示例#1
0
class OWChoropleth(widget.OWWidget):
    """Choropleth map widget.

    Each input row is located by its (latitude, longitude) attributes and
    assigned to an administrative region at the chosen level. The selected
    attribute is aggregated per region and the regions are shaded on a
    Leaflet map. Regions selected on the map are sent out as a data subset
    and as an annotated copy of the input data.
    """
    name = 'Choropleth'
    description = 'A thematic map in which areas are shaded in proportion ' \
                  'to the measurement of the statistical variable being displayed.'
    icon = "icons/Choropleth.svg"
    priority = 120

    inputs = [("Data", Table, "set_data", widget.Default)]

    outputs = [("Selected Data", Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Table)]

    settingsHandler = settings.DomainContextHandler()

    want_main_area = True

    AGG_FUNCS = (
        'Count',
        'Count defined',
        'Sum',
        'Mean',
        'Median',
        'Mode',
        'Max',
        'Min',
        'Std',
    )
    # Maps a user-facing aggregation name to what pandas' GroupBy.agg accepts.
    # Names missing here are passed to agg lower-cased (e.g. 'Mean' -> 'mean').
    AGG_FUNCS_TRANSFORM = {
        'Count': 'size',
        'Count defined': 'count',
        # scipy >= 1.11 returns the mode as a scalar by default, while older
        # versions return a length-1 array; ravel()[0] works for both.
        'Mode': lambda x: np.ravel(stats.mode(x, nan_policy='omit').mode)[0],
    }
    # Aggregations that make sense for categorical attributes.
    AGG_FUNCS_DISCRETE = ('Count', 'Count defined', 'Mode')
    # Aggregations whose result must not be rendered as a time value.
    AGG_FUNCS_CANT_TIME = ('Count', 'Count defined', 'Sum', 'Std')

    autocommit = settings.Setting(True)
    lat_attr = settings.ContextSetting('')
    lon_attr = settings.ContextSetting('')
    attr = settings.ContextSetting('')
    agg_func = settings.ContextSetting(AGG_FUNCS[0])
    admin = settings.Setting(0)
    opacity = settings.Setting(70)
    color_steps = settings.Setting(5)
    color_quantization = settings.Setting('equidistant')
    show_labels = settings.Setting(True)
    show_legend = settings.Setting(True)
    show_details = settings.Setting(True)
    selection = settings.ContextSetting([])

    class Error(widget.OWWidget.Error):
        aggregation_discrete = widget.Msg(
            "Only certain types of aggregation defined on categorical attributes: {}"
        )

    class Warning(widget.OWWidget.Warning):
        logarithmic_nonpositive = widget.Msg(
            "Logarithmic quantization requires all values > 0. Using 'equidistant' quantization instead."
        )

    graph_name = "map"

    def __init__(self):
        super().__init__()
        self.map = map = LeafletChoropleth(self)
        self.mainArea.layout().addWidget(map)
        self.selection = []
        self.data = None
        self.latlon = None
        self.result_min_nonpositive = False
        # Region id of every input row (refreshed in aggregate) and the row
        # indices covered by the current map selection. They are initialised
        # here so commit() and the selection callback never hit a missing
        # attribute.
        self.ids = pd.Series([], dtype=object)
        self._indices = []

        def selectionChanged(selection):
            # Rows whose region id is among the selected region ids.
            # Series.nonzero() was removed in pandas 1.0, hence np.flatnonzero.
            self._indices = np.flatnonzero(self.ids.isin(selection))
            self.selection = selection
            self.commit()

        map.selectionChanged.connect(selectionChanged)

        box = gui.vBox(self.controlArea, 'Aggregation')

        self._latlon_model = DomainModel(parent=self,
                                         valid_types=ContinuousVariable)
        self._combo_lat = combo = gui.comboBox(box,
                                               self,
                                               'lat_attr',
                                               orientation=Qt.Horizontal,
                                               label='Latitude:',
                                               sendSelectedValue=True,
                                               callback=self.aggregate)
        combo.setModel(self._latlon_model)

        self._combo_lon = combo = gui.comboBox(box,
                                               self,
                                               'lon_attr',
                                               orientation=Qt.Horizontal,
                                               label='Longitude:',
                                               sendSelectedValue=True,
                                               callback=self.aggregate)
        combo.setModel(self._latlon_model)

        self._combo_attr = combo = gui.comboBox(box,
                                                self,
                                                'attr',
                                                orientation=Qt.Horizontal,
                                                label='Attribute:',
                                                sendSelectedValue=True,
                                                callback=self.aggregate)
        combo.setModel(
            DomainModel(parent=self,
                        valid_types=(ContinuousVariable, DiscreteVariable)))

        gui.comboBox(box,
                     self,
                     'agg_func',
                     orientation=Qt.Horizontal,
                     items=self.AGG_FUNCS,
                     label='Aggregation:',
                     sendSelectedValue=True,
                     callback=self.aggregate)

        self._detail_slider = gui.hSlider(box,
                                          self,
                                          'admin',
                                          None,
                                          0,
                                          2,
                                          1,
                                          label='Administrative level:',
                                          labelFormat=' %d',
                                          callback=self.aggregate)

        box = gui.vBox(self.controlArea, 'Visualization')

        gui.spin(box,
                 self,
                 'color_steps',
                 3,
                 15,
                 1,
                 label='Color steps:',
                 callback=lambda: self.map.set_color_steps(self.color_steps))

        def _set_quantization():
            # Warn when logarithmic quantization is requested but the
            # aggregated values include non-positive numbers.
            self.Warning.logarithmic_nonpositive(
                shown=(self.color_quantization.startswith('log')
                       and self.result_min_nonpositive))
            self.map.set_quantization(self.color_quantization)

        gui.comboBox(box,
                     self,
                     'color_quantization',
                     label='Color quantization:',
                     orientation=Qt.Horizontal,
                     sendSelectedValue=True,
                     items=('equidistant', 'logarithmic', 'quantile',
                            'k-means'),
                     callback=_set_quantization)

        self._opacity_slider = gui.hSlider(
            box,
            self,
            'opacity',
            None,
            20,
            100,
            5,
            label='Opacity:',
            labelFormat=' %d%%',
            callback=lambda: self.map.set_opacity(self.opacity))

        gui.checkBox(box,
                     self,
                     'show_legend',
                     label='Show legend',
                     callback=lambda: self.map.toggle_legend(self.show_legend))
        gui.checkBox(
            box,
            self,
            'show_labels',
            label='Show map labels',
            callback=lambda: self.map.toggle_map_labels(self.show_labels))
        gui.checkBox(box,
                     self,
                     'show_details',
                     label='Show region details in tooltip',
                     callback=lambda: self.map.toggle_tooltip_details(
                         self.show_details))

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, 'autocommit', 'Send Selection')

        # Push the restored settings to the map view.
        self.map.toggle_legend(self.show_legend)
        self.map.toggle_map_labels(self.show_labels)
        self.map.toggle_tooltip_details(self.show_details)
        self.map.set_quantization(self.color_quantization)
        self.map.set_color_steps(self.color_steps)
        self.map.set_opacity(self.opacity)

    def __del__(self):
        self.progressBarFinished(None)
        self.map = None

    def commit(self):
        """Send the selected rows and the input data annotated with them."""
        self.send(
            'Selected Data', self.data[self._indices]
            if self.data is not None and self.selection else None)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(self.data, self._indices))

    def set_data(self, data):
        """Handle new input data: reset state, guess lat/lon and re-plot."""
        self.data = data
        # Forget the coordinates of any previous input.
        self.latlon = None

        self.closeContext()

        self.clear()

        domain = data.domain if data is not None else None
        # Reset the combo models even when data is None, so they do not keep
        # listing the previous input's variables.
        self._combo_attr.model().set_domain(domain)
        self._latlon_model.set_domain(domain)

        if data is None:
            return

        lat, lon = find_lat_lon(data)
        if lat or lon:
            self._combo_lat.setCurrentIndex(
                -1 if lat is None else self._latlon_model.indexOf(lat))
            self._combo_lon.setCurrentIndex(
                -1 if lon is None else self._latlon_model.indexOf(lon))
            self.lat_attr = lat.name if lat else None
            self.lon_attr = lon.name if lon else None
            if lat and lon:
                self.latlon = np.c_[
                    self.data.get_column_view(self.lat_attr)[0],
                    self.data.get_column_view(self.lon_attr)[0]]

        if data.domain.class_var:
            self.attr = data.domain.class_var.name
        else:
            self.attr = self._combo_attr.itemText(0)

        self.openContext(data)

        if self.selection:
            self.map.preset_region_selection(self.selection)
        self.aggregate()
        self.map.fit_to_bounds()

    def aggregate(self):
        """Aggregate ``attr`` per region and send the results to the map.

        The map is cleared instead when there is no data, when latitude,
        longitude or attribute are not all set and present in the domain, or
        when the aggregation is not defined for a categorical attribute.
        """
        domain = self.data.domain if self.data is not None else None
        # Gate on the selected names rather than on self.latlon: the user may
        # pick lat/lon manually, or the context may restore a different pair
        # than the one that was guessed in set_data.
        if domain is None or not all(
                name and name in domain
                for name in (self.lat_attr, self.lon_attr, self.attr)):
            self.clear(caches=False)
            return

        attr = domain[self.attr]

        if attr.is_discrete and self.agg_func not in self.AGG_FUNCS_DISCRETE:
            self.Error.aggregation_discrete(', '.join(
                map(str.lower, self.AGG_FUNCS_DISCRETE)))
            self.Warning.logarithmic_nonpositive.clear()
            self.clear(caches=False)
            return
        self.Error.aggregation_discrete.clear()

        try:
            _regions, adm0, result, self.map.bounds = \
                self.get_grouped(self.lat_attr, self.lon_attr, self.admin, self.attr, self.agg_func)
        except ValueError:
            # This can happen in a File -> Choropleth scheme when the
            # attribute selected here is changed to a string attribute in
            # File and the data set is reloaded; the stale setting is ignored.
            return

        # get_regions is memoized, so its assignment to self.ids is skipped on
        # a cache hit (e.g. switching admin level back and forth). Refresh the
        # per-row region ids from the returned value so selections map to the
        # rows of the current admin level.
        _, ids, _, _ = self.get_regions(self.lat_attr, self.lon_attr,
                                        self.admin)
        self.ids = pd.Series(ids)

        discrete_values = list(
            attr.values) if attr.is_discrete and not self.agg_func.startswith(
                'Count') else []

        self.result_min_nonpositive = bool(
            attr.is_continuous and result.min() <= 0)
        force_quantization = self.color_quantization.startswith(
            'log') and self.result_min_nonpositive
        self.Warning.logarithmic_nonpositive(shown=force_quantization)

        repr_time = isinstance(
            attr,
            TimeVariable) and self.agg_func not in self.AGG_FUNCS_CANT_TIME

        # Palette: per-value colors for categorical results, a fixed
        # blue-yellow pair for counts of a categorical attribute, otherwise
        # the attribute's continuous palette without its last entry.
        if discrete_values:
            palette = attr.colors
        elif attr.is_discrete:
            palette = ((0, 0, 255), (255, 255, 0))
        else:
            palette = attr.colors[:-1]

        if attr.is_discrete and not discrete_values:
            minmax = [result.min(), result.max()]
        elif repr_time or not discrete_values:
            minmax = [attr.repr_val(result.min()),
                      attr.repr_val(result.max())]
        else:
            minmax = []

        self.map.exposeObject(
            'results',
            dict(
                discrete=discrete_values,
                colors=[color_to_hex(i) for i in palette],
                regions=list(adm0),
                attr=attr.name,
                have_nonpositive=self.result_min_nonpositive
                or discrete_values,
                values=result.to_dict(),
                repr_vals=result.map(attr.repr_val).to_dict()
                if repr_time else {},
                minmax=minmax))

        self.map.evalJS('replot();')

    @memoize_method(3)
    def get_regions(self, lat_attr, lon_attr, admin):
        """Assign every input row to a region at the given admin level.

        Returns ``(regions, ids, adm0, bounds)``:

        - ``regions``: the set of distinct region ids, excluding None.
        - ``ids``: the region id of each row, or None.
        - ``adm0``: the map layer keys to draw: ``{'0'}`` for level 0,
          ``'1-<adm0_a3>'`` for level 1, and ``'2-<adm0_a3>'`` for level 2
          (``'1-'`` for countries not in ADMIN2_COUNTRIES).
        - ``bounds``: the bounding rectangle of the regions, or None.
        """
        latlon = np.c_[self.data.get_column_view(lat_attr)[0],
                       self.data.get_column_view(lon_attr)[0]]
        regions = latlon2region(latlon, admin)
        adm0 = ({'0'} if admin == 0 else {
            '1-' + a3
            for a3 in (i.get('adm0_a3') for i in regions) if a3
        } if admin == 1 else {('2-' if a3 in ADMIN2_COUNTRIES else '1-') + a3
                              for a3 in (i.get('adm0_a3') for i in regions)
                              if a3})
        ids = [i.get('_id') for i in regions]
        # Kept for backward compatibility; aggregate() refreshes self.ids
        # because this line does not run on a memoization cache hit.
        self.ids = pd.Series(ids)
        regions = set(ids) - {None}
        bounds = get_bounding_rect(regions) if regions else None
        return regions, ids, adm0, bounds

    @memoize_method(6)
    def get_grouped(self, lat_attr, lon_attr, admin, attr, agg_func):
        """Aggregate ``attr`` with ``agg_func`` per region.

        Returns ``(regions, adm0, result, bounds)``, where ``result`` is a
        pandas Series indexed by region id.
        """
        log.debug('Grouping %s(%s) by (%s, %s; admin%d)', agg_func, attr,
                  lat_attr, lon_attr, admin)
        regions, ids, adm0, bounds = self.get_regions(lat_attr, lon_attr,
                                                      admin)
        attr = self.data.domain[attr]
        result = pd.Series(self.data.get_column_view(attr)[0], dtype=float)\
            .groupby(ids)\
            .agg(self.AGG_FUNCS_TRANSFORM.get(agg_func, agg_func.lower()))
        return regions, adm0, result, bounds

    def clear(self, caches=True):
        """Remove the plotted results and the selection.

        :param caches: also drop the memoized region/grouping results
            (needed whenever the input data changes).
        """
        if caches:
            try:
                self.get_regions.cache_clear()
                self.get_grouped.cache_clear()
            except AttributeError:
                pass  # back-compat https://github.com/biolab/orange3/pull/2229
        self.selection = []
        # Drop row indices belonging to the old selection or old data.
        self._indices = []
        self.map.exposeObject('results', {})
        self.map.evalJS('replot();')
示例#2
0
class OWPythagoreanForest(OWWidget):
    """Pythagorean forest widget.

    Draws every tree of an input random forest as a Pythagoras tree
    thumbnail. Selecting a thumbnail sends the corresponding tree model to
    the output.
    """
    name = 'Pythagorean Forest'
    description = 'Pythagorean forest for visualising random forests.'
    icon = 'icons/PythagoreanForest.svg'
    settings_version = 2
    keywords = ["fractal"]

    priority = 1001

    class Inputs:
        random_forest = Input("Random Forest",
                              RandomForestModel,
                              replaces=["Random forest"])

    class Outputs:
        tree = Output("Tree", TreeModel)

    # Enable the save as feature
    graph_name = 'scene'

    # Settings
    settingsHandler = settings.DomainContextHandler()

    depth_limit = settings.ContextSetting(10)
    target_class_index = settings.ContextSetting(0)
    size_calc_idx = settings.Setting(0)
    zoom = settings.Setting(200)

    selected_index = settings.ContextSetting(None)

    # (label, transform) pairs offered in the "Size" combo box.
    SIZE_CALCULATION = [
        ('Normal', lambda x: x),
        ('Square root', lambda x: sqrt(x)),
        ('Logarithmic', lambda x: log(x + 1)),
    ]

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade settings stored by older versions of the widget."""
        if version < 2:
            settings.pop('selected_tree_index', None)
            # Version 2 changed the zoom slider range from [20, 150] to
            # [100, 400]; map an old stored value linearly onto the new range.
            # Older settings may not contain a zoom at all.
            if 'zoom' in settings:
                v1_min, v1_max = 20, 150
                v2_min, v2_max = 100, 400
                ratio = (v2_max - v2_min) / (v1_max - v1_min)
                settings['zoom'] = int(ratio * (settings['zoom'] - v1_min) +
                                       v2_min)

    def __init__(self):
        super().__init__()
        self.rf_model = None
        self.forest = None
        self.instances = None

        self.color_palette = None

        # CONTROL AREA
        # Tree info area
        box_info = gui.widgetBox(self.controlArea, 'Forest')
        self.ui_info = gui.widgetLabel(box_info)

        # Display controls area
        box_display = gui.widgetBox(self.controlArea, 'Display')
        self.ui_depth_slider = gui.hSlider(
            box_display,
            self,
            'depth_limit',
            label='Depth',
            ticks=False,
        )  # type: QSlider
        self.ui_target_class_combo = gui.comboBox(
            box_display,
            self,
            'target_class_index',
            label='Target class',
            orientation=Qt.Horizontal,
            items=[],
            contentsLength=8,
        )  # type: gui.OrangeComboBox
        self.ui_size_calc_combo = gui.comboBox(
            box_display,
            self,
            'size_calc_idx',
            label='Size',
            orientation=Qt.Horizontal,
            items=list(zip(*self.SIZE_CALCULATION))[0],
            contentsLength=8,
        )  # type: gui.OrangeComboBox
        self.ui_zoom_slider = gui.hSlider(
            box_display,
            self,
            'zoom',
            label='Zoom',
            ticks=False,
            minValue=100,
            maxValue=400,
            createLabel=False,
            intOnly=False,
        )  # type: QSlider

        # Stretch to fill the rest of the unused area
        gui.rubber(self.controlArea)

        self.controlArea.setSizePolicy(QSizePolicy.Preferred,
                                       QSizePolicy.Expanding)

        # MAIN AREA
        self.forest_model = PythagoreanForestModel(parent=self)
        self.forest_model.update_item_size(self.zoom)
        self.ui_depth_slider.valueChanged.connect(
            self.forest_model.update_depth)
        self.ui_target_class_combo.currentIndexChanged.connect(
            self.forest_model.update_target_class)
        self.ui_zoom_slider.valueChanged.connect(
            self.forest_model.update_item_size)
        self.ui_size_calc_combo.currentIndexChanged.connect(
            self.forest_model.update_size_calc)

        self.list_delegate = PythagorasTreeDelegate(parent=self)
        self.list_view = ClickToClearSelectionListView(parent=self)
        self.list_view.setWrapping(True)
        self.list_view.setFlow(QListView.LeftToRight)
        self.list_view.setResizeMode(QListView.Adjust)
        self.list_view.setModel(self.forest_model)
        self.list_view.setItemDelegate(self.list_delegate)
        self.list_view.setSpacing(2)
        self.list_view.setSelectionMode(QListView.SingleSelection)
        self.list_view.selectionModel().selectionChanged.connect(self.commit)
        self.list_view.setUniformItemSizes(True)
        self.mainArea.layout().addWidget(self.list_view)

        self.resize(800, 500)

        # Clear to set sensible default values
        self.clear()

    @Inputs.random_forest
    def set_rf(self, model=None):
        """When a different forest is given."""
        self.closeContext()
        self.clear()
        self.rf_model = model

        if model is not None:
            self.forest = self._get_forest_adapter(self.rf_model)
            self.forest_model[:] = self.forest.trees
            self.instances = model.instances

            self._update_info_box()
            self._update_target_class_combo()
            self._update_depth_slider()

        self.openContext(model)
        # Restore item selection
        if self.selected_index is not None:
            index = self.list_view.model().index(self.selected_index)
            selection = QItemSelection(index, index)
            self.list_view.selectionModel().select(
                selection, QItemSelectionModel.ClearAndSelect)

    def clear(self):
        """Clear all relevant data from the widget."""
        self.rf_model = None
        self.forest = None
        # Also drop the training data so it is not kept alive after the
        # forest is removed.
        self.instances = None
        self.forest_model.clear()
        self.selected_index = None

        self._clear_info_box()
        self._clear_target_class_combo()
        self._clear_depth_slider()

    def _update_info_box(self):
        self.ui_info.setText('Trees: {}'.format(len(self.forest.trees)))

    def _update_depth_slider(self):
        self.depth_limit = self._get_max_depth()

        self.ui_depth_slider.parent().setEnabled(True)
        self.ui_depth_slider.setMaximum(self.depth_limit)
        self.ui_depth_slider.setValue(self.depth_limit)

    def _update_target_class_combo(self):
        """Fill the target combo with class values or node color methods."""
        self._clear_target_class_combo()
        # The combo's label is a sibling QLabel created by gui.comboBox.
        label = next(
            x for x in self.ui_target_class_combo.parent().children()
            if isinstance(x, QLabel))

        if self.instances.domain.has_discrete_class:
            label_text = 'Target class'
            values = [
                c.title() for c in self.instances.domain.class_vars[0].values
            ]
            values.insert(0, 'None')
        else:
            label_text = 'Node color'
            values = list(ContinuousTreeNode.COLOR_METHODS.keys())
        label.setText(label_text)
        self.ui_target_class_combo.addItems(values)
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_info_box(self):
        self.ui_info.setText('No forest on input.')

    def _clear_target_class_combo(self):
        self.ui_target_class_combo.clear()
        self.target_class_index = 0
        self.ui_target_class_combo.setCurrentIndex(self.target_class_index)

    def _clear_depth_slider(self):
        self.ui_depth_slider.parent().setEnabled(False)
        self.ui_depth_slider.setMaximum(0)

    def _get_max_depth(self):
        """Return the largest tree depth in the forest (0 if it has no trees)."""
        return max((tree.max_depth for tree in self.forest.trees), default=0)

    def _get_forest_adapter(self, model):
        return SklRandomForestAdapter(model)

    def onDeleteWidget(self):
        """When deleting the widget."""
        super().onDeleteWidget()
        self.clear()

    def commit(self, selection: QItemSelection) -> None:
        """Commit the selected tree to output."""
        selected_indices = selection.indexes()

        if not len(selected_indices):
            self.selected_index = None
            self.Outputs.tree.send(None)
            return

        # We only allow selecting a single tree so there will always be one index
        self.selected_index = selected_indices[0].row()

        tree = self.rf_model.trees[self.selected_index]
        # Attach the display settings so a downstream tree viewer can reuse them.
        tree.instances = self.instances
        tree.meta_target_class_index = self.target_class_index
        tree.meta_size_calc_idx = self.size_calc_idx
        tree.meta_depth_limit = self.depth_limit

        self.Outputs.tree.send(tree)

    def send_report(self):
        """Send report."""
        self.report_plot()
示例#3
0
class OWDistributions(widget.OWWidget):
    name = "Distributions"
    description = "Display value distributions of a data feature in a graph."
    icon = "icons/Distribution.svg"
    priority = 120

    class Inputs:
        data = Input("Data", Orange.data.Table, doc="Set the input data set")

    settingsHandler = settings.DomainContextHandler(
        match_values=settings.DomainContextHandler.MATCH_VALUES_ALL)
    #: Selected variable index
    variable_idx = settings.ContextSetting(-1)
    #: Selected group variable
    groupvar_idx = settings.ContextSetting(0)

    relative_freq = settings.Setting(False)
    disc_cont = settings.Setting(False)

    smoothing_index = settings.Setting(5)
    show_prob = settings.ContextSetting(0)

    graph_name = "plot"

    ASH_HIST = 50

    bins = [2, 3, 4, 5, 8, 10, 12, 15, 20, 30, 50]
    smoothing_facs = list(reversed([0.1, 0.2, 0.4, 0.6, 0.8, 1, 1.5, 2, 4, 6, 10]))

    def __init__(self):
        """Build the control-area widgets and the two-view-box plot."""
        super().__init__()
        self.data = None

        # Statistics for the current variable; filled in by _setup().
        self.distributions = None
        self.contingencies = None
        # Currently displayed variable and group-by variable (None when unset).
        self.var = self.cvar = None
        varbox = gui.vBox(self.controlArea, "Variable")

        self.varmodel = itemmodels.VariableListModel()
        # Plain list: "(None)" followed by discrete variables (see set_data).
        self.groupvarmodel = []

        # Single-selection list of the variables whose distribution is shown.
        self.varview = QListView(
            selectionMode=QListView.SingleSelection)
        self.varview.setSizePolicy(
            QSizePolicy.Minimum, QSizePolicy.Expanding)
        self.varview.setModel(self.varmodel)
        self.varview.setSelectionModel(
            itemmodels.ListSingleSelectionModel(self.varmodel))
        self.varview.selectionModel().selectionChanged.connect(
            self._on_variable_idx_changed)
        varbox.layout().addWidget(self.varview)

        box = gui.vBox(self.controlArea, "Precision")

        gui.separator(self.controlArea, 4, 4)

        # The slider picks an index into smoothing_facs / bins; its side
        # labels are relabelled by _setup_smoothing().
        box2 = gui.hBox(box)
        self.l_smoothing_l = gui.widgetLabel(box2, "Smooth")
        gui.hSlider(box2, self, "smoothing_index",
                    minValue=0, maxValue=len(self.smoothing_facs) - 1,
                    callback=self._on_set_smoothing, createLabel=False)
        self.l_smoothing_r = gui.widgetLabel(box2, "Precise")

        self.cb_disc_cont = gui.checkBox(
            gui.indentedBox(box, sep=4),
            self, "disc_cont", "Bin numeric variables",
            callback=self._on_groupvar_idx_changed,
            tooltip="Show numeric variables as categorical.")

        box = gui.vBox(self.controlArea, "Group by")
        self.icons = gui.attributeIconDict
        self.groupvarview = gui.comboBox(
            box, self, "groupvar_idx",
            callback=self._on_groupvar_idx_changed,
            valueType=str, contentsLength=12)
        box2 = gui.indentedBox(box, sep=4)
        self.cb_rel_freq = gui.checkBox(
            box2, self, "relative_freq", "Show relative frequencies",
            callback=self._on_relative_freq_changed,
            tooltip="Normalize probabilities so that probabilities "
                    "for each group-by value sum to 1.")
        gui.separator(box2)
        # Items are filled in by _setup() from the group-by variable's values.
        self.cb_prob = gui.comboBox(
            box2, self, "show_prob", label="Show probabilities:",
            orientation=Qt.Horizontal,
            callback=self._on_relative_freq_changed,
            tooltip="Show probabilities for a chosen group-by value "
                    "(at each point probabilities for all group-by values sum to 1).")

        self.plotview = pg.PlotWidget(background=None)
        self.plotview.setRenderHint(QPainter.Antialiasing)
        self.mainArea.layout().addWidget(self.plotview)
        # NOTE(review): QBoxLayout.addWidget's second positional argument is
        # the stretch factor, so Qt.AlignCenter is passed as stretch here, not
        # as alignment. Confirm whether addWidget(w, 0, Qt.AlignCenter) was
        # intended.
        w = QLabel()
        w.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self.mainArea.layout().addWidget(w, Qt.AlignCenter)
        self.ploti = pg.PlotItem()
        # self.plot is the main view box holding bars and density curves.
        self.plot = self.ploti.vb
        self.ploti.hideButtons()
        self.plotview.setCentralItem(self.ploti)

        # A second view box for probability curves, overlaid on the main one.
        # It shares the x axis and uses the right axis with a fixed [0, 1]
        # y range.
        self.plot_prob = pg.ViewBox()
        self.ploti.hideAxis('right')
        self.ploti.scene().addItem(self.plot_prob)
        self.ploti.getAxis("right").linkToView(self.plot_prob)
        self.ploti.getAxis("right").setLabel("Probability")
        self.plot_prob.setZValue(10)
        self.plot_prob.setXLink(self.ploti)
        self.update_views()
        # Keep the overlay's geometry in sync whenever the main view resizes.
        self.ploti.vb.sigResized.connect(self.update_views)
        self.plot_prob.setRange(yRange=[0, 1])

        def disable_mouse(plot):
            plot.setMouseEnabled(False, False)
            plot.setMenuEnabled(False)

        disable_mouse(self.plot)
        disable_mouse(self.plot_prob)

        # (view box, item) pairs checked by help_event for tooltips.
        self.tooltip_items = []
        self.plot.scene().installEventFilter(
            HelpEventDelegate(self.help_event, self))

        pen = QPen(self.palette().color(QPalette.Text))
        for axis in ("left", "bottom"):
            self.ploti.getAxis(axis).setPen(pen)

        # Legend anchored to the top-right corner of the main view box;
        # hidden until something is drawn into it.
        self._legend = LegendItem()
        self._legend.setParentItem(self.plot)
        self._legend.hide()
        self._legend.anchor((1, 0), (1, 0))

    def update_views(self):
        """Align the probability overlay with the main view box.

        The overlay takes the main view's scene rectangle and follows its
        x-axis range.
        """
        main_view = self.plot
        overlay = self.plot_prob
        overlay.setGeometry(main_view.sceneBoundingRect())
        overlay.linkedViewChanged(main_view, overlay.XAxis)

    @Inputs.data
    def set_data(self, data):
        """Receive new input data and rebuild the variable lists and plot.

        With ``None`` the widget is only cleared. With an empty table a
        warning is shown and nothing is plotted.
        """
        self.closeContext()
        self.clear()
        self.warning()
        self.data = data
        self.distributions = None
        self.contingencies = None
        if data is None:
            return
        if not data:
            self.warning("Empty input data cannot be visualized")
            return

        domain = data.domain
        # Every variable, plus the numeric or categorical metas, can be shown.
        plottable_metas = [meta for meta in domain.metas
                           if meta.is_continuous or meta.is_discrete]
        self.varmodel[:] = list(domain.variables) + plottable_metas

        # Only categorical variables and metas can be used for grouping.
        self.groupvarview.clear()
        discrete_vars = [v for v in domain.variables if v.is_discrete]
        discrete_metas = [meta for meta in domain.metas if meta.is_discrete]
        self.groupvarmodel = ["(None)"] + discrete_vars + discrete_metas
        self.groupvarview.addItem("(None)")
        for group_var in self.groupvarmodel[1:]:
            self.groupvarview.addItem(self.icons[group_var], group_var.name)

        # Group by the class variable by default, if it is categorical.
        if domain.has_discrete_class:
            position = self.groupvarmodel[1:].index(domain.class_var)
            self.groupvar_idx = position + 1

        self.openContext(domain)
        # Clamp the (possibly context-restored) indices to the valid ranges.
        last_var = len(self.varmodel) - 1
        self.variable_idx = min(max(self.variable_idx, 0), last_var)
        last_group = len(self.groupvarmodel) - 1
        self.groupvar_idx = min(max(self.groupvar_idx, 0), last_group)
        itemmodels.select_row(self.varview, self.variable_idx)
        self._setup()

    def clear(self):
        """Empty both plots, the legend, the variable models and the combos."""
        for view in (self.plot, self.plot_prob):
            view.clear()
        self.varmodel[:] = []
        self.groupvarmodel = []
        self.variable_idx, self.groupvar_idx = -1, 0
        legend = self._legend
        legend.clear()
        legend.hide()
        for combo in (self.groupvarview, self.cb_prob):
            combo.clear()

    def _setup_smoothing(self):
        """Relabel the precision slider for smoothing or for binning.

        An unbinned numeric variable is drawn as a smoothed curve, so the
        slider ends read "Smooth"/"Precise". Otherwise the slider selects
        the number of bins, and the labels show the bin-count range.
        """
        if not self.disc_cont and self.var and self.var.is_continuous:
            checkbox_text = "Bin numeric variables"
            left_text, right_text = "Smooth", "Precise"
        else:
            n_bins = self.bins[self.smoothing_index]
            checkbox_text = "Bin numeric variables into {} bins".format(n_bins)
            left_text = " " + str(self.bins[0])
            right_text = " " + str(self.bins[-1])
        self.cb_disc_cont.setText(checkbox_text)
        self.l_smoothing_l.setText(left_text)
        self.l_smoothing_r.setText(right_text)

    @property
    def smoothing_factor(self):
        """Smoothing factor selected by the precision slider (for ash_curve)."""
        factors = self.smoothing_facs
        return factors[self.smoothing_index]

    def _setup(self):
        """Recompute the statistics for the selected variable(s) and redraw.

        Resets the plots and legend, resolves ``self.var`` and ``self.cvar``
        from the selected indices, and refills the probability combo for the
        group-by variable. If binning is enabled, the variable is
        discretized into equal-width bins. Then either the contingency
        (with a group-by variable) or the plain distribution is drawn.
        """
        self.plot.clear()
        self.plot_prob.clear()
        self._legend.clear()
        self._legend.hide()

        varidx = self.variable_idx
        self.var = self.cvar = None
        if varidx >= 0:
            self.var = self.varmodel[varidx]
        # Index 0 of groupvarmodel is the "(None)" placeholder.
        if self.groupvar_idx > 0:
            self.cvar = self.groupvarmodel[self.groupvar_idx]
            # Combo items: "(None)", one per group value, then "(All)".
            self.cb_prob.clear()
            self.cb_prob.addItem("(None)")
            self.cb_prob.addItems(self.cvar.values)
            self.cb_prob.addItem("(All)")
            self.show_prob = min(max(self.show_prob, 0),
                                 len(self.cvar.values) + 1)
        # NOTE(review): when no group-by variable is selected, cb_prob keeps
        # the items of the previous group variable. Confirm whether it should
        # be cleared here.
        data = self.data
        self._setup_smoothing()
        if self.var is None:
            return
        if self.disc_cont:
            # Restrict to the shown (and group) variable, then bin it into
            # equal-width intervals. self.var becomes the binned variable.
            domain = Orange.data.Domain(
                [self.var, self.cvar] if self.cvar else [self.var])
            data = Orange.data.Table(domain, data)
            disc = EqualWidth(n=self.bins[self.smoothing_index])
            data = Discretize(method=disc, remove_const=False)(data)
            self.var = data.domain[0]
        self.set_left_axis_name()
        self.enable_disable_rel_freq()
        if self.cvar:
            self.contingencies = \
                contingency.get_contingency(data, self.var, self.cvar)
            self.display_contingency()
        else:
            self.distributions = \
                distribution.get_distribution(data, self.var)
            self.display_distribution()
        self.plot.autoRange()

    def help_event(self, ev):
        """Show a tooltip for the plot items under the mouse cursor.

        Each ``(view_box, item)`` pair in ``self.tooltip_items`` is hit-tested
        in its own view box's coordinates. Curves use their mouse shape and
        bars their bounding rectangle. The tooltips of all hit items are
        joined by blank lines.

        :return: True if a tooltip was shown (the event is handled),
            False otherwise.
        """
        scene_pos = ev.scenePos()
        tooltips = []
        for view_box, item in self.tooltip_items:
            # Map once per item instead of once per type check. The former
            # standalone self.plot.mapSceneToView(...) discarded its result
            # and has been removed.
            pos = view_box.mapSceneToView(scene_pos)
            hit = (isinstance(item, pg.PlotCurveItem)
                   and item.mouseShape().contains(pos)) \
                or (isinstance(item, DistributionBarItem)
                    and item.boundingRect().contains(pos))
            if hit:
                tooltips.append(item.tooltip)
        if not tooltips:
            return False
        QToolTip.showText(ev.screenPos(), "\n\n".join(tooltips),
                          widget=self.plotview)
        return True

    def display_distribution(self):
        """Draw the distribution of ``self.var`` from ``self.distributions``.

        A continuous variable is drawn as a density curve computed by
        ``ash_curve``. A discrete variable gets one gray bar per value.
        Every drawn item is registered in ``self.tooltip_items`` so that
        help_event can show its tooltip.
        """
        dist = self.distributions
        var = self.var
        # Nothing to draw without a non-empty distribution.
        if dist is None or not len(dist):
            return
        self.plot.clear()
        self.plot_prob.clear()
        # The right (probability) axis is only used for contingencies.
        self.ploti.hideAxis('right')
        self.tooltip_items = []

        bottomaxis = self.ploti.getAxis("bottom")
        bottomaxis.setLabel(var.name)
        bottomaxis.resizeEvent()

        self.set_left_axis_name()
        if var and var.is_continuous:
            # Numeric ticks chosen automatically by the axis.
            bottomaxis.setTicks(None)
            if not len(dist[0]):
                return
            edges, curve = ash_curve(dist, None, m=OWDistributions.ASH_HIST,
                                     smoothing_factor=self.smoothing_factor)
            # Shift by half a bin width to bin centres and drop the last edge,
            # presumably so the x values match the curve's length — confirm
            # against ash_curve.
            edges = edges + (edges[1] - edges[0])/2
            edges = edges[:-1]
            item = pg.PlotCurveItem()
            pen = QPen(QBrush(Qt.white), 3)
            pen.setCosmetic(True)
            item.setData(edges, curve, antialias=True, stepMode=False,
                         fillLevel=0, brush=QBrush(Qt.gray), pen=pen)
            self.plot.addItem(item)
            item.tooltip = "Density"
            self.tooltip_items.append((self.plot, item))
        else:
            # One tick per category value; bar i is centred at x = i.
            bottomaxis.setTicks([list(enumerate(var.values))])
            for i, w in enumerate(dist):
                geom = QRectF(i - 0.33, 0, 0.66, w)
                item = DistributionBarItem(geom, [1.0],
                                           [QColor(128, 128, 128)])
                self.plot.addItem(item)
                item.tooltip = "Frequency for %s: %r" % (var.values[i], w)
                self.tooltip_items.append((self.plot, item))

    def _on_relative_freq_changed(self):
        self.set_left_axis_name()
        if self.cvar and self.cvar.is_discrete:
            self.display_contingency()
        else:
            self.display_distribution()
        self.plot.autoRange()

    def display_contingency(self):
        """
        Set the contingency to display.

        Plots ``self.contingencies``, i.e. the distribution of ``self.var``
        split by the values of the grouping variable ``self.cvar``:

        * continuous ``var``: one smoothed (ASH) curve per ``cvar`` value
          that has data, scaled by that group's share of the total weight
          (or divided back out when ``relative_freq`` is set). With
          ``show_prob``, dotted curves of each group's share of the summed
          curves are drawn on ``self.plot_prob``.
        * discrete ``var``: one multi-colored ``DistributionBarItem`` per
          ``var`` value. With ``show_prob``, probability marks with
          1.96 * standard-error bars are added on ``self.plot_prob``.

        ``show_prob`` selects a single group by 1-based index, or all groups
        when it is one past the number of groups.

        Every drawn item gets a ``tooltip`` attribute and is registered in
        ``self.tooltip_items`` for :meth:`help_event`. A legend entry is
        added per displayed ``cvar`` value.
        """
        cont = self.contingencies
        var, cvar = self.var, self.cvar
        if cont is None or not len(cont):
            return
        self.plot.clear()
        self.plot_prob.clear()
        self._legend.clear()
        self.tooltip_items = []

        if self.show_prob:
            self.ploti.showAxis('right')
        else:
            self.ploti.hideAxis('right')

        bottomaxis = self.ploti.getAxis("bottom")
        bottomaxis.setLabel(var.name)
        bottomaxis.resizeEvent()

        cvar_values = cvar.values
        colors = [QColor(*col) for col in cvar.colors]

        if var and var.is_continuous:
            bottomaxis.setTicks(None)

            # Keep only groups with data; colors and legend values are
            # filtered in step so the legend matches the drawn curves.
            weights, cols, cvar_values, curves = [], [], [], []
            for i, dist in enumerate(cont):
                v, W = dist
                if len(v):
                    weights.append(numpy.sum(W))
                    cols.append(colors[i])
                    cvar_values.append(cvar.values[i])
                    curves.append(ash_curve(
                        dist, cont, m=OWDistributions.ASH_HIST,
                        smoothing_factor=self.smoothing_factor))
            weights = numpy.array(weights)
            sumw = numpy.sum(weights)
            weights /= sumw
            colors = cols
            curves = [(X, Y * w) for (X, Y), w in zip(curves, weights)]

            curvesline = [] #from histograms to lines
            for X, Y in curves:
                X = X + (X[1] - X[0])/2
                X = X[:-1]
                X = numpy.array(X)
                Y = numpy.array(Y)
                curvesline.append((X, Y))

            # Draw all fills first, then all lines, so no fill covers a line.
            for t in ["fill", "line"]:
                curve_data = list(zip(curvesline, colors, weights, cvar_values))
                for (X, Y), color, w, cval in reversed(curve_data):
                    item = pg.PlotCurveItem()
                    pen = QPen(QBrush(color), 3)
                    pen.setCosmetic(True)
                    color = QColor(color)
                    color.setAlphaF(0.2)
                    item.setData(X, Y/(w if self.relative_freq else 1),
                                 antialias=True, stepMode=False,
                                 fillLevel=0 if t == "fill" else None,
                                 brush=QBrush(color), pen=pen)
                    self.plot.addItem(item)
                    if t == "line":
                        item.tooltip = "{}\n{}={}".format(
                            "Normalized density " if self.relative_freq else "Density ",
                            cvar.name, cval)
                        self.tooltip_items.append((self.plot, item))

            if self.show_prob:
                all_X = numpy.array(numpy.unique(numpy.hstack([X for X, _ in curvesline])))
                inter_X = numpy.array(numpy.linspace(all_X[0], all_X[-1], len(all_X)*2))
                curvesinterp = [numpy.interp(inter_X, X, Y) for (X, Y) in curvesline]
                sumprob = numpy.sum(curvesinterp, axis=0)
                # Skip x positions where the summed curves are tiny, since the
                # ratio would be dominated by noise there.
                legal = sumprob > 0.05 * numpy.max(sumprob)

                i = len(curvesinterp) + 1
                show_all = self.show_prob == i
                for Y, color, cval in reversed(list(zip(curvesinterp, colors, cvar_values))):
                    i -= 1
                    if show_all or self.show_prob == i:
                        item = pg.PlotCurveItem()
                        pen = QPen(QBrush(color), 3, style=Qt.DotLine)
                        pen.setCosmetic(True)
                        prob = Y[legal] / sumprob[legal]
                        item.setData(
                            inter_X[legal], prob, antialias=True, stepMode=False,
                            fillLevel=None, brush=None, pen=pen)
                        self.plot_prob.addItem(item)
                        item.tooltip = "Probability that \n" + cvar.name + "=" + cval
                        self.tooltip_items.append((self.plot_prob, item))

        elif var and var.is_discrete:
            bottomaxis.setTicks([list(enumerate(var.values))])

            cont = numpy.array(cont)

            maxh = 0 #maximal column height
            maxrh = 0 #maximal relative column height
            scvar = cont.sum(axis=1)
            #a cvar with sum=0 will always have distribution counts 0,
            #therefore we can divide it by anything
            scvar[scvar == 0] = 1
            for _value, dist in zip(var.values, cont.T):
                maxh = max(maxh, max(dist))
                maxrh = max(maxrh, max(dist/scvar))

            for i, (value, dist) in enumerate(zip(var.values, cont.T)):
                dsum = sum(dist)
                geom = QRectF(i - 0.333, 0, 0.666,
                              maxrh if self.relative_freq else maxh)
                item = DistributionBarItem(geom, dist/scvar/maxrh
                                           if self.relative_freq
                                           else dist/maxh, colors)
                self.plot.addItem(item)
                tooltip = "\n".join(
                    "%s: %.*f" % (n, 3 if self.relative_freq else 1, v)
                    for n, v in zip(cvar_values, dist/scvar if self.relative_freq else dist))
                item.tooltip = "{} ({}={}):\n{}".format(
                    "Normalized frequency " if self.relative_freq else "Frequency ",
                    cvar.name, value, tooltip)
                self.tooltip_items.append((self.plot, item))

                if self.show_prob:
                    item.tooltip += "\n\nProbabilities:"
                    for ic, a in enumerate(dist):
                        # Draw only the selected group, or all of them.
                        if self.show_prob - 1 != ic and \
                                self.show_prob - 1 != len(dist):
                            continue
                        position = -0.333 + ((ic+0.5)*0.666/len(dist))
                        if dsum < 1e-6:
                            continue
                        prob = a / dsum
                        if not 1e-6 < prob < 1 - 1e-6:
                            continue
                        ci = 1.96 * sqrt(prob * (1 - prob) / dsum)
                        item.tooltip += "\n%s: %.3f ± %.3f" % (cvar_values[ic], prob, ci)
                        mark = pg.ScatterPlotItem()
                        errorbar = pg.ErrorBarItem()
                        pen = QPen(QBrush(QColor(0)), 1)
                        pen.setCosmetic(True)
                        # Clip the error bar to the [0, 1] probability range.
                        errorbar.setData(x=[i+position], y=[prob],
                                         bottom=min(numpy.array([ci]), prob),
                                         top=min(numpy.array([ci]), 1 - prob),
                                         beam=numpy.array([0.05]),
                                         brush=QColor(1), pen=pen)
                        mark.setData([i+position], [prob], antialias=True, symbol="o",
                                     fillLevel=None, pxMode=True, size=10,
                                     brush=QColor(colors[ic]), pen=pen)
                        self.plot_prob.addItem(errorbar)
                        self.plot_prob.addItem(mark)

        for color, name in zip(colors, cvar_values):
            self._legend.addItem(
                ScatterPlotItem(pen=color, brush=color, size=10, shape="s"),
                escape(name)
            )
        self._legend.show()

    def set_left_axis_name(self):
        """Label the left axis according to the variable type and mode.

        Continuous variables use "density" and everything else uses
        "frequency". The "Relative" prefix is used only when a grouping
        variable is set and ``relative_freq`` is on.
        """
        leftaxis = self.ploti.getAxis("left")
        relative = self.cvar is not None and self.relative_freq
        if self.var and self.var.is_continuous:
            label = "Relative density" if relative else "Density"
        else:
            label = "Relative frequency" if relative else "Frequency"
        leftaxis.setLabel(label)
        leftaxis.resizeEvent()

    def enable_disable_rel_freq(self):
        """Enable the probability and relative-frequency controls only when
        both a variable and a grouping variable are selected."""
        disable = self.var is None or self.cvar is None
        for control in (self.cb_prob, self.cb_rel_freq):
            control.setDisabled(disable)

    def _on_variable_idx_changed(self):
        """Store the index selected in the variable view and rebuild the plot."""
        current = selected_index(self.varview)
        self.variable_idx = current
        self._setup()

    def _on_groupvar_idx_changed(self):
        self._setup()

    def _on_set_smoothing(self):
        self._setup()

    def onDeleteWidget(self):
        """Clear the plot before the base class tears the widget down."""
        plot = self.plot
        plot.clear()
        super().onDeleteWidget()

    def get_widget_name_extension(self):
        """Return the selected variable for the widget title, or None."""
        if self.variable_idx < 0:
            return None
        return self.varmodel[self.variable_idx]

    def send_report(self):
        """Add the current plot and a caption describing it to the report.

        Nothing is reported when no variable is selected
        (``variable_idx < 0``). When a grouping variable is set, the caption
        mentions the grouping and/or the single value whose probabilities
        are shown.
        """
        self.plotview.scene().setSceneRect(self.plotview.sceneRect())
        if self.variable_idx < 0:
            return
        self.report_plot()
        text = "Distribution of '{}'".format(
            self.varmodel[self.variable_idx])
        if self.groupvar_idx:
            group_var = self.groupvarmodel[self.groupvar_idx]
            prob = self.cb_prob
            # True when the probability combo is on neither its first nor its
            # last entry (presumably "none" and "all"), i.e. a single value.
            indiv_probs = 0 < prob.currentIndex() < prob.count() - 1
            if not indiv_probs or self.relative_freq:
                text += " grouped by '{}'".format(group_var)
                if self.relative_freq:
                    text += " (relative frequencies)"
            if indiv_probs:
                text += "; probabilities for '{}={}'".format(
                    group_var, prob.currentText())
        self.report_caption(text)
示例#4
0
class OWMap(widget.OWWidget):
    """Show data points on a world map rendered by a ``LeafletMap``.

    Points are placed by a latitude and a longitude attribute and can be
    colored, labelled, shaped and sized by other attributes. When a learner
    is connected and a target (``class_attr``) is chosen, a model trained
    on the two coordinates is handed to the map via ``set_model``. Points
    selected on the map are sent to the outputs.
    """
    name = 'Geo Map'
    description = 'Show data points on a world map.'
    icon = "icons/GeoMap.svg"
    priority = 100

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        learner = Input("Learner", Learner)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    replaces = [
        "Orange.widgets.visualize.owmap.OWMap",
    ]

    settingsHandler = settings.DomainContextHandler()

    want_main_area = True

    autocommit = settings.Setting(True)
    tile_provider = settings.Setting('Black and white')
    lat_attr = settings.ContextSetting('')
    lon_attr = settings.ContextSetting('')
    class_attr = settings.ContextSetting('(None)')
    color_attr = settings.ContextSetting('')
    label_attr = settings.ContextSetting('')
    shape_attr = settings.ContextSetting('')
    size_attr = settings.ContextSetting('')
    opacity = settings.Setting(100)
    zoom = settings.Setting(100)
    jittering = settings.Setting(0)
    cluster_points = settings.Setting(False)
    show_legend = settings.Setting(True)

    # Display name -> provider identifier passed to LeafletMap.set_map_provider
    TILE_PROVIDERS = OrderedDict((
        ('Black and white', 'OpenStreetMap.BlackAndWhite'),
        ('OpenStreetMap', 'OpenStreetMap.Mapnik'),
        ('Topographic', 'OpenTopoMap'),
        ('Satellite', 'Esri.WorldImagery'),
        ('Print', 'Stamen.TonerLite'),
        ('Dark', 'CartoDB.DarkMatter'),
        ('Watercolor', 'Stamen.Watercolor'),
    ))

    class Error(widget.OWWidget.Error):
        model_error = widget.Msg("Error predicting: {}")
        learner_error = widget.Msg("Error modelling: {}")

    class Warning(widget.OWWidget.Warning):
        all_nan_slice = widget.Msg('Latitude and/or longitude has no defined values (is all-NaN)')

    UserAdviceMessages = [
        widget.Message(
            'Select markers by holding <b><kbd>Shift</kbd></b> key and dragging '
            'a rectangle around them. Clear the selection by clicking anywhere.',
            'shift-selection')
    ]

    graph_name = "map"

    def __init__(self):
        super().__init__()
        # Local alias named `lmap` so the builtin `map` is not shadowed.
        self.map = lmap = LeafletMap(self)  # type: LeafletMap
        self.mainArea.layout().addWidget(lmap)
        self.selection = None
        # Indices of the selected rows; read by commit() even before the
        # first selection is made.
        self._indices = None
        self.data = None
        self.learner = None

        def selectionChanged(indices):
            self.selection = self.data[indices] if self.data is not None and indices else None
            self._indices = indices
            self.commit()

        lmap.selectionChanged.connect(selectionChanged)

        def _set_map_provider():
            lmap.set_map_provider(self.TILE_PROVIDERS[self.tile_provider])

        box = gui.vBox(self.controlArea, 'Map')
        gui.comboBox(box, self, 'tile_provider',
                     orientation=Qt.Horizontal,
                     label='Map:',
                     items=tuple(self.TILE_PROVIDERS.keys()),
                     sendSelectedValue=True,
                     callback=_set_map_provider)

        self._latlon_model = DomainModel(
            parent=self, valid_types=ContinuousVariable)
        self._class_model = DomainModel(
            parent=self, placeholder='(None)', valid_types=DomainModel.PRIMITIVE)
        self._color_model = DomainModel(
            parent=self, placeholder='(Same color)', valid_types=DomainModel.PRIMITIVE)
        self._shape_model = DomainModel(
            parent=self, placeholder='(Same shape)', valid_types=DiscreteVariable)
        self._size_model = DomainModel(
            parent=self, placeholder='(Same size)', valid_types=ContinuousVariable)
        self._label_model = DomainModel(
            parent=self, placeholder='(No labels)')

        def _set_lat_long():
            self.map.set_data(self.data, self.lat_attr, self.lon_attr)
            self.train_model()

        self._combo_lat = combo = gui.comboBox(
            box, self, 'lat_attr', orientation=Qt.Horizontal,
            label='Latitude:', sendSelectedValue=True, callback=_set_lat_long)
        combo.setModel(self._latlon_model)
        self._combo_lon = combo = gui.comboBox(
            box, self, 'lon_attr', orientation=Qt.Horizontal,
            label='Longitude:', sendSelectedValue=True, callback=_set_lat_long)
        combo.setModel(self._latlon_model)

        def _toggle_legend():
            self.map.toggle_legend(self.show_legend)

        gui.checkBox(box, self, 'show_legend', label='Show legend',
                     callback=_toggle_legend)

        box = gui.vBox(self.controlArea, 'Overlay')
        self._combo_class = combo = gui.comboBox(
            box, self, 'class_attr', orientation=Qt.Horizontal,
            label='Target:', sendSelectedValue=True, callback=self.train_model
        )
        self.controls.class_attr.setModel(self._class_model)
        self.set_learner(self.learner)

        box = gui.vBox(self.controlArea, 'Points')
        self._combo_color = combo = gui.comboBox(
            box, self, 'color_attr',
            orientation=Qt.Horizontal,
            label='Color:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_color(self.color_attr))
        combo.setModel(self._color_model)
        self._combo_label = combo = gui.comboBox(
            box, self, 'label_attr',
            orientation=Qt.Horizontal,
            label='Label:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_label(self.label_attr))
        combo.setModel(self._label_model)
        self._combo_shape = combo = gui.comboBox(
            box, self, 'shape_attr',
            orientation=Qt.Horizontal,
            label='Shape:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_shape(self.shape_attr))
        combo.setModel(self._shape_model)
        self._combo_size = combo = gui.comboBox(
            box, self, 'size_attr',
            orientation=Qt.Horizontal,
            label='Size:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_size(self.size_attr))
        combo.setModel(self._size_model)

        def _set_opacity():
            lmap.set_marker_opacity(self.opacity)

        def _set_zoom():
            lmap.set_marker_size_coefficient(self.zoom)

        def _set_jittering():
            lmap.set_jittering(self.jittering)

        def _set_clustering():
            lmap.set_clustering(self.cluster_points)

        self._opacity_slider = gui.hSlider(
            box, self, 'opacity', None, 1, 100, 5,
            label='Opacity:', labelFormat=' %d%%',
            callback=_set_opacity)
        self._zoom_slider = gui.valueSlider(
            box, self, 'zoom', None, values=(20, 50, 100, 200, 300, 400, 500, 700, 1000),
            label='Symbol size:', labelFormat=' %d%%',
            callback=_set_zoom)
        self._jittering = gui.valueSlider(
            box, self, 'jittering', label='Jittering:', values=(0, .5, 1, 2, 5),
            labelFormat=' %.1f%%', ticks=True,
            callback=_set_jittering)
        self._clustering_check = gui.checkBox(
            box, self, 'cluster_points', label='Cluster points',
            callback=_set_clustering)

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, 'autocommit', 'Send Selection')

        # Push the stored settings to the map once the event loop runs.
        QTimer.singleShot(0, _set_map_provider)
        QTimer.singleShot(0, _toggle_legend)
        QTimer.singleShot(0, _set_opacity)
        QTimer.singleShot(0, _set_zoom)
        QTimer.singleShot(0, _set_jittering)
        QTimer.singleShot(0, _set_clustering)

    def __del__(self):
        self.progressBarFinished(None)
        self.map = None

    def commit(self):
        """Send the selected rows and the annotated input data."""
        self.Outputs.selected_data.send(self.selection)
        self.Outputs.annotated_data.send(create_annotated_table(self.data, self._indices))

    @Inputs.data
    def set_data(self, data):
        """Receive new data, pick default attributes and redraw the map.

        Empty or missing data clears the widget.
        """
        self.data = data

        self.closeContext()

        if data is None or not len(data):
            return self.clear()

        domain = data.domain
        for model in (self._latlon_model,
                      self._class_model,
                      self._color_model,
                      self._shape_model,
                      self._size_model,
                      self._label_model):
            model.set_domain(domain)

        # find_lat_lon may find only one of the two coordinates, so each
        # is applied independently.
        lat, lon = find_lat_lon(data)
        if lat is not None:
            self._combo_lat.setCurrentIndex(self._latlon_model.indexOf(lat))
            self.lat_attr = lat.name
        if lon is not None:
            self._combo_lon.setCurrentIndex(self._latlon_model.indexOf(lon))
            self.lon_attr = lon.name

        if data.domain.class_var:
            self.color_attr = data.domain.class_var.name
        elif len(self._color_model):
            self._combo_color.setCurrentIndex(0)
        if len(self._shape_model):
            self._combo_shape.setCurrentIndex(0)
        if len(self._size_model):
            self._combo_size.setCurrentIndex(0)
        if len(self._label_model):
            self._combo_label.setCurrentIndex(0)
        if len(self._class_model):
            self._combo_class.setCurrentIndex(0)

        self.openContext(data)

        self.map.set_data(self.data, self.lat_attr, self.lon_attr)
        self.map.set_marker_color(self.color_attr, update=False)
        self.map.set_marker_label(self.label_attr, update=False)
        self.map.set_marker_shape(self.shape_attr, update=False)
        self.map.set_marker_size(self.size_attr, update=True)

    @Inputs.data_subset
    def set_subset(self, subset):
        """Pass the ids of the subset rows to the map (empty when no subset)."""
        self.map.set_subset_ids(subset.ids if subset is not None else np.array([]))

    def handleNewSignals(self):
        super().handleNewSignals()
        self.train_model()

    @Inputs.learner
    def set_learner(self, learner):
        """Store the learner; the target combo is enabled only when one is set."""
        self.learner = learner
        self.controls.class_attr.setEnabled(learner is not None)
        self.controls.class_attr.setToolTip(
            'Needs a Learner input for modelling.' if learner is None else '')

    def train_model(self):
        """Fit the learner on (latitude, longitude) -> target and give it to the map.

        The map gets ``None`` when there is no data, learner or target, or
        when fitting fails (the failure is shown as a widget error).
        """
        model = None
        self.Error.clear()
        if self.data and self.learner and self.class_attr != '(None)':
            domain = self.data.domain
            if self.lat_attr and self.lon_attr and self.class_attr in domain:
                domain = Domain([domain[self.lat_attr], domain[self.lon_attr]],
                                [domain[self.class_attr]])
                train = Table.from_table(domain, self.data)
                try:
                    model = self.learner(train)
                except Exception as e:
                    # Any learner can fail in its own way; report, don't crash.
                    self.Error.learner_error(e)
        self.map.set_model(model)

    def disable_some_controls(self, disabled):
        """Enable/disable the controls that are too costly for many visible points."""
        tooltip = (
            "Available when the zoom is close enough to have "
            "<{} points in the viewport.".format(self.map.N_POINTS_PER_ITER)
            if disabled else '')
        for control in (self._combo_label,
                        self._combo_shape,
                        self._clustering_check):
            control.setDisabled(disabled)
            control.setToolTip(tooltip)

    def clear(self):
        """Empty the map, the attribute models and the attribute settings.

        The previous selection refers to the old data, so it is dropped and
        the outputs are updated.
        """
        self.map.set_data(None, '', '')
        for model in (self._latlon_model,
                      self._class_model,
                      self._color_model,
                      self._shape_model,
                      self._size_model,
                      self._label_model):
            model.set_domain(None)
        self.lat_attr = self.lon_attr = self.class_attr = self.color_attr = \
            self.label_attr = self.shape_attr = self.size_attr = None
        self.selection = None
        self._indices = None
        self.commit()
示例#5
0
class GeoCodeFromFile(widget.OWWidget):
    """Fill in missing coordinates of an Excel sheet from ``worldcities.csv``.

    The user picks an ``.xlsx`` file; its path is remembered in
    ``xlfp.txt`` in the working directory. On "Resolve Coordinates", rows
    whose ``Lat`` is missing or 0 are looked up by their ``District`` in
    ``worldcities.csv`` (columns ``city_ascii``, ``lat``, ``lng``). Rows
    that still have no coordinates get NaN. The result is written to
    ``output.xlsx`` and sent on the "Coded Data" output.
    """
    name = 'Geocode File'
    description = 'Encode region names into geographical coordinates from a custom file of coordinates - output this widget to an orange data table'
    icon = "icons/Geocoding.svg"
    priority = 40

    class Inputs:
        # This widget has no inputs.
        pass

    class Outputs:
        coded_data = Output("Coded Data", Table, default=True)

    settingsHandler = settings.DomainContextHandler()
    resizing_enabled = False

    def __init__(self):
        super().__init__()
        self.data = None
        self.domainmodels = []
        self.unmatched = []

        box = gui.vBox(self.controlArea, "Geocode A File")
        self._fpbox = gui.vBox(self.controlArea, "Selected File")

        gui.button(box, self, label='Pick File',
                   callback=lambda: self._pick_file())
        gui.button(box,
                   self,
                   label='Resolve Coordinates',
                   callback=lambda: self._resolve_coords())

    def _refresh_label(self, file_path):
        """Replace the "Selected File" box with one showing *file_path*."""
        old_box = self._fpbox
        old_box.hide()
        old_box.deleteLater()
        self._fpbox = gui.vBox(self.controlArea, "Selected File")
        if not file_path.endswith('xlsx'):
            file_path = "SELECT AN XLSX FILE"
        gui.label(self._fpbox, self, file_path)

    def _pick_file(self):
        """Ask for a file, remember its path in ``xlfp.txt`` and show it."""
        root = tk.Tk()
        root.withdraw()
        try:
            file_path = filedialog.askopenfilename()
        finally:
            # Without this the hidden Tk root window is never released.
            root.destroy()
        # A cancelled dialog returns an empty string, or on some Tk builds an
        # empty tuple. Treat both as "no file".
        if not file_path:
            file_path = ''
        with open("xlfp.txt", "w") as pathwriter:
            pathwriter.write(file_path)
        self._refresh_label(file_path)

    @staticmethod
    def _build_city_index(coords):
        """Map each ``city_ascii`` name to the first row index holding it.

        *coords* is the ``DataFrame.to_dict()`` of the coordinate file, so
        ``coords['city_ascii']`` maps row index -> city name.
        """
        city_index = {}
        for i, city in coords['city_ascii'].items():
            city_index.setdefault(city, i)
        return city_index

    def _resolve_coords(self):
        """Fill in missing coordinates and send the resulting table.

        Each failure (unreadable input, missing coordinate file, failed
        write) is reported with ``print`` and aborts the operation.
        """
        try:
            with open("xlfp.txt", "r") as fp:
                xlsx_path = fp.readline()
            data = pd.read_excel(xlsx_path)
            # Missing coordinates are marked with 0, which means "look up".
            data.fillna({'Lat': 0, 'Long': 0}, inplace=True)
        except Exception:
            print("Invalid file or Broken Path")
            return

        try:
            csv_path = path.join(path.dirname(path.dirname(__file__)),
                                 'worldcities.csv')
            coordfile = pd.read_csv(csv_path).to_dict()
        except Exception:
            print(
                "could not find coordinate file worldcities.csv in main Orange folder"
            )
            return

        # One dict lookup per row instead of scanning all cities each time.
        city_index = self._build_city_index(coordfile)
        for row in data.itertuples(index=True):
            if row.Lat != 0.0:
                continue
            idx = city_index.get(row.District)
            if idx is not None:
                data.at[row.Index, 'Lat'] = coordfile['lat'][idx]
                data.at[row.Index, 'Long'] = coordfile['lng'][idx]
            if (data.at[row.Index, 'Lat'] == 0.0
                    or data.at[row.Index, 'Long'] == 0.0):
                # Still unresolved: mark as missing rather than (0, 0).
                data.at[row.Index, 'Lat'] = np.nan
                data.at[row.Index, 'Long'] = np.nan

        try:
            # ExcelWriter.save() was removed in pandas 2.0; the context
            # manager closes (and writes) the file on all versions.
            with pd.ExcelWriter('output.xlsx') as writer:
                data.to_excel(writer, sheet_name='Sheet1')
            output = Table.from_file('output.xlsx')
        except Exception:
            print("Unable to write excel file")
            return
        self.Outputs.coded_data.send(output)
示例#6
0
class OWFreeViz(widget.OWWidget):
    """FreeViz projection widget.

    Projects instances of a data set with a class variable onto a plane
    spanned by one anchor per attribute. The anchors can be optimized with
    ``FreeViz.freeviz`` in an asynchronous loop. The widget outputs the
    selected instances, an annotated copy of the data and the anchor
    components.
    """
    name = "FreeViz"
    description = "Displays FreeViz projection"
    icon = "icons/Freeviz.svg"
    priority = 240

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    #: Initialization type
    Circular, Random = 0, 1

    jitter_sizes = [0, 0.1, 0.5, 1, 2]

    settings_version = 2
    settingsHandler = settings.DomainContextHandler()

    radius = settings.Setting(0)
    initialization = settings.Setting(Circular)
    auto_commit = settings.Setting(True)

    resolution = 256
    graph = settings.SettingProvider(OWFreeVizGraph)

    # Custom Qt event type; customEvent() replots when it receives one.
    ReplotRequest = QEvent.registerEventType()

    graph_name = "graph.plot_widget.plotItem"

    class Warning(widget.OWWidget.Warning):
        sparse_not_supported = widget.Msg("Sparse data is ignored.")

    class Error(widget.OWWidget.Error):
        no_class_var = widget.Msg("Need a class variable")
        not_enough_class_vars = widget.Msg(
            "Needs discrete class variable with at least 2 values")
        features_exceeds_instances = widget.Msg(
            "Algorithm should not be used when number of features exceeds "
            "the number of instances.")
        too_many_data_instances = widget.Msg("Cannot handle so large data.")
        no_valid_data = widget.Msg("No valid data.")

    def __init__(self):
        super().__init__()

        self.data = None
        self.subset_data = None
        self._subset_mask = None
        self._validmask = None
        self._X = None
        self._Y = None
        self._selection = None
        self.__replot_requested = False

        # Synthetic variables holding the projected coordinates.
        self.variable_x = ContinuousVariable("freeviz-x")
        self.variable_y = ContinuousVariable("freeviz-y")

        box0 = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWFreeVizGraph(self, box0, "Plot", view_box=FreeVizInteractiveViewBox)
        box0.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        box = gui.widgetBox(self.controlArea, "Optimization", spacing=10)
        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10
        )
        form.addRow(
            "Initialization",
            gui.comboBox(box, self, "initialization",
                         items=["Circular", "Random"],
                         callback=self.reset_initialization)
        )
        box.layout().addLayout(form)

        self.btn_start = gui.button(widget=box, master=self, label="Optimize",
                                    callback=self.toogle_start, enabled=False)

        self.viewbox = plot.getViewBox()
        self.replot = None

        g = self.graph.gui
        g.point_properties_box(self.controlArea)
        self.models = g.points_models

        box = gui.widgetBox(self.controlArea, "Show anchors")
        self.rslider = gui.hSlider(
            box, self, "radius", minValue=0, maxValue=100,
            step=5, label="Radius", createLabel=False, ticks=True,
            callback=self.update_radius)
        self.rslider.setTickInterval(0)
        self.rslider.setPageStep(10)

        box = gui.vBox(self.controlArea, "Plot Properties")

        g.add_widgets([g.JitterSizeSlider], box)
        g.add_widgets([g.ShowLegend,
                       g.ClassDensity,
                       g.LabelOnlySelected],
                      box)

        self.graph.box_zoom_select(self.controlArea)
        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")
        self.graph.zoom_actions(self)
        # FreeViz optimization runs as a coroutine driven by AsyncUpdateLoop.
        self._loop = AsyncUpdateLoop(parent=self)
        self._loop.yielded.connect(self.__set_projection)
        self._loop.finished.connect(self.__freeviz_finished)
        self._loop.raised.connect(self.__on_error)

        self._new_plotdata()

    def keyPressEvent(self, event):
        # Let the graph update its tooltips for the current key modifiers.
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        # Let the graph update its tooltips for the current key modifiers.
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    @staticmethod
    def _scale_to_unit(coords):
        """Return *coords* divided by the largest row norm.

        When every row has zero norm, *coords* is returned unchanged. This
        avoids the 0/0 division that would fill the plot with NaNs.
        """
        radius = np.max(np.linalg.norm(coords, axis=1))
        if radius > 0:
            return coords / radius
        return coords

    def update_radius(self):
        """Hide anchors shorter than the radius slider and resize the circle."""
        assert self.plotdata is not None
        if self.plotdata.hidecircle is None:
            return

        minradius = self.radius / 100 + 1e-5
        for anchor, item in zip(self.plotdata.anchors,
                                self.plotdata.anchoritem):
            item.setVisible(np.linalg.norm(anchor) > minradius)
        self.plotdata.hidecircle.setRect(
            QRectF(-minradius, -minradius,
                   2 * minradius, 2 * minradius))

    def toogle_start(self):
        """Start the optimization, or cancel it if it is already running."""
        if self._loop.isRunning():
            self._loop.cancel()
            self.btn_start.setText("Optimize")
            self.progressBarFinished(processEvents=False)
        else:
            self._start()

    def _start(self):
        """
        Start the projection optimization.
        """
        assert self.plotdata is not None

        X, Y = self.plotdata.X, self.plotdata.Y
        anchors = self.plotdata.anchors

        def update_freeviz(interval, initial):
            # Run FreeViz in chunks of `interval` iterations and yield
            # intermediate results until the anchors stop moving.
            anchors = initial
            while True:
                res = FreeViz.freeviz(X, Y, scale=False, center=False,
                                      initial=anchors, maxiter=interval)
                _, anchors_new = res[:2]
                yield res[:2]
                if np.allclose(anchors, anchors_new, rtol=1e-5, atol=1e-4):
                    return

                anchors = anchors_new

        interval = 10  # TODO

        self._loop.setCoroutine(
            update_freeviz(interval, anchors))
        self.btn_start.setText("Stop")
        self.progressBarInit(processEvents=False)
        self.setBlocking(True)
        self.setStatusMessage("Optimizing")

    def reset_initialization(self):
        """
        Reset the current 'anchor' initialization, and restart the
        optimization if necessary.
        """
        running = self._loop.isRunning()

        if running:
            self._loop.cancel()

        if self.data is not None:
            # _clear_plot() discards the anchors, so setup_plot() creates
            # new ones using the selected initialization.
            self._clear_plot()
            self.setup_plot()

        if running:
            self._start()

    def __set_projection(self, res):
        # Set/update the projection matrix and coordinate embeddings
        # from one (embedding, anchors) pair yielded by the loop.
        assert self.plotdata is not None
        increment = 1  # TODO
        self.progressBarAdvance(
            increment * 100. / MAX_ITERATIONS, processEvents=False)  # TODO
        embedding_coords, projection = res
        self.plotdata.embedding_coords = embedding_coords
        self.plotdata.anchors = projection
        self._update_xy()
        self.update_radius()
        self.update_density()

    def __freeviz_finished(self):
        # Projection optimization has finished
        self.btn_start.setText("Optimize")
        self.setStatusMessage("")
        self.setBlocking(False)
        self.progressBarFinished(processEvents=False)
        self.commit()

    def __on_error(self, err):
        # Report exceptions raised inside the optimization loop.
        sys.excepthook(type(err), err, getattr(err, "__traceback__"))

    def _update_xy(self):
        """Rescale the embedding coordinates and redraw the plot."""
        self.graph.plot_widget.clear()
        self.plotdata.embedding_coords = self._scale_to_unit(
            self.plotdata.embedding_coords)
        self.plot(show_anchors=(len(self.data.domain.attributes) < MAX_ANCHORS))

    def _new_plotdata(self):
        """Reset ``self.plotdata`` to an empty state."""
        self.plotdata = namespace(
            validmask=None,
            embedding_coords=None,
            anchors=[],
            anchoritem=[],
            X=None,
            Y=None,
            indicators=[],
            hidecircle=None,
            data=None,
            items=[],
            topattrs=None,
            rand=None,
            selection=None,  # np.array
        )

    def _anchor_circle(self):
        """Recreate the anchor arrows and the radius circle in the view box."""
        # minimum visible anchor radius (radius)
        minradius = self.radius / 100 + 1e-5
        for item in chain(self.plotdata.anchoritem, self.plotdata.items):
            self.viewbox.removeItem(item)
        self.plotdata.anchoritem = []
        self.plotdata.items = []
        for anchor, var in zip(self.plotdata.anchors, self.data.domain.attributes):
            axitem = AnchorItem(
                line=QLineF(0, 0, *anchor), text=var.name,)
            # Every anchor gets an item; short ones are only hidden so that
            # update_radius() can show them again.
            axitem.setVisible(np.linalg.norm(anchor) > minradius)
            axitem.setPen(pg.mkPen((100, 100, 100)))
            axitem.setArrowVisible(True)
            self.plotdata.anchoritem.append(axitem)
            self.viewbox.addItem(axitem)

        hidecircle = QGraphicsEllipseItem()
        hidecircle.setRect(
            QRectF(-minradius, -minradius,
                   2 * minradius, 2 * minradius))

        _pen = QPen(Qt.lightGray, 1)
        _pen.setCosmetic(True)
        hidecircle.setPen(_pen)
        self.viewbox.addItem(hidecircle)
        self.plotdata.items.append(hidecircle)
        self.plotdata.hidecircle = hidecircle

    def update_colors(self):
        pass

    def sizeHint(self):
        return QSize(800, 500)

    def _clear(self):
        """
        Clear/reset the widget state
        """
        self._loop.cancel()
        self.data = None
        self._selection = None
        self._clear_plot()

    def _clear_plot(self):
        """Remove all plot items and reset ``self.plotdata``."""
        for item in chain(self.plotdata.anchoritem, self.plotdata.items):
            self.viewbox.removeItem(item)
        self.graph.plot_widget.clear()
        self._new_plotdata()

    def init_attr_values(self):
        """Point the graph's attribute models at the current domain.

        The class variable is used for coloring when there is non-empty data.
        """
        domain = self.data and len(self.data) and self.data.domain or None
        for model in self.models:
            model.set_domain(domain)
        self.graph.attr_label = None
        self.graph.attr_size = None
        self.graph.attr_shape = None
        self.graph.attr_color = self.data.domain.class_var if domain else None

    @Inputs.data
    def set_data(self, data):
        """Validate and store new input data.

        Sparse data, a missing class variable, a discrete class with fewer
        than two values, more features than instances, too many valid
        instances, or no valid instances are reported, and the data is
        dropped.
        """
        self.clear_messages()
        self._clear()
        self.closeContext()
        if data is not None:
            if data and data.is_sparse():
                self.Warning.sparse_not_supported()
                data = None
            elif data.domain.class_var is None:
                self.Error.no_class_var()
                data = None
            elif data.domain.class_var.is_discrete and \
                    len(data.domain.class_var.values) < 2:
                self.Error.not_enough_class_vars()
                data = None
            if data and len(data.domain.attributes) > data.X.shape[0]:
                self.Error.features_exceeds_instances()
                data = None
        if data is not None:
            valid_instances_count = self._prepare_freeviz_data(data)
            if valid_instances_count > MAX_INSTANCES:
                self.Error.too_many_data_instances()
                data = None
            elif valid_instances_count == 0:
                self.Error.no_valid_data()
                data = None
        self.data = data
        self.init_attr_values()
        if data is not None:
            self.cb_class_density.setEnabled(data.domain.has_discrete_class)
            self.openContext(data)
            self.btn_start.setEnabled(True)
        else:
            self.btn_start.setEnabled(False)
            self._X = self._Y = None
            self.graph.new_data(None, None)

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """Store the subset input; handleNewSignals() rebuilds its mask."""
        self.subset_data = subset
        # Discard the mask of the previous subset. Otherwise a removed
        # subset would keep being highlighted in plot().
        self._subset_mask = None
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """Recompute the subset mask, replot and commit after new inputs."""
        if self.data is not None and self.subset_data is not None:
            dataids = self.data.ids.ravel()
            subsetids = np.unique(self.subset_data.ids)
            self._subset_mask = np.isin(dataids, subsetids, assume_unique=True)
        else:
            # No data/subset pair: make sure no stale mask is used.
            self._subset_mask = None
        if self._X is not None:
            self.setup_plot(True)
        self.commit()

    def customEvent(self, event):
        if event.type() == OWFreeViz.ReplotRequest:
            self.__replot_requested = False
            self.setup_plot()
        else:
            super().customEvent(event)

    def _prepare_freeviz_data(self, data):
        """Store centered and range-scaled valid rows of *data*.

        Rows with a NaN in ``X`` or ``Y`` are dropped. Each remaining column
        is centered and divided by its range (when non-zero). The results go
        to ``_X``, ``_Y`` and ``_validmask``.

        Returns the number of valid rows.
        """
        X = data.X
        Y = data.Y
        mask = np.bitwise_or.reduce(np.isnan(X), axis=1)
        mask |= np.isnan(Y)
        validmask = ~mask
        X = X[validmask, :]
        Y = Y[validmask]

        if not len(X):
            self._X = None
            return 0

        if data.domain.class_var.is_discrete:
            Y = Y.astype(int)
        X = (X - np.mean(X, axis=0))
        span = np.ptp(X, axis=0)
        X[:, span > 0] /= span[span > 0].reshape(1, -1)
        self._X = X
        self._Y = Y
        self._validmask = validmask
        return len(X)

    def setup_plot(self, reset_view=True):
        """Project the prepared data with the current anchors and plot it.

        If ``plotdata`` has no anchors yet, they are initialized circularly or
        randomly depending on ``self.initialization``.
        """
        assert self._X is not None

        self.graph.jitter_continuous = True
        self.__replot_requested = False

        X = self.plotdata.X = self._X
        self.plotdata.Y = self._Y
        self.plotdata.validmask = self._validmask
        if self._selection is not None:
            self.plotdata.selection = self._selection
        else:
            self.plotdata.selection = np.zeros(len(self._validmask), dtype=np.uint8)
        anchors = self.plotdata.anchors
        if len(anchors) == 0:
            if self.initialization == self.Circular:
                anchors = FreeViz.init_radial(X.shape[1])
            else:
                anchors = FreeViz.init_random(X.shape[1], 2)

        EX = np.dot(X, anchors)
        # For every instance, rank attributes by the squared length of their
        # contribution along both anchor axes (largest first); keep the top 10.
        # Vectorized equivalent of a per-row argsort loop.
        contrib = (np.power(X * anchors[:, 0], 2) +
                   np.power(X * anchors[:, 1], 2))
        order = np.argsort(contrib, axis=1)[:, ::-1]
        self.plotdata.topattrs = order[:, :10].astype(int)

        self.plotdata.anchors = anchors

        self.plotdata.embedding_coords = self._scale_to_unit(EX)
        if reset_view:
            self.viewbox.setRange(RANGE)
            self.viewbox.setAspectLocked(True, 1)
        self.plot(reset_view=reset_view)

    def randomize_indices(self):
        """Pick MAX_POINTS random rows to plot when there are more rows than that."""
        X = self._X
        self.plotdata.rand = np.random.choice(len(X), MAX_POINTS, replace=False) \
            if len(X) > MAX_POINTS else None

    def manual_move_anchor(self, show_anchors=True):
        """Redraw the projection after anchors were moved interactively."""
        self.__replot_requested = False
        X = self.plotdata.X = self._X
        anchors = self.plotdata.anchors
        validmask = self.plotdata.validmask
        # Scale with the radius of all points, before any subsampling.
        coords = self._scale_to_unit(np.dot(X, anchors))
        data_x = self.data.X[validmask]
        data_y = self.data.Y[validmask]
        selection = self.plotdata.selection[validmask]
        if self.plotdata.rand is not None:
            rand = self.plotdata.rand
            coords = coords[rand]
            data_x = data_x[rand]
            data_y = data_y[rand]
            selection = selection[rand]

        if show_anchors:
            self._anchor_circle()
        attributes = () + self.data.domain.attributes + (self.variable_x, self.variable_y)
        domain = Domain(attributes=attributes,
                        class_vars=self.data.domain.class_vars)
        data = Table.from_numpy(domain, X=np.hstack((data_x, coords)),
                                Y=data_y)
        self.graph.new_data(data, None)
        self.graph.selection = selection
        self.graph.update_data(self.variable_x, self.variable_y, reset_view=False)

    def plot(self, reset_view=False, show_anchors=True):
        """Send the data extended with projected coordinates to the graph."""
        if show_anchors:
            self._anchor_circle()
        attributes = () + self.data.domain.attributes + (self.variable_x, self.variable_y)
        domain = Domain(attributes=attributes,
                        class_vars=self.data.domain.class_vars,
                        metas=self.data.domain.metas)
        mask = self.plotdata.validmask
        # np.float was removed in NumPy 1.24; builtin float maps to float64.
        array = np.zeros((len(self.data), 2), dtype=float)
        array[mask] = self.plotdata.embedding_coords
        data = Table.from_numpy(domain, X=np.hstack((self.data.X, array)),
                                Y=self.data.Y,
                                metas=self.data.metas)
        subset_data = data[self._subset_mask & mask]\
            if self._subset_mask is not None and len(self._subset_mask) else None
        self.plotdata.data = data
        self.graph.new_data(data[mask], subset_data)
        if self.plotdata.selection is not None:
            self.graph.selection = self.plotdata.selection[self.plotdata.validmask]
        self.graph.update_data(self.variable_x, self.variable_y, reset_view=reset_view)

    def reset_graph_data(self, *_):
        if self.data is not None:
            self.graph.rescale_data()
            self._update_graph()

    def _update_graph(self, reset_view=True, **_):
        self.graph.zoomStack = []
        assert self.graph.data is not None
        self.graph.update_data(self.variable_x, self.variable_y, reset_view)

    def update_density(self):
        if self.graph.data is None:
            return
        self._update_graph(reset_view=False)

    def selection_changed(self):
        """Copy the graph selection into the full-length selection array and commit."""
        if self.graph.selection is not None:
            # Not named `pd` to avoid shadowing pandas.
            plotdata = self.plotdata
            plotdata.selection[plotdata.validmask] = self.graph.selection
            self._selection = plotdata.selection
        self.commit()

    def prepare_data(self):
        pass

    def commit(self):
        """Send the selected, annotated and component tables."""
        selected = annotated = components = None
        graph = self.graph
        if self.data is not None and self.plotdata.validmask is not None:
            name = self.data.name
            metas = () + self.data.domain.metas + (self.variable_x, self.variable_y)
            domain = Domain(attributes=self.data.domain.attributes,
                            class_vars=self.data.domain.class_vars,
                            metas=metas)
            data = self.plotdata.data.transform(domain)
            validmask = self.plotdata.validmask
            mask = np.array(validmask, dtype=int)
            # Put the graph's group indices at the valid positions; without
            # a graph selection nothing is selected.
            mask[mask == 1] = graph.selection if graph.selection is not None else 0
            selection = np.flatnonzero(mask)
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes
            if graph.selection is not None and np.max(graph.selection) > 1:
                annotated = create_groups_table(data, mask)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

            comp_domain = Domain(
                self.data.domain.attributes,
                metas=[StringVariable(name='component')])

            metas = np.array([["FreeViz 1"], ["FreeViz 2"]])
            components = Table.from_numpy(
                comp_domain,
                X=self.plotdata.anchors.T,
                metas=metas)

            components.name = name + ": components"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        if self.data is None:
            return

        def name(var):
            return var and var.name

        caption = report.render_items_vert((
            ("Color", name(self.graph.attr_color)),
            ("Label", name(self.graph.attr_label)),
            ("Shape", name(self.graph.attr_shape)),
            ("Size", name(self.graph.attr_size)),
            ("Jittering", self.graph.jitter_size != 0 and "{} %".format(self.graph.jitter_size))))
        self.report_plot()
        if caption:
            self.report_caption(caption)
示例#7
0
class OWMDS(widget.OWWidget):
    """Multidimensional scaling widget.

    Embeds either the input distance matrix or Euclidean distances computed
    from the input data into two dimensions with ``sklearn.manifold.MDS``.
    It plots the embedding and outputs it, optionally appended to the data.
    """
    name = "MDS"
    description = "Multidimensional scaling"
    icon = "icons/MDS.svg"

    inputs = ({
        "name": "Data",
        "type": Orange.data.Table,
        "handler": "set_data"
    }, {
        "name": "Distances",
        "type": Orange.misc.DistMatrix,
        "handler": "set_disimilarity"
    })

    outputs = ({"name": "Data", "type": Orange.data.Table}, )

    #: Initialization type
    PCA, Random = 0, 1

    settingsHandler = settings.DomainContextHandler()

    max_iter = settings.Setting(300)
    eps = settings.Setting(1e-3)
    initialization = settings.Setting(PCA)
    n_init = settings.Setting(1)

    output_embeding_role = settings.Setting(1)
    autocommit = settings.Setting(True)

    color_var = settings.ContextSetting(0, not_variable=True)
    shape_var = settings.ContextSetting(0, not_variable=True)
    size_var = settings.ContextSetting(0, not_variable=True)

    # output embeding role.
    NoRole, AttrRole, MetaRole = 0, 1, 2

    def __init__(self, parent=None):
        super().__init__(parent)
        self.matrix = None
        self.data = None

        # 2-column embedding and its stress, set by apply(). Initialized
        # here so _update_plot()/commit() do not raise AttributeError when
        # they run before the first apply().
        self.embeding = None
        self.stress = None

        # Cached plot attributes; None means "recompute in _setup_plot".
        self._pen_data = None
        self._shape_data = None
        self._size_data = None

        self._invalidated = False
        self._effective_matrix = None
        self._output_changed = False

        box = gui.widgetBox(self.controlArea, "MDS Optimization")
        form = QtGui.QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QtGui.QFormLayout.AllNonFixedFieldsGrow,
        )

        form.addRow("Max iterations:",
                    gui.spin(box, self, "max_iter", 10, 10**4, step=1))

        #         form.addRow("Eps:",
        #                     gui.spin(box, self, "eps", 1e-9, 1e-3, step=1e-9,
        #                              spinType=float))

        form.addRow(
            "Initialization",
            gui.comboBox(box,
                         self,
                         "initialization",
                         items=["PCA (Torgerson)", "Random"]))

        #         form.addRow("N Restarts:",
        #                     gui.spin(box, self, "n_init", 1, 10, step=1))

        box.layout().addLayout(form)
        gui.button(box, self, "Apply", callback=self._invalidate_embeding)

        box = gui.widgetBox(self.controlArea, "Graph")
        self.colorvar_model = itemmodels.VariableListModel()
        cb = gui.comboBox(box,
                          self,
                          "color_var",
                          box="Color",
                          callback=self._on_color_var_changed)
        cb.setModel(self.colorvar_model)
        cb.box.setFlat(True)

        self.shapevar_model = itemmodels.VariableListModel()
        cb = gui.comboBox(box,
                          self,
                          "shape_var",
                          box="Shape",
                          callback=self._on_shape_var_changed)
        cb.setModel(self.shapevar_model)
        cb.box.setFlat(True)

        self.sizevar_model = itemmodels.VariableListModel()
        cb = gui.comboBox(box,
                          self,
                          "size_var",
                          "Size",
                          callback=self._on_size_var_changed)
        cb.setModel(self.sizevar_model)
        cb.box.setFlat(True)

        gui.rubber(self.controlArea)
        box = gui.widgetBox(self.controlArea, "Output")
        cb = gui.comboBox(box,
                          self,
                          "output_embeding_role",
                          box="Append coordinates",
                          items=["Do not append", "As attributes", "As metas"],
                          callback=self._invalidate_output)
        cb.box.setFlat(True)

        cb = gui.checkBox(box, self, "autocommit", "Auto commit")
        b = gui.button(box, self, "Commit", callback=self.commit, default=True)
        gui.setStopper(self, b, cb, "_output_changed", callback=self.commit)

        self.plot = pg.PlotWidget(background="w")
        self.mainArea.layout().addWidget(self.plot)

    def set_data(self, data):
        """Store input data and rebuild the color/shape/size models.

        Without a distance matrix, the embedding is marked for
        recomputation from the data.
        """
        self.closeContext()
        self._clear()
        self.data = data
        if data is not None:
            self._initialize(data)
            self.openContext(data)

        if self.matrix is None:
            self._effective_matrix = None
            self._invalidated = True

    def set_disimilarity(self, matrix):
        """Store the input distance matrix and mark the embedding for recomputation."""
        self.matrix = matrix
        self._effective_matrix = matrix
        self._invalidated = True

    def _clear(self):
        """Drop cached plot attributes and reset the variable models."""
        self._pen_data = None
        self._shape_data = None
        self._size_data = None
        self.colorvar_model[:] = ["Same color"]
        self.shapevar_model[:] = ["Same shape"]
        self.sizevar_model[:] = ["Same size"]

        self.color_var = 0
        self.shape_var = 0
        self.size_var = 0

    def _initialize(self, data):
        # initialize the graph state from data
        domain = data.domain
        all_vars = list(domain.variables + domain.metas)
        disc_vars = list(filter(is_discrete, all_vars))
        cont_vars = list(filter(is_continuous, all_vars))

        def set_separator(model, index):
            # Render the entry at `index` as a non-selectable separator.
            index = model.index(index, 0)
            model.setData(index, "separator", Qt.AccessibleDescriptionRole)
            model.setData(index, Qt.NoItemFlags, role="flags")

        self.colorvar_model[:] = ["Same color", ""] + all_vars
        set_separator(self.colorvar_model, 1)

        self.shapevar_model[:] = ["Same shape", ""] + disc_vars
        set_separator(self.shapevar_model, 1)

        self.sizevar_model[:] = ["Same size", ""] + cont_vars
        set_separator(self.sizevar_model, 1)

        # Color by the class variable by default.
        if domain.class_var is not None:
            self.color_var = list(self.colorvar_model).index(domain.class_var)

    def apply(self):
        """Compute the 2-D MDS embedding into ``self.embeding``.

        The input distance matrix is used when given. Otherwise Euclidean
        distances are computed from the data. With neither input, the
        embedding is cleared.
        """
        if self.data is None and self.matrix is None:
            self.embeding = None
            self._update_plot()
            return

        if self._effective_matrix is None:
            if self.matrix is not None:
                self._effective_matrix = self.matrix
            elif self.data is not None:
                self._effective_matrix = Orange.distance.Euclidean()(self.data)

        X = self._effective_matrix.X

        if self.initialization == OWMDS.PCA:
            init = torgerson(X, n_components=2)
            # A deterministic initialization needs only a single run.
            n_init = 1
        else:
            init = None
            n_init = self.n_init

        dissim = "precomputed"

        mds = sklearn.manifold.MDS(dissimilarity=dissim,
                                   n_components=2,
                                   n_init=n_init,
                                   max_iter=self.max_iter)
        embeding = mds.fit_transform(X, init=init)
        self.embeding = embeding
        self.stress = mds.stress_

    def handleNewSignals(self):
        if self._invalidated:
            self._invalidated = False
            self.apply()

        self._update_plot()
        self.commit()

    def _invalidate_embeding(self):
        self.apply()
        self._update_plot()
        self._invalidate_output()

    def _invalidate_output(self):
        # Commit immediately with auto-commit, otherwise flag pending output.
        if self.autocommit:
            self.commit()
        else:
            self._output_changed = True

    def _on_color_var_changed(self):
        self._pen_data = None
        self._update_plot()

    def _on_shape_var_changed(self):
        self._shape_data = None
        self._update_plot()

    def _on_size_var_changed(self):
        self._size_data = None
        self._update_plot()

    def _update_plot(self):
        self.plot.clear()
        if self.embeding is not None:
            self._setup_plot()

    def _setup_plot(self):
        """Draw the embedding as a scatter plot.

        Pen (color), symbol (shape) and point size are derived from the
        selected variables and cached until invalidated.
        """
        have_data = self.data is not None

        if self._pen_data is None:
            if have_data and self.color_var > 0:
                color_var = self.colorvar_model[self.color_var]
                if is_discrete(color_var):
                    palette = colorpalette.ColorPaletteGenerator(
                        len(color_var.values))
                else:
                    palette = None

                color_data = colors(self.data, color_var, palette)
                pen_data = [
                    QtGui.QPen(QtGui.QColor(r, g, b)) for r, g, b in color_data
                ]
            else:
                pen_data = QtGui.QPen(Qt.black)
            self._pen_data = pen_data

        if self._shape_data is None:
            if have_data and self.shape_var > 0:
                Symbols = pg.graphicsItems.ScatterPlotItem.Symbols
                symbols = numpy.array(list(Symbols.keys()))

                shape_var = self.shapevar_model[self.shape_var]
                data = numpy.array(self.data[:, shape_var]).ravel()
                # Cycle through all but the last symbol; the last symbol
                # marks missing values.
                data = data % (len(Symbols) - 1)
                data[numpy.isnan(data)] = len(Symbols) - 1
                shape_data = symbols[data.astype(int)]
            else:
                shape_data = "o"
            self._shape_data = shape_data

        if self._size_data is None:
            MinPointSize = 1
            point_size = 8 + MinPointSize
            if have_data and self.size_var > 0:
                size_var = self.sizevar_model[self.size_var]
                size_data = numpy.array(self.data[:, size_var]).ravel()
                dmin, dmax = numpy.nanmin(size_data), numpy.nanmax(size_data)
                if dmax - dmin > 0:
                    size_data = (size_data - dmin) / (dmax - dmin)

                size_data = MinPointSize + size_data * point_size
            else:
                size_data = point_size
            # Cache like the pen and shape data. Previously the value was
            # never stored, so the size was recomputed on every redraw.
            self._size_data = size_data

        item = pg.ScatterPlotItem(x=self.embeding[:, 0],
                                  y=self.embeding[:, 1],
                                  pen=self._pen_data,
                                  symbol=self._shape_data,
                                  brush=QtGui.QBrush(Qt.transparent),
                                  size=self._size_data,
                                  antialias=True)
        # plot(x, y, colors=plot.colors(data[:, color_var]),
        #      point_size=data[:, size_var],
        #      symbol=data[:, symbol_var])

        self.plot.addItem(item)

    def commit(self):
        """Send the embedding, optionally appended to the input data.

        Without input data, the output is a table with just the ``X`` and
        ``Y`` coordinates.
        """
        if self.embeding is not None:
            output = embeding = Orange.data.Table.from_numpy(
                Orange.data.Domain([
                    Orange.data.ContinuousVariable("X"),
                    Orange.data.ContinuousVariable("Y")
                ]), self.embeding)
        else:
            output = embeding = None

        if self.embeding is not None and self.data is not None:
            X, Y, M = self.data.X, self.data.Y, self.data.metas
            domain = self.data.domain
            attrs = domain.attributes
            class_vars = domain.class_vars
            metas = domain.metas

            if self.output_embeding_role == OWMDS.NoRole:
                pass
            elif self.output_embeding_role == OWMDS.AttrRole:
                attrs = attrs + embeding.domain.attributes
                X = numpy.c_[X, embeding.X]
            elif self.output_embeding_role == OWMDS.MetaRole:
                metas = metas + embeding.domain.attributes
                M = numpy.c_[M, embeding.X]

            domain = Orange.data.Domain(attrs, class_vars, metas)
            output = Orange.data.Table.from_numpy(domain, X, Y, M)

        self.send("Data", output)
        self._output_changed = False

    def onDeleteWidget(self):
        self.plot.clear()
        super().onDeleteWidget()
示例#8
0
class OWSparkFillNa(widget.OWWidget):
    """Widget that replaces null values of a Spark DataFrame.

    The input DataFrame is stored on arrival. Pressing "Apply" calls
    ``DataFrame.fillna(value, subset)`` with the values entered in the form,
    sends the result, remembers the raw form values and hides the widget.
    """
    priority = 4
    name = "FillNa"
    description = "Replace null values"
    icon = "../icons/Impute.svg"

    inputs = [("DataFrame", pyspark.sql.DataFrame, "get_input", widget.Default)]
    outputs = [("DataFrame", pyspark.sql.DataFrame, widget.Default)]
    settingsHandler = settings.DomainContextHandler()

    in_df = None
    want_main_area = False
    resizing_enabled = True
    # Raw form values of the last successful "Apply", keyed by parameter
    # name ('value', 'subset'); used as form defaults on the next start.
    saved_gui_params = Setting(OrderedDict())

    def __init__(self):
        super().__init__()

        self.main_box = gui.widgetBox(self.controlArea,
                                      orientation='horizontal',
                                      addSpace=True)
        self.box = gui.widgetBox(self.main_box, 'Parameters:', addSpace=True)
        self.help_box = gui.widgetBox(self.main_box,
                                      'Documentation',
                                      addSpace=True)

        # Read-only documentation of DataFrame.fillna.
        self.method_info_label = QtGui.QTextEdit('', self.help_box)
        self.method_info_label.setAcceptRichText(True)
        self.method_info_label.setReadOnly(True)
        self.method_info_label.setText(get_dataframe_function_info('fillna'))
        self.help_box.layout().addWidget(self.method_info_label)

        # Parameter inputs, keyed by the fillna() argument they feed. Saved
        # values from the previous session take precedence over defaults.
        self.gui_parameters = OrderedDict()
        self.gui_parameters['value'] = GuiParam(
            parent_widget=self.box,
            label='value',
            default_value=self.saved_gui_params.get('value', '0'))
        self.gui_parameters['subset'] = GuiParam(
            parent_widget=self.box,
            label='subset',
            default_value=self.saved_gui_params.get('subset', 'None'))

        self.action_box = gui.widgetBox(self.box)
        # Action Button
        self.create_sc_btn = gui.button(self.action_box,
                                        self,
                                        label='Apply',
                                        callback=self.apply)

    def get_input(self, obj=None):
        """Store the incoming DataFrame (``None`` when the input is removed)."""
        self.in_df = obj

    def apply(self):
        """Run ``fillna`` on the stored DataFrame and send the result.

        Does nothing when no DataFrame has been received.
        """
        if self.in_df is not None:
            value = self.gui_parameters['value'].get_usable_value()
            subset = self.gui_parameters['subset'].get_usable_value()
            self.send("DataFrame", self.in_df.fillna(value, subset))
            self.update_saved_gui_parameters()
            self.hide()

    def update_saved_gui_parameters(self):
        """Copy the current raw form values into ``saved_gui_params``."""
        for key, param in self.gui_parameters.items():
            self.saved_gui_params[key] = param.get_value()
# Example #9
# 0
class OWRadviz(widget.OWWidget):
    """Radviz visualization widget.

    The selected variables become anchor points; their positions and the
    instance coordinates are computed by ``radviz()`` and drawn together
    with a unit reference circle. Outputs the selected instances, the data
    annotated with the selection, and a table describing the anchors
    ("components").
    """
    name = "Radviz"
    description = "Radviz"
    icon = "icons/Radviz.svg"
    priority = 240
    keywords = ["viz"]

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    settings_version = 1
    settingsHandler = settings.DomainContextHandler()

    # Per-domain placement of variables, as produced by
    # VariablesSelection.encode_var_state: key -> (list index, position),
    # where list index 0 is the "selected" list.
    variable_state = settings.ContextSetting({})

    auto_commit = settings.Setting(True)
    graph = settings.SettingProvider(OWRadvizGraph)
    vizrank = settings.SettingProvider(RadvizVizRank)

    jitter_sizes = [0, 0.1, 0.5, 1.0, 2.0]

    # Custom event type used to coalesce replot requests (invalidate_plot).
    ReplotRequest = QEvent.registerEventType()

    graph_name = "graph.plot_widget.plotItem"

    class Information(widget.OWWidget.Information):
        sql_sampled_data = widget.Msg("Data has been sampled")

    class Warning(widget.OWWidget.Warning):
        no_features = widget.Msg("At least 2 features have to be chosen")

    class Error(widget.OWWidget.Error):
        sparse_data = widget.Msg("Sparse data is not supported")
        no_features = widget.Msg(
            "At least 3 numeric or categorical variables are required")
        no_instances = widget.Msg("At least 2 data instances are required")

    def __init__(self):
        super().__init__()

        self.data = None
        self.subset_data = None
        self._subset_mask = None
        self._selection = None  # np.array
        self.__replot_requested = False
        self._new_plotdata()

        # Meta variables that hold the projected coordinates.
        self.variable_x = ContinuousVariable("radviz-x")
        self.variable_y = ContinuousVariable("radviz-y")

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWRadvizGraph(self,
                                   box,
                                   "Plot",
                                   view_box=RadvizInteractiveViewBox)
        self.graph.hide_axes()

        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        SIZE_POLICY = (QSizePolicy.Minimum, QSizePolicy.Maximum)

        self.variables_selection = VariablesSelection()
        self.model_selected = VariableListModel(enable_dnd=True)
        self.model_other = VariableListModel(enable_dnd=True)
        self.variables_selection(self, self.model_selected, self.model_other)

        self.vizrank, self.btn_vizrank = RadvizVizRank.add_vizrank(
            self.controlArea, self, "Suggest features", self.vizrank_set_attrs)
        self.btn_vizrank.setSizePolicy(*SIZE_POLICY)
        self.variables_selection.add_remove.layout().addWidget(
            self.btn_vizrank)

        self.viewbox = plot.getViewBox()
        self.replot = None

        g = self.graph.gui
        pp_box = g.point_properties_box(self.controlArea)
        pp_box.setSizePolicy(*SIZE_POLICY)
        self.models = g.points_models

        box = gui.vBox(self.controlArea, "Plot Properties")
        box.setSizePolicy(*SIZE_POLICY)
        g.add_widget(g.JitterSizeSlider, box)

        g.add_widgets([g.ShowLegend, g.ClassDensity, g.LabelOnlySelected], box)

        zoom_select = self.graph.box_zoom_select(self.controlArea)
        zoom_select.setSizePolicy(*SIZE_POLICY)

        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea,
                        self,
                        "auto_commit",
                        "Send Selection",
                        auto_label="Send Automatically")

        self.graph.zoom_actions(self)

        # Unit circle drawn as a reference around the projection.
        self._circle = QGraphicsEllipseItem()
        self._circle.setRect(QRectF(-1., -1., 2., 2.))
        self._circle.setPen(pg.mkPen(QColor(0, 0, 0), width=2))

    def resizeEvent(self, event):
        """Let the base widget handle the resize, then re-layout anchor labels."""
        super().resizeEvent(event)
        # Label widths depend on the view's pixel size.
        self._update_points_labels()

    def keyPressEvent(self, event):
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def vizrank_set_attrs(self, attrs):
        """Make ``attrs`` (suggested by VizRank) the selected variables."""
        if not attrs:
            return
        self.variables_selection.display_none()
        self.model_selected[:] = attrs[:]
        self.model_other[:] = [v for v in self.model_other if v not in attrs]

    def _new_plotdata(self):
        """Reset the per-plot cache of computed coordinates and items."""
        self.plotdata = namespace(
            valid_mask=None,
            embedding_coords=None,
            points=None,
            arcarrows=[],
            point_labels=[],
            rand=None,
            data=None,
        )

    def update_colors(self):
        self._vizrank_color_change()
        # NOTE(review): cb_class_density is not assigned in this class;
        # confirm it is provided elsewhere.
        self.cb_class_density.setEnabled(self.graph.can_draw_density())

    def sizeHint(self):
        return QSize(800, 500)

    def clear(self):
        """
        Clear/reset the widget state
        """
        self.data = None
        self.model_selected.clear()
        self.model_other.clear()
        self._clear_plot()

    def _clear_plot(self):
        self._new_plotdata()
        self.graph.plot_widget.clear()

    def invalidate_plot(self):
        """
        Schedule a delayed replot.

        Multiple calls before the event is processed result in one replot.
        """
        if not self.__replot_requested:
            self.__replot_requested = True
            QApplication.postEvent(self, QEvent(self.ReplotRequest),
                                   Qt.LowEventPriority - 10)

    def init_attr_values(self):
        self.graph.set_domain(self.data)

    def _vizrank_color_change(self):
        """Enable VizRank only when it can run on the current data and color."""
        attr_color = self.graph.attr_color
        is_enabled = self.data is not None and not self.data.is_sparse() and \
                     (len(self.model_other) + len(self.model_selected)) > 3 and len(self.data) > 1
        self.btn_vizrank.setEnabled(
            is_enabled and attr_color is not None and not np.isnan(
                self.data.get_column_view(attr_color)[0].astype(float)).all())
        self.vizrank.initialize()

    @Inputs.data
    def set_data(self, data):
        """
        Set the input dataset and check if data is valid.

        The data passes a chain of checks (SQL materialization, sparsity,
        number of primitive variables, number of instances); a failing check
        reports an error and discards the data.

        Args:
            data (Orange.data.table): data instances
        """
        def sql(data):
            # Materialize SQL tables; large ones are sampled first.
            self.Information.sql_sampled_data.clear()
            if isinstance(data, SqlTable):
                if data.approx_len() < 4000:
                    data = Table(data)
                else:
                    self.Information.sql_sampled_data()
                    data_sample = data.sample_time(1, no_cache=True)
                    data_sample.download_data(2000, partial=True)
                    data = Table(data_sample)
            return data

        def restore_state(data):
            # get the default encoded state, replacing the position with Inf
            state = VariablesSelection.encode_var_state(
                [list(self.model_selected),
                 list(self.model_other)])
            state = {
                key: (source_ind, np.inf)
                for key, (source_ind, _) in state.items()
            }

            self.openContext(data.domain)

            # update the defaults state (the encoded state must contain
            # all variables in the input domain)
            state.update(self.variable_state)
            # ... and restore it with saved positions taking precedence over
            # the defaults
            selected, other = VariablesSelection.decode_var_state(
                state, [list(self.model_selected),
                        list(self.model_other)])
            return selected, other

        def is_sparse(data):
            if data.is_sparse():
                self.Error.sparse_data()
                data = None
            return data

        def are_features(data):
            domain = data.domain
            variables = [
                var for var in chain(domain.class_vars, domain.metas,
                                     domain.attributes) if var.is_primitive()
            ]
            if len(variables) < 3:
                self.Error.no_features()
                data = None
            return data

        def are_instances(data):
            if len(data) < 2:
                self.Error.no_instances()
                data = None
            return data

        self.clear_messages()
        self.btn_vizrank.setEnabled(False)
        self.closeContext()
        self.clear()
        self.information()
        self.Error.clear()
        for check in [sql, is_sparse, are_features, are_instances]:
            if data is None:
                break
            data = check(data)

        if data is not None:
            self.data = data
            self.init_attr_values()
            domain = data.domain
            variables = [
                v for v in chain(domain.metas, domain.attributes)
                if v.is_primitive()
            ]
            # Defaults: first five variables selected; saved context wins.
            self.model_selected[:] = variables[:5]
            self.model_other[:] = variables[5:] + list(domain.class_vars)
            self.model_selected[:], self.model_other[:] = restore_state(data)
            self._selection = np.zeros(len(data), dtype=np.uint8)
            self.invalidate_plot()
        else:
            self.data = None

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """
        Set the supplementary input subset dataset.

        Args:
            subset (Orange.data.table): subset of data instances
        """
        self.subset_data = subset
        self._subset_mask = None
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        if self.data is not None:
            self._clear_plot()
            if self.subset_data is not None and self._subset_mask is None:
                # Mark instances of the main data that appear in the subset.
                dataids = self.data.ids.ravel()
                subsetids = np.unique(self.subset_data.ids)
                self._subset_mask = np.in1d(dataids,
                                            subsetids,
                                            assume_unique=True)
            self.setup_plot(reset_view=True)
            # NOTE(review): cb_class_density is not assigned in this class;
            # confirm it is provided elsewhere.
            self.cb_class_density.setEnabled(self.graph.can_draw_density())
        else:
            self.init_attr_values()
            self.graph.new_data(None)
        self._vizrank_color_change()
        self.commit()

    def customEvent(self, event):
        if event.type() == OWRadviz.ReplotRequest:
            self.__replot_requested = False
            self._clear_plot()
            self.setup_plot(reset_view=True)
        else:
            super().customEvent(event)

    def closeContext(self):
        # Save the current variable placement into the context setting.
        self.variable_state = VariablesSelection.encode_var_state(
            [list(self.model_selected),
             list(self.model_other)])
        super().closeContext()

    def prepare_radviz_data(self, variables):
        """Compute radviz coordinates for ``variables`` into ``self.plotdata``."""
        ec, points, valid_mask = radviz(self.data, variables,
                                        self.plotdata.points)
        self.plotdata.embedding_coords = ec
        self.plotdata.points = points
        self.plotdata.valid_mask = valid_mask

    def setup_plot(self, reset_view=True):
        """Project the data and (re)draw points, anchors, circle and labels."""
        if self.data is None:
            return
        self.graph.jitter_continuous = True
        self.__replot_requested = False

        variables = list(self.model_selected)
        if len(variables) < 2:
            self.Warning.no_features()
            self.graph.new_data(None)
            return

        self.Warning.clear()
        self.prepare_radviz_data(variables)

        if self.plotdata.embedding_coords is None:
            return

        # Extend the domain with the two coordinate metas and fill them in;
        # rows that could not be projected keep zeros.
        domain = self.data.domain
        new_metas = domain.metas + (self.variable_x, self.variable_y)
        domain = Domain(attributes=domain.attributes,
                        class_vars=domain.class_vars,
                        metas=new_metas)
        mask = self.plotdata.valid_mask
        array = np.zeros((len(self.data), 2), dtype=float)
        array[mask] = self.plotdata.embedding_coords
        data = self.data.transform(domain)
        data[:, self.variable_x] = array[:, 0].reshape(-1, 1)
        data[:, self.variable_y] = array[:, 1].reshape(-1, 1)
        subset_data = data[self._subset_mask & mask]\
            if self._subset_mask is not None and len(self._subset_mask) else None
        self.plotdata.data = data
        self.graph.new_data(data[mask], subset_data)
        if self._selection is not None:
            self.graph.selection = self._selection[self.plotdata.valid_mask]
        self.graph.update_data(self.variable_x,
                               self.variable_y,
                               reset_view=reset_view)
        self.graph.plot_widget.addItem(self._circle)
        self.graph.scatterplot_points = ScatterPlotItem(
            x=self.plotdata.points[:, 0], y=self.plotdata.points[:, 1])
        self._update_points_labels()
        self.graph.plot_widget.addItem(self.graph.scatterplot_points)

    def randomize_indices(self):
        """Pick at most MAX_POINTS random indices of projected points."""
        ec = self.plotdata.embedding_coords
        self.plotdata.rand = np.random.choice(len(ec), MAX_POINTS, replace=False) \
            if len(ec) > MAX_POINTS else None

    def manual_move(self):
        """Redraw after anchors were moved by hand, using cached anchor points."""
        self.__replot_requested = False

        if self.plotdata.rand is not None:
            rand = self.plotdata.rand
            valid_mask = self.plotdata.valid_mask
            data = self.data[valid_mask]
            selection = self._selection[valid_mask]
            selection = selection[rand]
            ec, _, valid_mask = radviz(data, list(self.model_selected),
                                       self.plotdata.points)
            assert sum(valid_mask) == len(data)
            data = data[rand]
            ec = ec[rand]
            data_x = data.X
            data_y = data.Y
            data_metas = data.metas
        else:
            self.prepare_radviz_data(list(self.model_selected))
            ec = self.plotdata.embedding_coords
            valid_mask = self.plotdata.valid_mask
            data_x = self.data.X[valid_mask]
            data_y = self.data.Y[valid_mask]
            data_metas = self.data.metas[valid_mask]
            selection = self._selection[valid_mask]

        attributes = (self.variable_x,
                      self.variable_y) + self.data.domain.attributes
        domain = Domain(attributes=attributes,
                        class_vars=self.data.domain.class_vars,
                        metas=self.data.domain.metas)
        data = Table.from_numpy(domain,
                                X=np.hstack((ec, data_x)),
                                Y=data_y,
                                metas=data_metas)
        self.graph.new_data(data, None)
        self.graph.selection = selection
        self.graph.update_data(self.variable_x,
                               self.variable_y,
                               reset_view=True)
        self.graph.plot_widget.addItem(self._circle)
        self.graph.scatterplot_points = ScatterPlotItem(
            x=self.plotdata.points[:, 0], y=self.plotdata.points[:, 1])
        self._update_points_labels()
        self.graph.plot_widget.addItem(self.graph.scatterplot_points)

    def _update_points_labels(self):
        """Replace the anchor name labels, sized to the current view scale."""
        if self.plotdata.points is None:
            return
        for point_label in self.plotdata.point_labels:
            self.graph.plot_widget.removeItem(point_label)
        self.plotdata.point_labels = []
        sx, sy = self.graph.view_box.viewPixelSize()

        # Each row holds the anchor's x, y and its variable.
        for row in self.plotdata.points:
            ti = TextItem()
            metrics = QFontMetrics(ti.textItem.font())
            text_width = ((RANGE.width()) / 2. - np.abs(row[0])) / sx
            name = row[2].name
            ti.setText(name)
            ti.setTextWidth(text_width)
            ti.setColor(QColor(0, 0, 0))
            br = ti.boundingRect()
            width = min(metrics.width(name), br.width())
            width = sx * (width + 5)
            height = sy * br.height()
            # Shift labels left of / above the anchor so they stay outside.
            ti.setPos(row[0] - (row[0] < 0) * width,
                      row[1] + (row[1] > 0) * height)
            self.plotdata.point_labels.append(ti)
            self.graph.plot_widget.addItem(ti)

    def _update_jitter(self):
        self.invalidate_plot()

    def reset_graph_data(self, *_):
        if self.data is not None:
            self.graph.rescale_data()
            self._update_graph()

    def _update_graph(self, reset_view=True, **_):
        self.graph.zoomStack = []
        if self.graph.data is None:
            return
        self.graph.update_data(self.variable_x,
                               self.variable_y,
                               reset_view=reset_view)

    def update_density(self):
        self._update_graph(reset_view=True)

    def selection_changed(self):
        """Store the graph selection into the full-length selection array."""
        if self.graph.selection is not None:
            self._selection[self.plotdata.valid_mask] = self.graph.selection
        self.commit()

    def prepare_data(self):
        pass

    def commit(self):
        """Send selected data, annotated data and the components table."""
        selected = annotated = components = None
        graph = self.graph
        if self.plotdata.data is not None:
            name = self.data.name
            data = self.plotdata.data
            # Scatter the selection of the plotted (valid) rows back onto a
            # mask over all rows; with no selection every valid row gets 0.
            mask = self.plotdata.valid_mask.astype(int)
            mask[mask == 1] = graph.selection if graph.selection is not None \
                else 0
            selection = np.flatnonzero(mask)
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes
            # Values above 1 mean multiple selection groups are in use.
            if graph.selection is not None and np.max(graph.selection) > 1:
                annotated = create_groups_table(data, mask)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

            # One column per anchor variable; rows: x, y and polar angle.
            comp_domain = Domain(self.plotdata.points[:, 2],
                                 metas=[StringVariable(name='component')])

            metas = np.array([["RX"], ["RY"], ["angle"]])
            angle = np.arctan2(
                np.array(self.plotdata.points[:, 1].T, dtype=float),
                np.array(self.plotdata.points[:, 0].T, dtype=float))
            components = Table.from_numpy(comp_domain,
                                          X=np.vstack(
                                              (self.plotdata.points[:, :2].T,
                                               angle)),
                                          metas=metas)
            components.name = name + ": components"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        if self.data is None:
            return

        def name(var):
            return var and var.name

        caption = report.render_items_vert(
            (("Color", name(self.graph.attr_color)),
             ("Label", name(self.graph.attr_label)),
             ("Shape", name(self.graph.attr_shape)),
             ("Size", name(self.graph.attr_size)),
             ("Jittering", self.graph.jitter_size != 0
              and "{} %".format(self.graph.jitter_size))))
        self.report_plot()
        if caption:
            self.report_caption(caption)
class OWGeneInfo(widget.OWWidget):
    name = "Gene Info"
    description = "Displays gene information from NCBI and other sources."
    icon = "../widgets/icons/OWGeneInfo.svg"
    priority = 5

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        selected_genes = Output("Selected Genes", Orange.data.Table)
        data = Output("Data", Orange.data.Table)

    settingsHandler = settings.DomainContextHandler()

    organism_index = settings.ContextSetting(0)
    taxid = settings.ContextSetting("9606")

    gene_attr = settings.ContextSetting(0)

    auto_commit = settings.Setting(False)
    search_string = settings.Setting("")

    useAttr = settings.ContextSetting(False)
    useAltSource = settings.ContextSetting(False)

    def __init__(
        self,
        parent=None,
    ):
        """Build the GUI and start the background taxonomy download.

        The widget stays blocked (with a progress bar) until the download
        task finishes and ``initialize`` runs, or ``_onInitializeError``
        reports the failure.
        """
        # ``super()`` already binds ``self``; it must not be passed again.
        super().__init__(parent)

        self.selectionChangedFlag = False

        self.__initialized = False
        self.initfuture = None
        self.itemsfuture = None

        self.map_input_to_ensembl = None
        self.infoLabel = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Info", addSpace=True),
            "Initializing\n")

        self.organisms = None
        self.organismBox = gui.widgetBox(self.controlArea,
                                         "Organism",
                                         addSpace=True)

        self.organismComboBox = gui.comboBox(
            self.organismBox,
            self,
            "organism_index",
            callback=self._onSelectedOrganismChanged)

        box = gui.widgetBox(self.controlArea, "Gene names", addSpace=True)
        self.geneAttrComboBox = gui.comboBox(box,
                                             self,
                                             "gene_attr",
                                             "Gene attribute",
                                             callback=self.updateInfoItems)
        # The attribute combo is irrelevant when genes are column names.
        self.geneAttrComboBox.setEnabled(not self.useAttr)

        self.geneAttrCheckbox = gui.checkBox(box,
                                             self,
                                             "useAttr",
                                             "Use column names",
                                             callback=self.updateInfoItems)
        self.geneAttrCheckbox.toggled[bool].connect(
            self.geneAttrComboBox.setDisabled)

        gui.auto_commit(self.controlArea, self, "auto_commit", "Commit")

        gui.rubber(self.controlArea)

        gui.lineEdit(self.mainArea,
                     self,
                     "search_string",
                     "Filter",
                     callbackOnType=True,
                     callback=self.searchUpdate)

        self.treeWidget = QTreeView(self.mainArea)

        self.treeWidget.setAlternatingRowColors(True)
        self.treeWidget.setSortingEnabled(True)
        self.treeWidget.setSelectionMode(QTreeView.ExtendedSelection)
        self.treeWidget.setUniformRowHeights(True)
        self.treeWidget.setRootIsDecorated(False)

        # ID columns are rendered as links.
        self.treeWidget.setItemDelegateForColumn(
            HEADER_SCHEMA['NCBI ID'],
            gui.LinkStyledItemDelegate(self.treeWidget))
        self.treeWidget.setItemDelegateForColumn(
            HEADER_SCHEMA['Ensembl ID'],
            gui.LinkStyledItemDelegate(self.treeWidget))

        self.treeWidget.viewport().setMouseTracking(True)
        self.mainArea.layout().addWidget(self.treeWidget)

        box = gui.widgetBox(self.mainArea, "", orientation="horizontal")
        gui.button(box, self, "Select Filtered", callback=self.selectFiltered)
        gui.button(box,
                   self,
                   "Clear Selection",
                   callback=self.treeWidget.clearSelection)

        self.geneinfo = []
        self.cells = []
        self.row2geneinfo = {}
        self.data = None

        # : (# input genes, # matches genes)
        self.matchedInfo = 0, 0

        self.setBlocking(True)
        self.executor = ThreadExecutor(self)

        self.progressBarInit()

        task = Task(
            function=partial(taxonomy.ensure_downloaded,
                             callback=methodinvoke(self, "advance", ())))

        # Connect before submitting so no completion signal can be missed.
        task.resultReady.connect(self.initialize)
        task.exceptionReady.connect(self._onInitializeError)

        self.initfuture = self.executor.submit(task)

    def sizeHint(self):
        """Preferred initial window size (1024 x 720)."""
        width, height = 1024, 720
        return QSize(width, height)

    @Slot()
    def advance(self):
        """Move the progress bar one step forward.

        Invoked through ``methodinvoke`` from worker tasks; asserts that it
        actually runs in the widget's own thread.
        """
        assert self.thread() is QThread.currentThread()
        next_value = self.progressBarValue + 1
        self.progressBarSet(next_value, processEvents=None)

    def _get_available_organisms(self):
        """Fill ``self.organisms`` and the organism combo, ordered by organism name.

        ``self.organisms`` holds the taxonomy ids in the same order as the
        names added to the combo box.
        """
        pairs = [(tax_id, taxonomy.name(tax_id))
                 for tax_id in taxonomy.common_taxids()]
        pairs.sort(key=lambda pair: pair[1])

        self.organisms = [tax_id for tax_id, _ in pairs]
        self.organismComboBox.addItems([label for _, label in pairs])

    def initialize(self):
        """Finish set-up after the taxonomy download; subsequent calls are no-ops.

        Populates the organism list, selects ``taxonomy.DEFAULT_ORGANISM``,
        and unblocks the widget.
        """
        if self.__initialized:
            return
        self.__initialized = True

        self._get_available_organisms()
        default_index = self.organisms.index(taxonomy.DEFAULT_ORGANISM)
        self.organism_index = default_index
        self.taxid = self.organisms[self.organism_index]

        self.infoLabel.setText("No data on input\n")
        self.initfuture = None

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)

    def _onInitializeError(self, exc):
        """Report a failed taxonomy download.

        The exception goes to ``sys.excepthook`` and an error is shown on
        the widget.
        """
        exc_type = type(exc)
        sys.excepthook(exc_type, exc, None)
        self.error(0, "Could not download the necessary files.")

    def _onSelectedOrganismChanged(self):
        """Sync ``taxid`` with the organism chosen in the combo and refresh the info."""
        # Valid indices are 0 .. len - 1; the old ``<=`` bound let an
        # out-of-range index through to an IndexError below.
        assert 0 <= self.organism_index < len(self.organisms)
        self.taxid = self.organisms[self.organism_index]

        if self.data is not None:
            self.updateInfoItems()

    @Inputs.data
    def setData(self, data=None):
        """Receive the input table and start looking up its genes.

        Waits for pending initialization first. Raises if a lookup is still
        running. Candidate gene columns are string and discrete variables;
        the taxonomy id and the "genes as attribute names" flag are taken
        from the table's ``attributes`` when present.
        """
        if not self.__initialized:
            # Block until the taxonomy download finishes, then set up.
            self.initfuture.result()
            self.initialize()

        if self.itemsfuture is not None:
            raise Exception("Already processing")

        self.data = data

        if data is None:
            self.clear()
            return

        combo = self.geneAttrComboBox
        combo.clear()
        candidate_types = (Orange.data.StringVariable,
                           Orange.data.DiscreteVariable)
        domain = data.domain
        self.attributes = [
            var for var in domain.variables + domain.metas
            if isinstance(var, candidate_types)
        ]

        for var in self.attributes:
            combo.addItem(*gui.attributeItem(var))

        table_info = self.data.attributes
        self.taxid = str(table_info.get(TAX_ID, ''))
        self.useAttr = table_info.get(GENE_AS_ATTRIBUTE_NAME, self.useAttr)

        # Keep the stored column index within the new list of candidates.
        self.gene_attr = min(self.gene_attr, len(self.attributes) - 1)

        if self.taxid in self.organisms:
            self.organism_index = self.organisms.index(self.taxid)

        self.updateInfoItems()

    def updateInfoItems(self):
        """Start a background lookup of gene information for the input data.

        Gene names come from the attribute names when ``useAttr`` is set,
        otherwise from the selected string/discrete column (NaN values
        skipped). ``ncbi_info`` runs on ``self.executor``; the widget is
        blocked and disabled until ``_onItemsCompleted`` handles the result.
        """
        self.warning(0)
        if self.data is None:
            return

        if self.useAttr:
            genes = [attr.name for attr in self.data.domain.attributes]
        elif self.attributes:
            attr = self.attributes[self.gene_attr]
            genes = [
                str(ex[attr]) for ex in self.data if not math.isnan(ex[attr])
            ]
        else:
            genes = []
        if not genes:
            self.warning(0, "Could not extract genes from input dataset.")

        self.warning(1)
        org = self.organisms[min(self.organism_index, len(self.organisms) - 1)]

        self.error(0)

        self.progressBarInit()
        self.setBlocking(True)
        self.setEnabled(False)
        self.infoLabel.setText("Retrieving info records.\n")

        self.genes = genes

        task = Task(function=partial(
            ncbi_info, org, genes,
            advance=methodinvoke(self, "advance", ())))
        # Connect before submitting: a task finishing before the connection
        # is made would emit ``finished`` unheard and leave the widget
        # blocked forever.
        task.finished.connect(self._onItemsCompleted)
        self.itemsfuture = self.executor.submit(task)

    def _onItemsCompleted(self):
        """Show the finished lookup in the tree view and send the "Data" output.

        The future's result is ``(map_input_to_ensembl, geneinfo)``, where
        ``geneinfo`` is a sequence of ``(input_name, info)`` pairs. Only
        pairs with a non-empty ``info`` become table rows; ``row2geneinfo``
        maps each row back to its position in ``geneinfo``. Ensembl IDs are
        then attached to the output: on the gene variables when genes are
        column names, otherwise as a new "Ensembl ID" meta column.
        """
        self.setBlocking(False)
        self.progressBarFinished()
        self.setEnabled(True)

        future = self.itemsfuture
        try:
            self.map_input_to_ensembl, geneinfo = future.result()
        finally:
            self.itemsfuture = None

        self.geneinfo = geneinfo
        cells = []
        self.cells = cells
        self.row2geneinfo = {}

        for index, (_input_name, info) in enumerate(geneinfo):
            if not info:
                continue
            row = list(info)
            # Synonyms arrive as a sequence; show them comma-separated.
            row[HEADER_SCHEMA['Synonyms']] = ','.join(
                row[HEADER_SCHEMA['Synonyms']])
            cells.append(row)
            self.row2geneinfo[len(cells) - 1] = index

        model = TreeModel(cells, list(HEADER_SCHEMA.keys()), None)
        proxy = QSortFilterProxyModel(self)
        proxy.setSourceModel(model)

        view = self.treeWidget
        view.setModel(proxy)
        view.selectionModel().selectionChanged.connect(self.commit)

        for column in range(len(HEADER_SCHEMA)):
            view.resizeColumnToContents(column)
            view.setColumnWidth(column, min(view.columnWidth(column), 200))

        n_genes, n_matched = len(self.genes), len(cells)
        self.infoLabel.setText("%i genes\n%i matched NCBI's IDs" %
                               (n_genes, n_matched))
        self.matchedInfo = n_genes, n_matched

        mapping = self.map_input_to_ensembl
        if self.useAttr:
            new_data = self.data.from_table(self.data.domain, self.data)
            for gene_var in new_data.domain.attributes:
                gene_var.attributes['Ensembl ID'] = str(mapping[gene_var.name])
            self.Outputs.data.send(new_data)
        elif self.attributes:
            gene_column = self.data.get_column_view(
                self.attributes[self.gene_attr])[0]
            ensembl_ids = [
                mapping[gene_name]
                if gene_name and gene_name in mapping else ''
                for gene_name in gene_column
            ]
            data_with_ensembl = append_columns(
                self.data,
                metas=[(Orange.data.StringVariable('Ensembl ID'), ensembl_ids)
                       ])
            self.Outputs.data.send(data_with_ensembl)

    def clear(self):
        """Reset the info view to an empty table and empty the gene combo.

        Also sends ``None`` on the ``selected_genes`` output.
        """
        self.infoLabel.setText("No data on input\n")
        column_titles = [
            "NCBI ID", "Symbol", "Locus Tag", "Chromosome", "Description",
            "Synonyms", "Nomenclature",
        ]
        self.treeWidget.setModel(TreeModel([], column_titles, self.treeWidget))

        self.geneAttrComboBox.clear()
        self.Outputs.selected_genes.send(None)

        self.geneAttrComboBox.clear()
        self.Outputs.selected_genes.send(None)

    def commit(self):
        """Send the data for the genes selected in the info view.

        Without input data, both ``selected_genes`` and ``data`` outputs are
        cleared. Otherwise the selected rows of the (sort/filter proxy) view
        are mapped to source-model rows and, through ``self.row2geneinfo``
        and ``self.geneinfo``, to input gene names:

        * ``useAttr`` mode: ``selected_genes`` receives the input restricted
          to the attribute columns whose names were selected;
        * otherwise, when ``self.attributes`` is non-empty: ``selected_genes``
          receives the input rows whose gene value was selected, extended by
          one string meta per ``gene.GENE_INFO_HEADER_LABELS`` entry filled
          from the view's source model (``None`` if no row matches).

        An empty selection sends ``None`` on ``selected_genes``.
        """
        if self.data is None:
            self.Outputs.selected_genes.send(None)
            self.Outputs.data.send(None)
            return

        proxy = self.treeWidget.model()
        source_selection = proxy.mapSelectionToSource(
            self.treeWidget.selectionModel().selection())
        rows = [
            row
            for sel_range in source_selection
            for row in range(sel_range.top(), sel_range.bottom() + 1)
        ]
        source_model = proxy.sourceModel()

        # input gene name behind every selected source row
        row_names = [self.geneinfo[self.row2geneinfo[row]][0] for row in rows]
        chosen = set(row_names)
        # with duplicate names the last selected row wins
        name_to_row = {name: row for name, row in zip(row_names, rows)}

        if not chosen:
            self.Outputs.selected_genes.send(None)
            return

        domain = self.data.domain
        if self.useAttr:
            kept_vars = [var for var in domain.attributes if var.name in chosen]
            subset_domain = Orange.data.Domain(kept_vars, domain.class_vars,
                                               domain.metas)
            self.Outputs.selected_genes.send(
                self.data.from_table(subset_domain, self.data))
        elif self.attributes:
            gene_var = self.attributes[self.gene_attr]
            labels = [gene_var.str_val(value)
                      for value in self.data.get_column_view(gene_var)[0]]
            matches = [(pos, label) for pos, label in enumerate(labels)
                       if label in chosen]

            # SELECTED GENES OUTPUT
            info_metas = [Orange.data.StringVariable(label)
                          for label in gene.GENE_INFO_HEADER_LABELS]
            extended_domain = Orange.data.Domain(
                domain.attributes, domain.class_vars,
                domain.metas + tuple(info_metas))
            out = self.data.from_table(extended_domain, self.data)[
                [pos for pos, _ in matches]]

            source_rows = [name_to_row[label] for _, label in matches]
            n_columns = source_model.columnCount()
            for col, meta in enumerate(info_metas[:n_columns]):
                cell_texts = [
                    str(source_model.index(row, col).data(Qt.DisplayRole))
                    for row in source_rows
                ]
                out[:, meta] = np.array(cell_texts, dtype=object, ndmin=2).T

            self.Outputs.selected_genes.send(out if len(out) else None)

    def rowFiltered(self, row):
        """Return True when cell row *row* fails the current search.

        The row's cells are joined with spaces and lower-cased. The row passes
        only if it contains every whitespace-separated, lower-cased word of
        ``self.search_string``. An empty search therefore filters nothing.
        """
        terms = self.search_string.lower().split()
        haystack = " ".join(self.cells[row]).lower()
        return any(term not in haystack for term in terms)

    def searchUpdate(self):
        """Hide view rows that do not contain every search term.

        Does nothing when there is no input data. Matching is
        case-insensitive against the space-joined cells of each row.
        """
        if not self.data:
            return
        needles = self.search_string.lower().split()
        proxy = self.treeWidget.model()
        source_index = proxy.sourceModel().index
        to_view = proxy.mapFromSource
        for i, cell_row in enumerate(self.cells):
            text = " ".join(cell_row).lower()
            hide = any(needle not in text for needle in needles)
            view_row = to_view(source_index(i, 0)).row()
            self.treeWidget.setRowHidden(view_row, QModelIndex(), hide)

    def selectFiltered(self):
        """Add every row that passes the search filter to the selection.

        Does nothing when there is no input data. Rows are selected whole
        and added to (not replacing) the current selection.
        """
        if not self.data:
            return
        wanted = QItemSelection()

        proxy = self.treeWidget.model()
        source_index = proxy.sourceModel().index
        for i, _ in enumerate(self.cells):
            if self.rowFiltered(i):
                continue
            view_index = proxy.mapFromSource(source_index(i, 0))
            wanted.select(view_index, view_index)

        flags = QItemSelectionModel.Select | QItemSelectionModel.Rows
        self.treeWidget.selectionModel().select(wanted, flags)

    def onAltSourceChange(self):
        """Callback: re-run ``updateInfoItems()`` after a source change."""
        refresh = self.updateInfoItems
        refresh()

    def onDeleteWidget(self):
        """Cancel outstanding lookups and shut the executor down on removal."""
        # Try to cancel pending tasks; the attributes are read one at a time
        # so a cancel() side effect is seen by the next check.
        for future_name in ("initfuture", "itemsfuture"):
            pending = getattr(self, future_name)
            if pending:
                pending.cancel()

        self.executor.shutdown(wait=False)
        super().onDeleteWidget()
class OWAverage(OWWidget):
    """Average the instances of the input table, optionally per group."""
    # Widget's name as displayed in the canvas
    name = "Average Spectra"

    # Short widget description
    description = ("Calculates averages.")

    icon = "icons/average.svg"

    # Define inputs and outputs
    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)

    class Outputs:
        averages = Output("Averages", Orange.data.Table, default=True)

    settingsHandler = settings.DomainContextHandler()
    # Discrete variable to group by; None averages the whole table.
    group_var = settings.ContextSetting(None)

    autocommit = settings.Setting(True)

    want_main_area = False
    resizing_enabled = False

    class Warning(OWWidget.Warning):
        nodata = Msg("No useful data on input!")

    def __init__(self):
        super().__init__()

        self.data = None
        # The model must exist before set_data() runs, because set_data()
        # resets it (also when the input is None).
        self.group_vars = DomainModel(placeholder="None",
                                      separators=False,
                                      valid_types=Orange.data.DiscreteVariable)
        self.set_data(self.data)  # show warning

        self.group_view = gui.listView(self.controlArea,
                                       self,
                                       "group_var",
                                       box="Group by",
                                       model=self.group_vars,
                                       callback=self.grouping_changed)

        gui.auto_commit(self.controlArea, self, "autocommit", "Apply")

    @Inputs.data
    def set_data(self, dataset):
        """Store the new input, refresh the grouping choices and commit.

        With ``dataset is None`` the "no data" warning is shown and the list
        of grouping variables is emptied. Without this, variables of the
        previous input would stay listed and selectable.
        """
        self.Warning.nodata.clear()
        self.closeContext()
        self.data = dataset
        self.group_var = None
        if dataset is None:
            self.Warning.nodata()
            self.group_vars.set_domain(None)
        else:
            self.group_vars.set_domain(dataset.domain)
            self.openContext(dataset.domain)

        self.commit()

    @staticmethod
    def average_table(table):
        """
        Return a one-row table with the averages of *table*.

        - Columns of ``table.X`` are averaged with ``bottleneck.nanmean``
          (missing values are ignored).
        - ContinuousVariable class variables and metas are averaged with
          ``np.nanmean``. Orange's TimeVariable derives from
          ContinuousVariable, so it is averaged as well.
        - Every other class variable or meta (e.g. DiscreteVariable,
          StringVariable) keeps its value if it is the same in all rows and
          becomes unknown otherwise.

        An empty *table* is returned unchanged (the same object).
        """
        if len(table) == 0:
            return table
        mean = bottleneck.nanmean(table.X, axis=0).reshape(1, -1)
        # start from the first row's class values and metas; they are
        # overwritten below where averaging or "unknown" applies
        avg_table = Orange.data.Table.from_numpy(
            table.domain,
            X=mean,
            Y=np.atleast_2d(table.Y[0].copy()),
            metas=np.atleast_2d(table.metas[0].copy()))
        cont_vars = [
            var for var in table.domain.class_vars + table.domain.metas
            if isinstance(var, Orange.data.ContinuousVariable)
        ]
        for var in cont_vars:
            index = table.domain.index(var)
            col, _ = table.get_column_view(index)
            try:
                avg_table[0, index] = np.nanmean(col)
            except AttributeError:
                # numpy.lib.nanfunctions._replace_nan just guesses and returns
                # a boolean array mask for object arrays because object arrays
                # do not support `isnan` (numpy-gh-9009)
                # Since we know that ContinuousVariable values must be np.float64
                # do an explicit cast here
                avg_table[0, index] = np.nanmean(col, dtype=np.float64)

        other_vars = [
            var for var in table.domain.class_vars + table.domain.metas
            if not isinstance(var, Orange.data.ContinuousVariable)
        ]
        for var in other_vars:
            index = table.domain.index(var)
            col, _ = table.get_column_view(index)
            # keep the first row's value only if every row has it
            val = var.to_val(avg_table[0, var])
            if not np.all(col == val):
                avg_table[0, var] = Orange.data.Unknown

        return avg_table

    def grouping_changed(self):
        """Calls commit() indirectly to respect auto_commit setting."""
        self.commit()

    def commit(self):
        """Send the averaged table (``None`` without input data).

        Without a grouping variable, send ``average_table(self.data)``. With
        one, send the concatenation of one averaged table per value of the
        grouping variable, followed by one for the instances where it is
        undefined. Empty subsets contribute no rows, because
        ``average_table`` returns them unchanged.
        """
        averages = None
        if self.data is not None:
            if self.group_var is None:
                averages = self.average_table(self.data)
            else:
                parts = []
                for value in self.group_var.values:
                    svfilter = SameValue(self.group_var, value)
                    v_table = self.average_table(svfilter(self.data))
                    parts.append(v_table)
                # Using "None" as in OWSelectRows
                # Values is required because FilterDiscrete doesn't have
                # negate keyword or IsDefined method
                deffilter = Values(
                    conditions=[FilterDiscrete(self.group_var, None)],
                    negate=True)
                v_table = self.average_table(deffilter(self.data))
                parts.append(v_table)
                averages = Orange.data.Table.concatenate(parts, axis=0)
        self.Outputs.averages.send(averages)
# 示例#12 (Example #12)
# 0
class OWUnique(widget.OWWidget):
    """Keep one instance per group of instances that share key values.

    The key is formed by the instance's values of the selected variables
    (all variables when none are selected). The chosen tie-breaker decides
    which instance of each group is kept, if any.
    """
    name = 'Unique'
    icon = 'icons/Unique.svg'
    description = 'Filter instances unique by specified key attribute(s).'
    category = "Transform"
    priority = 1120

    class Inputs:
        data = widget.Input("Data", Table)

    class Outputs:
        data = widget.Output("Data", Table)

    want_main_area = False

    # Each tie-breaker receives the ascending list of row indices of one
    # group and returns the index to keep, or None to keep none of them.
    TIEBREAKERS = {
        'Last instance':
        itemgetter(-1),
        'First instance':
        itemgetter(0),
        'Middle instance':
        lambda seq: seq[len(seq) // 2],
        'Random instance':
        np.random.choice,
        'Discard non-unique instances':
        lambda seq: seq[0] if len(seq) == 1 else None
    }

    settingsHandler = settings.DomainContextHandler()
    selected_vars = settings.ContextSetting([])
    # default: the first tie-breaker in insertion order ('Last instance')
    tiebreaker = settings.Setting(next(iter(TIEBREAKERS)))
    autocommit = settings.Setting(True)

    def __init__(self):
        super().__init__()
        self.data = None

        self.var_model = DomainModel(parent=self, order=DomainModel.MIXED)
        var_list = gui.listView(self.controlArea,
                                self,
                                "selected_vars",
                                box="Group by",
                                model=self.var_model,
                                callback=self.commit.deferred,
                                viewType=ListViewSearch)
        var_list.setSelectionMode(var_list.ExtendedSelection)

        gui.comboBox(self.controlArea,
                     self,
                     'tiebreaker',
                     box=True,
                     label='Instance to select in each group:',
                     items=tuple(self.TIEBREAKERS),
                     callback=self.commit.deferred,
                     sendSelectedValue=True)
        gui.auto_commit(self.controlArea,
                        self,
                        'autocommit',
                        'Commit',
                        orientation=Qt.Horizontal)

    @Inputs.data
    def set_data(self, data):
        """Store the input; by default every variable is part of the key."""
        self.closeContext()
        self.data = data
        self.selected_vars = []
        if data:
            self.var_model.set_domain(data.domain)
            self.selected_vars = self.var_model[:]
            self.openContext(data.domain)
        else:
            self.var_model.set_domain(None)

        self.commit.now()

    @gui.deferred
    def commit(self):
        """Send the filtered table, or ``None`` without input data."""
        if self.data is None:
            self.Outputs.data.send(None)
        else:
            self.Outputs.data.send(self._compute_unique_data())

    def _compute_unique_data(self):
        """Return the rows of ``self.data`` kept by the current tie-breaker.

        Rows are grouped by their values of the selected variables (all
        variables of the model when the selection is empty). Missing values
        (NaN) count as equal to each other, so rows that differ only by being
        missing in the same columns fall into the same group.

        Returns ``None`` when no row is kept.
        """
        columns = [
            self.data.get_column_view(attr)[0]
            for attr in self.selected_vars or self.var_model
        ]
        groups = {}
        for row_index, key in enumerate(zip(*columns)):
            groups.setdefault(self._normalized_key(key), []).append(row_index)

        choose = self.TIEBREAKERS[self.tiebreaker]
        selection = sorted(x
                           for x in (choose(inds) for inds in groups.values())
                           if x is not None)
        return self.data[selection] if selection else None

    @staticmethod
    def _normalized_key(key):
        """Return *key* as a tuple with every NaN replaced by ``None``.

        NaN compares unequal to itself (and, since Python 3.10, hashes by
        object identity). Without this, every row with a missing key value
        would form a group of its own.
        """
        # `v != v` holds only for NaN; for strings or None it is False.
        return tuple(None if v != v else v for v in key)
# 示例#13 (Example #13)
# 0
class OWImageViewer(widget.OWWidget):
    """Show thumbnails of images referenced by a string meta attribute.

    References are resolved to URLs by :meth:`urlFromValue`. Local files are
    read synchronously while the scene is built; other URLs are requested
    through ``ImageLoader``. Selecting thumbnails sends the corresponding
    rows of the input table.
    """
    name = "Image Viewer"
    description = "View images referred to in the data."
    icon = "icons/ImageViewer.svg"
    priority = 4050

    inputs = [("Data", Orange.data.Table, "setData")]
    outputs = [(
        "Data",
        Orange.data.Table,
    )]

    settingsHandler = settings.DomainContextHandler()

    # index into self.stringAttrs / self.allAttrs respectively
    imageAttr = settings.ContextSetting(0)
    titleAttr = settings.ContextSetting(0)

    imageSize = settings.Setting(100)
    autoCommit = settings.Setting(False)

    buttons_area_orientation = Qt.Vertical
    graph_name = "scene"

    UserAdviceMessages = [
        widget.Message(
            "Pressing the 'Space' key while the thumbnail view has focus and "
            "a selected item will open a window with a full image",
            persistent_id="preview-introduction")
    ]

    def __init__(self):
        super().__init__()
        self.data = None
        self.allAttrs = []
        self.stringAttrs = []

        # rows of self.data selected in the thumbnail view
        self.selectedIndices = []

        #: List of _ImageItems
        self.items = []

        # per-scene counters of failed / loaded images
        self._errcount = 0
        self._successcount = 0

        self.info = gui.widgetLabel(gui.vBox(self.controlArea, "Info"),
                                    "Waiting for input.\n")

        self.imageAttrCB = gui.comboBox(
            self.controlArea,
            self,
            "imageAttr",
            box="Image Filename Attribute",
            tooltip="Attribute with image filenames",
            callback=[self.clearScene, self.setupScene],
            contentsLength=12,
            addSpace=True,
        )

        self.titleAttrCB = gui.comboBox(self.controlArea,
                                        self,
                                        "titleAttr",
                                        box="Title Attribute",
                                        tooltip="Attribute with image title",
                                        callback=self.updateTitles,
                                        contentsLength=12,
                                        addSpace=True)

        gui.hSlider(self.controlArea,
                    self,
                    "imageSize",
                    box="Image Size",
                    minValue=32,
                    maxValue=1024,
                    step=16,
                    callback=self.updateSize,
                    createLabel=False)
        gui.rubber(self.controlArea)

        gui.auto_commit(self.buttonsArea,
                        self,
                        "autoCommit",
                        "Send",
                        box=False)

        self.thumbnailView = ThumbnailView(
            alignment=Qt.AlignTop | Qt.AlignLeft,  # scene alignment,
            focusPolicy=Qt.StrongFocus,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn)
        self.mainArea.layout().addWidget(self.thumbnailView)
        self.scene = self.thumbnailView.scene()
        self.scene.selectionChanged.connect(self.onSelectionChanged)
        self.loader = ImageLoader(self)

    def sizeHint(self):
        return QSize(800, 600)

    def setData(self, data):
        """Input handler: rebuild the attribute combos and the scene."""
        self.closeContext()
        self.clear()

        self.data = data

        if data is not None:
            domain = data.domain
            self.allAttrs = (domain.class_vars + domain.metas +
                             domain.attributes)
            self.stringAttrs = [a for a in domain.metas if a.is_string]

            # string metas with a "type" annotation are listed first
            self.stringAttrs = sorted(self.stringAttrs,
                                      key=lambda attr: 0
                                      if "type" in attr.attributes else 1)

            # preselect the first attribute explicitly tagged as image
            indices = [
                i for i, var in enumerate(self.stringAttrs)
                if var.attributes.get("type") == "image"
            ]
            if indices:
                self.imageAttr = indices[0]

            self.imageAttrCB.setModel(VariableListModel(self.stringAttrs))
            self.titleAttrCB.setModel(VariableListModel(self.allAttrs))

            self.openContext(data)

            # clamp the (possibly context-restored) indices into range
            self.imageAttr = max(
                min(self.imageAttr,
                    len(self.stringAttrs) - 1), 0)
            self.titleAttr = max(min(self.titleAttr,
                                     len(self.allAttrs) - 1), 0)

            if self.stringAttrs:
                self.setupScene()
        else:
            self.info.setText("Waiting for input.\n")

    def clear(self):
        """Drop the data, the combo contents and the whole scene."""
        self.data = None
        self.error()
        self.imageAttrCB.clear()
        self.titleAttrCB.clear()
        self.clearScene()

    def setupScene(self):
        """Create one thumbnail per instance with a finite image value.

        Instances whose image-attribute value is not finite (presumably
        missing) are skipped. Each ``_ImageItem`` stores the row index of
        its instance in ``self.data``, which ``commit`` uses to output the
        selected rows.
        """
        self.error()
        if self.data:
            attr = self.stringAttrs[self.imageAttr]
            titleAttr = self.allAttrs[self.titleAttr]
            # Keep each instance's row in self.data: once rows are skipped,
            # positions in this filtered list no longer match data rows.
            instances = [
                (row, inst) for row, inst in enumerate(self.data)
                if numpy.isfinite(inst[attr])
            ]
            assert self.thumbnailView.count() == 0
            size = QSizeF(self.imageSize, self.imageSize)

            for row, inst in instances:
                url = self.urlFromValue(inst[attr])
                title = str(inst[titleAttr])

                thumbnail = GraphicsThumbnailWidget(QPixmap(), title=title)
                thumbnail.setThumbnailSize(size)
                thumbnail.setToolTip(url.toString())
                thumbnail.instance = inst
                self.thumbnailView.addThumbnail(thumbnail)

                if url.isValid() and url.isLocalFile():
                    # local file: read synchronously
                    reader = QImageReader(url.toLocalFile())
                    image = reader.read()
                    if image.isNull():
                        error = reader.errorString()
                        thumbnail.setToolTip(thumbnail.toolTip() + "\n" +
                                             error)
                        self._errcount += 1
                    else:
                        pixmap = QPixmap.fromImage(image)
                        thumbnail.setPixmap(pixmap)
                        self._successcount += 1

                    # already-completed future, so every item looks alike
                    future = Future()
                    future.set_result(image)
                    future._reply = None
                elif url.isValid():
                    future = self.loader.get(url)

                    # `thumb` is bound as a default to avoid late binding
                    @future.add_done_callback
                    def set_pixmap(future, thumb=thumbnail):
                        if future.cancelled():
                            return

                        assert future.done()

                        if future.exception():
                            # Should be some generic error image.
                            pixmap = QPixmap()
                            thumb.setToolTip(thumb.toolTip() + "\n" +
                                             str(future.exception()))
                        else:
                            pixmap = QPixmap.fromImage(future.result())

                        thumb.setPixmap(pixmap)

                        self._noteCompleted(future)
                else:
                    future = None

                self.items.append(_ImageItem(row, thumbnail, url, future))

            if any(it.future is not None and not it.future.done()
                   for it in self.items):
                self.info.setText("Retrieving...\n")
            else:
                self._updateStatus()

    def urlFromValue(self, value):
        """Resolve *value* against its variable's "origin" attribute.

        A relative reference is resolved against the origin (a local
        directory or URL); a URL without a scheme gets the "file" scheme.
        """
        variable = value.variable
        origin = variable.attributes.get("origin", "")
        if origin and QDir(origin).exists():
            origin = QUrl.fromLocalFile(origin)
        elif origin:
            origin = QUrl(origin)
            if not origin.scheme():
                origin.setScheme("file")
        else:
            origin = QUrl("")
        base = origin.path()
        # a trailing "/" makes resolved() treat the path as a directory
        if base.strip() and not base.endswith("/"):
            origin.setPath(base + "/")

        name = QUrl(str(value))
        url = origin.resolved(name)
        if not url.scheme():
            url.setScheme("file")
        return url

    def _cancelAllFutures(self):
        """Cancel every item's future and close its network reply, if any."""
        for item in self.items:
            if item.future is not None:
                item.future.cancel()
                if item.future._reply is not None:
                    item.future._reply.close()
                    item.future._reply.deleteLater()
                    item.future._reply = None

    def clearScene(self):
        """Cancel pending loads, remove all thumbnails, reset counters."""
        self._cancelAllFutures()

        self.items = []
        self.thumbnailView.clear()
        self._errcount = 0
        self._successcount = 0

    def thumbnailItems(self):
        """Return the thumbnail widgets of all items."""
        return [item.widget for item in self.items]

    def updateSize(self):
        """Apply the current ``imageSize`` to every thumbnail."""
        size = QSizeF(self.imageSize, self.imageSize)
        for item in self.thumbnailItems():
            item.setThumbnailSize(size)

    def updateTitles(self):
        """Retitle thumbnails with the currently chosen title attribute."""
        titleAttr = self.allAttrs[self.titleAttr]
        for item in self.items:
            item.widget.setTitle(str(item.widget.instance[titleAttr]))

    def onSelectionChanged(self):
        """Record the data rows of selected thumbnails and commit."""
        selected = [item for item in self.items if item.widget.isSelected()]
        self.selectedIndices = [item.index for item in selected]
        self.commit()

    def commit(self):
        """Send the selected rows, or ``None`` without data or selection."""
        if self.data:
            if self.selectedIndices:
                selected = self.data[self.selectedIndices]
            else:
                selected = None
            self.send("Data", selected)
        else:
            self.send("Data", None)

    def _noteCompleted(self, future):
        """Count a finished remote load and refresh the status label."""
        # Note the completed future's state
        if future.cancelled():
            return

        if future.exception():
            self._errcount += 1
            _log.debug("Error: %r", future.exception())
        else:
            self._successcount += 1

        self._updateStatus()

    def _updateStatus(self):
        """Show load progress; report an error if every image failed."""
        count = len([item for item in self.items if item.future is not None])
        self.info.setText("Retrieving:\n" +
                          "{} of {} images".format(self._successcount, count))

        if self._errcount + self._successcount == count:
            if self._errcount:
                self.info.setText(
                    "Done:\n" +
                    "{} images, {} errors".format(count, self._errcount))
            else:
                self.info.setText("Done:\n" + "{} images".format(count))
            attr = self.stringAttrs[self.imageAttr]
            if self._errcount == count and "type" not in attr.attributes:
                self.error("No images found! Make sure the '%s' attribute "
                           "is tagged with 'type=image'" % attr.name)

    def onDeleteWidget(self):
        """Cancel pending loads and clear everything on widget removal."""
        self._cancelAllFutures()
        self.clear()
# 示例#14 (Example #14)
# 0
class OWDiscretize(widget.OWWidget):
    # pylint: disable=too-many-instance-attributes
    name = "Discretize"
    description = "Discretize the numeric data features."
    icon = "icons/Discretize.svg"
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table, doc="Input data table")

    class Outputs:
        data = Output("Data",
                      Orange.data.Table,
                      doc="Table with discretized features")

    settingsHandler = settings.DomainContextHandler()
    # presumably consumed by a settings-migration method outside this view
    settings_version = 2
    # per-variable discretization choices, stored per domain context
    saved_var_states = settings.ContextSetting({})

    #: The default method name
    # (stored as the Methods member's name, i.e. a string)
    default_method_name = settings.Setting(Methods.EqualFreq.name)
    #: The k for Equal{Freq,Width}
    default_k = settings.Setting(3)
    #: The default cut points for custom entry
    default_cutpoints: Tuple[float, ...] = settings.Setting(())
    # bound to the auto-apply button created in __init__
    autosend = settings.Setting(True)

    #: Discretization methods
    # Unpacking requires Methods to have exactly these seven members, in
    # this order; otherwise it raises ValueError at class creation.
    Default, Leave, MDL, EqualFreq, EqualWidth, Remove, Custom = list(Methods)

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        """Build the "Default Discretization" box and the per-variable box.

        The top box chooses the method applied to variables without an
        individual setting (its k spin box and manual cut-point editor
        included). The lower box lists the variables and holds radio
        buttons, a k spin box and a manual cut-point editor for the
        selected variable(s).
        """
        super().__init__()

        #: input data
        self.data = None
        self.class_var = None
        #: Current variable discretization state
        self.var_state = {}
        #: Saved variable discretization settings (context setting)
        # NOTE(review): reset here; presumably restored when a context is
        # opened for new data (handler not visible here) -- confirm.
        self.saved_var_states = {}

        # state bound to the per-variable controls below
        self.method = Methods.Default
        self.k = 5
        self.cutpoints = ()

        box = gui.vBox(self.controlArea, self.tr("Default Discretization"))
        # Placeholder attribute for gui.radioButtons; the chosen default
        # method is applied through the button group's click signal below.
        self._default_method_ = 0
        self.default_bbox = rbox = gui.radioButtons(
            box, self, "_default_method_", callback=self._default_disc_changed)
        self.default_button_group = bg = rbox.findChild(QButtonGroup)
        # NOTE(review): the buttonClicked[int] overload is deprecated since
        # Qt 5.15 (idClicked replaces it) and is gone in Qt 6.
        bg.buttonClicked[int].connect(self.set_default_method)

        # two columns of radio buttons inside the default box
        rb = gui.hBox(rbox)
        self.left = gui.vBox(rb)
        right = gui.vBox(rb)
        rb.layout().setStretch(0, 1)
        rb.layout().setStretch(1, 1)
        # (method, label) pairs; ids of the radio buttons are the methods
        self.options = [
            (Methods.Default, self.tr("Default")),
            (Methods.Leave, self.tr("Leave numeric")),
            (Methods.MDL, self.tr("Entropy-MDL discretization")),
            (Methods.EqualFreq, self.tr("Equal-frequency discretization")),
            (Methods.EqualWidth, self.tr("Equal-width discretization")),
            (Methods.Remove, self.tr("Remove numeric variables")),
            (Methods.Custom, self.tr("Manual")),
        ]

        # "Default" is skipped: it only applies to individual variables
        for id_, opt in self.options[1:]:
            t = gui.appendRadioButton(rbox, opt)
            bg.setId(t, id_)
            t.setChecked(id_ == self.default_method)
            # bool index: "Equal-..." options go to the left column (next to
            # the k spin box added there below), the rest to the right
            [right, self.left][opt.startswith("Equal")].layout().addWidget(t)

        def _intbox(parent, attr, callback):
            # indented spin box (2..10) for the number of intervals
            box = gui.indentedBox(parent)
            s = gui.spin(box,
                         self,
                         attr,
                         minv=2,
                         maxv=10,
                         label="Num. of intervals:",
                         callback=callback)
            s.setMaximumWidth(60)
            s.setAlignment(Qt.AlignRight)
            gui.rubber(s.box)
            return box.box

        self.k_general = _intbox(self.left, "default_k",
                                 self._default_disc_changed)
        self.k_general.layout().setContentsMargins(0, 0, 0, 0)

        def manual_cut_editline(text="", enabled=True) -> QLineEdit:
            # Line edit for comma-separated cut points that recolors itself
            # according to its validator's verdict on every text change.
            edit = QLineEdit(
                text=text,
                placeholderText="e.g. 0.0, 0.5, 1.0",
                toolTip="Enter fixed discretization cut points (a comma "
                "separated list of strictly increasing numbers e.g. "
                "0.0, 0.5, 1.0).",
                enabled=enabled,
            )

            @edit.textChanged.connect
            def update():
                validator = edit.validator()
                if validator is not None:
                    state, _, _ = validator.validate(edit.text(), 0)
                else:
                    state = QValidator.Acceptable
                palette = edit.palette()
                # yellow for intermediate, red for invalid input
                colors = {
                    QValidator.Intermediate: (Qt.yellow, Qt.black),
                    QValidator.Invalid: (Qt.red, Qt.black),
                }.get(state, None)
                if colors is None:
                    palette = QPalette()
                else:
                    palette.setColor(QPalette.Base, colors[0])
                    palette.setColor(QPalette.Text, colors[1])

                cr = edit.cursorRect()
                p = edit.mapToGlobal(cr.bottomRight())
                edit.setPalette(palette)
                # show the format hint next to the cursor while the text is
                # not acceptable; an empty text presumably hides the tip
                if state != QValidator.Acceptable and edit.isVisible():
                    show_tip(edit, p, edit.toolTip(), textFormat=Qt.RichText)
                else:
                    show_tip(edit, p, "")

            return edit

        self.manual_cuts_edit = manual_cut_editline(
            text=", ".join(map(str, self.default_cutpoints)),
            enabled=self.default_method == Methods.Custom,
        )

        def set_manual_default_cuts():
            # NOTE(review): on focus loss Qt may emit editingFinished even
            # when the validator is not Acceptable, in which case float()
            # could raise ValueError -- confirm.
            text = self.manual_cuts_edit.text()
            self.default_cutpoints = tuple(
                float(s.strip()) for s in text.split(",") if s.strip())
            self._default_disc_changed()

        self.manual_cuts_edit.editingFinished.connect(set_manual_default_cuts)

        # one validator instance is shared by both cut-point editors
        validator = IncreasingNumbersListValidator()
        self.manual_cuts_edit.setValidator(validator)
        ibox = gui.indentedBox(right, orientation=Qt.Horizontal)
        ibox.layout().addWidget(self.manual_cuts_edit)

        right.layout().addStretch(10)
        self.left.layout().addStretch(10)

        # keep the editor text in sync when the setting changes in code
        self.connect_control(
            "default_cutpoints",
            lambda values: self.manual_cuts_edit.setText(", ".join(
                map(str, values))))
        vlayout = QHBoxLayout()
        box = gui.widgetBox(self.controlArea,
                            "Individual Attribute Settings",
                            orientation=vlayout,
                            spacing=8)

        # List view with all attributes
        self.varview = ListViewSearch(
            selectionMode=QListView.ExtendedSelection,
            uniformItemSizes=True,
        )
        self.varview.setItemDelegate(DiscDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._var_selection_changed)

        vlayout.addWidget(self.varview)
        # Controls for individual attr settings
        self.bbox = controlbox = gui.radioButtons(
            box, self, "method", callback=self._disc_method_changed)
        vlayout.addWidget(controlbox)
        self.variable_button_group = bg = controlbox.findChild(QButtonGroup)
        # Default, Leave, MDL, EqualFreq, EqualWidth -- then the k spin box,
        # then Remove and Manual (added separately below)
        for id_, opt in self.options[:5]:
            b = gui.appendRadioButton(controlbox, opt)
            bg.setId(b, id_)

        self.k_specific = _intbox(controlbox, "k", self._disc_method_changed)

        gui.appendRadioButton(controlbox,
                              "Remove attribute",
                              id=Methods.Remove)
        b = gui.appendRadioButton(controlbox, "Manual", id=Methods.Custom)

        self.manual_cuts_specific = manual_cut_editline(
            text=", ".join(map(str, self.cutpoints)),
            enabled=self.method == Methods.Custom)
        self.manual_cuts_specific.setValidator(validator)
        # the editor is enabled only while "Manual" is checked
        b.toggled[bool].connect(self.manual_cuts_specific.setEnabled)

        def set_manual_cuts():
            text = self.manual_cuts_specific.text()
            # `t.split()` is used as an "is not blank" test here
            points = [t for t in text.split(",") if t.split()]
            self.cutpoints = tuple(float(t) for t in points)
            self._disc_method_changed()

        self.manual_cuts_specific.editingFinished.connect(set_manual_cuts)

        self.connect_control(
            "cutpoints",
            lambda values: self.manual_cuts_specific.setText(", ".join(
                map(str, values))))
        ibox = gui.indentedBox(controlbox, orientation=Qt.Horizontal)
        self.copy_current_to_manual_button = b = FixedSizeButton(
            text="CC",
            toolTip="Copy the current cut points to manual mode",
            enabled=False)
        b.clicked.connect(self._copy_to_manual)
        ibox.layout().addWidget(self.manual_cuts_specific)
        ibox.layout().addWidget(b)

        gui.rubber(controlbox)
        # presumably enabled once variables are selected in the list
        controlbox.setEnabled(False)
        # NOTE(review): the returned button is discarded, so this statement
        # has no effect -- perhaps `.setChecked(True)` was intended.
        bg.button(self.method)
        self.controlbox = controlbox

        gui.auto_apply(self.buttonsArea, self, "autosend")

        self._update_spin_positions()

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

    @property
    def default_method(self) -> Methods:
        """The widget-wide default discretization method.

        The setting stores only the member name (`default_method_name`);
        it is turned back into a `Methods` member by name lookup.
        """
        stored_name = self.default_method_name
        return Methods[stored_name]

    @default_method.setter
    def default_method(self, method):
        """Assigning to `default_method` is forwarded to `set_default_method`."""
        setter = self.set_default_method
        setter(method)

    def set_default_method(self, method: Methods):
        """Select a new widget-wide default discretization method.

        `method` is either an int (interpreted as a `Methods` value) or a
        method instance, which is mapped with `Methods.from_method`. Only
        when the resulting member differs from the current default are the
        stored name and the checked radio button updated and
        `_default_disc_changed` invoked. The manual cut-points editor is
        (re)enabled exactly when the method is `Methods.Custom`.
        """
        method = (Methods(method) if isinstance(method, int)
                  else Methods.from_method(method))

        changed = method != self.default_method
        if changed:
            self.default_method_name = method.name
            self.default_button_group.button(method).setChecked(True)
            self._default_disc_changed()
        is_manual = method == Methods.Custom
        self.manual_cuts_edit.setEnabled(is_manual)

    @Inputs.data
    def set_data(self, data):
        """Receive a new input table (or None) and send the new output.

        For a table, the variable list is rebuilt, the context is opened,
        saved per-variable settings are re-applied and any missing cut
        points are induced. For None, the widget is cleared. The output is
        committed unconditionally in both cases.
        """
        self.closeContext()
        self.data = data
        if self.data is None:
            self.info.set_input_summary(self.info.NoInput)
            self._clear()
        else:
            self._initialize(data)
            self.openContext(data)
            # re-apply the per-variable settings restored with the context
            self._restore(self.saved_var_states)
            # induce cut points for every state that still lacks them
            self._update_points()
            self.info.set_input_summary(len(data),
                                        format_summary_details(data))
        self.unconditional_commit()

    def _initialize(self, data):
        """Prepare the widget for a newly received (non-None) table.

        Lists the continuous variables of `data.domain`, enables the
        Entropy-MDL buttons (default and per-variable) only when the domain
        has a discrete class, moves a selection away from MDL if it has just
        been disabled, and resets every variable's state to the default
        method.
        """
        self.class_var = data.domain.class_var
        cvars = [var for var in data.domain.variables if var.is_continuous]
        self.varmodel[:] = cvars

        has_disc_class = data.domain.has_discrete_class

        def set_enabled(box: QWidget, id_: Methods, state: bool):
            # radio buttons of `box` are registered in its QButtonGroup by id
            bg = box.findChild(QButtonGroup)
            bg.button(id_).setEnabled(state)

        # MDL needs a discrete class; enable/disable it once in each box.
        set_enabled(self.default_bbox, Methods.MDL, has_disc_class)
        set_enabled(self.bbox, Methods.MDL, has_disc_class)

        # If the newly disabled MDL button is checked then change it
        if not has_disc_class:
            if self.default_method == Methods.MDL:
                self.default_method = Methods.Leave
            if self.method == Methods.MDL:
                self.method = Methods.Default

        # Reset (initialize) the variable discretization states.
        self._reset()

    def _restore(self, saved_state):
        """Apply saved per-variable states to the listed variables.

        `saved_state` maps `variable_key(var)` to a saved state. A state
        saved with a `Default` method is re-bound to the *current* default
        method (with no cut points yet). Variables without an entry keep
        their present state.
        """
        current_default = self._current_default_method()
        for row, var in enumerate(self.varmodel):
            key = variable_key(var)
            if key not in saved_state:
                continue
            restored = saved_state[key]
            if isinstance(restored.method, Default):
                restored = DState(Default(current_default), None, None)
            self._set_var_state(row, restored)

    def _reset(self):
        # restore the individual variable settings back to defaults.
        def_method = self._current_default_method()
        self.var_state = {}
        for i in range(len(self.varmodel)):
            state = DState(Default(def_method), None, None)
            self._set_var_state(i, state)

    def _set_var_state(self, index, state):
        """Record `state` for the variable in row `index`.

        The state is kept in `var_state` and also stored in the list model
        under `Qt.UserRole`.
        """
        self.var_state[index] = state
        model = self.varmodel
        model_index = model.index(index)
        model.setData(model_index, state, Qt.UserRole)

    def _clear(self):
        self.data = None
        self.varmodel[:] = []
        self.var_state = {}
        self.saved_var_states = {}
        self.default_button_group.button(self.MDL).setEnabled(True)
        self.variable_button_group.button(self.MDL).setEnabled(True)

    def _update_points(self):
        """
        Update the induced cut points.
        """
        if self.data is None:
            return

        def induce_cuts(method, data, var):
            dvar = _dispatch[type(method)](method, data, var)
            if dvar is None:
                # removed
                return [], None
            elif dvar is var:
                # no transformation took place
                return None, var
            elif is_discretized(dvar):
                return dvar.compute_value.points, dvar
            raise ValueError

        for i, var in enumerate(self.varmodel):
            state = self.var_state[i]
            if state.points is None and state.disc_var is None:
                points, dvar = induce_cuts(state.method, self.data, var)
                new_state = state._replace(points=points, disc_var=dvar)
                self._set_var_state(i, new_state)

    def _current_default_method(self):
        """Build the method instance currently selected as the default.

        Returns a Leave, MDL, EqualFreq, EqualWidth, Remove or Custom
        instance; the interval methods use `default_k` and Custom uses
        `default_cutpoints`.

        Raises:
            ValueError: if `default_method` is not one of those methods.
        """
        method = self.default_method
        k = self.default_k
        if method == Methods.Leave:
            return Leave()
        if method == Methods.MDL:
            return MDL()
        if method == Methods.EqualFreq:
            return EqualFreq(k)
        if method == Methods.EqualWidth:
            return EqualWidth(k)
        if method == Methods.Remove:
            return Remove()
        if method == Methods.Custom:
            return Custom(self.default_cutpoints)
        # An explicit raise (unlike `assert False`) survives `python -O`,
        # where the assert would vanish and the caller would hit an
        # UnboundLocalError instead.
        raise ValueError("unsupported default method: {!r}".format(method))

    def _current_method(self):
        """Build the method instance chosen in the per-variable control box.

        `Methods.Default` wraps the current default method in `Default`;
        the interval methods use `self.k` and Custom uses `self.cutpoints`.

        Raises:
            ValueError: if `self.method` is not a supported method (e.g. the
                ``-1`` marker that `_var_selection_changed` sets for a mixed
                selection).
        """
        method = self.method
        if method == Methods.Default:
            return Default(self._current_default_method())
        if method == Methods.Leave:
            return Leave()
        if method == Methods.MDL:
            return MDL()
        if method == Methods.EqualFreq:
            return EqualFreq(self.k)
        if method == Methods.EqualWidth:
            return EqualWidth(self.k)
        if method == Methods.Remove:
            return Remove()
        if method == Methods.Custom:
            return Custom(self.cutpoints)
        # An explicit raise (unlike `assert False`) survives `python -O`.
        raise ValueError(
            "unsupported discretization method: {!r}".format(method))

    def _update_spin_positions(self):
        """Enable each "number of intervals" spin box only for the
        equal-frequency/width methods and move it next to the selected
        radio button (in the default box and the per-variable box)."""
        kmethods = (Methods.EqualFreq, Methods.EqualWidth)
        # (spin box, selected method, container, index after EqualFreq,
        #  index after EqualWidth)
        general = (self.k_general, self.default_method, self.left, 1, 2)
        self._place_spin(kmethods, *general)
        specific = (self.k_specific, self.method, self.bbox, 4, 5)
        self._place_spin(kmethods, *specific)

    @staticmethod
    def _place_spin(kmethods, spin, current, container, freq_pos, width_pos):
        # Enable/disable `spin` and insert it at the position for `current`.
        spin.setDisabled(current not in kmethods)
        if current == Methods.EqualFreq:
            container.layout().insertWidget(freq_pos, spin)
        elif current == Methods.EqualWidth:
            container.layout().insertWidget(width_pos, spin)

    def _default_disc_changed(self):
        """React to a new default method.

        Every variable that follows the default gets the new default (with
        cut points to be re-induced), then points are updated and the output
        committed.
        """
        self._update_spin_positions()
        shared_state = DState(Default(self._current_default_method()),
                              None, None)
        following_default = [
            row for row in range(len(self.varmodel))
            if isinstance(self.var_state[row].method, Default)
        ]
        for row in following_default:
            self._set_var_state(row, shared_state)
        self._update_points()
        self.commit()

    def _disc_method_changed(self):
        """Apply the method chosen in the per-variable box to all selected
        variables, re-induce cut points, refresh the "CC" button and commit.
        """
        self._update_spin_positions()
        rows = self.selected_indices()
        new_state = DState(self._current_method(), None, None)
        for row in rows:
            self._set_var_state(row, new_state)
        self._update_points()
        self._copy_to_manual_update_enabled()
        self.commit()

    def _copy_to_manual(self):
        """Convert the single selected variable's current cut points into a
        manual (Custom) method so they can be edited by hand.

        Does nothing unless exactly one variable is selected. Missing cut
        points (None) become an empty tuple.
        """
        rows = self.selected_indices()
        if len(rows) != 1:
            return
        (row,) = rows
        current = self.var_state[row]
        formatter = self.varmodel[row].repr_val
        cuts = () if current.points is None else tuple(current.points)
        manual = current._replace(method=Custom(cuts), points=None,
                                  disc_var=None)
        self._set_var_state(row, manual)
        self.method = Methods.Custom
        self.cutpoints = cuts
        # show the cuts formatted with the variable's own value formatter
        self.manual_cuts_specific.setText(", ".join(map(formatter, cuts)))
        self._update_points()
        self.commit()

    def _copy_to_manual_update_enabled(self):
        indices = self.selected_indices()
        methods = [self.var_state[i].method for i in indices]
        self.copy_current_to_manual_button.setEnabled(
            len(indices) == 1 and not isinstance(methods[0], Custom))

    def _var_selection_changed(self, *_):
        """Synchronize the per-variable controls with the list selection.

        If all selected variables share one method, that method (and its
        k or cut points) is shown; for an empty or mixed selection no radio
        button is checked. All `Default` methods count as the same method.
        """
        self._copy_to_manual_update_enabled()
        chosen = [self.var_state[row].method for row in self.selected_indices()]

        def method_key(m):
            # every Default wrapper is equivalent, whatever it wraps
            if isinstance(m, Default):
                return Default, (None, )
            return type(m), tuple(m)

        distinct = list(unique_everseen(chosen, key=method_key))

        self.controlbox.setEnabled(bool(distinct))
        if len(distinct) != 1:
            # mixed (or empty) selection: uncheck every radio button
            self.method = -1
            button_group_reset(self.controlbox.group)
        else:
            (only,) = distinct
            self.method = Methods.from_method(only)
            if isinstance(only, (EqualFreq, EqualWidth)):
                self.k = only.k
            elif isinstance(only, Custom):
                self.cutpoints = only.points
        self._update_spin_positions()

    def selected_indices(self):
        """Return the row numbers of the selected variables in the list."""
        selection_model = self.varview.selectionModel()
        return [model_index.row()
                for model_index in selection_model.selectedRows()]

    def method_for_index(self, index):
        """Return the discretization method of the variable in row `index`."""
        return self.var_state[index].method

    def discretized_var(self, index):
        # type: (int) -> Optional[Orange.data.DiscreteVariable]
        """Return the output variable for row `index`, or None if removed.

        A discretized variable with an empty list of cut points (as produced
        when Entropy-MDL finds no cuts) is treated as removed.
        """
        state = self.var_state[index]
        no_cuts_found = state.disc_var is not None and state.points == []
        return None if no_cuts_found else state.disc_var

    def discretized_domain(self):
        """
        Return the effective output domain, or None when there is no data.

        Each attribute and class variable listed in `varmodel` is replaced
        by its discretized counterpart; variables that map to None are
        dropped. Other variables and all metas pass through unchanged.
        """
        if self.data is None:
            return None

        replacements = {var: self.discretized_var(row)
                        for row, var in enumerate(self.varmodel)}

        def convert(variables):
            # substitute, then drop the removed (None) variables
            converted = [replacements.get(v, v) for v in variables]
            return [v for v in converted if v is not None]

        source = self.data.domain
        attributes = convert(source.attributes)
        class_vars = convert(source.class_vars)
        return Orange.data.Domain(attributes, class_vars,
                                  metas=source.metas)

    def commit(self):
        """Send the discretized data and update the output summary.

        Without input data, None is sent and the summary shows "no output".
        Otherwise the transformed table is sent and its row count shown,
        even when that count is zero. The check is ``output is not None``
        rather than truthiness, because a zero-row table (which defines
        `__len__`) is falsy and would otherwise be reported as "no output"
        while still being sent.
        """
        output = None
        if self.data is not None:
            domain = self.discretized_domain()
            output = self.data.transform(domain)

        if output is None:
            self.info.set_output_summary(self.info.NoOutput, "")
        else:
            self.info.set_output_summary(len(output),
                                         format_summary_details(output))
        self.Outputs.data.send(output)

    def storeSpecificSettings(self):
        """Save the per-variable states into the `saved_var_states` context
        setting, keyed by `variable_key(var)`.

        Only the chosen method is kept: induced cut points and discretized
        variables are cleared, so they are re-induced after restoring.
        """
        super().storeSpecificSettings()
        saved = {}
        for row, var in enumerate(self.varmodel):
            saved[variable_key(var)] = self.var_state[row]._replace(
                points=None, disc_var=None)
        self.saved_var_states = saved

    def send_report(self):
        """Report the default method and, if variables are listed, the
        thresholds of each one ("leave numeric" when there are none)."""
        default_label = self.options[self.default_method][1]
        self.report_items((("Default method", default_label), ))
        if not self.varmodel:
            return
        thresholds = []
        for row, var in enumerate(self.varmodel):
            cuts = DiscDelegate.cutsText(self.var_state[row], var.repr_val)
            thresholds.append((var.name, cuts or "leave numeric"))
        self.report_items("Thresholds", thresholds)

    @classmethod
    def migrate_settings(cls, settings, version):  # pylint: disable=redefined-outer-name
        """Upgrade stored settings to the current format.

        Before version 2 the default method was stored under
        "default_method" as an int index into `Methods`, offset by one; it
        is replaced by the member's name under "default_method_name".
        """
        if version is not None and version >= 2:
            return
        legacy_index = settings.pop("default_method", 0)
        settings["default_method_name"] = Methods(legacy_index + 1).name
示例#15
0
class OWDiscretize(widget.OWWidget):
    """Widget that discretizes the continuous features of a data table.

    A widget-wide default method applies to every continuous variable;
    individual variables may override it with their own method.
    """
    name = "离散化(Discretize)"
    description = "离散化数值数据特征"
    icon = "icons/Discretize.svg"
    keywords = []

    class Inputs:
        data = Input("数据(Data)", Orange.data.Table, doc="Input data table", replaces=['Data'])

    class Outputs:
        data = Output("数据(Data)", Orange.data.Table, doc="Table with discretized features", replaces=['Data'])

    settingsHandler = settings.DomainContextHandler()
    saved_var_states = settings.ContextSetting({})

    # `default_method` indexes the methods *without* "Default"
    # (0 = Leave ... 4 = Remove), i.e. it is one less than the matching
    # class constant below; `_current_default_method` adds the 1 back.
    default_method = settings.Setting(2)
    default_k = settings.Setting(3)
    autosend = settings.Setting(True)

    #: Discretization methods
    Default, Leave, MDL, EqualFreq, EqualWidth, Remove, Custom = range(7)

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        super().__init__()

        #: input data
        self.data = None
        self.class_var = None
        #: Current variable discretization state
        self.var_state = {}
        #: Saved variable discretization settings (context setting)
        self.saved_var_states = {}

        # State of the per-variable controls: selected method (a class
        # constant), number of intervals and manual cut points.
        self.method = 0
        self.k = 5
        self.cutpoints = []

        box = gui.vBox(self.controlArea, self.tr("默认离散化"))
        self.default_bbox = rbox = gui.radioButtons(
            box, self, "default_method", callback=self._default_disc_changed)
        rb = gui.hBox(rbox)
        self.left = gui.vBox(rb)
        right = gui.vBox(rb)
        rb.layout().setStretch(0, 1)
        rb.layout().setStretch(1, 1)
        options = self.options = [
            self.tr("默认"),
            self.tr("保留原值"),
            self.tr("熵MDL离散化(Entropy-MDL discretization)"),
            self.tr("等频离散化(Equal-frequency discretization)"),
            self.tr("等宽离散化(Equal-width discretization)"),
            self.tr("删除数值变量")
        ]

        # The default box offers every option except "Default". Buttons are
        # appended in the original order (backward compatibility of saved
        # schemata); the equal-frequency/width ones are then moved to the
        # left column (where the interval spin box goes), the rest to the
        # right. The choice is made by method id, not by label text, because
        # the translated labels no longer start with "Equal".
        for method_id, opt in enumerate(options[1:], start=1):
            t = gui.appendRadioButton(rbox, opt)
            in_left = method_id in (self.EqualFreq, self.EqualWidth)
            (self.left if in_left else right).layout().addWidget(t)
        gui.separator(right, 18, 18)

        def _intbox(parent, attr, callback):
            # An indented spin box (2-10 intervals) bound to `attr`.
            box = gui.indentedBox(parent)
            s = gui.spin(
                box, self, attr, minv=2, maxv=10, label="间隔数(Num. of intervals):",
                callback=callback)
            s.setMaximumWidth(60)
            s.setAlignment(Qt.AlignRight)
            gui.rubber(s.box)
            return box.box

        self.k_general = _intbox(self.left, "default_k",
                                 self._default_disc_changed)
        self.k_general.layout().setContentsMargins(0, 0, 0, 0)
        vlayout = QHBoxLayout()
        box = gui.widgetBox(
            self.controlArea, "单个属性设置",
            orientation=vlayout, spacing=8
        )

        # List view with all attributes
        self.varview = QListView(
            selectionMode=QListView.ExtendedSelection,
            uniformItemSizes=True,
        )
        self.varview.setItemDelegate(DiscDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._var_selection_changed
        )

        vlayout.addWidget(self.varview)
        # Controls for individual attr settings
        self.bbox = controlbox = gui.radioButtons(
            box, self, "method", callback=self._disc_method_changed
        )
        vlayout.addWidget(controlbox)

        # Per-variable box: "Default" .. "Equal-width" (ids 0-4), then the
        # spin box, then "Remove" (id 5).
        for opt in options[:5]:
            gui.appendRadioButton(controlbox, opt)

        self.k_specific = _intbox(controlbox, "k", self._disc_method_changed)

        gui.appendRadioButton(controlbox, "删除属性")

        gui.rubber(controlbox)
        controlbox.setEnabled(False)

        self.controlbox = controlbox

        box = gui.auto_apply(self.controlArea, self, "autosend")
        box.button.setFixedWidth(180)
        box.layout().insertStretch(0)

        self._update_spin_positions()

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

    @Inputs.data
    def set_data(self, data):
        """Handle a new input table (or None) and send the result."""
        self.closeContext()
        self.data = data
        if self.data is not None:
            self._initialize(data)
            self.openContext(data)
            # Restore the per variable discretization settings
            self._restore(self.saved_var_states)
            # Complete the induction of cut points
            self._update_points()
            self.info.set_input_summary(len(data))
        else:
            self.info.set_input_summary(self.info.NoInput)
            self._clear()
        self.unconditional_commit()

    def _initialize(self, data):
        """Set up the variable list and MDL availability for `data`.

        Entropy-MDL is enabled only when the domain has a discrete class;
        if it was selected and has just been disabled, the selection falls
        back (default box: Leave, per-variable box: Default). All variable
        states are then reset to the default method.
        """
        self.class_var = data.domain.class_var
        cvars = [var for var in data.domain.variables
                 if var.is_continuous]
        self.varmodel[:] = cvars

        has_disc_class = data.domain.has_discrete_class

        # The default box has no "Default" button, so its button indices
        # (and `default_method`) are one less than the class constants.
        self.default_bbox.buttons[self.MDL - 1].setEnabled(has_disc_class)
        self.bbox.buttons[self.MDL].setEnabled(has_disc_class)

        # If the newly disabled MDL button is checked then change it
        if not has_disc_class and self.default_method == self.MDL - 1:
            self.default_method = 0
        if not has_disc_class and self.method == self.MDL:
            self.method = 0

        # Reset (initialize) the variable discretization states.
        self._reset()

    def _restore(self, saved_state):
        """Apply saved per-variable states (keyed by `variable_key`).

        States saved with a `Default` method are re-bound to the current
        default method; variables without an entry are left unchanged.
        """
        def_method = self._current_default_method()
        for i, var in enumerate(self.varmodel):
            key = variable_key(var)
            if key in saved_state:
                state = saved_state[key]
                if isinstance(state.method, Default):
                    state = DState(Default(def_method), None, None)
                self._set_var_state(i, state)

    def _reset(self):
        """Give every listed variable the current default method again."""
        def_method = self._current_default_method()
        self.var_state = {}
        for i in range(len(self.varmodel)):
            state = DState(Default(def_method), None, None)
            self._set_var_state(i, state)

    def _set_var_state(self, index, state):
        """Record `state` for row `index` in `var_state` and in the list
        model (under `Qt.UserRole`)."""
        self.var_state[index] = state
        self.varmodel.setData(self.varmodel.index(index), state, Qt.UserRole)

    def _clear(self):
        """Forget the data and all variable states; re-enable MDL."""
        self.data = None
        self.varmodel[:] = []
        self.var_state = {}
        self.saved_var_states = {}
        self.default_bbox.buttons[self.MDL - 1].setEnabled(True)
        self.bbox.buttons[self.MDL].setEnabled(True)

    def _update_points(self):
        """
        Update the induced cut points.

        Only states with neither `points` nor `disc_var` are (re)computed.
        Does nothing without data or for an empty table.

        Raises:
            ValueError: if a method yields a variable that is neither the
                original, None, nor a discretized variable.
        """
        if self.data is None or not len(self.data):
            return

        def induce_cuts(method, data, var):
            dvar = _dispatch[type(method)](method, data, var)
            if dvar is None:
                # removed
                return [], None
            elif dvar is var:
                # no transformation took place
                return None, var
            elif is_discretized(dvar):
                return dvar.compute_value.points, dvar
            # explicit raise instead of `assert False`: survives `python -O`
            raise ValueError

        for i, var in enumerate(self.varmodel):
            state = self.var_state[i]
            if state.points is None and state.disc_var is None:
                points, dvar = induce_cuts(state.method, self.data, var)
                new_state = state._replace(points=points, disc_var=dvar)
                self._set_var_state(i, new_state)

    @staticmethod
    def _method_index(method):
        # Index of `method`'s type in METHODS (defined elsewhere); presumably
        # METHODS lists 1-tuples of method types in control-box order - verify.
        return METHODS.index((type(method), ))

    def _current_default_method(self):
        """Build the method instance selected in the default box.

        Raises:
            ValueError: if `default_method` does not map to a method.
        """
        method = self.default_method + 1
        k = self.default_k
        if method == OWDiscretize.Leave:
            def_method = Leave()
        elif method == OWDiscretize.MDL:
            def_method = MDL()
        elif method == OWDiscretize.EqualFreq:
            def_method = EqualFreq(k)
        elif method == OWDiscretize.EqualWidth:
            def_method = EqualWidth(k)
        elif method == OWDiscretize.Remove:
            def_method = Remove()
        else:
            raise ValueError(
                "unsupported default method: {!r}".format(self.default_method))
        return def_method

    def _current_method(self):
        """Build the method instance selected in the per-variable box.

        Raises:
            ValueError: if `self.method` is not a supported method (e.g. the
                ``-1`` set for a mixed selection).
        """
        if self.method == OWDiscretize.Default:
            method = Default(self._current_default_method())
        elif self.method == OWDiscretize.Leave:
            method = Leave()
        elif self.method == OWDiscretize.MDL:
            method = MDL()
        elif self.method == OWDiscretize.EqualFreq:
            method = EqualFreq(self.k)
        elif self.method == OWDiscretize.EqualWidth:
            method = EqualWidth(self.k)
        elif self.method == OWDiscretize.Remove:
            method = Remove()
        elif self.method == OWDiscretize.Custom:
            method = Custom(self.cutpoints)
        else:
            raise ValueError(
                "unsupported discretization method: {!r}".format(self.method))
        return method

    def _update_spin_positions(self):
        """Enable the interval spin boxes only for equal-frequency/width and
        move each next to the selected radio button."""
        # default box (offset by one): 2 = EqualFreq, 3 = EqualWidth
        self.k_general.setDisabled(self.default_method not in [2, 3])
        if self.default_method == 2:
            self.left.layout().insertWidget(1, self.k_general)
        elif self.default_method == 3:
            self.left.layout().insertWidget(2, self.k_general)

        # per-variable box (class constants): 3 = EqualFreq, 4 = EqualWidth
        self.k_specific.setDisabled(self.method not in [3, 4])
        if self.method == 3:
            self.bbox.layout().insertWidget(4, self.k_specific)
        elif self.method == 4:
            self.bbox.layout().insertWidget(5, self.k_specific)

    def _default_disc_changed(self):
        """Re-bind variables following the default method, then commit."""
        self._update_spin_positions()
        method = self._current_default_method()
        state = DState(Default(method), None, None)
        for i, _ in enumerate(self.varmodel):
            if isinstance(self.var_state[i].method, Default):
                self._set_var_state(i, state)
        self._update_points()
        self.commit()

    def _disc_method_changed(self):
        """Apply the per-variable method to the selected variables, then
        commit."""
        self._update_spin_positions()
        indices = self.selected_indices()
        method = self._current_method()
        state = DState(method, None, None)
        for idx in indices:
            self._set_var_state(idx, state)
        self._update_points()
        self.commit()

    def _var_selection_changed(self, *_):
        """Show the selected variables' common method, or none if mixed."""
        indices = self.selected_indices()
        # set of all methods for the current selection
        methods = [self.var_state[i].method for i in indices]
        mset = set(methods)
        self.controlbox.setEnabled(len(mset) > 0)
        if len(mset) == 1:
            method = mset.pop()
            self.method = self._method_index(method)
            if isinstance(method, (EqualFreq, EqualWidth)):
                self.k = method.k
            elif isinstance(method, Custom):
                self.cutpoints = method.points
        else:
            # deselect the current button
            self.method = -1
            bg = self.controlbox.group
            button_group_reset(bg)
        self._update_spin_positions()

    def selected_indices(self):
        """Return the row numbers of the selected variables."""
        rows = self.varview.selectionModel().selectedRows()
        return [index.row() for index in rows]

    def discretized_var(self, index):
        # type: (int) -> Optional[Orange.data.DiscreteVariable]
        """Return the output variable for row `index`, or None if removed."""
        state = self.var_state[index]
        if state.disc_var is not None and state.points == []:
            # Removed by MDL Entropy
            return None
        else:
            return state.disc_var

    def discretized_domain(self):
        """
        Return the current effective discretized domain.

        Returns None when there is no data; removed variables are dropped
        and metas are passed through.
        """
        if self.data is None:
            return None

        # a mapping of all applied changes for variables in `varmodel`
        mapping = {var: self.discretized_var(i)
                   for i, var in enumerate(self.varmodel)}

        def disc_var(source):
            return mapping.get(source, source)

        # map the full input domain to the new variables (where applicable)
        attributes = [disc_var(v) for v in self.data.domain.attributes]
        attributes = [v for v in attributes if v is not None]

        class_vars = [disc_var(v) for v in self.data.domain.class_vars]
        class_vars = [v for v in class_vars if v is not None]

        domain = Orange.data.Domain(
            attributes, class_vars, metas=self.data.domain.metas
        )
        return domain

    def commit(self):
        """Send the discretized table (None for missing or empty input)."""
        output = None
        if self.data is not None and len(self.data):
            domain = self.discretized_domain()
            output = self.data.transform(domain)
            self.info.set_output_summary(len(output))
        else:
            self.info.set_output_summary(self.info.NoOutput)
        self.Outputs.data.send(output)

    def storeSpecificSettings(self):
        """Save per-variable methods (without induced points) to context."""
        super().storeSpecificSettings()
        self.saved_var_states = {
            variable_key(var):
                self.var_state[i]._replace(points=None, disc_var=None)
            for i, var in enumerate(self.varmodel)
        }

    def send_report(self):
        """Report the default method and each variable's thresholds."""
        self.report_items((
            ("Default method", self.options[self.default_method + 1]),))
        if self.varmodel:
            self.report_items("Thresholds", [
                (var.name,
                 DiscDelegate.cutsText(self.var_state[i]) or "leave numeric")
                for i, var in enumerate(self.varmodel)])
示例#16
0
class OWVennDiagram(widget.OWWidget):
    """
    Venn diagram of up to five input tables.

    Elements are either rows (instances, matched by instance identity or by
    a common string attribute) or columns (attribute names). The instances
    or columns from the selected diagram areas are sent to the outputs.
    """
    name = "Venn Diagram"
    description = "A graphical visualization of the overlap of data instances " \
                  "from a collection of input datasets."
    icon = "icons/VennDiagram.svg"
    priority = 280
    keywords = []
    settings_version = 2

    class Inputs:
        data = Input("Data", Table, multiple=True)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    class Error(widget.OWWidget.Error):
        instances_mismatch = Msg(
            "Data sets do not contain the same instances.")
        too_many_inputs = Msg("Venn diagram accepts at most five datasets.")

    class Warning(widget.OWWidget.Warning):
        renamed_vars = Msg("Some variables have been renamed "
                           "to avoid duplicates.\n{}")

    selection: list

    settingsHandler = settings.DomainContextHandler()
    # Indices of selected disjoint areas
    selection = settings.Setting([], schema_only=True)
    #: Output unique items (one output row for every unique instance `key`)
    #: or preserve all duplicates in the output.
    output_duplicates = settings.Setting(False)
    autocommit = settings.Setting(True)
    rowwise = settings.Setting(True)
    selected_feature = settings.ContextSetting(None)

    want_control_area = False
    graph_name = "scene"
    atr_types = ['attributes', 'metas', 'class_vars']
    # Table attribute holding the values of each part of the domain
    atr_vals = {'metas': 'metas', 'attributes': 'X', 'class_vars': 'Y'}
    # Same as atr_vals, but for a single RowInstance
    row_vals = {'attributes': 'x', 'class_vars': 'y', 'metas': 'metas'}

    def __init__(self):
        super().__init__()

        # Diagram update is in progress
        self._updating = False
        # Input update is in progress
        self._inputUpdate = False
        # Input datasets in the order they were 'connected'.
        self.data = {}
        # Extracted input item sets in the order they were 'connected'
        self.itemsets = {}
        # A list with 2 ** len(self.data) elements that store item sets
        # belonging to each area
        self.disjoint = []
        # A list with  2 ** len(self.data) elements that store keys of tables
        # intersected in each area
        self.area_keys = []
        # Union of the items in the selected areas; recomputed in `commit`
        self.selected_items = set()

        # Main area view
        self.scene = QGraphicsScene()
        self.view = QGraphicsView(self.scene)
        self.view.setRenderHint(QPainter.Antialiasing)
        self.view.setBackgroundRole(QPalette.Window)
        self.view.setFrameStyle(QGraphicsView.StyledPanel)

        self.mainArea.layout().addWidget(self.view)
        self.vennwidget = VennDiagram()
        self._resize()
        self.vennwidget.itemTextEdited.connect(self._on_itemTextEdited)
        self.scene.selectionChanged.connect(self._on_selectionChanged)

        self.scene.addItem(self.vennwidget)

        controls = gui.hBox(self.mainArea)
        box = gui.radioButtonsInBox(controls,
                                    self,
                                    'rowwise', [
                                        "Columns (features)",
                                        "Rows (instances), matched by",
                                    ],
                                    box="Elements",
                                    callback=self._on_matching_changed)
        gui.comboBox(gui.indentedBox(box),
                     self,
                     "selected_feature",
                     model=itemmodels.VariableListModel(
                         placeholder="Instance identity"),
                     callback=self._on_inputAttrActivated)
        box.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)

        self.outputs_box = box = gui.vBox(controls, "Output")
        self.output_duplicates_cb = gui.checkBox(
            box,
            self,
            "output_duplicates",
            "Output duplicates",
            callback=lambda: self.commit())  # pylint: disable=unnecessary-lambda
        gui.auto_send(box, self, "autocommit", box=False)
        self.output_duplicates_cb.setEnabled(bool(self.rowwise))
        self._queue = []

    def resizeEvent(self, event):
        super().resizeEvent(event)
        self._resize()

    def showEvent(self, event):
        super().showEvent(event)
        self._resize()

    def _resize(self):
        # vennwidget draws so that the diagram fits into its geometry,
        # while labels take further 120 pixels, hence -120 in below formula
        size = max(200, min(self.view.width(), self.view.height()) - 120)
        self.vennwidget.resize(size, size)
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    @Inputs.data
    @check_sql_input
    def setData(self, data, key=None):
        """
        Add, replace or remove the input table connected under `key`.

        `data` of None removes the input with that key. At most five inputs
        are accepted; an additional one shows `too_many_inputs` and is
        ignored.
        """
        self.Error.too_many_inputs.clear()
        self._inputUpdate = True
        if key in self.data:
            if data is None:
                # Remove the input
                # Clear possible warnings.
                self.Warning.clear()
                del self.data[key]
            else:
                # Update existing item (keeps its position in the dict)
                self.data[key] = self.data[key]._replace(name=data.name,
                                                         table=data)

        elif data is not None:
            # TODO: Allow setting more than 5 inputs and let the user
            # select the 5 to display.
            if len(self.data) >= 5:
                self.Error.too_many_inputs()
                return
            # Add a new input
            self.data[key] = _InputData(key, data.name, data)
        self._setInterAttributes()

    def data_equality(self):
        """ Checks if all input datasets have same ids. """
        if not self.data:
            return True
        sets = [set(val.table.ids) for val in self.data.values()]
        inter = reduce(set.intersection, sets)
        return len(inter) == max(map(len, sets))

    def settings_compatible(self):
        """
        Return False (and show `instances_mismatch`) when column matching
        is selected but the inputs do not share the same instance ids.
        """
        self.Error.instances_mismatch.clear()
        if not self.rowwise:
            if not self.data_equality():
                self.vennwidget.clear()
                self.Error.instances_mismatch()
                self.itemsets = {}
                return False
        return True

    def handleNewSignals(self):
        self._inputUpdate = False
        self.vennwidget.clear()
        if not self.settings_compatible():
            self.invalidateOutput()
            return

        self._createItemsets()
        self._createDiagram()
        # If autocommit is enabled, _createDiagram already outputs data
        # If not, call unconditional_commit from here
        if not self.autocommit:
            self.unconditional_commit()

        self._updateInfo()
        super().handleNewSignals()

    def intersectionStringAttrs(self):
        """Return the string attributes common to all input domains."""
        sets = [
            set(string_attributes(data_.table.domain))
            for data_ in self.data.values()
        ]
        if sets:
            return reduce(set.intersection, sets)
        return set()

    def _setInterAttributes(self):
        """
        Refill the matching-attribute combo; reset the selected feature to
        the placeholder if it is no longer common to all inputs.
        """
        model = self.controls.selected_feature.model()
        model[:] = [None] + list(self.intersectionStringAttrs())
        if self.selected_feature:
            names = (var.name for var in model if var)
            if self.selected_feature.name not in names:
                self.selected_feature = model[0]

    def _itemsForInput(self, key):
        """
        Calculates input for venn diagram, according to user's settings.
        """
        table = self.data[key].table
        attr = self.selected_feature
        if attr:
            return [
                str(inst[attr]) for inst in table if not np.isnan(inst[attr])
            ]
        return list(table.ids)

    def _createItemsets(self):
        """
        Create itemsets over rows or columns (domains) of input tables.
        """
        olditemsets = dict(self.itemsets)
        self.itemsets.clear()

        for key, input_ in self.data.items():
            if self.rowwise:
                items = self._itemsForInput(key)
            else:
                items = [el.name for el in input_.table.domain.attributes]
            name = input_.name
            if key in olditemsets and olditemsets[key].name == name:
                # Reuse the title (which might have been changed by the user)
                title = olditemsets[key].title
            else:
                title = name

            itemset = _ItemSet(key=key, name=name, title=title, items=items)
            self.itemsets[key] = itemset

    def _createDiagram(self):
        """
        Rebuild the Venn items and areas from `self.itemsets`, restoring the
        previous area selection, then refresh the output.
        """
        self._updating = True

        oldselection = list(self.selection)

        n = len(self.itemsets)
        self.disjoint, self.area_keys = \
            self.get_disjoint(set(s.items) for s in self.itemsets.values())

        vennitems = []
        colors = colorpalettes.LimitedDiscretePalette(n, force_hsv=True)

        for i, item in enumerate(self.itemsets.values()):
            cnt = len(set(item.items))
            cnt_all = len(item.items)
            if cnt != cnt_all:
                fmt = '{} <i>(all: {})</i>'
            else:
                fmt = '{}'
            counts = fmt.format(cnt, cnt_all)
            gr = VennSetItem(text=item.title, informativeText=counts)
            color = colors[i]
            color.setAlpha(100)
            gr.setBrush(QBrush(color))
            gr.setPen(QPen(Qt.NoPen))
            vennitems.append(gr)

        self.vennwidget.setItems(vennitems)

        for i, area in enumerate(self.vennwidget.vennareas()):
            area_items = list(map(str, list(self.disjoint[i])))
            if i:
                area.setText("{0}".format(len(area_items)))

            label = disjoint_set_label(i, n, simplify=False)
            head = "<h4>|{}| = {}</h4>".format(label, len(area_items))
            if len(area_items) > 32:
                items_str = ", ".join(map(escape, area_items[:32]))
                hidden = len(area_items) - 32
                # Well-formed rich text: <br/> line break, closed <span>
                tooltip = (
                    "{}<span>{}, ...<br/>({} items not shown)</span>".format(
                        head, items_str, hidden))
            elif area_items:
                tooltip = "{}<span>{}</span>".format(
                    head, ", ".join(map(escape, area_items)))
            else:
                tooltip = head

            area.setToolTip(tooltip)

            area.setPen(QPen(QColor(10, 10, 10, 200), 1.5))
            area.setFlag(QGraphicsPathItem.ItemIsSelectable, True)
            area.setSelected(i in oldselection)

        self._updating = False
        self._on_selectionChanged()

    def _updateInfo(self):
        # Clear all warnings
        self.warning()

        if self.selected_feature is None:
            no_idx = [
                "#{}".format(i + 1) for i, key in enumerate(self.data)
                if not source_attributes(self.data[key].table.domain)
            ]
            if len(no_idx) == 1:
                self.warning("Dataset {} has no suitable identifiers.".format(
                    no_idx[0]))
            elif len(no_idx) > 1:
                self.warning(
                    "Datasets {} and {} have no suitable identifiers.".format(
                        ", ".join(no_idx[:-1]), no_idx[-1]))

    def _on_selectionChanged(self):
        if self._updating:
            return

        areas = self.vennwidget.vennareas()
        self.selection = [
            i for i, area in enumerate(areas) if area.isSelected()
        ]
        self.invalidateOutput()

    def _on_matching_changed(self):
        # Duplicates can only be output when matching rows
        self.output_duplicates_cb.setEnabled(bool(self.rowwise))
        if not self.settings_compatible():
            self.invalidateOutput()
            return
        self._createItemsets()
        self._createDiagram()
        self._updateInfo()

    def _on_inputAttrActivated(self):
        # Choosing a matching attribute implies row-wise matching
        self.rowwise = 1
        self._on_matching_changed()

    def _on_itemTextEdited(self, index, text):
        text = str(text)
        key = list(self.itemsets)[index]
        self.itemsets[key] = self.itemsets[key]._replace(title=text)

    def invalidateOutput(self):
        self.commit()

    def merge_data(self, domain, values, ids=None):
        """
        Build a Table from per-part variable lists and value arrays.

        `domain` maps 'attributes'/'metas'/'class_vars' to lists of
        variables; clashing names are made unique in place (and reported
        via `renamed_vars`). `values` maps the same keys to lists of arrays
        that are combined with `np.hstack`. If `ids` is given it is assigned
        as the table's ids.
        """
        X, metas, class_vars = None, None, None
        renamed = []
        for val in domain.values():
            names = [var.name for var in val]
            unique_names = get_unique_names_duplicates(names)
            for idx, (n, u, var) in enumerate(zip(names, unique_names, val)):
                if n != u:
                    val[idx] = var.copy(name=u)
                    renamed.append(n)
        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        if 'attributes' in values:
            X = np.hstack(values['attributes'])
        if 'metas' in values:
            metas = np.hstack(values['metas'])
        if 'class_vars' in values:
            class_vars = np.hstack(values['class_vars'])
        table = Table.from_numpy(Domain(**domain), X, class_vars, metas)
        if ids is not None:
            table.ids = ids
        return table

    def extract_columnwise(self, var_dict, columns=None):
        """
        Merge tables column-wise.

        Variables flagged as different are copied once per table, with the
        1-based input index appended to the name. When `columns` is given,
        each output attribute gets a 'Selected' flag telling whether its
        name is in `columns`.
        """
        domain = defaultdict(list)
        values = defaultdict(list)
        renamed = []
        for atr_type, vars_dict in var_dict.items():
            for var_name, var_data in vars_dict.items():
                is_selected = bool(columns) and var_name.name in columns
                if var_data[0]:
                    #columns are different, copy all, rename them
                    for var, table_key in var_data[1]:
                        idx = list(self.data).index(table_key) + 1
                        new_atr = var.copy(name=f'{var_name.name} ({idx})')
                        if columns and atr_type == 'attributes':
                            new_atr.attributes['Selected'] = is_selected
                        domain[atr_type].append(new_atr)
                        renamed.append(var_name.name)
                        values[atr_type].append(
                            getattr(self.data[table_key].table[:, var_name],
                                    self.atr_vals[atr_type]).reshape(-1, 1))
                else:
                    new_atr = var_data[1][0][0].copy()
                    if columns and atr_type == 'attributes':
                        new_atr.attributes['Selected'] = is_selected
                    domain[atr_type].append(new_atr)
                    values[atr_type].append(
                        getattr(
                            self.data[var_data[1][0][1]].table[:, var_name],
                            self.atr_vals[atr_type]).reshape(-1, 1))
        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        return self.merge_data(domain, values)

    def curry_merge(self, table_key, atr_type, ids=None, selection=False):
        """
        Return a reducer that registers the variables of table `table_key`
        into a shared dict, flagging variables whose values differ between
        tables. Intended for use with `functools.reduce`.
        """
        if self.rowwise:
            check_equality = self.arrays_equal_rows
        else:
            check_equality = self.arrays_equal_cols

        def inner(new_atrs, atr):
            """
            atr - the variable to register
            new_atrs - dictionary where key is a variable, val
                is [is_different: bool, [(variable, table_key), ...]].
                is_different is set to True when values differ between
                tables; when outputting row-wise duplicates it is set
                unconditionally (that output path does not use the flag).
            """
            if atr in new_atrs:
                # The duplicates checkbox is disabled for column matching,
                # so a stale `output_duplicates` must not force columns
                # to be duplicated there.
                if not selection and self.output_duplicates and self.rowwise:
                    new_atrs[atr][0] = True
                elif not new_atrs[atr][0]:
                    for var, key in new_atrs[atr][1]:
                        if not check_equality(table_key, key, atr.name,
                                              self.atr_vals[atr_type],
                                              type(var), ids):
                            new_atrs[atr][0] = True
                            break
                new_atrs[atr][1].append((atr, table_key))
            else:
                new_atrs[atr] = [False, [(atr, table_key)]]
            return new_atrs

        return inner

    def arrays_equal_rows(self, key1, key2, name, data_type, type_, ids):
        """Compare column `name` of two tables on their commonly matched rows."""
        t1 = self.data[key1].table
        t2 = self.data[key2].table
        inter_val = set(ids[key1]) & set(ids[key2])
        t1_inter = [ids[key1][val] for val in inter_val]
        t2_inter = [ids[key2][val] for val in inter_val]
        return arrays_equal(
            getattr(t1[t1_inter, name], data_type).reshape(-1, 1),
            getattr(t2[t2_inter, name], data_type).reshape(-1, 1), type_)

    def arrays_equal_cols(self, key1, key2, name, data_type, type_, _ids=None):
        """Compare column `name` of two tables over all rows."""
        return arrays_equal(getattr(self.data[key1].table[:, name], data_type),
                            getattr(self.data[key2].table[:, name], data_type),
                            type_)

    def create_from_columns(self, columns, relevant_keys, get_selected):
        """
        Merge the tables in `relevant_keys` column-wise.

        Columns are duplicated only if values differ (even if only in order
        of values); the 1-based input index is then appended to the column
        name. With `get_selected`, only attributes whose names are in
        `columns` are kept; otherwise all are kept and marked with a
        'Selected' flag.
        """
        var_dict = {}
        for atr_type in self.atr_types:
            container = {}
            for table_key in relevant_keys:
                table = self.data[table_key].table
                if atr_type == 'attributes':
                    if get_selected:
                        atrs = list(
                            compress(table.domain.attributes, [
                                c.name in columns
                                for c in table.domain.attributes
                            ]))
                    else:
                        atrs = getattr(table.domain, atr_type)
                else:
                    atrs = getattr(table.domain, atr_type)
                merge_vars = self.curry_merge(table_key, atr_type)
                container = reduce(merge_vars, atrs, container)
            var_dict[atr_type] = container

        if get_selected:
            annotated = self.extract_columnwise(var_dict, None)
        else:
            annotated = self.extract_columnwise(var_dict, columns)

        return annotated

    def extract_rowwise(self, var_dict, ids=None, selection=False):
        """
        Merge tables row-wise, aligning rows by their match ids.

        var_dict: maps 'attributes'/'metas'/'class_vars' to dicts whose keys
            are variables and values are [is_different, [(variable, table_key), ...]]
        ids: dict mapping each table key to a {match id: row index} dict
        selection: if True, return an annotated table marking rows whose id
            is in `self.selected_items`
        """
        all_ids = sorted(
            reduce(set.union, [set(val) for val in ids.values()], set()))

        permutations = {}
        for table_key, dict_ in ids.items():
            permutations[table_key] = get_perm(list(dict_), all_ids)

        domain = defaultdict(list)
        values = defaultdict(list)
        renamed = []
        for atr_type, vars_dict in var_dict.items():
            for var_name, var_data in vars_dict.items():
                different = var_data[0]
                if different:
                    # Columns are different, copy and rename them.
                    # Renaming is done here to mark appropriately the source table.
                    # Additional strange clashes are checked later in merge_data
                    for var, table_key in var_data[1]:
                        temp = self.data[table_key].table
                        idx = list(self.data).index(table_key) + 1
                        domain[atr_type].append(
                            var.copy(name=f'{var_name.name} ({idx})'))
                        renamed.append(var_name.name)
                        v = getattr(
                            temp[list(ids[table_key].values()), var_name],
                            self.atr_vals[atr_type])
                        perm = permutations[table_key]
                        if len(v) < len(all_ids):
                            values[atr_type].append(
                                pad_columns(v, perm, len(all_ids)))
                        else:
                            values[atr_type].append(v[perm].reshape(-1, 1))
                else:
                    value = np.full((len(all_ids), 1), np.nan)
                    domain[atr_type].append(var_data[1][0][0].copy())
                    for _, table_key in var_data[1]:
                        #different tables have different part of the same attribute vector
                        perm = permutations[table_key]
                        v = getattr(
                            self.data[table_key].table[
                                list(ids[table_key].values()), var_name],
                            self.atr_vals[atr_type]).reshape(-1, 1)
                        value = value.astype(v.dtype, copy=False)
                        value[perm] = v
                    values[atr_type].append(value)

        if renamed:
            self.Warning.renamed_vars(', '.join(renamed))
        idas = None if self.selected_feature else np.array(all_ids)
        table = self.merge_data(domain, values, idas)
        if selection:
            mask = [idx in self.selected_items for idx in all_ids]
            return create_annotated_table(table, mask)
        return table

    def get_indices(self, table, selection):
        """Returns mappings of ids (be it row id or string) to indices in tables"""
        if self.selected_feature:
            if self.output_duplicates and selection:
                items, inverse = np.unique(getattr(
                    table[:, self.selected_feature], 'metas'),
                                           return_inverse=True)
                ids = [
                    np.nonzero(inverse == idx)[0] for idx in range(len(items))
                ]
            else:
                items, ids = np.unique(getattr(table[:, self.selected_feature],
                                               'metas'),
                                       return_index=True)

        else:
            items = table.ids
            ids = range(len(table))

        if selection:
            return {
                item: idx
                for item, idx in zip(items, ids) if item in self.selected_items
            }

        return dict(zip(items, ids))

    def get_indices_to_match_by(self, relevant_keys, selection=False):
        """Map each key in `relevant_keys` to `get_indices` of its table."""
        dict_ = {}
        for key in relevant_keys:
            table = self.data[key].table
            dict_[key] = self.get_indices(table, selection)
        return dict_

    def create_from_rows(self, relevant_ids, selection=False):
        """
        Merge the tables keyed in `relevant_ids` row-wise; uses the
        duplicates-preserving path when outputting duplicates for the
        (non-annotated) selection output.
        """
        var_dict = {}
        for atr_type in self.atr_types:
            container = {}
            for table_key in relevant_ids:
                merge_vars = self.curry_merge(table_key, atr_type,
                                              relevant_ids, selection)
                atrs = getattr(self.data[table_key].table.domain, atr_type)
                container = reduce(merge_vars, atrs, container)
            var_dict[atr_type] = container
        if self.output_duplicates and not selection:
            return self.extract_rowwise_duplicates(var_dict, relevant_ids)
        return self.extract_rowwise(var_dict, relevant_ids, selection)

    def expand_table(self, table, atrs, metas, cv):
        """
        Lay out the values of `table` (a Table or a RowInstance) in the
        column order given by `atrs`, `metas` and `cv`, filling columns the
        table lacks with NaN. Returns (X, metas, Y, ids) arrays.
        """
        exp = []
        n = 1 if isinstance(table, RowInstance) else len(table)
        if isinstance(table, RowInstance):
            ids = table.id.reshape(-1, 1)
            atr_vals = self.row_vals
        else:
            ids = table.ids.reshape(-1, 1)
            atr_vals = self.atr_vals
        for all_el, atr_type in zip([atrs, metas, cv], self.atr_types):
            cur_el = getattr(table.domain, atr_type)
            array = np.full((n, len(all_el)), np.nan)
            if cur_el:
                perm = get_perm(cur_el, all_el)
                b = getattr(table,
                            atr_vals[atr_type]).reshape(len(array), len(perm))
                array = array.astype(b.dtype, copy=False)
                array[:, perm] = b
            exp.append(array)
        return (*exp, ids)

    def extract_rowwise_duplicates(self, var_dict, ids):
        """
        Stack the matched rows of every table (one output row per source
        row) over the union of all variables, sorted by name.
        """
        all_ids = sorted(
            reduce(set.union, [set(val) for val in ids.values()], set()))
        sort_key = attrgetter("name")
        all_atrs = sorted(var_dict['attributes'], key=sort_key)
        all_metas = sorted(var_dict['metas'], key=sort_key)
        all_cv = sorted(var_dict['class_vars'], key=sort_key)

        all_x, all_y, all_m = [], [], []
        new_table_ids = []
        for idx in all_ids:
            #iterate trough tables with same idx
            for table_key, t_indices in ids.items():
                if idx not in t_indices:
                    continue
                map_ = t_indices[idx]
                extracted = self.data[table_key].table[map_]
                # pylint: disable=unbalanced-tuple-unpacking
                x, m, y, t_ids = self.expand_table(extracted, all_atrs,
                                                   all_metas, all_cv)
                all_x.append(x)
                all_y.append(y)
                all_m.append(m)
                new_table_ids.append(t_ids)
        domain = {
            'attributes': all_atrs,
            'metas': all_metas,
            'class_vars': all_cv
        }
        values = {
            'attributes': [np.vstack(all_x)],
            'metas': [np.vstack(all_m)],
            'class_vars': [np.vstack(all_y)]
        }
        return self.merge_data(domain, values, np.vstack(new_table_ids))

    def commit(self):
        """Send the selected and the annotated data for the current selection."""
        if not self.vennwidget.vennareas() or not self.data:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            return

        self.selected_items = reduce(
            set.union, [self.disjoint[index] for index in self.selection],
            set())
        selected_keys = reduce(
            set.union, [set(self.area_keys[area]) for area in self.selection],
            set())
        selected = None

        if self.rowwise:
            if self.selected_items:
                selected_ids = self.get_indices_to_match_by(
                    selected_keys, bool(self.selection))
                selected = self.create_from_rows(selected_ids, False)
            annotated_ids = self.get_indices_to_match_by(self.data)
            annotated = self.create_from_rows(annotated_ids, True)
        else:
            annotated = self.create_from_columns(self.selected_items,
                                                 self.data, False)
            if self.selected_items:
                selected = self.create_from_columns(self.selected_items,
                                                    selected_keys, True)

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def send_report(self):
        self.report_plot()

    def get_disjoint(self, sets):
        """
        Return all disjoint subsets.

        Returns two lists of length 2 ** len(sets), indexed by area number
        `i` (membership given by `setkey(i, n)`): the items belonging to
        exactly that combination of sets, and the keys of the input tables
        included in that combination.
        """
        sets = list(sets)
        n = len(sets)
        disjoint_sets = [None] * (2**n)
        included_tables = [None] * (2**n)
        for i in range(2**n):
            key = setkey(i, n)
            included = [s for s, inc in zip(sets, key) if inc]
            if included:
                excluded = [s for s, inc in zip(sets, key) if not inc]
                s = reduce(set.intersection, included)
                s = reduce(set.difference, excluded, s)
            else:
                s = set()
            disjoint_sets[i] = s
            included_tables[i] = [k for k, inc in zip(self.data, key) if inc]

        return disjoint_sets, included_tables
示例#17
0
class OWImpute(OWWidget):
    """Impute missing values using a default method plus per-variable overrides.

    Each variable is imputed by a separate task submitted to a
    ``qconcurrent.ThreadExecutor``. ``__commit_finish`` assembles the output
    table once all tasks are done.
    """
    name = "Impute"
    description = "Impute missing values in the data table."
    icon = "icons/Impute.svg"
    priority = 2130
    keywords = ["substitute", "missing"]

    class Inputs:
        data = Input("Data", Orange.data.Table)
        learner = Input("Learner", Learner)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    class Error(OWWidget.Error):
        imputation_failed = Msg("Imputation failed for '{}'")
        model_based_imputer_sparse = \
            Msg("Model based imputer does not work for sparse data")

    class Warning(OWWidget.Warning):
        cant_handle_var = Msg("Default method can not handle '{}'")

    settingsHandler = settings.DomainContextHandler()

    # Default imputation method, stored as the int value of a ``Method``
    _default_method_index = settings.Setting(int(Method.Leave))  # type: int
    # Per-variable imputation state (synced in storeSpecificSettings)
    _variable_imputation_state = settings.ContextSetting(
        {})  # type: VariableState

    autocommit = settings.Setting(True)
    # Global fixed values used when the default method is ``Method.Default``
    default_numeric_value = settings.Setting(0.0)
    default_time = settings.Setting(0)  # seconds since the epoch

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        super().__init__()
        self.data = None  # type: Optional[Orange.data.Table]
        self.learner = None  # type: Optional[Learner]
        # Learner for the model-based imputer when no learner is connected
        self.default_learner = SimpleTreeLearner(min_instances=10,
                                                 max_depth=10)
        # Set by _invalidate; reset once a commit has sent its output
        self.modified = False
        self.executor = qconcurrent.ThreadExecutor(self)
        # The currently running imputation ``Task`` (or None)
        self.__task = None

        box = gui.vBox(self.controlArea, "Default Method")

        box_layout = QGridLayout()
        box_layout.setSpacing(8)
        box.layout().addLayout(box_layout)

        button_group = QButtonGroup()
        button_group.buttonClicked[int].connect(self.set_default_method)

        # One radio button per default method, laid out in columns of three.
        # The slice drops the first and last METHODS entries. Presumably these
        # are AsAboveSoBelow (which cannot be a default) and Default (which
        # gets its own row with value editors below); TODO confirm the
        # METHODS order.
        for i, (method, _) in enumerate(list(METHODS.items())[1:-1]):
            imputer = self.create_imputer(method)
            button = QRadioButton(imputer.name)
            button.setChecked(method == self.default_method_index)
            button_group.addButton(button, method)
            box_layout.addWidget(button, i % 3, i // 3)

        def set_default_time(datetime):
            # Store the edited time (as seconds since the epoch) and re-run
            # imputation only if fixed values are the active default.
            datetime = datetime.toSecsSinceEpoch()
            if datetime != self.default_time:
                self.default_time = datetime
                if self.default_method_index == Method.Default:
                    self._invalidate()

        hlayout = QHBoxLayout()
        box.layout().addLayout(hlayout)
        button = QRadioButton("Fixed values; numeric variables:")
        button_group.addButton(button, Method.Default)
        button.setChecked(Method.Default == self.default_method_index)
        hlayout.addWidget(button)

        self.numeric_value_widget = DoubleSpinBox(
            minimum=DBL_MIN,
            maximum=DBL_MAX,
            singleStep=.1,
            value=self.default_numeric_value,
            alignment=Qt.AlignRight,
            enabled=self.default_method_index == Method.Default,
        )
        self.numeric_value_widget.editingFinished.connect(
            self.__on_default_numeric_value_edited)
        self.connect_control("default_numeric_value",
                             self.numeric_value_widget.setValue)
        hlayout.addWidget(self.numeric_value_widget)

        hlayout.addWidget(QLabel(", time:"))

        self.time_widget = gui.DateTimeEditWCalendarTime(self)
        self.time_widget.setEnabled(
            self.default_method_index == Method.Default)
        self.time_widget.setKeyboardTracking(False)
        self.time_widget.setContentsMargins(0, 0, 0, 0)
        self.time_widget.set_datetime(
            QDateTime.fromSecsSinceEpoch(self.default_time))
        self.connect_control(
            "default_time", lambda value: self.time_widget.set_datetime(
                QDateTime.fromSecsSinceEpoch(value)))
        self.time_widget.dateTimeChanged.connect(set_default_time)
        hlayout.addWidget(self.time_widget)

        self.default_button_group = button_group

        box = gui.hBox(self.controlArea,
                       self.tr("Individual Attribute Settings"),
                       flat=False)

        self.varview = ListViewSearch(
            selectionMode=QListView.ExtendedSelection, uniformItemSizes=True)
        self.varview.setItemDelegate(DisplayFormatDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._on_var_selection_changed)
        self.selection = self.varview.selectionModel()

        box.layout().addWidget(self.varview)
        vertical_layout = QVBoxLayout(margin=0)

        self.methods_container = QWidget(enabled=False)
        method_layout = QVBoxLayout(margin=0)
        self.methods_container.setLayout(method_layout)

        # Per-variable method buttons: one for every ``Method``
        button_group = QButtonGroup()
        for method in Method:
            imputer = self.create_imputer(method)
            button = QRadioButton(text=imputer.name)
            button_group.addButton(button, method)
            method_layout.addWidget(button)

        # Editors for a per-variable fixed value: a combo for a single
        # discrete variable, a spin box for numeric variables (which one is
        # shown is decided in _on_var_selection_changed).
        self.value_combo = QComboBox(
            minimumContentsLength=8,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            activated=self._on_value_selected)
        self.value_double = DoubleSpinBox(
            editingFinished=self._on_value_selected,
            minimum=DBL_MIN,
            maximum=DBL_MAX,
            singleStep=.1,
        )
        self.value_stack = value_stack = QStackedWidget()
        value_stack.addWidget(self.value_combo)
        value_stack.addWidget(self.value_double)
        method_layout.addWidget(value_stack)

        button_group.buttonClicked[int].connect(
            self.set_method_for_current_selection)

        self.reset_button = QPushButton(
            "Restore All to Default",
            enabled=False,
            default=False,
            autoDefault=False,
            clicked=self.reset_variable_state,
        )

        vertical_layout.addWidget(self.methods_container)
        vertical_layout.addStretch(2)
        vertical_layout.addWidget(self.reset_button)

        box.layout().addLayout(vertical_layout)

        self.variable_button_group = button_group

        gui.auto_apply(self.buttonsArea, self, "autocommit")

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

    def create_imputer(self, method, *args):
        # type: (Method, ...) -> impute.BaseImputeMethod
        """Return an imputer instance for ``method``.

        * ``Method.Model`` uses the connected learner, or the default tree
          learner when none is set.
        * ``Method.AsAboveSoBelow`` wraps the current default method in
          ``AsDefault``.
        * ``Method.Default`` without ``args`` uses the global fixed numeric
          and time values.
        * Anything else is built as ``METHODS[method](*args)``.
        """
        if method == Method.Model:
            if self.learner is not None:
                return impute.Model(self.learner)
            else:
                return impute.Model(self.default_learner)
        elif method == Method.AsAboveSoBelow:
            assert self.default_method_index != Method.AsAboveSoBelow
            default = self.create_imputer(Method(self.default_method_index))
            m = AsDefault()
            m.method = default
            return m
        elif method == Method.Default and not args:  # global default values
            return impute.FixedValueByType(
                default_continuous=self.default_numeric_value,
                default_time=self.default_time)
        else:
            return METHODS[method](*args)

    @property
    def default_method_index(self):
        """The int value of the current default ``Method``."""
        return self._default_method_index

    @default_method_index.setter
    def default_method_index(self, index):
        # Sync the radio buttons and value editors, refresh the variable
        # view and re-run imputation, but only on an actual change.
        if self._default_method_index != index:
            assert index != Method.AsAboveSoBelow
            self._default_method_index = index
            self.default_button_group.button(index).setChecked(True)
            self.time_widget.setEnabled(index == Method.Default)
            self.numeric_value_widget.setEnabled(index == Method.Default)
            # update variable view
            self.update_varview()
            self._invalidate()

    def set_default_method(self, index):
        """Set the current selected default imputation method.
        """
        self.default_method_index = index

    def __on_default_numeric_value_edited(self):
        # Store the edited global numeric value; re-impute only when fixed
        # values are the active default method.
        val = self.numeric_value_widget.value()
        if val != self.default_numeric_value:
            self.default_numeric_value = val
            if self.default_method_index == Method.Default:
                self._invalidate()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive new input data, restore per-variable state and commit."""
        self.cancel()
        self.closeContext()
        self.varmodel[:] = []
        self._variable_imputation_state = {}  # type: VariableState
        self.modified = False
        self.data = data

        if data is not None:
            self.varmodel[:] = data.domain.variables
            self.openContext(data.domain)
            # restore per variable imputation state
            self._restore_state(self._variable_imputation_state)
        self.reset_button.setEnabled(len(self.varmodel) > 0)
        # Compare with None: an empty table is falsy but is still an input.
        has_data = data is not None
        summary = len(data) if has_data else self.info.NoInput
        details = format_summary_details(data) if has_data else ""
        self.info.set_input_summary(summary, details)

        self.update_varview()
        self.unconditional_commit()

    @Inputs.learner
    def set_learner(self, learner):
        """Receive a learner for model-based imputation.

        A connected learner also becomes the default method.
        """
        self.cancel()
        self.learner = learner or self.default_learner
        imputer = self.create_imputer(Method.Model)
        button = self.default_button_group.button(Method.Model)
        button.setText(imputer.name)

        variable_button = self.variable_button_group.button(Method.Model)
        variable_button.setText(imputer.name)

        if learner is not None:
            self.default_method_index = Method.Model

        self.update_varview()
        self.commit()

    def get_method_for_column(self, column_index):
        # type: (int) -> impute.BaseImputeMethod
        """
        Return the imputation method for column by its index.

        Columns without an explicit per-variable state use
        ``Method.AsAboveSoBelow``, i.e. the default method.
        """
        assert 0 <= column_index < len(self.varmodel)
        idx = self.varmodel.index(column_index, 0)
        state = idx.data(StateRole)
        if state is None:
            state = (Method.AsAboveSoBelow, ())
        return self.create_imputer(state[0], *state[1])

    def _invalidate(self):
        """Mark settings as modified, cancel a running task and recommit."""
        self.modified = True
        if self.__task is not None:
            self.cancel()
        self.commit()

    def commit(self):
        """Start imputation of ``self.data``.

        Without data (or without variables) the input is sent unchanged.
        Otherwise one background task per variable is submitted and the
        output is sent from ``__commit_finish``.
        """
        self.cancel()
        self.warning()
        self.Error.imputation_failed.clear()
        self.Error.model_based_imputer_sparse.clear()
        # Compare with None: an empty table is falsy but is still data.
        has_data = self.data is not None
        summary = len(self.data) if has_data else self.info.NoOutput
        detail = format_summary_details(self.data) if has_data else ""
        self.info.set_output_summary(summary, detail)

        if not self.data or not self.varmodel.rowCount():
            self.Outputs.data.send(self.data)
            self.modified = False
            return

        data = self.data
        impute_state = [(i, var, self.get_method_for_column(i))
                        for i, var in enumerate(self.varmodel)]
        # normalize to the effective method bypassing AsDefault
        impute_state = [(i, var, m.method if isinstance(m, AsDefault) else m)
                        for i, var, m in impute_state]

        def impute_one(method, var, data):
            # type: (impute.BaseImputeMethod, Variable, Table) -> Any
            # Runs in the executor. Returns a RowMask for DropInstances,
            # otherwise whatever the method returns for this variable.
            # Readability counts, pylint: disable=no-else-raise
            if isinstance(method, impute.Model) and data.is_sparse():
                raise SparseNotSupported()
            elif isinstance(method, impute.DropInstances):
                return RowMask(method(data, var))
            elif not method.supports_variable(var):
                raise VariableNotSupported(var)
            else:
                return method(data, var)

        futures = []
        for _, var, method in impute_state:
            # Each task gets its own copy of the method object
            f = self.executor.submit(impute_one, copy.deepcopy(method), var,
                                     data)
            futures.append(f)

        w = qconcurrent.FutureSetWatcher(futures)
        w.doneAll.connect(self.__commit_finish)
        w.progressChanged.connect(self.__progress_changed)
        self.__task = Task(futures, w)
        self.progressBarInit()
        self.setInvalidated(True)

    @Slot()
    def __commit_finish(self):
        """Collect the finished per-variable results and send the output."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        futures = self.__task.futures
        assert len(futures) == len(self.varmodel)
        assert self.data is not None

        def get_variable(variable, future, drop_mask) \
                -> Optional[List[Orange.data.Variable]]:
            # Returns a (potentially empty) list of variables,
            # or None on failure that should interrupt the imputation
            assert future.done()
            try:
                res = future.result()
            except SparseNotSupported:
                self.Error.model_based_imputer_sparse()
                # The variable is dropped from the output; others continue
                return []  # None?
            except VariableNotSupported:
                self.Warning.cant_handle_var(variable.name)
                return []
            except Exception:  # pylint: disable=broad-except
                log = logging.getLogger(__name__)
                log.info("Error for %s", variable.name, exc_info=True)
                self.Error.imputation_failed(variable.name)
                return None
            if isinstance(res, RowMask):
                # Rows to drop accumulate in the shared mask (in place)
                drop_mask |= res.mask
                newvar = variable
            else:
                newvar = res
            if isinstance(newvar, Orange.data.Variable):
                newvar = [newvar]
            return newvar

        def create_data(attributes, class_vars):
            # Build the output table from the new variables, keeping the
            # original metas and dropping rows flagged in drop_mask.
            domain = Orange.data.Domain(attributes, class_vars,
                                        self.data.domain.metas)
            try:
                return self.data.from_table(domain, self.data[~drop_mask])
            except Exception:  # pylint: disable=broad-except
                log = logging.getLogger(__name__)
                log.info("Error", exc_info=True)
                self.Error.imputation_failed("Unknown")
                return None

        self.__task = None
        self.setInvalidated(False)
        self.progressBarFinished()

        attributes = []
        class_vars = []
        drop_mask = np.zeros(len(self.data), bool)
        for i, (var, fut) in enumerate(zip(self.varmodel, futures)):
            newvar = get_variable(var, fut, drop_mask)
            if newvar is None:
                data = None
                break
            # varmodel lists attributes first, then class variables
            if i < len(self.data.domain.attributes):
                attributes.extend(newvar)
            else:
                class_vars.extend(newvar)
        else:
            data = create_data(attributes, class_vars)

        self.Outputs.data.send(data)
        self.modified = False
        # Compare with None: dropping every row yields an empty (falsy)
        # table that is nevertheless sent as output.
        has_data = data is not None
        summary = len(data) if has_data else self.info.NoOutput
        details = format_summary_details(data) if has_data else ""
        self.info.set_output_summary(summary, details)

    @Slot(int, int)
    def __progress_changed(self, n, d):
        # n of d tasks are done
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        self.progressBarSet(100. * n / d)

    def cancel(self):
        """Cancel the running imputation task (without waiting)."""
        self.__cancel(wait=False)

    def __cancel(self, wait=False):
        # Detach the running task's signals so its results are ignored;
        # optionally block until its futures have finished.
        if self.__task is not None:
            task, self.__task = self.__task, None
            task.cancel()
            task.watcher.doneAll.disconnect(self.__commit_finish)
            task.watcher.progressChanged.disconnect(self.__progress_changed)
            if wait:
                concurrent.futures.wait(task.futures)
                task.watcher.flush()
            self.progressBarFinished()
            self.setInvalidated(False)

    def onDeleteWidget(self):
        self.__cancel(wait=True)
        super().onDeleteWidget()

    def send_report(self):
        """Report the default method and any per-variable overrides."""
        specific = []
        for i, var in enumerate(self.varmodel):
            method = self.get_method_for_column(i)
            if not isinstance(method, AsDefault):
                specific.append("{} ({})".format(var.name, str(method)))

        default = self.create_imputer(Method.AsAboveSoBelow)
        if specific:
            self.report_items((("Default method", default.name),
                               ("Specific imputers", ", ".join(specific))))
        else:
            self.report_items((("Method", default.name), ))

    def _on_var_selection_changed(self):
        """Sync the per-variable method buttons and value editors with the
        currently selected variables."""
        # Method is well documented, splitting it is not needed for readability,
        # thus pylint: disable=too-many-branches
        indexes = self.selection.selectedIndexes()
        self.methods_container.setEnabled(len(indexes) > 0)
        defmethod = (Method.AsAboveSoBelow, ())
        methods = [index.data(StateRole) for index in indexes]
        methods = [m if m is not None else defmethod for m in methods]
        methods = set(methods)
        selected_vars = [self.varmodel[index.row()] for index in indexes]
        has_discrete = any(var.is_discrete for var in selected_vars)
        fixed_value = None
        value_stack_enabled = False
        current_value_widget = None

        if len(methods) == 1:
            method_type, parameters = methods.pop()
            for m in Method:
                if method_type == m:
                    self.variable_button_group.button(m).setChecked(True)

            if method_type == Method.Default:
                (fixed_value, ) = parameters

        elif self.variable_button_group.checkedButton() is not None:
            # Uncheck the current button
            self.variable_button_group.setExclusive(False)
            self.variable_button_group.checkedButton().setChecked(False)
            self.variable_button_group.setExclusive(True)
            assert self.variable_button_group.checkedButton() is None

        # Update variable methods GUI enabled state based on selection.
        for method in Method:
            # use a default constructed imputer to query support
            imputer = self.create_imputer(method)
            enabled = all(
                imputer.supports_variable(var) for var in selected_vars)
            button = self.variable_button_group.button(method)
            button.setEnabled(enabled)

        # Update the "Value" edit GUI.
        if not has_discrete:
            # no discrete variables -> allow mass edit for all (continuous vars)
            value_stack_enabled = True
            current_value_widget = self.value_double
        elif len(selected_vars) == 1:
            # single discrete var -> enable and fill the values combo
            value_stack_enabled = True
            current_value_widget = self.value_combo
            self.value_combo.clear()
            self.value_combo.addItems(selected_vars[0].values)
        else:
            # mixed type selection -> disable
            value_stack_enabled = False
            current_value_widget = None
            self.variable_button_group.button(Method.Default).setEnabled(False)

        self.value_stack.setEnabled(value_stack_enabled)
        if current_value_widget is not None:
            self.value_stack.setCurrentWidget(current_value_widget)
            if fixed_value is not None:
                # set current value
                if current_value_widget is self.value_combo:
                    self.value_combo.setCurrentIndex(fixed_value)
                elif current_value_widget is self.value_double:
                    self.value_double.setValue(fixed_value)
                else:
                    assert False

    def set_method_for_current_selection(self, method_index):
        # type: (Method) -> None
        """Apply ``method_index`` to all currently selected variables."""
        indexes = self.selection.selectedIndexes()
        self.set_method_for_indexes(indexes, method_index)

    def set_method_for_indexes(self, indexes, method_index):
        # type: (List[QModelIndex], Method) -> None
        """Store the imputation state for ``indexes`` and recommit.

        ``AsAboveSoBelow`` clears the per-variable state. ``Default`` stores
        the value from the visible value editor (combo index or number).
        """
        if method_index == Method.AsAboveSoBelow:
            for index in indexes:
                self.varmodel.setData(index, None, StateRole)
        elif method_index == Method.Default:
            current = self.value_stack.currentWidget()
            if current is self.value_combo:
                value = self.value_combo.currentIndex()
            else:
                value = self.value_double.value()
            state = (int(Method.Default), (value, ))
            for index in indexes:
                self.varmodel.setData(index, state, StateRole)
        else:
            state = (int(method_index), ())
            for index in indexes:
                self.varmodel.setData(index, state, StateRole)

        self.update_varview(indexes)
        self._invalidate()

    def update_varview(self, indexes=None):
        """Refresh the displayed method for ``indexes`` (default: all rows)."""
        if indexes is None:
            indexes = map(self.varmodel.index, range(len(self.varmodel)))

        for index in indexes:
            self.varmodel.setData(index,
                                  self.get_method_for_column(index.row()),
                                  DisplayMethodRole)

    def _on_value_selected(self):
        # The fixed 'Value' in the widget has been changed by the user.
        self.variable_button_group.button(Method.Default).setChecked(True)
        self.set_method_for_current_selection(Method.Default)

    def reset_variable_state(self):
        """Reset every variable to follow the default method."""
        indexes = list(map(self.varmodel.index, range(len(self.varmodel))))
        self.set_method_for_indexes(indexes, Method.AsAboveSoBelow)
        self.variable_button_group.button(
            Method.AsAboveSoBelow).setChecked(True)

    def _store_state(self):
        # type: () -> VariableState
        """
        Save the current variable imputation state

        Only variables with an explicit state are included, keyed by
        ``var_key(var)``.
        """
        state = {}  # type: VariableState
        for i, var in enumerate(self.varmodel):
            index = self.varmodel.index(i)
            m = index.data(StateRole)
            if m is not None:
                state[var_key(var)] = m
        return state

    def _restore_state(self, state):
        # type: (VariableState) -> None
        """
        Restore the variable imputation state from the saved state

        Entries that are not a ``(method_int, params_tuple)`` pair with a
        valid method are ignored.
        """
        def check(state):
            # check if state is a proper State
            if isinstance(state, tuple) and len(state) == 2:
                m, p = state
                if isinstance(m, int) and isinstance(p, tuple) and \
                        0 <= m < len(Method):
                    return True
            return False

        for i, var in enumerate(self.varmodel):
            m = state.get(var_key(var), None)
            if check(m):
                self.varmodel.setData(self.varmodel.index(i), m, StateRole)

    def storeSpecificSettings(self):
        # Sync per-variable state into the context setting before saving
        self._variable_imputation_state = self._store_state()
        super().storeSpecificSettings()
示例#18
0
class OWHeatMap(widget.OWWidget):
    name = "Heat map"
    description = "Draw a two dimentional density plot."
    icon = "icons/Heatmap.svg"
    priority = 100

    inputs = [("Data", Orange.data.Table, "set_data")]

    settingsHandler = settings.DomainContextHandler()

    x_var_index = settings.Setting(0)
    y_var_index = settings.Setting(1)
    z_var_index = settings.Setting(0)
    selected_z_values = settings.Setting([])
    color_scale = settings.Setting(1)
    sample_level = settings.Setting(0)

    sample_percentages = []
    sample_percentages_captions = []
    sample_times = [0.1, 0.5, 3, 5, 20, 40, 80]
    sample_times_captions = ['0.1s', '1s', '5s', '10s', '30s', '1min', '2min']

    use_cache = settings.Setting(True)

    n_bins = 2 ** 4

    mouse_mode = 0

    def __init__(self, parent=None):
        """Build the control area and the density plot in the main area.

        The control area holds sampling, input info, X/Y/colour variable
        selection, colour scale and mouse mode.

        ``parent`` is kept in the signature for backward compatibility and
        is not used.
        """
        # Initialise the base class without arguments; ``self`` must not be
        # passed explicitly to ``super().__init__``.
        super().__init__()

        self.dataset = None
        # Full SqlTable input sampled by update_sample; None until an SQL
        # table is received (the sampling combo is usable before that).
        self.original_data = None
        self.z_values = []

        self._root = None
        self._displayed_root = None
        self._item = None
        self._cache = {}

        self.colors = colorpalette.ColorPaletteGenerator(10)

        self.sampling_box = box = gui.widgetBox(self.controlArea, "Sampling")
        sampling_options =\
            self.sample_times_captions + self.sample_percentages_captions
        gui.comboBox(box, self, 'sample_level',
                     items=sampling_options,
                     callback=self.update_sample)

        gui.button(box, self, "Sharpen", self.sharpen)

        box = gui.widgetBox(self.controlArea, "Input")

        self.labelDataInput = gui.widgetLabel(box, 'No data on input')
        self.labelDataInput.setTextFormat(Qt.PlainText)
        self.labelOutput = gui.widgetLabel(box, '')

        self.x_var_model = itemmodels.VariableListModel()
        self.comboBoxAttributesX = gui.comboBox(
            self.controlArea, self, value='x_var_index', box='X Attribute',
            callback=self.replot)
        self.comboBoxAttributesX.setModel(self.x_var_model)

        self.y_var_model = itemmodels.VariableListModel()
        self.comboBoxAttributesY = gui.comboBox(
            self.controlArea, self, value='y_var_index', box='Y Attribute',
            callback=self.replot)
        self.comboBoxAttributesY.setModel(self.y_var_model)

        box = gui.widgetBox(self.controlArea, "Color by")
        self.z_var_model = itemmodels.VariableListModel()
        self.comboBoxClassvars = gui.comboBox(
            box, self, value='z_var_index',
            callback=self._on_z_var_changed)
        self.comboBoxClassvars.setModel(self.z_var_model)

        box1 = gui.widgetBox(box, 'Colors displayed', margin=0)
        box1.setFlat(True)

        self.z_values_view = gui.listBox(
            box1, self, "selected_z_values", "z_values",
            callback=self._on_z_values_selection_changed,
            selectionMode=QtGui.QListView.MultiSelection,
            addSpace=False
        )
        box1 = gui.widgetBox(box, "Color Scale", margin=0)
        box1.setFlat(True)
        gui.comboBox(box1, self, "color_scale",
                     items=["Linear", "Square root", "Logarithmic"],
                     callback=self._on_color_scale_changed)

        self.mouseBehaviourBox = gui.radioButtons(
            self.controlArea, self, value='mouse_mode',
            btnLabels=('Drag', 'Select'),
            box='Mouse left button behavior',
            callback=self._update_mouse_mode
        )

        gui.rubber(self.controlArea)

        self.plot = pg.PlotWidget(background="w")
        self.plot.setMenuEnabled(False)
        self.plot.setFrameStyle(QtGui.QFrame.StyledPanel)
        self.plot.setMinimumSize(500, 500)

        def font_resize(font, factor, minsize=None, maxsize=None):
            # Return a copy of ``font`` scaled by ``factor``, clamped to
            # [minsize, maxsize] points when those bounds are given.
            font = QtGui.QFont(font)
            fontinfo = QtGui.QFontInfo(font)
            size = fontinfo.pointSizeF() * factor

            if minsize is not None:
                size = max(size, minsize)
            if maxsize is not None:
                size = min(size, maxsize)

            font.setPointSizeF(size)
            return font

        axisfont = font_resize(self.font(), 0.8, minsize=11)
        axispen = QtGui.QPen(self.palette().color(QtGui.QPalette.Text))
        axis = self.plot.getAxis("bottom")
        axis.setTickFont(axisfont)
        axis.setPen(axispen)
        axis = self.plot.getAxis("left")
        axis.setTickFont(axisfont)
        axis.setPen(axispen)

        self.plot.getViewBox().sigTransformChanged.connect(
            self._on_transform_changed)
        self.mainArea.layout().addWidget(self.plot)

    def set_data(self, dataset):
        """Handle a new input table.

        SQL tables are kept as ``original_data`` and sampled (the sampling
        box is shown); any other input is plotted directly (the box is
        hidden).
        """
        self.closeContext()
        self.clear()

        if not isinstance(dataset, SqlTable):
            self.dataset = dataset
            self.sampling_box.setVisible(False)
            self.set_sampled_data(self.dataset)
            return

        self.original_data = dataset
        self.sample_level = 0
        self.sampling_box.setVisible(True)
        self.update_sample()

    def update_sample(self):
        """Re-sample the SQL input at the currently selected sampling level.

        Levels below ``len(self.sample_times)`` index ``sample_times`` (time
        budgets passed to ``sample_time``). Higher levels index
        ``sample_percentages``.

        Does nothing when no SqlTable has been received yet. The sampling
        combo that triggers this method is usable before any input arrives.
        """
        # getattr keeps this safe even if __init__ did not create the
        # attribute.
        original = getattr(self, "original_data", None)
        if original is None:
            return

        self.clear()

        n_times = len(self.sample_times)
        if self.sample_level < n_times:
            level = self.sample_times[self.sample_level]
            self.dataset = original.sample_time(level, no_cache=True)
        else:
            level = self.sample_percentages[self.sample_level - n_times]
            if 0 < level < 100:
                self.dataset = original.sample_percentage(level, no_cache=True)
            if level >= 100:
                self.dataset = original
        self.set_sampled_data(self.dataset)

    def set_sampled_data(self, dataset):
        """Populate the variable models from ``dataset`` and redraw the plot.

        The X/Y combos list the continuous variables and the colour combo
        the discrete ones. The discrete class variable, if any, is preferred
        for colouring; otherwise the last discrete variable is used.
        """
        if dataset is None:
            self.labelDataInput.setText('No data on input')
            # No output is sent: this widget declares only ``inputs``.
            return

        domain = dataset.domain
        cvars = list(filter(is_continuous, domain.variables))
        dvars = list(filter(is_discrete, domain.variables))

        self.x_var_model[:] = cvars
        self.y_var_model[:] = cvars
        self.z_var_model[:] = dvars

        # Clamp the axis indices into the new models. The result is -1 when
        # there are no continuous variables, which setup_plot treats as
        # "nothing to plot".
        nvars = len(cvars)
        self.x_var_index = min(max(0, self.x_var_index), nvars - 1)
        self.y_var_index = min(max(0, self.y_var_index), nvars - 1)

        # Colour by the discrete class variable if there is one, otherwise
        # by the last discrete variable (-1 when there are none).
        if is_discrete(domain.class_var):
            self.z_var_index = dvars.index(domain.class_var)
        else:
            self.z_var_index = len(dvars) - 1

        self.openContext(dataset)

        if 0 <= self.z_var_index < len(self.z_var_model):
            self.z_values = self.z_var_model[self.z_var_index].values
            k = len(self.z_values)
            self.selected_z_values = range(k)
            self.colors = colorpalette.ColorPaletteGenerator(k)
            for i in range(k):
                item = self.z_values_view.item(i)
                item.setIcon(colorpalette.ColorPixmap(self.colors[i]))

        self.labelDataInput.setText(
            'Data set: %s'
            % (getattr(self.dataset, "name", "untitled"),)
        )

        self.setup_plot()

    def clear(self):
        """Forget the current data, variable models and cached maps, and
        clear the plot."""
        self.dataset = None
        for model in (self.x_var_model, self.y_var_model, self.z_var_model):
            model[:] = []
        self.z_values = []
        self._root = self._displayed_root = self._item = None
        self._cache = {}
        self.plot.clear()

    def _on_z_var_changed(self):
        """Refresh the colour-value list and palette for the newly chosen
        colour variable, then replot. Ignored for an out-of-range index."""
        if not 0 <= self.z_var_index < len(self.z_var_model):
            return

        chosen = self.z_var_model[self.z_var_index]
        self.z_values = chosen.values
        n_values = len(self.z_values)
        self.selected_z_values = range(n_values)

        self.colors = colorpalette.ColorPaletteGenerator(n_values)
        for row in range(n_values):
            self.z_values_view.item(row).setIcon(
                colorpalette.ColorPixmap(self.colors[row]))

        self.replot()

    def _on_z_values_selection_changed(self):
        """Redraw the displayed map after the selected colour values change."""
        displayed = self._displayed_root
        if displayed is None:
            return
        self.update_map(displayed)

    def _on_color_scale_changed(self):
        """Redraw the displayed map after the colour scale changes."""
        displayed = self._displayed_root
        if displayed is None:
            return
        self.update_map(displayed)

    def setup_plot(self):
        """Set up the density map plot.

        Uses the selected X/Y variables and, when the colour index is
        valid, the colour variable. Root trees are cached in ``self._cache``
        keyed by ``(xvar, yvar, zvar)``.
        """
        self.plot.clear()
        axes_selected = self.x_var_index != -1 and self.y_var_index != -1
        if self.dataset is None or not axes_selected:
            return

        xvar = self.x_var_model[self.x_var_index]
        yvar = self.y_var_model[self.y_var_index]
        has_colour = 0 <= self.z_var_index < len(self.z_var_model)
        zvar = self.z_var_model[self.z_var_index] if has_colour else None

        self.plot.getAxis("bottom").setLabel(xvar.name)
        self.plot.getAxis("left").setLabel(yvar.name)

        key = (xvar, yvar, zvar)
        if key not in self._cache:
            self._cache[key] = self.get_root(self.dataset, xvar, yvar, zvar)
        self._root = self._cache[key]

        self.update_map(self._root)

    def get_root(self, data, xvar, yvar, zvar=None):
        """Compute the root density map item

        Discretizes *xvar* and *yvar* into ``self.n_bins`` equal-width
        intervals, bins *data* on that grid (split by *zvar* when given)
        and returns the binned tree with ``xbins``/``ybins`` replaced by
        finite edges: the cut points padded by one bin width on each side.
        """
        assert self.n_bins > 2
        x_disc, y_disc = (EqualWidth(n=self.n_bins)(data, var)
                          for var in (xvar, yvar))

        def padded_edges(disc_var):
            # Cut points plus one extra bin on either side (the cut points
            # are assumed to be equally spaced).
            cuts = list(disc_var.compute_value.points)
            assert cuts[0] <= cuts[1]
            step = cuts[1] - cuts[0]
            return np.array([cuts[0] - step, *cuts, cuts[-1] + step])

        xbins, ybins = padded_edges(x_disc), padded_edges(y_disc)

        # grid_bin has an optimization for infinite outermost edges.
        open_xbins = np.r_[-np.inf, xbins[1:-1], np.inf]
        open_ybins = np.r_[-np.inf, ybins[1:-1], np.inf]

        binned = grid_bin(data, xvar, yvar, open_xbins, open_ybins, zvar=zvar)
        return binned._replace(xbins=xbins, ybins=ybins)

    def replot(self):
        """Wipe the plot and rebuild it from the current selection."""
        plot = self.plot
        plot.clear()
        self.setup_plot()

    def update_map(self, root):
        """Display the density tree *root* in the plot.

        Clears the plot and remembers *root* as the displayed tree. When the
        contingencies are 3-dimensional (a z variable is in use) the tree
        and palette are restricted to the selected z values; nothing is
        drawn when no z value is selected.

        :param root: root node of the density tree to display
        """
        self.plot.clear()
        self._item = None

        self._displayed_root = root

        palette = self.colors
        contingencies = root.contingencies

        def Tree_take(node, indices, axis):
            """Take elements from the contingency matrices in node."""
            contig = np.take(node.contingencies, indices, axis)
            if node.is_leaf:
                return node._replace(contingencies=contig)
            else:
                children_ar = np.full(node.children.size, None, dtype=object)
                children_ar[:] = [
                    Tree_take(ch, indices, axis) if ch is not None else None
                    for ch in node.children.flat
                ]
                children = children_ar.reshape(node.children.shape)
                return node._replace(contingencies=contig, children=children)

        if contingencies.ndim == 3:
            # ``selected_z_values`` may be a ``range`` or another sequence;
            # normalize to a list so the "everything selected" case below is
            # recognized and the full tree copy is skipped.
            selected = list(self.selected_z_values)
            if not selected:
                return

            _, _, k = contingencies.shape

            if selected != list(range(k)):
                palette = [palette[i] for i in selected]
                root = Tree_take(root, selected, 2)

        self._item = item = DensityPatch(
            root, cell_size=10,
            cell_shape=DensityPatch.Rect,
            color_scale=self.color_scale + 1,
            palette=palette
        )
        self.plot.addItem(item)

    def sharpen(self):
        """Sharpen the density map over the currently visible plot area."""
        view_box = self.plot.getViewBox()
        bounds = view_box.boundingRect()
        visible = QtCore.QRectF(view_box.mapToView(bounds.topLeft()),
                                view_box.mapToView(bounds.bottomRight()))
        self.sharpen_region(visible.normalized())

    def sharpen_root_region(self, region):
        """Progressively refine the density tree inside *region*.

        Consumes the module-level ``sharpen_region`` generator, redrawing
        intermediate trees every ``n_bins`` steps or after more than two
        seconds, then stores, caches and displays the final tree. Does
        nothing when there is no tree or it does not intersect *region*.

        :param region: rectangle (in data coordinates) to refine
        """
        if self._root is None:
            return

        data = self.dataset
        xvar = self.x_var_model[self.x_var_index]
        yvar = self.y_var_model[self.y_var_index]

        if 0 <= self.z_var_index < len(self.z_var_model):
            zvar = self.z_var_model[self.z_var_index]
        else:
            zvar = None

        root = self._root

        if not QRectF(*root.brect).intersects(region):
            return

        nbins = self.n_bins

        def bin_func(xbins, ybins):
            return grid_bin(data, xvar, yvar, xbins, ybins, zvar)

        self.progressBarInit()
        try:
            last_node = root  # tree currently shown in the plot
            node = root       # most recent tree produced by the generator
            update_time = time.time()
            changed = False

            for i, node in enumerate(
                    sharpen_region(root, region, nbins, bin_func)):
                tick = time.time() - update_time
                changed = changed or node is not last_node
                if changed and ((i % nbins == 0) or tick > 2.0):
                    self.update_map(node)
                    last_node = node
                    changed = False
                    update_time = time.time()
                    self.progressBarSet(100 * i / (nbins ** 2))

            # Keep the final tree even when the last iteration did not
            # trigger a redraw (``last_node`` may be an older tree).
            self._root = node
            self._cache[xvar, yvar, zvar] = node
            self.update_map(node)
        finally:
            self.progressBarFinished()

    def _sampling_width(self):
        if self._item is None:
            return 0

        item = self._item
        rect = item.rect()

        T = self.plot.transform() * item.sceneTransform()
#         lod = QtGui.QStyleOptionGraphicsItem.levelOfDetailFromTransform(T)
        lod = lod_from_transform(T)
        size1 = np.sqrt(rect.width() * rect.height()) / self.n_bins
        cell_size = 10
        scale = cell_size / (lod * size1)
        if np.isinf(scale):
            scale = np.finfo(float).max
        p = int(np.floor(np.log2(scale)))
        p = min(p, int(np.log2(self.n_bins)))
        return 2 ** int(p)

    def sharpen_region(self, region):
        """Sharpen the density tree inside *region*.

        Finds how deep the tree is already expanded inside the region,
        selects the nodes worth refining for the current sampling width,
        scores their not-yet-expanded non-empty cells and refines them in
        order of decreasing score (redrawing at most every two seconds).
        The result is stored in ``self._root`` and cached for the current
        (x, y, z) variables. Does nothing without a tree or when the tree
        does not intersect *region*.

        :param region: rectangle (in data coordinates) to refine
        """
        if self._root is None:
            return

        data = self.dataset
        xvar = self.x_var_model[self.x_var_index]
        yvar = self.y_var_model[self.y_var_index]

        if 0 <= self.z_var_index < len(self.z_var_model):
            zvar = self.z_var_model[self.z_var_index]
        else:
            zvar = None

        root = self._root
        nbins = self.n_bins

        if not QRectF(*root.brect).intersects(region):
            return

        def bin_func(xbins, ybins):
            return grid_bin(data, xvar, yvar, xbins, ybins, zvar)

        def min_depth(node, region):
            # np.inf when *node* misses the region; 1 for leaves, empty
            # nodes and nodes with an unexpanded non-empty cell inside the
            # region; otherwise one more than the minimum over the children.
            if not region.intersects(QRectF(*node.brect)):
                return np.inf
            elif node.is_leaf:
                return 1
            elif node.is_empty:
                return 1
            else:
                xs, xe, ys, ye = bindices(node, region)
                children = node.children[xs: xe, ys: ye].ravel()
                contingency = node.contingencies[xs: xe, ys: ye]
                # Flatten to one entry per cell so ``zip`` pairs each child
                # with its own cell: (cells, k) with a z variable, (cells,)
                # without one.
                if contingency.ndim == 3:
                    contingency = contingency.reshape(-1, contingency.shape[2])
                else:
                    contingency = contingency.ravel()

                if any(ch is None and np.any(val)
                       for ch, val in zip(children, contingency)):
                    return 1
                else:
                    ch_depth = [min_depth(ch, region) + 1
                                for ch in filter(is_not_none, children.flat)]
                    return min(ch_depth if ch_depth else [1])

        depth = min_depth(root, region)
        bw = self._sampling_width()
        nodes = self.select_nodes_to_sharpen(root, region, bw, depth + 1)

        def update_rects(node):
            # Scored candidate cells of *node*, restricted to non-empty
            # cells that have not been expanded into children yet.
            scored = score_candidate_rects(node, region)
            ind1 = set(zip(*Node_nonzero(node)))
            ind2 = set(zip(*node.children.nonzero())) \
                   if not node.is_leaf else set()
            ind = ind1 - ind2
            return [(score, r) for score, i, j, r in scored if (i, j) in ind]

        scored_rects = reduce(operator.iadd, map(update_rects, nodes), [])
        scored_rects = sorted(scored_rects, reverse=True,
                              key=operator.itemgetter(0))

        self.progressBarInit()
        try:
            update_time = time.time()

            for i, (_, rect) in enumerate(scored_rects):
                root = sharpen_region_recur(
                    root, rect.intersect(region),
                    nbins, depth + 1, bin_func
                )
                tick = time.time() - update_time
                if tick > 2.0:
                    self.update_map(root)
                    update_time = time.time()

                self.progressBarSet(100 * i / len(scored_rects))

            self._root = root
            self._cache[xvar, yvar, zvar] = root
            self.update_map(root)
        finally:
            self.progressBarFinished()

    def select_nodes_to_sharpen(self, node, region, bw, depth):
        """
        Collect the tree nodes inside *region* that should be refined.

        :param node: density tree node to start from
        :param region: rectangle (in data coordinates) to refine
        :param bw: bandwidth (samplewidth)
        :param depth: maximum node depth to consider
        :returns: list of nodes to sharpen (possibly empty)
        """
        # Scalar checks first: every early exit returns ``[]`` anyway, and
        # this skips building a QRectF when the answer is already known.
        if bw >= 1 or depth == 1:
            return []
        if not QRectF(*node.brect).intersects(region):
            return []
        if node.is_empty:
            return []
        if node.is_leaf:
            return [node]

        xs, xe, ys, ye = bindices(node, region)

        def intersect_indices(rows, cols):
            mask = (xs <= rows) & (rows < xe) & (ys <= cols) & (cols < ye)
            return rows[mask], cols[mask]

        indices1 = intersect_indices(*Node_nonzero(node))
        indices2 = intersect_indices(*node.children.nonzero())
        # If there are any non empty and non expanded cells in the
        # intersection return the node for sharpening, ...
        # ``np.array_equal`` also handles a differing number of cells, where
        # an element-wise ``!=`` cannot broadcast (ValueError in NumPy >= 1.25).
        if not np.array_equal(np.array(indices1), np.array(indices2)):
            return [node]

        children = node.children[indices2]
        # ... else run down the children in the intersection
        return reduce(operator.iadd,
                      (self.select_nodes_to_sharpen(
                           ch, region, bw * node.nbins, depth - 1)
                       for ch in children.flat),
                      [])

    def _update_mouse_mode(self):
        """Apply the selected mouse mode: 0 pans, anything else zooms by
        dragging a rectangle."""
        mode = (pg.ViewBox.PanMode if self.mouse_mode == 0
                else pg.ViewBox.RectMode)
        self.plot.getViewBox().setMouseMode(mode)

    def _on_transform_changed(self, *args):
        pass

    def onDeleteWidget(self):
        """Reset the widget's data and plot state before deletion.

        Calls ``clear`` first and then defers to the base-class
        implementation.
        """
        self.clear()
        super().onDeleteWidget()
示例#19
0
class OWImageViewer(widget.OWWidget):
    """Show thumbnails of the images referenced by a string attribute.

    Image locations are read from the chosen string variable (resolved
    against that variable's ``origin`` attribute), loaded asynchronously
    and laid out five per row. Instances whose thumbnails are selected are
    sent to the "Data" output.
    """
    name = "Image Viewer"
    description = "Views images embedded in the data."
    icon = "icons/ImageViewer.svg"
    priority = 4050

    inputs = [("Data", Orange.data.Table, "setData")]
    outputs = [(
        "Data",
        Orange.data.Table,
    )]

    settingsHandler = settings.DomainContextHandler()

    # Index into ``stringAttrs`` of the attribute holding image locations
    imageAttr = settings.ContextSetting(0)
    # Index into ``allAttrs`` of the attribute used for thumbnail titles
    titleAttr = settings.ContextSetting(0)

    # Zoom slider value (1-100); ``pixmapSize`` scales images by 2*zoom/100
    zoom = settings.Setting(25)
    autoCommit = settings.Setting(False)

    show_save_graph = True
    want_graph = True

    def __init__(self, parent=None):
        super().__init__(parent)

        # Data-dependent state, initialized so that methods such as
        # ``commit`` and ``updateTitles`` are safe before any input arrives.
        self.data = None
        self.allAttrs = []
        self.stringAttrs = []

        self.info = gui.widgetLabel(gui.widgetBox(self.controlArea, "Info"),
                                    "Waiting for input\n")

        self.imageAttrCB = gui.comboBox(
            self.controlArea,
            self,
            "imageAttr",
            box="Image Filename Attribute",
            tooltip="Attribute with image filenames",
            callback=[self.clearScene, self.setupScene],
            addSpace=True)

        self.titleAttrCB = gui.comboBox(self.controlArea,
                                        self,
                                        "titleAttr",
                                        box="Title Attribute",
                                        tooltip="Attribute with image title",
                                        callback=self.updateTitles,
                                        addSpace=True)

        gui.hSlider(self.controlArea,
                    self,
                    "zoom",
                    box="Zoom",
                    minValue=1,
                    maxValue=100,
                    step=1,
                    callback=self.updateZoom,
                    createLabel=False)

        gui.separator(self.controlArea)
        gui.auto_commit(self.controlArea, self, "autoCommit", "Commit",
                        "Auto commit")

        gui.rubber(self.controlArea)

        self.scene = GraphicsScene()
        self.sceneView = QGraphicsView(self.scene, self)
        self.sceneView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.sceneView.setRenderHint(QPainter.Antialiasing, True)
        self.sceneView.setRenderHint(QPainter.TextAntialiasing, True)
        self.sceneView.setFocusPolicy(Qt.WheelFocus)
        self.sceneView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        # Resize events are handled in ``eventFilter`` to reflow the grid.
        self.sceneView.installEventFilter(self)
        self.mainArea.layout().addWidget(self.sceneView)

        self.scene.selectionChanged.connect(self.onSelectionChanged)
        self.scene.selectionRectPointChanged.connect(
            self.onSelectionRectPointChanged, Qt.QueuedConnection)
        self.graphButton.clicked.connect(self.saveScene)
        self.resize(800, 600)

        self.thumbnailWidget = None
        self.sceneLayout = None
        self.selectedIndices = []

        #: List of _ImageItems
        self.items = []

        # Counters of failed / successful image loads for the info label
        self._errcount = 0
        self._successcount = 0

        self.loader = ImageLoader(self)

    def setData(self, data):
        """Input handler: set the data table and rebuild the thumbnails.

        String variables tagged with a ``type`` attribute are listed first,
        and the first one with ``type == "image"`` is preselected before the
        context is restored.
        """
        self.closeContext()
        self.clear()

        self.data = data

        if data is not None:
            domain = data.domain
            self.allAttrs = domain.variables + domain.metas
            self.stringAttrs = list(filter(is_string, self.allAttrs))

            self.stringAttrs = sorted(self.stringAttrs,
                                      key=lambda attr: 0
                                      if "type" in attr.attributes else 1)

            indices = [
                i for i, var in enumerate(self.stringAttrs)
                if var.attributes.get("type") == "image"
            ]
            if indices:
                self.imageAttr = indices[0]

            self.imageAttrCB.setModel(VariableListModel(self.stringAttrs))
            self.titleAttrCB.setModel(VariableListModel(self.allAttrs))

            self.openContext(data)

            # Clamp restored indices to the available attributes.
            self.imageAttr = max(
                min(self.imageAttr,
                    len(self.stringAttrs) - 1), 0)
            self.titleAttr = max(min(self.titleAttr,
                                     len(self.allAttrs) - 1), 0)

            if self.stringAttrs:
                self.setupScene()
        else:
            self.info.setText("Waiting for input\n")

    def clear(self):
        """Forget the data, reset messages and combo boxes, empty the scene."""
        self.data = None
        self.information(0)
        self.error(0)
        self.imageAttrCB.clear()
        self.titleAttrCB.clear()
        self.clearScene()

    def setupScene(self):
        """Build the thumbnail grid for the current data and image attribute.

        Every instance with a finite value of the image attribute gets a
        thumbnail (five per row). Valid URLs are loaded asynchronously
        through ``self.loader``; the pixmap is set once loading finishes.
        """
        self.information(0)
        self.error(0)
        if self.data and self.stringAttrs:
            attr = self.stringAttrs[self.imageAttr]
            titleAttr = self.allAttrs[self.titleAttr]
            instances = [
                inst for inst in self.data if numpy.isfinite(inst[attr])
            ]
            grid_widget = ThumbnailWidget()
            layout = grid_widget.layout()

            self.scene.addItem(grid_widget)

            for i, inst in enumerate(instances):
                url = self.urlFromValue(inst[attr])
                title = str(inst[titleAttr])

                thumbnail = GraphicsThumbnailWidget(QPixmap(),
                                                    title=title,
                                                    parent=grid_widget)

                thumbnail.setToolTip(url.toString())
                thumbnail.instance = inst
                # Grid cells need integer coordinates: five items per row.
                row, col = divmod(i, 5)
                layout.addItem(thumbnail, row, col)

                if url.isValid():
                    future = self.loader.get(url)
                    watcher = _FutureWatcher(parent=thumbnail)

                    # Defaults bind the current thumbnail/future (avoids the
                    # late-binding closure pitfall inside the loop).
                    def set_pixmap(thumb=thumbnail, future=future):
                        if future.cancelled():
                            return
                        if future.exception():
                            # Should be some generic error image.
                            pixmap = QPixmap()
                            thumb.setToolTip(thumb.toolTip() + "\n" +
                                             str(future.exception()))
                        else:
                            pixmap = QPixmap.fromImage(future.result())

                        thumb.setPixmap(pixmap)
                        if not pixmap.isNull():
                            thumb.setThumbnailSize(self.pixmapSize(pixmap))

                        self._updateStatus(future)

                    watcher.finished.connect(set_pixmap, Qt.QueuedConnection)
                    watcher.setFuture(future)
                else:
                    future = None
                self.items.append(_ImageItem(i, thumbnail, url, future))

            grid_widget.show()
            grid_widget.geometryChanged.connect(self._updateSceneRect)

            self.info.setText("Retrieving...\n")
            self.thumbnailWidget = grid_widget
            self.sceneLayout = layout

        if self.sceneLayout:
            width = (self.sceneView.width() -
                     self.sceneView.verticalScrollBar().width())
            self.thumbnailWidget.reflow(width)
            self.thumbnailWidget.setPreferredWidth(width)
            self.sceneLayout.activate()

    def urlFromValue(self, value):
        """Return the URL of the image named by *value*.

        The value is resolved relative to the variable's ``origin``
        attribute (a local directory or a URL); URLs without a scheme are
        treated as local files.
        """
        variable = value.variable
        origin = variable.attributes.get("origin", "")
        if origin and QDir(origin).exists():
            origin = QUrl.fromLocalFile(origin)
        elif origin:
            origin = QUrl(origin)
            if not origin.scheme():
                origin.setScheme("file")
        else:
            origin = QUrl("")
        base = origin.path()
        # A trailing slash makes ``resolved`` treat the base as a directory.
        if base.strip() and not base.endswith("/"):
            origin.setPath(base + "/")

        name = QUrl(str(value))
        url = origin.resolved(name)
        if not url.scheme():
            url.setScheme("file")
        return url

    def pixmapSize(self, pixmap):
        """
        Return the preferred pixmap size based on the current `zoom` value.
        """
        scale = 2 * self.zoom / 100.0
        size = QSizeF(pixmap.size()) * scale
        return size.expandedTo(QSizeF(16, 16))

    def clearScene(self):
        """Abort pending loads, drop all thumbnails and reset counters."""
        for item in self.items:
            if item.future:
                item.future._reply.close()
                item.future.cancel()

        self.items = []
        self._errcount = 0
        self._successcount = 0

        self.scene.clear()
        self.thumbnailWidget = None
        self.sceneLayout = None

    def thumbnailItems(self):
        """Return the thumbnail graphics widgets of all items."""
        return [item.widget for item in self.items]

    def updateZoom(self):
        """Resize all thumbnails for the current zoom and reflow the grid."""
        for item in self.thumbnailItems():
            item.setThumbnailSize(self.pixmapSize(item.pixmap()))

        if self.thumbnailWidget:
            width = (self.sceneView.width() -
                     self.sceneView.verticalScrollBar().width())

            self.thumbnailWidget.reflow(width)
            self.thumbnailWidget.setPreferredWidth(width)

        if self.sceneLayout:
            self.sceneLayout.activate()

    def updateTitles(self):
        """Retitle every thumbnail using the selected title attribute."""
        if not self.allAttrs:
            return
        titleAttr = self.allAttrs[self.titleAttr]
        for item in self.items:
            item.widget.setTitle(str(item.widget.instance[titleAttr]))

    def onSelectionChanged(self):
        """Remember the indices of selected thumbnails and commit."""
        selected = [item for item in self.items if item.widget.isSelected()]
        self.selectedIndices = [item.index for item in selected]
        self.commit()

    def onSelectionRectPointChanged(self, point):
        """Scroll so that the rubber-band point stays visible."""
        self.sceneView.ensureVisible(QRectF(point, QSizeF(1, 1)), 5, 5)

    def commit(self):
        """Send the selected instances (or None when nothing is selected)."""
        selected = None
        if self.data and self.selectedIndices:
            selected = self.data[self.selectedIndices]
        self.send("Data", selected)

    def saveScene(self):
        """Open the dialog for saving the scene as an image."""
        from OWDlgs import OWChooseImageSizeDlg
        sizeDlg = OWChooseImageSizeDlg(self.scene, parent=self)
        sizeDlg.exec_()

    def _updateStatus(self, future):
        """Count a finished load and update the info label.

        When every load failed and the image attribute has no ``type``
        attribute, an error hints that it should be tagged ``type=image``.
        """
        if future.cancelled():
            return

        if future.exception():
            self._errcount += 1
            _log.debug("Error: %r", future.exception())
        else:
            self._successcount += 1

        count = len([item for item in self.items if item.future is not None])
        self.info.setText("Retrieving:\n" +
                          "{} of {} images".format(self._successcount, count))

        if self._errcount + self._successcount == count:
            if self._errcount:
                self.info.setText(
                    "Done:\n" +
                    "{} images, {} errors".format(count, self._errcount))
            else:
                self.info.setText("Done:\n" + "{} images".format(count))
            attr = self.stringAttrs[self.imageAttr]
            if self._errcount == count and "type" not in attr.attributes:
                self.error(
                    0, "No images found! Make sure the '%s' attribute "
                    "is tagged with 'type=image'" % attr.name)

    def _updateSceneRect(self):
        """Fit the scene rectangle to the bounding box of its items."""
        self.scene.setSceneRect(self.scene.itemsBoundingRect())

    def onDeleteWidget(self):
        """Abort pending image loads, then defer to the base class."""
        for item in self.items:
            # Items with an invalid URL were never loaded (future is None).
            if item.future is not None:
                item.future._reply.abort()
                item.future.cancel()
        super().onDeleteWidget()

    def eventFilter(self, receiver, event):
        """Reflow the thumbnail grid whenever the scene view is resized."""
        if receiver is self.sceneView and event.type() == QEvent.Resize \
                and self.thumbnailWidget:
            width = (self.sceneView.width() -
                     self.sceneView.verticalScrollBar().width())

            self.thumbnailWidget.reflow(width)
            self.thumbnailWidget.setPreferredWidth(width)

        return super().eventFilter(receiver, event)
示例#20
0
class OWDiscretize(widget.OWWidget):
    """Discretize the continuous variables of the input data.

    A default method applies to every continuous variable; individual
    variables can override it in the "Individual Attribute Settings" box.
    Per-variable choices are kept in the ``saved_var_states`` context
    setting.
    """
    name = "Discretize"
    description = "Discretization of continuous attributes."
    icon = "icons/Discretize.svg"
    inputs = [{
        "name": "Data",
        "type": Orange.data.Table,
        "handler": "set_data",
        "doc": "Input data table"
    }]

    outputs = [{
        "name": "Data",
        "type": Orange.data.Table,
        "doc": "Table with discretized features"
    }]

    settingsHandler = settings.DomainContextHandler()
    saved_var_states = settings.ContextSetting({})

    # Index into the default-method radio group, which has no "Default"
    # entry (so it is one less than the method constants below).
    default_method = settings.Setting(0)
    default_k = settings.Setting(5)

    # Discretization methods
    Default, Leave, MDL, EqualFreq, EqualWidth, Remove, Custom = range(7)

    want_main_area = False

    def __init__(self, parent=None):
        super().__init__(parent)

        #: input data
        self.data = None
        #: Current variable discretization state (row index -> DState)
        self.var_state = {}
        #: Saved variable discretization settings (context setting)
        self.saved_var_states = {}

        # Method / interval count shown in the individual-settings box;
        # ``method`` is -1 when the selection mixes several methods.
        self.method = 0
        self.k = 5

        box = gui.widgetBox(self.controlArea,
                            self.tr("Default Discretization"))
        self.default_bbox = rbox = gui.radioButtons(
            box, self, "default_method", callback=self._default_disc_changed)

        options = [
            self.tr("Default"),
            self.tr("Leave continuous"),
            self.tr("Entropy-MDL discretization"),
            self.tr("Equal-frequency discretization"),
            self.tr("Equal-width discretization"),
            self.tr("Remove continuous attributes")
        ]

        for opt in options[1:5]:
            gui.appendRadioButton(rbox, opt)

        gui.hSlider(gui.indentedBox(rbox),
                    self,
                    "default_k",
                    minValue=2,
                    maxValue=10,
                    label="Num. of intervals:",
                    callback=self._default_disc_changed)

        gui.appendRadioButton(rbox, options[-1])

        hlayout = QHBoxLayout()
        box = gui.widgetBox(self.controlArea,
                            "Individual Attribute Settings",
                            orientation=hlayout,
                            spacing=8)

        # List view with all attributes
        self.varview = QListView(selectionMode=QListView.ExtendedSelection)
        self.varview.setItemDelegate(DiscDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._var_selection_changed)

        hlayout.addWidget(self.varview)
        # Controls for individual attr settings
        self.bbox = controlbox = gui.radioButtons(
            box, self, "method", callback=self._disc_method_changed)
        hlayout.addWidget(controlbox)

        for opt in options[:5]:
            gui.appendRadioButton(controlbox, opt)

        gui.hSlider(gui.indentedBox(controlbox),
                    self,
                    "k",
                    minValue=2,
                    maxValue=10,
                    label="Num. of intervals:",
                    callback=self._disc_method_changed)

        gui.appendRadioButton(controlbox, options[-1])

        gui.rubber(controlbox)
        controlbox.setEnabled(False)

        self.controlbox = controlbox

        bbox = QDialogButtonBox(QDialogButtonBox.Apply)
        self.controlArea.layout().addWidget(bbox)
        bbox.accepted.connect(self.commit)
        button = bbox.button(QDialogButtonBox.Apply)
        button.clicked.connect(self.commit)

    def set_data(self, data):
        """Input handler: reset the per-variable states for *data*, restore
        the saved context settings, induce cut points and commit."""
        self.closeContext()
        self.data = data
        if self.data is not None:
            self._initialize(data)
            self.openContext(data)
            # Restore the per variable discretization settings
            self._restore(self.saved_var_states)
            # Complete the induction of cut points
            self._update_points()
        else:
            self._clear()

        self.commit()

    def _initialize(self, data):
        # Initialize the default variable states for new data.
        self.class_var = data.domain.class_var
        cvars = [
            var for var in data.domain
            if isinstance(var, Orange.data.ContinuousVariable)
        ]
        self.varmodel[:] = cvars

        # MDL needs a discrete class variable.
        has_disc_class = isinstance(self.class_var,
                                    Orange.data.DiscreteVariable)

        self.default_bbox.buttons[self.MDL - 1].setEnabled(has_disc_class)
        self.bbox.buttons[self.MDL].setEnabled(has_disc_class)

        # If the newly disabled MDL button is checked then change it
        if not has_disc_class and self.default_method == self.MDL - 1:
            self.default_method = 0
        if not has_disc_class and self.method == self.MDL:
            self.method = 0

        # Reset (initialize) the variable discretization states.
        self._reset()

    def _restore(self, saved_state):
        # Restore variable states from a saved_state dictionary.
        for i, var in enumerate(self.varmodel):
            key = variable_key(var)
            if key in saved_state:
                state = saved_state[key]
                self._set_var_state(i, state)

    def _reset(self):
        # restore the individual variable settings back to defaults.
        def_method = self._current_default_method()
        self.var_state = {}
        for i in range(len(self.varmodel)):
            state = DState(Default(def_method), None, None)
            self._set_var_state(i, state)

    def _set_var_state(self, index, state):
        # set the state of variable at `index` to `state` (also stored in
        # the model under Qt.UserRole for the item delegate).
        self.var_state[index] = state
        self.varmodel.setData(self.varmodel.index(index), state, Qt.UserRole)

    def _clear(self):
        # Forget the data and all per-variable state; re-enable MDL.
        self.data = None
        self.varmodel[:] = []
        self.var_state = {}
        self.saved_var_states = {}
        self.default_bbox.buttons[self.MDL - 1].setEnabled(True)
        self.bbox.buttons[self.MDL].setEnabled(True)

    def _update_points(self):
        """
        Update the induced cut points.

        Only variables whose state has neither points nor a discretized
        variable yet are (re)computed.
        """
        def induce_cuts(method, data, var):
            dvar = _dispatch[type(method)](method, data, var)
            if dvar is None:
                # removed
                return [], None
            elif dvar is var:
                # no transformation took place
                return None, var
            elif is_discretized(dvar):
                return dvar.get_value_from.points, dvar
            else:
                assert False

        for i, var in enumerate(self.varmodel):
            state = self.var_state[i]
            if state.points is None and state.disc_var is None:
                points, dvar = induce_cuts(state.method, self.data, var)
                new_state = state._replace(points=points, disc_var=dvar)
                self._set_var_state(i, new_state)

    def _method_index(self, method):
        # Radio-button index of a method instance.
        return METHODS.index((type(method), ))

    def _current_default_method(self):
        # Method instance for the default-method radio group.
        method = self.default_method + 1
        k = self.default_k
        if method == OWDiscretize.Leave:
            def_method = Leave()
        elif method == OWDiscretize.MDL:
            def_method = MDL()
        elif method == OWDiscretize.EqualFreq:
            def_method = EqualFreq(k)
        elif method == OWDiscretize.EqualWidth:
            def_method = EqualWidth(k)
        elif method == OWDiscretize.Remove:
            def_method = Remove()
        else:
            assert False
        return def_method

    def _current_method(self):
        # Method instance for the individual-settings radio group.
        if self.method == OWDiscretize.Default:
            method = Default(self._current_default_method())
        elif self.method == OWDiscretize.Leave:
            method = Leave()
        elif self.method == OWDiscretize.MDL:
            method = MDL()
        elif self.method == OWDiscretize.EqualFreq:
            method = EqualFreq(self.k)
        elif self.method == OWDiscretize.EqualWidth:
            method = EqualWidth(self.k)
        elif self.method == OWDiscretize.Remove:
            method = Remove()
        elif self.method == OWDiscretize.Custom:
            method = Custom(self.cutpoints)
        else:
            assert False
        return method

    def _default_disc_changed(self):
        # Re-apply the default method to every variable that uses it.
        method = self._current_default_method()
        state = DState(Default(method), None, None)
        for i, _ in enumerate(self.varmodel):
            if isinstance(self.var_state[i].method, Default):
                self._set_var_state(i, state)

        self._update_points()

    def _disc_method_changed(self):
        # Apply the individually chosen method to the selected variables.
        if self.method < 0:
            # Mixed selection: no method button is checked (e.g. the
            # interval slider was moved), so there is nothing to apply.
            return

        indices = self.selected_indices()
        method = self._current_method()
        state = DState(method, None, None)

        for idx in indices:
            self._set_var_state(idx, state)

        self._update_points()

    def _var_selection_changed(self, *args):
        indices = self.selected_indices()
        # set of all methods for the current selection
        methods = [self.var_state[i].method for i in indices]
        mset = set(methods)
        self.controlbox.setEnabled(len(mset) > 0)
        if len(mset) == 1:
            method = mset.pop()
            self.method = self._method_index(method)

            if isinstance(method, (EqualFreq, EqualWidth)):
                self.k = method.k
            elif isinstance(method, Custom):
                self.cutpoints = method.points

        else:
            # deselect the current button
            self.method = -1
            bg = self.controlbox.group
            button_group_reset(bg)

    def selected_indices(self):
        """Return the row indices of the selected variables."""
        rows = self.varview.selectionModel().selectedRows()
        return [index.row() for index in rows]

    def discretized_var(self, source):
        """Return the variable replacing *source* in the output domain, or
        None when it is removed (or no discretized variable exists)."""
        index = list(self.varmodel).index(source)
        state = self.var_state[index]
        if state.disc_var is None:
            return None
        elif state.disc_var is source:
            return source
        elif state.points == []:
            return None
        else:
            return state.disc_var

    def discretized_domain(self):
        """
        Return the current effective discretized domain.
        """
        if self.data is None:
            return None

        def disc_var(source):
            if isinstance(source, Orange.data.ContinuousVariable):
                return self.discretized_var(source)
            else:
                return source

        attributes = [disc_var(v) for v in self.data.domain.attributes]
        attributes = [v for v in attributes if v is not None]

        class_var = disc_var(self.data.domain.class_var)

        domain = Orange.data.Domain(attributes,
                                    class_var,
                                    metas=self.data.domain.metas)
        return domain

    def commit(self):
        """Send the data converted to the discretized domain (or None)."""
        output = None
        if self.data is not None:
            domain = self.discretized_domain()
            output = self.data.from_table(domain, self.data)
        self.send("Data", output)

    def storeSpecificSettings(self):
        # Save only the chosen methods; cut points and discretized
        # variables are data dependent and re-induced on restore.
        super().storeSpecificSettings()
        self.saved_var_states = {
            variable_key(var): self.var_state[i]._replace(points=None,
                                                          disc_var=None)
            for i, var in enumerate(self.varmodel)
        }
class OWImageGrid(widget.OWWidget):
    """
    Show the images referenced by the input data in a similarity grid.

    Image locations are read from a string meta attribute, the grid layout
    is computed by ``ImageGrid`` and every cell is drawn as a selectable
    thumbnail. Selected cells (optionally in several groups) are sent to
    the "Selected Images" output; all cells with an image, annotated with
    the selection, are sent to the "Images" output.
    """
    name = "Image Grid"
    description = "Visualize images in a similarity grid"
    icon = "icons/ImageGrid.svg"
    priority = 160
    keywords = ["image", "grid", "similarity"]
    graph_name = "scene"

    class Inputs:
        data = Input("Embeddings", Orange.data.Table)
        data_subset = Input("Data Subset", Orange.data.Table)

    class Outputs:
        selected_data = Output(
            "Selected Images", Orange.data.Table, default=True)
        data = Output("Images", Orange.data.Table)

    settingsHandler = settings.DomainContextHandler()

    # index of the checked "Image cell fit" radio button: 0 = Resize, 1 = Crop
    cell_fit = settings.Setting(0)
    columns = settings.Setting(10)
    rows = settings.Setting(10)

    imageAttr = settings.ContextSetting(0)
    imageSize = settings.Setting(100)
    label_attr = settings.ContextSetting(None, required=ContextSetting.OPTIONAL)
    label_selected = settings.Setting(True)

    auto_update = settings.Setting(True)
    auto_commit = settings.Setting(True)

    class Warning(OWWidget.Warning):
        incompatible_subset = Msg("Data subset is incompatible with Data")
        no_valid_data = Msg("No valid data")

    def __init__(self):
        super().__init__()

        self.grid = None

        self.data = None
        self.data_subset = None
        # one flag per input-data row: is the row contained in the subset
        self.subset_indices = []
        # indices of grid cells whose image value is not missing
        self.nonempty = []

        self.allAttrs = []
        self.stringAttrs = []
        self.domainAttrs = []
        self.label_model = DomainModel(placeholder="(No labels)")

        # per-cell group number (0 = unselected); None until a selection
        self.selection = None

        #: List of _ImageItems
        self.items = []

        self._errcount = 0
        self._successcount = 0

        self.imageAttrCB = gui.comboBox(
            self.controlArea, self, "imageAttr",
            box="Image Filename Attribute",
            tooltip="Attribute with image filenames",
            callback=self.change_image_attr,
            contentsLength=12,
            addSpace=True,
        )

        # cell fit (resize or crop); the radio group stores the index of the
        # chosen option, so replace anything that is not a valid index (e.g.
        # a legacy string value) before the control is built. The restored
        # setting is otherwise kept, so the user's choice persists.
        if self.cell_fit not in (0, 1):
            self.cell_fit = 0
        self.cellFitRB = gui.radioButtons(
            self.controlArea, self, "cell_fit", ["Resize", "Crop"],
            box="Image cell fit", callback=self.set_crop)

        self.gridSizeBox = gui.vBox(self.controlArea, "Grid size")

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10
        )

        self.colSpinner = gui.spin(
            self.gridSizeBox, self, "columns", minv=2, maxv=40,
            callback=self.update_size)
        self.rowSpinner = gui.spin(
            self.gridSizeBox, self, "rows", minv=2, maxv=40,
            callback=self.update_size)

        form.addRow("Columns:", self.colSpinner)
        form.addRow("Rows:", self.rowSpinner)

        gui.separator(self.gridSizeBox, 10)
        self.gridSizeBox.layout().addLayout(form)

        gui.button(
            self.gridSizeBox, self, "Set size automatically",
            callback=self.auto_set_size)

        self.label_box = gui.vBox(self.controlArea, "Labels")

        # labels control
        self.label_attr_cb = gui.comboBox(
            self.label_box, self, "label_attr",
            tooltip="Show labels",
            callback=self.update_size,
            addSpace=True,
            model=self.label_model
        )

        gui.rubber(self.controlArea)

        # auto commit
        self.autoCommitBox = gui.auto_commit(
            self.controlArea, self, "auto_commit", "Apply",
            checkbox_label="Apply automatically")

        self.image_grid = None

        self.thumbnailView = ThumbnailView(
            alignment=Qt.AlignTop | Qt.AlignLeft,
            focusPolicy=Qt.StrongFocus,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff
        )
        self.mainArea.layout().addWidget(self.thumbnailView)
        self.scene = self.thumbnailView.scene()
        self.scene.selectionChanged.connect(self.on_selection_changed)
        self.loader = ImageLoader(self)

    def process(self, size_x=0, size_y=0):
        """Let the image grid (if any) compute its layout for the given size."""
        if self.image_grid:
            self.image_grid.process(size_x, size_y)

    def sizeHint(self):
        return QSize(600, 600)

    # checks the input data for the right meta-attributes and finds images
    @Inputs.data
    def set_data(self, data):
        """Receive new data, restore the context and build the grid."""
        self.closeContext()
        self.clear()
        self.Warning.no_valid_data.clear()
        self.data = data

        if data is not None:
            domain = data.domain
            self.allAttrs = (domain.class_vars + domain.metas +
                             domain.attributes)

            self.stringAttrs = [a for a in domain.metas if a.is_string]
            self.domainAttrs = len(domain.attributes)

            # attributes that declare a "type" come first
            self.stringAttrs = sorted(
                self.stringAttrs,
                key=lambda attr: 0 if "type" in attr.attributes else 1
            )

            indices = [i for i, var in enumerate(self.stringAttrs)
                       if var.attributes.get("type") == "image"]
            if indices:
                self.imageAttr = indices[0]

            self.imageAttrCB.setModel(VariableListModel(self.stringAttrs))

            # set label combo labels; the default must be assigned before
            # openContext so that a label restored from the context is kept
            self.label_model.set_domain(domain)
            self.label_attr = self.label_model[0]
            self.openContext(data)
            self.imageAttr = max(
                min(self.imageAttr, len(self.stringAttrs) - 1), 0)

            if self.is_valid_data():
                self.image_grid = ImageGrid(data)
                self.setup_scene()
            else:
                self.Warning.no_valid_data()

    @Inputs.data_subset
    def set_data_subset(self, data_subset):
        self.data_subset = data_subset

    def clear(self):
        """Drop the data and grid, reset the combos and clear the scene."""
        self.data = None
        self.image_grid = None
        self.error()
        self.imageAttrCB.clear()
        self.label_attr_cb.clear()
        self.clear_scene()

    def is_valid_data(self):
        """Truthy when there is data, a string meta and at least one attribute."""
        return self.data and self.stringAttrs and self.domainAttrs

    # loads the images and places them into the viewing area
    def setup_scene(self, process_grid=True):
        """
        Create a thumbnail for every grid cell and start loading its image.

        Local files are read synchronously; other valid URLs are fetched
        through `self.loader` and the thumbnail is updated when the
        returned future completes. When `process_grid` is true the grid
        layout is recomputed and `columns`/`rows` are taken from it.
        """
        self.clear_scene()
        self.error()
        if self.data:
            attr = self.stringAttrs[self.imageAttr]
            assert self.thumbnailView.count() == 0
            size = QSizeF(self.imageSize, self.imageSize)

            if process_grid and self.image_grid:
                self.process()
                self.columns = self.image_grid.size_x
                self.rows = self.image_grid.size_y

            self.thumbnailView.setFixedColumnCount(self.columns)
            self.thumbnailView.setFixedRowCount(self.rows)

            for i, inst in enumerate(self.image_grid.image_list):
                label_text = (str(inst[self.label_attr])
                    if self.label_attr is not None else "")
                if label_text == "?":
                    label_text = ""

                thumbnail = GraphicsThumbnailWidget(
                    QPixmap(), crop=self.cell_fit == 1,
                    add_label=self.label_selected and
                    self.label_attr is not None, text=label_text)
                thumbnail.setThumbnailSize(size)
                thumbnail.instance = inst
                self.thumbnailView.addThumbnail(thumbnail)

                if not np.isfinite(inst[attr]) or inst[attr] == "?":
                    # skip missing
                    future, url = None, None
                else:
                    url = self.url_from_value(inst[attr])
                    thumbnail.setToolTip(url.toString())
                    self.nonempty.append(i)

                    if url.isValid() and url.isLocalFile():
                        reader = QImageReader(url.toLocalFile())
                        image = reader.read()
                        if image.isNull():
                            error = reader.errorString()
                            thumbnail.setToolTip(
                                thumbnail.toolTip() + "\n" + error)

                            self._errcount += 1
                        else:
                            pixmap = QPixmap.fromImage(image)
                            thumbnail.setPixmap(pixmap)
                            self._successcount += 1

                        # wrap the synchronous result in an already
                        # completed future so all items look alike
                        future = Future()
                        future.set_result(image)
                        future._reply = None
                    elif url.isValid():
                        future = self.loader.get(url)

                        @future.add_done_callback
                        def set_pixmap(future, thumb=thumbnail):
                            if future.cancelled():
                                return

                            assert future.done()

                            if future.exception():
                                # Should be some generic error image.
                                pixmap = QPixmap()
                                thumb.setToolTip(thumb.toolTip() + "\n" +
                                                 str(future.exception()))
                            else:
                                pixmap = QPixmap.fromImage(future.result())

                            thumb.setPixmap(pixmap)

                            self._note_completed(future)
                    else:
                        future = None

                self.items.append(_ImageItem(i, thumbnail, url, future))

            # when no download is pending, finish immediately
            if not any(
                    not it.future.done() if it.future
                    else False for it in self.items):
                self._update_status()
                self.apply_subset()
                self.update_selection()

    def handleNewSignals(self):
        """Recompute subset membership and refresh the outputs."""
        self.Warning.incompatible_subset.clear()
        self.subset_indices = []

        if self.data and self.data_subset:
            if np.all(self.data.domain.metas == self.data_subset.domain.metas):
                # only transform once the domains are known to be compatible
                transformed = self.data_subset.transform(self.data.domain)
                indices = {e.id for e in transformed}
                self.subset_indices = [ex.id in indices for ex in self.data]

            else:
                self.Warning.incompatible_subset()

        self.apply_subset()
        # the inputs changed, so the previously sent tables are stale
        self.commit()

    def url_from_value(self, value):
        """
        Resolve a (possibly relative) image path against the variable's
        "origin" attribute, which may be a local directory or a URL.
        """
        base = value.variable.attributes.get("origin", "")
        if QDir(base).exists():
            base = QUrl.fromLocalFile(base)
        else:
            base = QUrl(base)

        # a base without a trailing slash would have its last segment
        # replaced by QUrl.resolved
        path = base.path()
        if path.strip() and not path.endswith("/"):
            base.setPath(path + "/")

        url = base.resolved(QUrl(str(value)))
        return url

    def cancel_all_futures(self):
        """Cancel pending image loads and release their network replies."""
        for item in self.items:
            if item.future is not None:
                item.future.cancel()
                if item.future._reply is not None:
                    item.future._reply.close()
                    item.future._reply.deleteLater()
                    item.future._reply = None

    def clear_scene(self):
        """Cancel loads and remove all thumbnails, items and the selection."""
        self.cancel_all_futures()
        self.items = []
        self.nonempty = []
        self.selection = None
        self.thumbnailView.clear()
        self._errcount = 0
        self._successcount = 0

    def change_image_attr(self):
        self.clear_scene()
        if self.is_valid_data():
            self.setup_scene()

    def thumbnail_items(self):
        return [item.widget for item in self.items]

    def update_size(self):
        """
        Re-layout the grid with the user-chosen number of columns and rows.

        If the grid rejects the size (signalled by AssertionError), the
        spinners are reset to the view's grid size and their minimums are
        raised to it.
        """
        try:
            self.process(self.columns, self.rows)
            self.colSpinner.setMinimum(2)
            self.rowSpinner.setMinimum(2)

        except AssertionError:
            grid_size = self.thumbnailView.grid_size()
            self.columns = grid_size[0]
            self.rows = grid_size[1]
            self.colSpinner.setMinimum(self.columns)
            self.rowSpinner.setMinimum(self.rows)
            return

        self.clear_scene()
        if self.is_valid_data():
            self.setup_scene(process_grid=False)

    def set_crop(self):
        self.thumbnailView.setCrop(self.cell_fit == 1)

    def auto_set_size(self):
        self.clear_scene()
        if self.is_valid_data():
            self.setup_scene()

    def apply_subset(self):
        """Mark the thumbnails of subset rows (all rows when no subset)."""
        if self.image_grid:
            subset_indices = (self.subset_indices if self.subset_indices
                else [True] * len(self.items))
            ordered_subset_indices = self.image_grid.order_to_grid(
                subset_indices)

            for item, in_subset in zip(self.items, ordered_subset_indices):
                item.widget.setSubset(in_subset)

    def on_selection_changed(self, selected_items, keys):
        """
        Update the grouped selection from the scene's selected thumbnails.

        Alt removes from the selection, Ctrl+Shift adds to the last group,
        Shift starts a new group and no modifier replaces the selection.
        """
        if self.selection is None:
            self.selection = np.zeros(len(self.items), dtype=np.uint8)

        # newly selected
        indices = [item.index for item in self.items
                   if item.widget in selected_items]

        # Remove from selection
        if keys & Qt.AltModifier:
            self.selection[indices] = 0
        # Append to the last group (the first group if none exists yet;
        # writing 0 would deselect instead)
        elif keys & Qt.ShiftModifier and keys & Qt.ControlModifier:
            self.selection[indices] = max(np.max(self.selection), 1)
        # Create a new group
        elif keys & Qt.ShiftModifier:
            self.selection[indices] = np.max(self.selection) + 1
        # No modifiers: new selection
        else:
            self.selection = np.zeros(len(self.items), dtype=np.uint8)
            self.selection[indices] = 1

        self.update_selection()
        self.commit()

    def commit(self):
        """
        Send the selected images and the annotated table of all images.

        "Selected Images" gets the selected cells with a "Group" column.
        "Images" gets the cells that contain an image, annotated with a
        group column when there is more than one group and with a
        selected/unselected column otherwise. Both outputs are None when
        there is no data or no grid; before any selection, "Selected
        Images" is None and "Images" is annotated with nothing selected.
        """
        if not self.data or self.image_grid is None:
            self.Outputs.data.send(None)
            self.Outputs.selected_data.send(None)
            return

        image_list = self.image_grid.image_list

        if self.selection is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.data.send(
                create_annotated_table(image_list[self.nonempty], []))
            return

        # add Group column (group number)
        self.Outputs.selected_data.send(
            create_groups_table(image_list, self.selection, False, "Group"))

        # filter out empty cells - keep indices of cells that contain images
        # add Selected column
        # (Yes/No if one group, else Unselected or group number)
        if np.max(self.selection) > 1:
            out_data = create_groups_table(
                image_list[self.nonempty],
                self.selection[self.nonempty])
        else:
            out_data = create_annotated_table(
                image_list[self.nonempty],
                np.nonzero(self.selection[self.nonempty]))
        self.Outputs.data.send(out_data)

    def update_selection(self):
        """Reflect `self.selection` in the thumbnails' selection state and colour."""
        if self.selection is not None:
            pen, brush = self.compute_colors()

            for s, item, p, b in zip(self.selection, self.items, pen, brush):
                item.widget.setSelected(s > 0)
                item.widget.setSelectionColor(p, b)

    # Adapted from Scatter Plot Graph (change brush instead of pen)
    def compute_colors(self):
        """Return per-item selection pens and brushes (one colour per group)."""
        no_brush = DEFAULT_SELECTION_BRUSH
        sels = np.max(self.selection)
        if sels == 1:
            brushes = [no_brush, no_brush]
        else:
            palette = ColorPaletteGenerator(number_of_colors=sels + 1)
            brushes = [no_brush] + [QBrush(palette[i]) for i in range(sels)]
        brush = [brushes[a] for a in self.selection]

        pen = [DEFAULT_SELECTION_PEN] * len(self.items)
        return pen, brush

    def send_report(self):
        if self.is_valid_data():
            items = [("Number of images", len(self.data))]
            self.report_items(items)
            self.report_plot("Grid", self.scene)

    def _note_completed(self, future):
        # Note the completed future's state
        if future.cancelled():
            return

        if future.exception():
            self._errcount += 1
            _log.debug("Error: %r", future.exception())
        else:
            self._successcount += 1

        self._update_status()

    def _update_status(self):
        """Once every load has finished, report when no image could be loaded."""
        count = len([item for item in self.items if item.future is not None])

        if self._errcount + self._successcount == count:
            attr = self.stringAttrs[self.imageAttr]
            if self._errcount == count and "type" not in attr.attributes:
                self.error("No images found! Make sure the '%s' attribute "
                           "is tagged with 'type=image'" % attr.name)

    def onDeleteWidget(self):
        """Stop pending loads and release data when the widget is removed."""
        self.cancel_all_futures()
        self.clear()
        super().onDeleteWidget()
示例#22
0
class OWImpute(OWWidget):
    """
    Widget that imputes (or filters out) missing values in a data table.

    A default method applies to every variable unless an individual method
    is chosen for it in the variable list; the result is sent as "Data".
    """
    name = "Impute"
    description = "Impute missing values in the data table."
    icon = "icons/Impute.svg"
    priority = 2130

    inputs = [("Data", Orange.data.Table, "set_data"),
              ("Learner", Orange.classification.Learner, "set_learner")]
    outputs = [("Data", Orange.data.Table)]

    METHODS = METHODS

    settingsHandler = settings.DomainContextHandler()

    default_method = settings.Setting(1)
    # variable_key(var) -> State chosen individually for that variable
    variable_methods = settings.ContextSetting({})
    autocommit = settings.Setting(True)

    want_main_area = False

    def __init__(self, parent=None):
        super().__init__(parent)
        self.modified = False

        box = group_box(self.tr("Default method"), layout=layout(Qt.Vertical))
        self.controlArea.layout().addWidget(box)

        bgroup = QButtonGroup()

        # the first and the last method are not offered as defaults
        for i, m in enumerate(self.METHODS[1:-1], 1):
            b = radio_button(m.name,
                             checked=i == self.default_method,
                             group=bgroup,
                             group_id=i)
            box.layout().addWidget(b)

        self.defbggroup = bgroup

        bgroup.buttonClicked[int].connect(self.set_default_method)
        box = group_box(self.tr("Individual attribute settings"),
                        layout=layout(Qt.Horizontal))
        self.controlArea.layout().addWidget(box)

        self.varview = QtGui.QListView(
            selectionMode=QtGui.QListView.ExtendedSelection)
        self.varview.setItemDelegate(DisplayFormatDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._on_var_selection_changed)
        self.selection = self.varview.selectionModel()

        box.layout().addWidget(self.varview)

        method_layout = layout(Qt.Vertical, margins=0)
        box.layout().addLayout(method_layout)

        methodbox = group_box(layout=layout(Qt.Vertical))

        bgroup = QButtonGroup()
        for i, m in enumerate(self.METHODS):
            b = radio_button(m.name, group=bgroup, group_id=i)
            methodbox.layout().addWidget(b)

        # the value editors below belong to the last ("value") method
        assert self.METHODS[-1].short == "value"

        self.value_stack = value_stack = QStackedLayout()
        self.value_combo = QComboBox(activated=self._on_value_changed)
        self.value_line = QLineEdit(editingFinished=self._on_value_changed)
        self.value_line.setValidator(QDoubleValidator())
        value_stack.addWidget(self.value_combo)
        value_stack.addWidget(self.value_line)
        methodbox.layout().addLayout(value_stack)

        bgroup.buttonClicked[int].connect(
            self.set_method_for_current_selection)
        reset_button = push_button("Restore all to default",
                                   clicked=self.reset_var_methods,
                                   default=False,
                                   autoDefault=False)

        method_layout.addWidget(methodbox)
        method_layout.addStretch(2)
        method_layout.addWidget(reset_button)
        self.varmethodbox = methodbox
        self.varbgroup = bgroup

        gui.auto_commit(self.controlArea,
                        self,
                        "autocommit",
                        "Commit",
                        orientation="horizontal",
                        checkbox_label="Commit on any change")
        self.data = None
        self.learner = None

    def set_default_method(self, index):
        """
        Set the current selected default imputation method.
        """
        if self.default_method != index:
            self.default_method = index
            self.defbggroup.button(index).setChecked(True)
            self._invalidate()

    def set_data(self, data):
        """Receive new data, restore per-variable methods and commit."""
        self.closeContext()
        self.clear()
        self.data = data
        if data is not None:
            self.varmodel[:] = data.domain.variables
            self.openContext(data.domain)
            self.restore_state(self.variable_methods)
            itemmodels.select_row(self.varview, 0)
        self.unconditional_commit()

    def set_learner(self, learner):
        """
        Store the learner used by the "model" method; recommit only if some
        column is currently imputed with a model.
        """
        self.learner = learner

        if self.data is not None and \
                any(state.method.short == "model" for state in
                    map(self.state_for_column, range(len(self.varmodel)))):
            self.commit()

    def restore_state(self, state):
        """Put the stored per-variable states into the model's UserRole."""
        for i, var in enumerate(self.varmodel):
            key = variable_key(var)
            if key in state:
                index = self.varmodel.index(i)
                self.varmodel.setData(index, state[key], Qt.UserRole)

    def clear(self):
        self.varmodel[:] = []
        self.variable_methods = {}
        self.data = None
        self.modified = False

    def state_for_column(self, column):
        """
        #:: int -> State
        Return the effective imputation state for `column`.

        Variables without an individual state, or with the first method
        (METHODS[0]), get the default method with empty parameters.

        :param int column:
        :rtype State:

        """
        var = self.varmodel[column]

        state = self.variable_methods.get(variable_key(var), None)
        if state is None or state.method == METHODS[0]:
            state = State(METHODS[self.default_method], ())
        return state

    def imputer_for_column(self, column):
        """
        Return the column imputer for `column`, or None for "leave".

        :raises ValueError: if the column's method is not recognised
        """
        state = self.state_for_column(column)
        data = self.data
        var = data.domain[column]
        method, params = state
        if method.short == "leave":
            return None
        elif method.short == "avg":
            return column_imputer_average(var, data)
        elif method.short == "model":
            learner = (self.learner if self.learner is not None
                       else MeanLearner())
            return column_imputer_by_model(var, data, learner=learner)
        elif method.short == "random":
            return column_imputer_random(var, data)
        elif method.short == "value":
            return column_imputer_defaults(var, data, float(params[0]))
        elif method.short == "as_value":
            return column_imputer_as_value(var, data)
        else:
            # an explicit error survives `python -O`, unlike `assert False`
            raise ValueError(
                "Unknown imputation method: {!r}".format(method.short))

    def commit(self):
        """
        Impute the data (dropping rows where "drop" columns are undefined)
        and send it; send None when there is no data.
        """
        if self.data is not None:
            states = [
                self.state_for_column(i) for i in range(len(self.varmodel))
            ]

            # Columns to filter unknowns by dropping rows.
            filter_columns = [
                i for i, state in enumerate(states)
                if state.method.short == "drop"
            ]

            impute_columns = [
                i for i, state in enumerate(states)
                if state.method.short not in ["drop", "leave"]
            ]

            imputers = [(self.varmodel[i], self.imputer_for_column(i))
                        for i in impute_columns]

            data = self.data

            if imputers:
                table_imputer = ImputerModel(data.domain, dict(imputers))
                data = table_imputer(data)

            if filter_columns:
                filter_ = data_filter.IsDefined(filter_columns)
                data = filter_(data)
        else:
            data = None

        self.send("Data", data)
        self.modified = False

    def _invalidate(self):
        self.modified = True
        self.commit()

    def _on_var_selection_changed(self):
        """
        Show the method and fixed value shared by the selected variables.

        The value editor is a line edit when all selected variables are
        continuous, a combo with the variable's values for a single
        discrete variable, and disabled otherwise.
        """
        indexes = self.selection.selectedIndexes()

        variables = [self.varmodel[index.row()] for index in indexes]
        defstate = State(METHODS[0], ())
        states = [
            self.variable_methods.get(variable_key(var), defstate)
            for var in variables
        ]
        all_cont = all(var.is_continuous for var in variables)
        states = list(unique(states))
        method = None
        params = ()
        if len(states) == 1:
            method, params = states[0]
            mindex = METHODS.index(method)
            self.varbgroup.button(mindex).setChecked(True)
        elif self.varbgroup.checkedButton() is not None:
            # no common method: uncheck the button, which requires lifting
            # the group's exclusivity for a moment
            self.varbgroup.setExclusive(False)
            self.varbgroup.checkedButton().setChecked(False)
            self.varbgroup.setExclusive(True)

        has_value = (method is not None and method.short == "value"
                     and len(params) > 0 and params[0] is not None)

        values, enabled, stack_index = [], False, 0
        value, value_index = "0.0", 0
        if all_cont:
            enabled, stack_index = True, 1
            if has_value:
                # QLineEdit.setText needs a string
                value = str(params[0])

        elif len(variables) == 1 and variables[0].is_discrete:
            values, enabled, stack_index = variables[0].values, True, 0
            if has_value:
                stored = params[0]
                # set_method_for_indexes stores the combo's current index
                if isinstance(stored, int) and 0 <= stored < len(values):
                    value_index = stored
                elif stored in values:
                    value_index = values.index(stored)

        self.value_stack.setCurrentIndex(stack_index)
        self.value_stack.setEnabled(enabled)
        # QLayout.setEnabled does not affect child widgets; toggle the
        # editors themselves so a disabled editor cannot be used
        self.value_combo.setEnabled(enabled)
        self.value_line.setEnabled(enabled)

        if stack_index == 0:
            self.value_combo.clear()
            self.value_combo.addItems(values)
            self.value_combo.setCurrentIndex(value_index)
        else:
            self.value_line.setText(value)

    def _on_value_changed(self):
        # The "fixed" value in the widget has been changed by the user.
        index = self.varbgroup.checkedId()
        self.set_method_for_current_selection(index)

    def set_method_for_current_selection(self, methodindex):
        indexes = self.selection.selectedIndexes()
        self.set_method_for_indexes(indexes, methodindex)

    def set_method_for_indexes(self, indexes, methodindex):
        """
        Assign METHODS[methodindex] to the variables at `indexes`.

        For the "value" method the parameter is the combo's current index
        (discrete editor) or the line edit's text (continuous editor).
        """
        method = METHODS[methodindex]
        params = (None, )
        if method.short == "value":
            if self.value_stack.currentIndex() == 0:
                value = self.value_combo.currentIndex()
            else:
                value = self.value_line.text()
            params = (value, )
        elif method.short == "model":
            params = ("model", )
        state = State(method, params)

        for index in indexes:
            self.varmodel.setData(index, state, Qt.UserRole)
            var = self.varmodel[index.row()]
            self.variable_methods[variable_key(var)] = state

        self._invalidate()

    def reset_var_methods(self):
        """Set every variable back to METHODS[0] (use the default method)."""
        indexes = map(self.varmodel.index, range(len(self.varmodel)))
        self.set_method_for_indexes(indexes, 0)
示例#23
0
class OWCorrespondenceAnalysis(widget.OWWidget):
    """
    Widget that runs correspondence analysis on the selected categorical
    variables and plots their values on two chosen components.
    """
    name = "Correspondence Analysis"
    description = "Correspondence analysis for categorical multivariate data."
    icon = "icons/CorrespondenceAnalysis.svg"

    inputs = [("Data", Orange.data.Table, "set_data")]

    Invalidate = QEvent.registerEventType()

    settingsHandler = settings.DomainContextHandler()

    selected_var_indices = settings.ContextSetting([])

    graph_name = "plot.plotItem"

    def __init__(self):
        super().__init__()

        self.data = None
        # result of `correspondence`; None while nothing is analysed
        self.ca = None
        self.component_x = 0
        self.component_y = 1
        # True while an Invalidate event is posted but not yet handled
        self.__invalidated = False

        box = gui.widgetBox(self.controlArea, "Variables")
        self.varlist = itemmodels.VariableListModel()
        self.varview = view = QListView(selectionMode=QListView.MultiSelection)
        view.setModel(self.varlist)
        view.selectionModel().selectionChanged.connect(self._var_changed)

        box.layout().addWidget(view)

        axes_box = gui.widgetBox(self.controlArea, "Axes")
        box = gui.widgetBox(axes_box, "Axis X", margin=0)
        box.setFlat(True)
        self.axis_x_cb = gui.comboBox(box,
                                      self,
                                      "component_x",
                                      callback=self._component_changed)

        box = gui.widgetBox(axes_box, "Axis Y", margin=0)
        box.setFlat(True)
        self.axis_y_cb = gui.comboBox(box,
                                      self,
                                      "component_y",
                                      callback=self._component_changed)

        self.infotext = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Contribution to Inertia"), "\n")

        gui.rubber(self.controlArea)
        self.inline_graph_report()

        self.plot = pg.PlotWidget(background="w")
        self.plot.setMenuEnabled(False)
        self.mainArea.layout().addWidget(self.plot)

    def set_data(self, data):
        """Receive new data, restore the variable selection and analyse it."""
        self.closeContext()
        self.clear()
        self.warning(0)
        self.data = data

        if data is not None:
            self.varlist[:] = [
                var for var in data.domain.variables if var.is_discrete
            ]
            # by default select (up to) the first two variables; the
            # context may restore a different selection
            self.selected_var_indices = [0, 1][:len(self.varlist)]
            self.component_x, self.component_y = 0, 1
            self.openContext(data)
            self._restore_selection()

        self._update_CA()

    def clear(self):
        self.data = None
        self.ca = None
        self.plot.clear()
        self.varlist[:] = []

    def selected_vars(self):
        """Return the selected variables in list order."""
        rows = sorted(ind.row()
                      for ind in self.varview.selectionModel().selectedRows())
        return [self.varlist[i] for i in rows]

    def _restore_selection(self):
        def restore(view, indices):
            with itemmodels.signal_blocking(view.selectionModel()):
                select_rows(view, indices)

        restore(self.varview, self.selected_var_indices)

    def _p_axes(self):
        """Return the (x, y) component indices shown in the plot."""
        return (self.component_x, self.component_y)

    def _var_changed(self):
        self.selected_var_indices = sorted(
            ind.row() for ind in self.varview.selectionModel().selectedRows())
        self._invalidate()

    def _component_changed(self):
        if self.ca is not None:
            self._setup_plot()
            self._update_info()

    def _invalidate(self):
        """
        Schedule a recomputation; several invalidations made before the
        posted event is handled result in a single recomputation.
        """
        if not self.__invalidated:
            self.__invalidated = True
            QApplication.postEvent(self, QEvent(self.Invalidate))

    def customEvent(self, event):
        if event.type() == self.Invalidate:
            self.__invalidated = False
            self.ca = None
            self.plot.clear()
            self._update_CA()
            return
        return super().customEvent(event)

    def _update_CA(self):
        """Recompute the analysis for the selected variables and redraw."""
        ca_vars = self.selected_vars()
        if not ca_vars:
            # nothing to analyse: drop stale components and info text
            self.ca = None
            self.axis_x_cb.clear()
            self.axis_y_cb.clear()
            self._update_info()
            return

        if len(ca_vars) == 2:
            ctable = contingency.get_contingency(self.data, *ca_vars[::-1])
        else:
            # a single variable or more than two: use the Burt table
            _, ctable = burt_table(self.data, ca_vars)

        self.ca = correspondence(ctable)
        n_axes = self.ca.row_factors.shape[1]
        axes = ["{}".format(i + 1) for i in range(n_axes)]
        self.axis_x_cb.clear()
        self.axis_x_cb.addItems(axes)
        self.axis_y_cb.clear()
        self.axis_y_cb.addItems(axes)

        if n_axes == 0:
            # no components to plot
            self.ca = None
            self._update_info()
            return

        # keep the chosen components within the available ones; the
        # reassignment presumably also re-syncs the refilled combo boxes
        self.component_x = max(0, min(self.component_x, n_axes - 1))
        self.component_y = max(0, min(self.component_y, n_axes - 1))

        self._setup_plot()
        self._update_info()

    def _setup_plot(self):
        """Draw the points (and value labels) of every selected variable."""
        self.plot.clear()

        variables = self.selected_vars()
        colors = colorpalette.ColorPaletteGenerator(len(variables))

        p_axes = self._p_axes()

        if len(variables) == 2:
            row_points = self.ca.row_factors[:, p_axes]
            col_points = self.ca.col_factors[:, p_axes]
            point_groups = [row_points, col_points]
        else:
            # rows of the Burt-table factors are grouped by variable
            all_points = self.ca.row_factors[:, p_axes]
            counts = [len(var.values) for var in variables]
            range_indices = numpy.cumsum([0] + counts)
            ranges = zip(range_indices, range_indices[1:])
            point_groups = [all_points[s:e] for s, e in ranges]

        for i, (var, points) in enumerate(zip(variables, point_groups)):
            color_outline = colors[i]
            color_outline.setAlpha(200)
            color = QColor(color_outline)
            color.setAlpha(120)
            item = ScatterPlotItem(
                x=points[:, 0],
                y=points[:, 1],
                brush=QBrush(color),
                pen=pg.mkPen(color_outline.darker(120), width=1.5),
                size=numpy.full((points.shape[0], ), 10.1),
            )
            self.plot.addItem(item)

            for name, point in zip(var.values, points):
                item = pg.TextItem(name, anchor=(0.5, 0))
                self.plot.addItem(item)
                item.setPos(point[0], point[1])

        inertia = self.ca.inertia_of_axis()
        inertia = 100 * inertia / numpy.sum(inertia)

        ax = self.plot.getAxis("bottom")
        ax.setLabel("Component {} ({:.1f}%)".format(p_axes[0] + 1,
                                                    inertia[p_axes[0]]))
        ax = self.plot.getAxis("left")
        ax.setLabel("Component {} ({:.1f}%)".format(p_axes[1] + 1,
                                                    inertia[p_axes[1]]))

    def _update_info(self):
        """Show the inertia percentages of the two plotted components."""
        if self.ca is None:
            self.infotext.setText("\n\n")
        else:
            fmt = ("Axis 1: {:.2f}\n" "Axis 2: {:.2f}")
            inertia = self.ca.inertia_of_axis()
            inertia = 100 * inertia / numpy.sum(inertia)

            ax1, ax2 = self._p_axes()
            self.infotext.setText(fmt.format(inertia[ax1], inertia[ax2]))

    def send_report(self):
        if self.data is None:
            return

        selected = self.selected_vars()
        if not selected:
            return

        items = OrderedDict()
        items["Data instances"] = len(self.data)
        if len(selected) == 1:
            items["Selected variable"] = selected[0].name
        else:
            items["Selected variables"] = "{} and {}".format(
                ", ".join(var.name for var in selected[:-1]),
                selected[-1].name)
        self.report_items(items)

        self.report_plot()
示例#24
0
class OWPCA(widget.OWWidget):
    """Principal component analysis widget.

    Fits a PCA model to the input data (optionally normalized first), shows
    the explained variance ratio of the components in a scree plot and
    outputs the data projected onto the selected components, the component
    loadings and the PCA projector.

    The selection is driven either by the number of components
    (``ncomponents``; 0 is the spin box's special "All" value) or by the
    explained variance to cover (``variance_covered``, in percent).
    """
    name = "PCA"
    description = "Principal component analysis with a scree-diagram."
    icon = "icons/PCA.svg"
    priority = 3050
    keywords = ["principal component analysis", "linear transformation"]

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        transformed_data = Output("Transformed Data",
                                  Table,
                                  replaces=["Transformed data"])
        components = Output("Components", Table)
        pca = Output("PCA", PCA, dynamic=False)

    settingsHandler = settings.DomainContextHandler()

    # Number of selected components; 0 means "All" (see components_spin).
    ncomponents = settings.Setting(2)
    # Explained variance to cover, in percent (variance_spin range is 1-100).
    variance_covered = settings.Setting(100)
    auto_commit = settings.Setting(True)
    normalize = settings.ContextSetting(True)
    # Maximal number of components shown in the scree plot.
    maxp = settings.Setting(20)
    # Approximate number of tick labels on the plot's bottom axis.
    axis_labels = settings.Setting(10)

    graph_name = "plot.plotItem"

    class Warning(widget.OWWidget.Warning):
        trivial_components = widget.Msg(
            "All components of the PCA are trivial (explain 0 variance). "
            "Input data is constant (or near constant).")

    class Error(widget.OWWidget.Error):
        no_features = widget.Msg("At least 1 feature is required")
        no_instances = widget.Msg("At least 1 data instance is required")

    def __init__(self):
        super().__init__()
        self.data = None

        # Model state; reset by clear() and filled in by fit().
        self._pca = None
        self._transformed = None
        self._variance_ratio = None
        self._cumulative = None
        self._init_projector()

        # Components Selection
        box = gui.vBox(self.controlArea, "Components Selection")
        form = QFormLayout()
        box.layout().addLayout(form)

        self.components_spin = gui.spin(
            box,
            self,
            "ncomponents",
            1,
            MAX_COMPONENTS,
            callback=self._update_selection_component_spin,
            keyboardTracking=False)
        self.components_spin.setSpecialValueText("All")

        self.variance_spin = gui.spin(
            box,
            self,
            "variance_covered",
            1,
            100,
            callback=self._update_selection_variance_spin,
            keyboardTracking=False)
        self.variance_spin.setSuffix("%")

        form.addRow("Components:", self.components_spin)
        form.addRow("Explained variance:", self.variance_spin)

        # Options
        self.options_box = gui.vBox(self.controlArea, "Options")
        self.normalize_box = gui.checkBox(self.options_box,
                                          self,
                                          "normalize",
                                          "Normalize variables",
                                          callback=self._update_normalize)

        self.maxp_spin = gui.spin(self.options_box,
                                  self,
                                  "maxp",
                                  1,
                                  MAX_COMPONENTS,
                                  label="Show only first",
                                  callback=self._setup_plot,
                                  keyboardTracking=False)

        self.controlArea.layout().addStretch()

        gui.auto_apply(self.controlArea, self, "auto_commit")

        self.plot = SliderGraph("Principal Components",
                                "Proportion of variance", self._on_cut_changed)

        self.mainArea.layout().addWidget(self.plot)
        self._update_normalize()

    @Inputs.data
    def set_data(self, data):
        """Receive the input data, validate it and refit the model.

        An `SqlTable` is downloaded completely when its approximate length is
        below AUTO_DL_LIMIT and sampled otherwise.  A table without features
        or without instances is rejected with an error and the outputs are
        cleared.
        """
        self.closeContext()
        self.clear_messages()
        self.clear()
        self.information()
        self.data = None
        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                data = Table(data)
            else:
                self.information("Data has been sampled")
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)
        if isinstance(data, Table):
            if not data.domain.attributes:
                self.Error.no_features()
                self.clear_outputs()
                return
            if not data:
                self.Error.no_instances()
                self.clear_outputs()
                return

        self.openContext(data)
        self._init_projector()

        self.data = data
        self.fit()

    def fit(self):
        """Fit the PCA projector to ``self.data`` and commit the outputs.

        When ``normalize`` is set, a ``Normalize(center=False)`` preprocessor
        is appended to the projector's default preprocessors.  If the
        cumulative explained variance is not finite, the trivial-components
        warning is shown and no model is kept.
        """
        self.clear()
        self.Warning.trivial_components.clear()
        if self.data is None:
            return

        data = self.data

        if self.normalize:
            self._pca_projector.preprocessors = \
                self._pca_preprocessors + [preprocess.Normalize(center=False)]
        else:
            self._pca_projector.preprocessors = self._pca_preprocessors

        if not isinstance(data, SqlTable):
            pca = self._pca_projector(data)
            variance_ratio = pca.explained_variance_ratio_
            cumulative = numpy.cumsum(variance_ratio)

            if numpy.isfinite(cumulative[-1]):
                self.components_spin.setRange(0, len(cumulative))
                self._pca = pca
                self._variance_ratio = variance_ratio
                self._cumulative = cumulative
                self._setup_plot()
            else:
                self.Warning.trivial_components()

            self.unconditional_commit()

    def clear(self):
        """Forget the fitted model and its cached results; clear the plot."""
        self._pca = None
        self._transformed = None
        self._variance_ratio = None
        self._cumulative = None
        self.plot.clear_plot()

    def clear_outputs(self):
        """Send no tables; the PCA output still receives the projector."""
        self.Outputs.transformed_data.send(None)
        self.Outputs.components.send(None)
        self.Outputs.pca.send(self._pca_projector)

    def _setup_plot(self):
        """Draw the explained and the cumulative variance ratio of the first
        ``maxp`` components, with the cut line at the selected component."""
        if self._pca is None:
            self.plot.clear_plot()
            return

        explained_ratio = self._variance_ratio
        explained = self._cumulative
        cutpos = self._nselected_components()
        p = min(len(self._variance_ratio), self.maxp)

        self.plot.update(numpy.arange(1, p + 1),
                         [explained_ratio[:p], explained[:p]],
                         [Qt.red, Qt.darkYellow],
                         cutpoint_x=cutpos,
                         names=LINE_NAMES)

        self._update_axis()

    def _on_cut_changed(self, components):
        """Handle the plot's cut line being moved to ``components``.

        Nothing changes when the value is already selected, when "All"
        (``ncomponents == 0``) is selected, or when the cut is at the last
        component.
        """
        if (components == self.ncomponents
                or self.ncomponents == 0
                or (self._pca is not None
                    and components == len(self._variance_ratio))):
            return

        self.ncomponents = components
        if self._pca is not None:
            var = self._cumulative[components - 1]
            if numpy.isfinite(var):
                self.variance_covered = int(var * 100)

        self._invalidate_selection()

    def _update_selection_component_spin(self):
        """Sync the variance spin and the cut line after ``ncomponents``
        was changed in its spin box, then commit."""
        if self._pca is None:
            self._invalidate_selection()
            return

        if self.ncomponents == 0:
            # Special "All" value
            cut = len(self._variance_ratio)
        else:
            cut = self.ncomponents

        var = self._cumulative[cut - 1]
        if numpy.isfinite(var):
            self.variance_covered = int(var * 100)

        self.plot.set_cut_point(cut)
        self._invalidate_selection()

    def _update_selection_variance_spin(self):
        """Select the fewest components whose cumulative explained variance
        reaches ``variance_covered`` (capped at the number of components),
        move the cut line there and commit."""
        if self._pca is None:
            return

        cut = numpy.searchsorted(self._cumulative,
                                 self.variance_covered / 100.0) + 1
        # Plain int: the value is stored into the `ncomponents` setting.
        cut = int(min(cut, len(self._cumulative)))
        self.ncomponents = cut
        self.plot.set_cut_point(cut)
        self._invalidate_selection()

    def _update_normalize(self):
        """Refit after the normalization option changed.

        fit() commits by itself when there is data; without data, commit
        explicitly so the outputs are refreshed.
        """
        self.fit()
        if self.data is None:
            self._invalidate_selection()

    def _init_projector(self):
        """Create a fresh PCA projector with the default preprocessors."""
        self._pca_projector = PCA(n_components=MAX_COMPONENTS, random_state=0)
        self._pca_projector.component = self.ncomponents
        self._pca_preprocessors = PCA.preprocessors

    def _nselected_components(self):
        """Return the number of selected components.

        As a side effect this may update ``variance_covered`` or
        ``ncomponents`` so that the two spin boxes agree.
        """
        if self._pca is None:
            return 0

        if self.ncomponents == 0:
            # Special "All" value
            max_comp = len(self._variance_ratio)
        else:
            max_comp = self.ncomponents

        var_max = self._cumulative[max_comp - 1]
        # NOTE(review): `floor(variance_covered / 100)` is only ever 0 or 1,
        # so this effectively tests whether var_max is exactly 0 or 1;
        # behaviour kept as-is.
        if var_max != numpy.floor(self.variance_covered / 100.0):
            cut = max_comp
            assert numpy.isfinite(var_max)
            self.variance_covered = int(var_max * 100)
        else:
            # Plain int: the value is stored into the `ncomponents` setting.
            self.ncomponents = cut = int(numpy.searchsorted(
                self._cumulative, self.variance_covered / 100.0) + 1)
        return cut

    def _invalidate_selection(self):
        """Commit the outputs after the selection changed."""
        self.commit()

    def _update_axis(self):
        """Put about ``axis_labels`` integer ticks on the bottom axis."""
        p = min(len(self._variance_ratio), self.maxp)
        axis = self.plot.getAxis("bottom")
        # Guard against axis_labels <= 1, which would divide by zero.
        d = max((p - 1) // max(self.axis_labels - 1, 1), 1)
        axis.setTicks([[(i, str(i)) for i in range(1, p + 1, d)]])

    def commit(self):
        """Send the projected data, the component loadings and the
        projector; the tables are None when no model is fitted."""
        transformed = components = None
        if self._pca is not None:
            if self._transformed is None:
                # Compute the full transform (MAX_COMPONENTS components) once.
                self._transformed = self._pca(self.data)
            transformed = self._transformed

            # ncomponents == 0 is the special "All" value; slicing with it
            # directly would output no components.  Also never ask for more
            # components than were fitted, so the loadings and their
            # 'PCn' labels stay the same length.
            n_total = len(self._variance_ratio)
            n_selected = min(self.ncomponents or n_total, n_total)

            domain = Domain(transformed.domain.attributes[:n_selected],
                            self.data.domain.class_vars,
                            self.data.domain.metas)
            transformed = transformed.from_table(domain, transformed)
            # prevent caching new features by defining compute_value
            dom = Domain([
                ContinuousVariable(a.name, compute_value=lambda _: None)
                for a in self._pca.orig_domain.attributes
            ],
                         metas=[StringVariable(name='component')])
            metas = numpy.array(
                [['PC{}'.format(i + 1) for i in range(n_selected)]],
                dtype=object).T
            components = Table(dom,
                               self._pca.components_[:n_selected],
                               metas=metas)
            components.name = 'components'

        self._pca_projector.component = self.ncomponents
        self.Outputs.transformed_data.send(transformed)
        self.Outputs.components.send(components)
        self.Outputs.pca.send(self._pca_projector)

    def send_report(self):
        """Report the normalization option, the selection and the plot."""
        if self.data is None:
            return
        self.report_items(
            (("Normalize data", str(self.normalize)), ("Selected components",
                                                       self.ncomponents),
             ("Explained variance", "{:.3f} %".format(self.variance_covered))))
        self.report_plot()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Repair and prune settings stored by older widget versions."""
        if "variance_covered" in settings:
            # Due to the error in gh-1896 the variance_covered was persisted
            # as a NaN value, causing a TypeError in the widgets `__init__`.
            vc = settings["variance_covered"]
            if isinstance(vc, numbers.Real):
                if numpy.isfinite(vc):
                    vc = int(vc)
                else:
                    vc = 100
                settings["variance_covered"] = vc
        if settings.get("ncomponents", 0) > MAX_COMPONENTS:
            settings["ncomponents"] = MAX_COMPONENTS

        # Remove old `decomposition_idx` when SVD was still included
        settings.pop("decomposition_idx", None)

        # Remove RemotePCA settings
        settings.pop("batch_size", None)
        settings.pop("address", None)
        settings.pop("auto_update", None)
class OWKEGGPathwayBrowser(widget.OWWidget):
    """Browse KEGG pathways enriched for the genes of the input data.

    Organism definitions are fetched in a background task at start-up.
    Input gene names (from a string variable, or the attribute names when
    ``useAttrNames`` is set) are matched to NCBI gene ids.  Pathway
    enrichment is then computed in a background task, and the pathways are
    listed either flat or in the KEGG orthology hierarchy.  Genes selected
    in the pathway diagram define the "Selected Data" output.
    """
    name = "KEGG Pathways"
    description = "Browse KEGG pathways that include an input set of genes."
    icon = "../widgets/icons/OWKEGGPathwayBrowser.svg"
    priority = 8

    inputs = [("Data", Orange.data.Table, "SetData", widget.Default),
              ("Reference", Orange.data.Table, "SetRefData")]
    outputs = [("Selected Data", Orange.data.Table, widget.Default),
               ("Unselected Data", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    organismIndex = settings.ContextSetting(0)
    # Index into geneAttrCandidates (-1 when there is no candidate).
    geneAttrIndex = settings.ContextSetting(0)
    # Take gene names from the attribute names instead of a variable.
    useAttrNames = settings.ContextSetting(False)

    autoCommit = settings.Setting(False)
    autoResize = settings.Setting(True)
    useReference = settings.Setting(False)
    showOrthology = settings.Setting(True)

    # Values of the private run state (Running is not used in this class).
    Ready, Initializing, Running = 0, 1, 2

    def __init__(self, parent=None):
        super().__init__(parent)

        self.organismCodes = []
        self._changedFlag = False
        self.__invalidated = False
        self.__runstate = OWKEGGPathwayBrowser.Initializing
        self.__in_setProgress = False

        # Enrichment state, reset by Clear() and filled in by
        # _onEnrichTaskFinished().  It is initialized here so that methods
        # reached before a successful enrichment (e.g. commit() via
        # "Select All" after a failed run) do not fail on missing
        # attributes.
        self.queryGenes = []
        self.referenceGenes = []
        self.genes = {}
        self.uniqueGenesDict = {}
        self.revUniqueGenesDict = {}
        self.pathways = {}
        self.org = None
        self.bestPValueItem = None

        self.controlArea.setMaximumWidth(250)
        box = gui.widgetBox(self.controlArea, "Info")
        self.infoLabel = gui.widgetLabel(box, "No data on input\n")

        # Organism selection.
        box = gui.widgetBox(self.controlArea, "Organism")
        self.organismComboBox = gui.comboBox(
            box, self, "organismIndex",
            items=[],
            callback=self.Update,
            addSpace=True,
            tooltip="Select the organism of the input genes")

        # Selection of genes attribute
        box = gui.widgetBox(self.controlArea, "Gene attribute")
        self.geneAttrCandidates = itemmodels.VariableListModel(parent=self)
        self.geneAttrCombo = gui.comboBox(
            box, self, "geneAttrIndex", callback=self.Update)
        self.geneAttrCombo.setModel(self.geneAttrCandidates)

        gui.checkBox(box, self, "useAttrNames",
                     "Use variable names", disables=[(-1, self.geneAttrCombo)],
                     callback=self.Update)

        self.geneAttrCombo.setDisabled(bool(self.useAttrNames))

        gui.separator(self.controlArea)

        gui.checkBox(self.controlArea, self, "useReference",
                     "From signal", box="Reference", callback=self.Update)

        gui.separator(self.controlArea)

        gui.checkBox(self.controlArea, self, "showOrthology",
                     "Show pathways in full orthology", box="Orthology",
                     callback=self.UpdateListView)

        gui.checkBox(self.controlArea, self, "autoResize",
                     "Resize to fit", box="Image",
                     callback=self.UpdatePathwayViewTransform)

        box = gui.widgetBox(self.controlArea, "Cache Control")

        gui.button(box, self, "Clear cache",
                   callback=self.ClearCache,
                   tooltip="Clear all locally cached KEGG data.",
                   default=False, autoDefault=False)

        gui.separator(self.controlArea)

        gui.auto_commit(self.controlArea, self, "autoCommit", "Commit")

        gui.rubber(self.controlArea)

        spliter = QSplitter(Qt.Vertical, self.mainArea)
        self.pathwayView = PathwayView(self, spliter)
        self.pathwayView.scene().selectionChanged.connect(
            self._onSelectionChanged
        )
        self.mainArea.layout().addWidget(spliter)

        self.listView = QTreeWidget(
            allColumnsShowFocus=True,
            selectionMode=QTreeWidget.SingleSelection,
            sortingEnabled=True,
            maximumHeight=200)

        spliter.addWidget(self.listView)

        self.listView.setColumnCount(4)
        self.listView.setHeaderLabels(
            ["Pathway", "P value", "Genes", "Reference"])

        self.listView.itemSelectionChanged.connect(self.UpdatePathwayView)

        select = QAction(
            "Select All", self,
            shortcut=QKeySequence.SelectAll
        )
        select.triggered.connect(self.selectAll)
        self.addAction(select)

        self.data = None
        self.refData = None

        self._executor = concurrent.ThreadExecutor()
        self.setEnabled(False)
        self.setBlocking(True)
        progress = concurrent.methodinvoke(self, "setProgress", (float,))

        def get_genome():
            """Return a KEGGGenome with the common org entries precached."""
            genome = kegg.KEGGGenome()

            essential = genome.essential_organisms()
            common = genome.common_organisms()
            # Remove duplicates of essential from common.
            # (essential + common list as defined here will be used in the
            # GUI.)
            common = [c for c in common if c not in essential]

            # TODO: Add option to specify additional organisms not
            # in the common list.

            keys = list(map(genome.org_code_to_entry_key, essential + common))

            genome.pre_cache(keys, progress_callback=progress)
            return (keys, genome)

        self._genomeTask = task = concurrent.Task(function=get_genome)
        task.finished.connect(self.__initialize_finish)

        self.progressBarInit()
        self.infoLabel.setText("Fetching organism definitions\n")
        self._executor.submit(task)

    def __initialize_finish(self):
        """Populate the organism combo box from the finished genome task.

        This is connected to the task's ``finished`` signal and is also
        called from SetData.  It takes effect only once: after a successful
        run the state is advanced to ``Ready`` and later calls return
        immediately.
        """
        if self.__runstate != OWKEGGPathwayBrowser.Initializing:
            return

        try:
            keys, genome = self._genomeTask.result()
        except Exception as err:
            # Do not leave the progress bar running or the scheme blocked.
            self.progressBarFinished()
            self.setBlocking(False)
            self.error(0, str(err))
            raise

        self.progressBarFinished()
        self.setEnabled(True)
        self.setBlocking(False)

        entries = [genome[key] for key in keys]
        items = [entry.definition for entry in entries]
        codes = [entry.organism_code for entry in entries]

        self.organismCodes = codes
        self.organismComboBox.clear()
        self.organismComboBox.addItems(items)
        self.organismComboBox.setCurrentIndex(self.organismIndex)

        self.infoLabel.setText("No data on input\n")
        self.__runstate = OWKEGGPathwayBrowser.Ready

    def Clear(self):
        """
        Clear the widget state.
        """
        self.queryGenes = []
        self.referenceGenes = []
        self.genes = {}
        self.uniqueGenesDict = {}
        self.revUniqueGenesDict = {}
        self.pathways = {}
        self.org = None
        self.geneAttrCandidates[:] = []

        self.infoLabel.setText("No data on input\n")
        self.listView.clear()
        self.pathwayView.SetPathway(None)

        self.send("Selected Data", None)
        self.send("Unselected Data", None)

    def SetData(self, data=None):
        """Receive the input data.

        Collects the string variables as gene attribute candidates and
        guesses the gene attribute, the organism (from the taxonomy hint)
        and the attribute-names mode.  The enrichment itself is started
        from handleNewSignals().
        """
        if self.__runstate == OWKEGGPathwayBrowser.Initializing:
            self.__initialize_finish()

        self.data = data
        self.warning(0)
        self.error(0)
        self.information(0)

        if data is not None:
            candidates = data.domain.variables + data.domain.metas
            candidates = [var for var in candidates
                          if isinstance(var, Orange.data.StringVariable)]
            self.geneAttrCandidates[:] = candidates

            # Try to guess the gene name variable: prefer a variable named
            # exactly "gene", then one whose name contains "gene".
            if candidates:
                names_lower = [v.name.lower() for v in candidates]
                scores = [(name == "gene", "gene" in name)
                          for name in names_lower]
                imax, _ = max(enumerate(scores), key=itemgetter(1))
            else:
                imax = -1

            self.geneAttrIndex = imax

            taxid = data_hints.get_hint(data, TAX_ID, None)
            if taxid:
                try:
                    code = kegg.from_taxid(taxid)
                    self.organismIndex = self.organismCodes.index(code)
                except Exception as ex:
                    print(ex, taxid)

            self.useAttrNames = data_hints.get_hint(data, GENE_NAME, self.useAttrNames)

            if len(self.geneAttrCandidates) == 0:
                self.useAttrNames = True
                self.geneAttrIndex = -1
            else:
                self.geneAttrIndex = min(self.geneAttrIndex,
                                         len(self.geneAttrCandidates) - 1)
        else:
            self.Clear()

        self.__invalidated = True

    def SetRefData(self, data=None):
        """Receive the reference data; it is used only with useReference."""
        self.refData = data
        self.information(1)

        if data is not None and self.useReference:
            self.__invalidated = True

    def handleNewSignals(self):
        """Recompute the enrichment if the inputs invalidated it."""
        if self.__invalidated:
            self.Update()
            self.__invalidated = False

    def UpdateListView(self):
        """Rebuild the pathway list from the enrichment results.

        With ``showOrthology`` the pathways are shown in the KEGG orthology
        (ko00001) hierarchy; otherwise they are listed flat, sorted by
        p-value, and the best one is selected.
        """
        self.bestPValueItem = None
        self.listView.clear()
        # Without a finished enrichment (self.org) there is nothing to list.
        if not self.data or self.org is None:
            return

        allPathways = self.org.pathways()
        allRefPathways = kegg.pathways("map")

        items = []
        kegg_pathways = kegg.KEGGPathways()

        org_code = self.SelectedOrganismCode()

        if self.showOrthology:
            self.koOrthology = kegg.KEGGBrite("ko00001")
            self.listView.setRootIsDecorated(True)
            # The last 5 characters of a pathway id are its map number.
            path_ids = {s[-5:] for s in self.pathways}

            def _walkCollect(koEntry):
                # Return koEntry's enriched descendants (depth first).
                # koEntry itself is included when it is enriched or has
                # enriched descendants.
                num = koEntry.title[:5] if koEntry.title else None
                collected = [entry for child in koEntry.entries
                             for entry in _walkCollect(child)]
                if num in path_ids:
                    return [koEntry] + collected
                return collected + ([koEntry] if collected else [])

            allClasses = [entry for koEntry in self.koOrthology
                          for entry in _walkCollect(koEntry)]

            def _walkCreate(koEntry, lvItem):
                # Add a tree item for koEntry under lvItem and recurse into
                # the children that were collected above.
                item = QTreeWidgetItem(lvItem)
                # Entries with an empty/None title can be collected as
                # ancestors of enriched pathways; treat the title as "".
                title = koEntry.title or ""
                num = title[:5]
                pathway_id = "path:" + org_code + num

                if num in path_ids:
                    p = kegg_pathways.get_entry(pathway_id)
                    if p is None:
                        # In case the genesets still have obsolete entries
                        name = title
                    else:
                        name = p.name
                    genes, p_value, ref = self.pathways[pathway_id]
                    item.setText(0, name)
                    item.setText(1, "%.5f" % p_value)
                    item.setText(2, "%i of %i" % (len(genes), len(self.genes)))
                    item.setText(3, "%i of %i" % (ref, len(self.referenceGenes)))
                    item.pathway_id = pathway_id if p is not None else None
                else:
                    entry = (kegg_pathways.get_entry(pathway_id)
                             if pathway_id in allPathways else None)
                    item.setText(0, entry.name if entry is not None else title)

                    if pathway_id in allPathways:
                        item.pathway_id = pathway_id
                    elif "path:map" + num in allRefPathways:
                        item.pathway_id = "path:map" + num
                    else:
                        item.pathway_id = None

                for child in koEntry.entries:
                    if child in allClasses:
                        _walkCreate(child, item)

            for koEntry in self.koOrthology:
                if koEntry in allClasses:
                    _walkCreate(koEntry, self.listView)

            self.listView.update()
        else:
            self.listView.setRootIsDecorated(False)
            # Sort by p-value (the second element of each result tuple).
            pathways = sorted(self.pathways.items(),
                              key=lambda entry: entry[1][1])

            for pathway_id, (genes, p_value, ref) in pathways:
                item = QTreeWidgetItem(self.listView)
                entry = kegg_pathways.get_entry(pathway_id)
                # Obsolete entries have no cached definition; show the id.
                item.setText(0, entry.name if entry is not None else pathway_id)
                item.setText(1, "%.5f" % p_value)
                item.setText(2, "%i of %i" % (len(genes), len(self.genes)))
                item.setText(3, "%i of %i" % (ref, len(self.referenceGenes)))
                item.pathway_id = pathway_id
                items.append(item)

        self.bestPValueItem = items[0] if items else None
        self.listView.expandAll()
        for i in range(4):
            self.listView.resizeColumnToContents(i)

        if self.bestPValueItem:
            index = self.listView.indexFromItem(self.bestPValueItem)
            self.listView.selectionModel().select(
                index, QItemSelectionModel.ClearAndSelect
            )

    def UpdatePathwayView(self):
        """Commit the gene selection and show the selected pathway (or the
        best-scoring one); its KGML and image are fetched in a background
        task."""
        items = self.listView.selectedItems()
        item = items[0] if items else None

        self.commit()
        item = item or self.bestPValueItem
        if not item or not item.pathway_id:
            self.pathwayView.SetPathway(None)
            return

        def get_kgml_and_image(pathway_id):
            """Return an initialized KEGGPathway with pre-cached data"""
            p = kegg.KEGGPathway(pathway_id)
            p._get_kgml()  # makes sure the kgml file is downloaded
            p._get_image_filename()  # makes sure the image is downloaded
            return (pathway_id, p)

        pathway_id = item.pathway_id
        self.setEnabled(False)
        self._pathwayTask = concurrent.Task(
            function=lambda: get_kgml_and_image(pathway_id)
        )
        self._pathwayTask.finished.connect(self._onPathwayTaskFinshed)
        self._executor.submit(self._pathwayTask)

    def _onPathwayTaskFinshed(self):
        """Show the pathway fetched by the background task, marking the
        enriched genes of that pathway (if any)."""
        self.setEnabled(True)
        pathway_id, self.pathway = self._pathwayTask.result()
        self.pathwayView.SetPathway(
            self.pathway,
            self.pathways.get(pathway_id, [[]])[0]
        )

    def UpdatePathwayViewTransform(self):
        """Re-apply the view transform (after the resize option changed)."""
        self.pathwayView.updateTransform()

    def Update(self):
        """
        Update (recompute enriched pathways) the widget state.
        """
        if not self.data:
            return

        self.error(0)
        self.information(0)

        # XXX: Check data in setData, do not even allow this to be executed if
        # data has no genes
        try:
            genes = self.GeneNamesFromData(self.data)
        except ValueError:
            self.error(0, "Cannot extract gene names from input.")
            genes = []

        if not self.useAttrNames and any("," in gene for gene in genes):
            genes = [name for gene in genes
                     for name in split_and_strip(gene, ",")]
            self.information(0,
                             "Separators detected in input gene names. "
                             "Assuming multiple genes per instance.")

        self.queryGenes = genes

        self.information(1)
        reference = None
        if self.useReference and self.refData:
            reference = self.GeneNamesFromData(self.refData)
            if not self.useAttrNames \
                    and any("," in gene for gene in reference):
                reference = [name for gene in reference
                             for name in split_and_strip(gene, ",")]
                self.information(1,
                                 "Separators detected in reference gene "
                                 "names. Assuming multiple genes per "
                                 "instance.")

        org_code = self.SelectedOrganismCode()
        taxid = kegg.to_taxid(org_code)

        from orangecontrib.bioinformatics.ncbi.gene import GeneMatcher

        def match_to_ncbi(names):
            # Return {input name: NCBI gene id as str} for `names`.
            matcher = GeneMatcher(taxid)
            matcher.genes = names
            matcher.run_matcher()
            return {name: str(ncbi_id)
                    for name, ncbi_id in matcher.map_input_to_ncbi().items()}

        mapped_genes = match_to_ncbi(genes)
        if reference is not None:
            # The pathway gene sets and the default reference
            # (org.get_ncbi_ids()) hold NCBI gene ids, so reference names
            # must be mapped the same way as the query genes.
            reference = list(match_to_ncbi(reference).values())

        def run_enrichment(org_code, genes, reference=None, progress=None):
            # Runs in the worker thread (submitted to self._executor).
            org = kegg.KEGGOrganism(org_code)
            if reference is None:
                reference = org.get_ncbi_ids()

            # `genes` is already a {input name: NCBI id} mapping; the
            # reference is reduced to a set of unique ids.
            unique_genes = genes
            unique_ref_genes = {gene: gene for gene in set(reference)}

            # Build the pathway gene sets, keyed by NCBI gene id, from the
            # organism's KEGG pathway links and KEGG -> NCBI id conversion.
            kegg_api = kegg.api.CachedKeggApi()
            linkmap = kegg_api.link(org.org_code, "pathway")
            converted_ids = kegg_api.conv(org.org_code, 'ncbi-geneid')
            kegg_sets = relation_list_to_multimap(
                linkmap,
                {gene.upper(): ncbi.split(':')[-1]
                 for ncbi, gene in converted_ids})

            kegg_sets = geneset.GeneSets(input=kegg_sets)

            pathways = pathway_enrichment(
                kegg_sets, unique_genes.values(),
                unique_ref_genes.keys(),
                callback=progress
            )
            # Ensure that pathway entries are pre-cached for later use in the
            # list/tree view
            kegg_pathways = kegg.KEGGPathways()
            kegg_pathways.pre_cache(
                pathways.keys(), progress_callback=progress
            )

            return pathways, org, unique_genes, unique_ref_genes

        self.progressBarInit()
        self.setEnabled(False)
        self.infoLabel.setText("Retrieving...\n")

        progress = concurrent.methodinvoke(self, "setProgress", (float,))

        self._enrichTask = concurrent.Task(
            function=lambda:
                run_enrichment(org_code, mapped_genes, reference, progress)
        )
        self._enrichTask.finished.connect(self._onEnrichTaskFinished)
        self._executor.submit(self._enrichTask)

    def _onEnrichTaskFinished(self):
        """Store the enrichment results and refresh the pathway list."""
        self.setEnabled(True)
        self.setBlocking(False)
        self.progressBarFinished()
        try:
            pathways, org, unique_genes, unique_ref_genes = \
                self._enrichTask.result()
        except Exception as err:
            # Report the failure (instead of leaving "Retrieving..." shown)
            # and propagate it, as __initialize_finish does.
            self.infoLabel.setText("Enrichment failed\n")
            self.error(0, str(err))
            raise

        self.org = org
        self.genes = unique_genes.keys()
        self.uniqueGenesDict = {ncbi_id: input_name for input_name, ncbi_id in unique_genes.items()}
        # Input name -> NCBI id taken directly from the matcher output;
        # inverting uniqueGenesDict would lose input names that share an
        # NCBI id, so their instances could never be selected.
        self.revUniqueGenesDict = dict(unique_genes)
        self.referenceGenes = unique_ref_genes.keys()
        self.pathways = pathways

        if not self.pathways:
            self.warning(0, "No enriched pathways found.")
        else:
            self.warning(0)

        count = len(set(self.queryGenes))
        self.infoLabel.setText(
            "%i unique gene names on input\n"
            "%i (%.1f%%) genes names matched" %
            (count, len(unique_genes),
             100.0 * len(unique_genes) / count if count else 0.0)
        )

        self.UpdateListView()

    @Slot(float)
    def setProgress(self, value):
        """Set the progress bar; nested calls while updating are ignored."""
        if self.__in_setProgress:
            return

        self.__in_setProgress = True
        try:
            self.progressBarSet(value)
        finally:
            # Always reset, or every later update would be ignored.
            self.__in_setProgress = False

    def GeneNamesFromData(self, data):
        """
        Extract and return gene names from `data`.

        Raises ValueError when neither attribute names nor a gene attribute
        are available.
        """
        if self.useAttrNames:
            genes = [str(v.name).strip() for v in data.domain.attributes]
        elif self.geneAttrCandidates:
            assert 0 <= self.geneAttrIndex < len(self.geneAttrCandidates)
            geneAttr = self.geneAttrCandidates[self.geneAttrIndex]
            genes = [str(e[geneAttr]) for e in data
                     if not numpy.isnan(e[geneAttr])]
        else:
            raise ValueError("No gene names in data.")
        return genes

    def SelectedOrganismCode(self):
        """
        Return the selected organism code.
        """
        return self.organismCodes[min(self.organismIndex,
                                      len(self.organismCodes) - 1)]

    def selectAll(self):
        """
        Select all items in the pathway view.
        """
        changed = False
        scene = self.pathwayView.scene()
        # Commit once at the end rather than on every selection change.
        with disconnected(scene.selectionChanged, self._onSelectionChanged):
            for item in scene.items():
                if item.flags() & QGraphicsItem.ItemIsSelectable and \
                        not item.isSelected():
                    item.setSelected(True)
                    changed = True
        if changed:
            self._onSelectionChanged()

    def _onSelectionChanged(self):
        # Item selection in the pathwayView/scene has changed
        self.commit()

    def commit(self):
        """Send the instances (or, with ``useAttrNames``, the columns) whose
        genes are selected in the pathway diagram."""
        if not self.data:
            self.send("Selected Data", None)
            self.send("Unselected Data", None)
            return

        selectedItems = self.pathwayView.scene().selectedItems()
        # presumably NCBI gene ids (the genes given to SetPathway) — verify
        # against PathwayView
        selectedGenes = set().union(
            *(item.marked_objects for item in selectedItems))

        if self.useAttrNames:
            selected = [self.data.domain[self.uniqueGenesDict[gene]]
                        for gene in selectedGenes]
            data = self.data[:, selected]
            self.send("Selected Data", data)
        elif self.geneAttrCandidates:
            assert 0 <= self.geneAttrIndex < len(self.geneAttrCandidates)
            geneAttr = self.geneAttrCandidates[self.geneAttrIndex]
            selectedIndices = []
            otherIndices = []
            for i, ex in enumerate(self.data):
                names = [self.revUniqueGenesDict.get(name, None)
                         for name in split_and_strip(str(ex[geneAttr]), ",")]
                if any(name and name in selectedGenes for name in names):
                    selectedIndices.append(i)
                else:
                    otherIndices.append(i)

            selected = self.data[selectedIndices] if selectedIndices else None
            other = self.data[otherIndices] if otherIndices else None

            self.send("Selected Data", selected)
            self.send("Unselected Data", other)

    def ClearCache(self):
        """Clear all locally cached KEGG data."""
        kegg.caching.clear_cache()

    def onDeleteWidget(self):
        """
        Called before the widget is removed from the canvas.
        """
        super().onDeleteWidget()

        self.org = None
        self._executor.shutdown(wait=False)
        gc.collect()  # Force a collection pass (original reason unknown).

    def sizeHint(self):
        return QSize(1024, 720)
示例#26
0
class OWDistributions(OWWidget):
    name = "Distributions"
    description = "Display value distributions of a data feature in a graph."
    icon = "icons/Distribution.svg"
    priority = 120
    keywords = ["histogram"]

    # Input signal: the dataset whose feature distribution is shown.
    class Inputs:
        data = Input("Data", Table, doc="Set the input dataset")

    # Output signals: the instances in the selected bars, the full data
    # annotated with selection membership (and, for continuous variables,
    # bin membership -- see apply), and a histogram table.
    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        histogram_data = Output("Histogram Data", Table)

    class Error(OWWidget.Error):
        no_defined_values_var = \
            Msg("Variable '{}' does not have any defined values")
        no_defined_values_pair = \
            Msg("No data instances with '{}' and '{}' defined")

    class Warning(OWWidget.Warning):
        ignored_nans = Msg("Data instances with missing values are ignored")

    settingsHandler = settings.DomainContextHandler()
    # variable whose distribution is plotted
    var = settings.ContextSetting(None)
    # optional discrete variable used to split the columns
    cvar = settings.ContextSetting(None)
    # set of indices into bar_items that are currently selected
    selection = settings.ContextSetting(set(), schema_only=True)
    # number_of_bins must be a context setting because selection depends on it
    # (it is an index into self.binnings, not a literal bin count)
    number_of_bins = settings.ContextSetting(5, schema_only=True)

    # index into Fitters
    fitted_distribution = settings.Setting(0)
    hide_bars = settings.Setting(False)
    show_probs = settings.Setting(False)
    stacked_columns = settings.Setting(False)
    cumulative_distr = settings.Setting(False)
    sort_by_freq = settings.Setting(False)
    # kernel-density smoothing; mapped to sigma=(22 - value) / 40
    kde_smoothing = settings.Setting(10)

    auto_apply = settings.Setting(True)

    graph_name = "plot"

    # (display name, distribution, names of the fitted parameters in the
    # order returned by .fit(), display names of those parameters).
    # Display names must pair one-to-one with the fitted parameters; names
    # starting with "-" are shown in parentheses and empty names are hidden.
    # Gamma's fit returns (a, loc, scale), so it has three display names.
    # NOTE(review): scipy's expon uses scale = 1/lambda, so the value shown
    # as "λ" for Exponential is the fitted scale -- confirm intended label.
    Fitters = (
        ("None", None, (), ()),
        ("Normal", norm, ("loc", "scale"), ("μ", "σ")),
        ("Beta", beta, ("a", "b", "loc", "scale"),
         ("α", "β", "-loc", "-scale")),
        ("Gamma", gamma, ("a", "loc", "scale"), ("α", "-loc", "-scale")),
        ("Rayleigh", rayleigh, ("loc", "scale"), ("-loc", "σ")),
        ("Pareto", pareto, ("b", "loc", "scale"), ("α", "-loc", "-scale")),
        ("Exponential", expon, ("loc", "scale"), ("-loc", "λ")),
        ("Kernel density", AshCurve, ("a",), ("",))
    )

    # State of an ongoing mouse-drag selection.
    DragNone, DragAdd, DragRemove = range(3)

    def __init__(self):
        """Build the control area (variable list, distribution and column
        options), the plot area and the legend, and initialise the
        selection and drag state."""
        super().__init__()
        self.data = None
        # finite values of `var` (and the matching `cvar` codes); set by
        # set_valid_data, None when nothing can be plotted
        self.valid_data = self.valid_group_data = None
        self.bar_items = []
        self.curve_items = []
        self.curve_descriptions = None
        # candidate binnings for a continuous `var`; number_of_bins indexes it
        self.binnings = []

        self.last_click_idx = None
        self.drag_operation = self.DragNone
        self.key_operation = None
        # bin index chosen by the user per variable (see
        # _on_bin_slider_released); reused by recompute_binnings
        self._user_var_bins = {}

        varview = gui.listView(
            self.controlArea, self, "var", box="Variable",
            model=DomainModel(valid_types=DomainModel.PRIMITIVE,
                              separators=False),
            callback=self._on_var_changed,
            viewType=ListViewSearch
        )
        gui.checkBox(
            varview.box, self, "sort_by_freq", "Sort categories by frequency",
            callback=self._on_sort_by_freq, stateWhenDisabled=False)

        box = self.continuous_box = gui.vBox(self.controlArea, "Distribution")
        gui.comboBox(
            box, self, "fitted_distribution", label="Fitted distribution",
            orientation=Qt.Horizontal, items=(name[0] for name in self.Fitters),
            callback=self._on_fitted_dist_changed)
        # binnings is still empty here; recompute_binnings sets the real maximum
        slider = gui.hSlider(
            box, self, "number_of_bins",
            label="Bin width", orientation=Qt.Horizontal,
            minValue=0, maxValue=max(1, len(self.binnings) - 1),
            createLabel=False, callback=self._on_bins_changed)
        self.bin_width_label = gui.widgetLabel(slider.box)
        self.bin_width_label.setFixedWidth(35)
        self.bin_width_label.setAlignment(Qt.AlignRight)
        # outputs are committed on release, not on every slider move
        slider.sliderReleased.connect(self._on_bin_slider_released)
        # enabled only for the kernel-density fitter (_set_smoothing_visibility)
        self.smoothing_box = gui.hSlider(
            box, self, "kde_smoothing",
            label="Smoothing", orientation=Qt.Horizontal,
            minValue=2, maxValue=20, callback=self.replot, disabled=True)
        gui.checkBox(
            box, self, "hide_bars", "Hide bars", stateWhenDisabled=False,
            callback=self._on_hide_bars_changed,
            disabled=not self.fitted_distribution)

        box = gui.vBox(self.controlArea, "Columns")
        gui.comboBox(
            box, self, "cvar", label="Split by", orientation=Qt.Horizontal,
            searchable=True,
            model=DomainModel(placeholder="(None)",
                              valid_types=(DiscreteVariable), ),
            callback=self._on_cvar_changed, contentsLength=18)
        gui.checkBox(
            box, self, "stacked_columns", "Stack columns",
            callback=self.replot)
        gui.checkBox(
            box, self, "show_probs", "Show probabilities",
            callback=self._on_show_probabilities_changed)
        gui.checkBox(
            box, self, "cumulative_distr", "Show cumulative distribution",
            callback=self._on_show_cumulative)

        gui.auto_apply(self.buttonsArea, self, commit=self.apply)

        self._set_smoothing_visibility()
        self._setup_plots()
        self._setup_legend()

    def _setup_plots(self):
        """Create the plot widget and three view boxes.

        The boxes are the main bar plot (`plot`), an overlay for fitted
        curves (`plot_pdf`, above the bars) and an overlay for selection
        rectangles (`plot_mark`, below the bars, y fixed to 0..1). Both
        overlays share the main plot's x axis.
        """
        def add_new_plot(zvalue):
            # Overlay view box sharing the main plot's x axis, stacked at
            # `zvalue` relative to the main plot's items.
            plot = pg.ViewBox(enableMouse=False, enableMenu=False)
            self.ploti.scene().addItem(plot)
            pg.AxisItem("right").linkToView(plot)
            plot.setXLink(self.ploti)
            plot.setZValue(zvalue)
            return plot

        self.plotview = DistributionWidget()
        self.plotview.item_clicked.connect(self._on_item_clicked)
        self.plotview.blank_clicked.connect(self._on_blank_clicked)
        self.plotview.mouse_released.connect(self._on_end_selecting)
        self.plotview.setRenderHint(QPainter.Antialiasing)
        box = gui.vBox(self.mainArea, box=True, margin=0)
        box.layout().addWidget(self.plotview)
        self.ploti = pg.PlotItem(
            enableMenu=False, enableMouse=False,
            axisItems={"bottom": ElidedAxisNoUnits("bottom")})
        self.plot = self.ploti.vb
        self.plot.setMouseEnabled(False, False)
        self.ploti.hideButtons()
        self.plotview.setCentralItem(self.ploti)

        self.plot_pdf = add_new_plot(10)
        self.plot_mark = add_new_plot(-10)
        self.plot_mark.setYRange(0, 1)
        # keep overlay geometry in sync whenever the main view box resizes
        self.ploti.vb.sigResized.connect(self.update_views)
        self.update_views()

        pen = QPen(self.palette().color(QPalette.Text))
        self.ploti.getAxis("bottom").setPen(pen)
        left = self.ploti.getAxis("left")
        left.setPen(pen)
        left.setStyle(stopAxisAtTick=(True, True))

    def _setup_legend(self):
        """Create the legend, attached to the curve overlay and anchored to
        its top-right corner; it starts hidden."""
        legend = LegendItem()
        self._legend = legend
        legend.setParentItem(self.plot_pdf)
        legend.hide()
        legend.anchor((1, 0), (1, 0))

    # -----------------------------
    # Event and signal handlers

    def update_views(self):
        """Align both overlay view boxes with the main plot's view box."""
        overlays = self.plot_pdf, self.plot_mark
        for view in overlays:
            view.setGeometry(self.plot.sceneBoundingRect())
            view.linkedViewChanged(self.plot, view.XAxis)

    def onDeleteWidget(self):
        """Empty the main and overlay view boxes, then run base teardown."""
        for view in (self.plot, self.plot_pdf, self.plot_mark):
            view.clear()
        super().onDeleteWidget()

    @Inputs.data
    def set_data(self, data):
        """Receive a new input table and rebuild the whole view.

        `var` and `cvar` are first reset to defaults, and the selection and
        remembered bin choices are cleared. The saved context is then
        restored, which may override the defaults. Finally the valid data,
        binnings, plot and outputs are recomputed.
        """
        self.closeContext()
        self.var = self.cvar = None
        self.data = data
        domain = self.data.domain if self.data else None
        varmodel = self.controls.var.model()
        cvarmodel = self.controls.cvar.model()
        varmodel.set_domain(domain)
        cvarmodel.set_domain(domain)
        if varmodel:
            # NOTE(review): index len(class_vars) presumably skips the class
            # variables listed first by DomainModel; clamped to the model size
            self.var = varmodel[min(len(domain.class_vars), len(varmodel) - 1)]
        if domain is not None and domain.has_discrete_class:
            # split by a discrete class by default
            self.cvar = domain.class_var
        self.reset_select()
        self._user_var_bins.clear()
        self.openContext(domain)
        self.set_valid_data()
        self.recompute_binnings()
        self.replot()
        self.apply()

    def _on_var_changed(self):
        """A different variable was chosen.

        Drops the selection, recomputes the valid data and binnings, redraws
        and commits.
        """
        steps = (self.reset_select, self.set_valid_data,
                 self.recompute_binnings, self.replot, self.apply)
        for step in steps:
            step()

    def _on_cvar_changed(self):
        """The split variable changed: refresh valid data, redraw, commit."""
        for step in (self.set_valid_data, self.replot, self.apply):
            step()

    def _on_show_cumulative(self):
        """Cumulative mode toggled: redraw and commit."""
        for step in (self.replot, self.apply):
            step()

    def _on_sort_by_freq(self):
        """Category ordering toggled: redraw and commit."""
        for step in (self.replot, self.apply):
            step()

    def _on_bins_changed(self):
        """The bin-width slider moved.

        Clears the selection, refreshes the width label and redraws.
        """
        self.reset_select()
        self._set_bin_width_slider_label()
        self.replot()
        # Fires continuously while dragging, so nothing is committed here;
        # _on_bin_slider_released calls apply once the slider is released.

    def _on_bin_slider_released(self):
        """Remember the chosen bin index for the current variable and commit."""
        chosen = self.number_of_bins
        self._user_var_bins[self.var] = chosen
        self.apply()

    def _on_fitted_dist_changed(self):
        """A different fitter was chosen.

        "Hide bars" only makes sense with a fitted curve, so that checkbox
        is enabled accordingly. Smoothing is updated and the plot redrawn.
        """
        no_fit = not self.fitted_distribution
        self.controls.hide_bars.setDisabled(no_fit)
        self._set_smoothing_visibility()
        self.replot()

    def _on_hide_bars_changed(self):
        """Show or hide the bars in place and adjust curve fills to match."""
        hidden = self.hide_bars
        for bar_item in self.bar_items:
            bar_item.setHidden(hidden)
        self._set_curve_brushes()
        self.plot.update()

    def _set_smoothing_visibility(self):
        """Enable the smoothing slider only for the kernel-density fitter."""
        fitter = self.Fitters[self.fitted_distribution][1]
        is_kde = fitter is AshCurve
        self.smoothing_box.setDisabled(not is_kde)

    def _set_bin_width_slider_label(self):
        """Show the current binning's width label, with time units shortened
        via ``short_time_units``.

        The label is empty when the slider index has no binning.
        """
        text = ""
        if self.number_of_bins < len(self.binnings):
            text = self.binnings[self.number_of_bins].width_label
            for long_unit, short_unit in short_time_units.items():
                text = text.replace(long_unit, short_unit)
        self.bin_width_label.setText(text)

    def _on_show_probabilities_changed(self):
        """Relabel the fitter combo for probability mode, then redraw."""
        if self.show_probs:
            caption = "Fitted probability"
            tip = ("Chosen distribution is used to compute "
                   "Bayesian probabilities")
        else:
            caption, tip = "Fitted distribution", ""
        label = self.controls.fitted_distribution.label
        label.setText(caption)
        label.setToolTip(tip)
        self.replot()

    @property
    def is_valid(self):
        """Whether ``set_valid_data`` produced plottable values for ``var``."""
        has_values = self.valid_data is not None
        return has_values

    def set_valid_data(self):
        """Extract the finite values of ``var`` (and of ``cvar``, if set).

        Errors are reported when no value, or no (var, cvar) pair, is
        defined; in that case ``valid_data`` stays None. A warning is shown
        when some instances had to be dropped because of missing values.
        """
        self.Error.no_defined_values_var.clear()
        self.Error.no_defined_values_pair.clear()
        self.Warning.ignored_nans.clear()

        self.valid_data = self.valid_group_data = None
        var, cvar = self.var, self.cvar
        if var is None:
            return

        values = self.data.get_column_view(var)[0].astype(float)
        keep = np.isfinite(values)
        if not keep.any():
            self.Error.no_defined_values_var(var.name)
            return
        if cvar:
            groups = self.data.get_column_view(cvar)[0].astype(float)
            keep &= np.isfinite(groups)
            if not keep.any():
                self.Error.no_defined_values_pair(var.name, cvar.name)
                return
            self.valid_group_data = groups[keep]
        if not keep.all():
            self.Warning.ignored_nans()
        self.valid_data = values[keep]

    # -----------------------------
    # Plotting

    def replot(self):
        """Clear and redraw the plot, then redraw the selection overlay.

        Axes, controls, bars/curves and the legend are only drawn when
        there is valid data.
        """
        self._clear_plot()
        if self.is_valid:
            drawing_steps = (self._set_axis_names,
                             self._update_controls_state,
                             self._call_plotting,
                             self._display_legend)
            for step in drawing_steps:
                step()
        self.show_selection()

    def _clear_plot(self):
        """Remove every plotted item and empty and hide the legend."""
        for view in (self.plot, self.plot_pdf, self.plot_mark):
            view.clear()
        self.bar_items, self.curve_items = [], []
        legend = self._legend
        legend.clear()
        legend.hide()

    def _set_axis_names(self):
        """Label the bottom axis with the variable and the left axis with
        either frequency or the conditional-probability description."""
        assert self.is_valid  # called only from replot, so assumes data is OK
        var, cvar = self.var, self.cvar

        bottom = self.ploti.getAxis("bottom")
        bottom.setLabel(var and var.name)
        bottom.setShowUnit(not (var and var.is_time))

        caption = "Frequency"
        if self.show_probs and cvar:
            caption = f"Probability of '{cvar.name}' at given '{var.name}'"
        left = self.ploti.getAxis("left")
        left.setLabel(caption)
        left.resizeEvent()

    def _update_controls_state(self):
        """Enable only the controls that apply to the current var/cvar."""
        assert self.is_valid  # called only from replot, so assumes data is OK
        var = self.var
        unsplit = self.cvar is None
        controls = self.controls
        controls.sort_by_freq.setDisabled(var.is_continuous)
        self.continuous_box.setDisabled(var.is_discrete)
        controls.show_probs.setDisabled(unsplit)
        controls.stacked_columns.setDisabled(unsplit)

    def _call_plotting(self):
        """Dispatch to one of the four plotters (discrete/continuous x
        split/unsplit) and auto-range the main view."""
        assert self.is_valid  # called only from replot, so assumes data is OK
        self.curve_descriptions = None
        split = bool(self.cvar)
        if self.var.is_discrete:
            plotter = self._disc_split_plot if split else self._disc_plot
        else:
            plotter = self._cont_split_plot if split else self._cont_plot
        plotter()
        self.plot.autoRange()

    def _add_bar(self, x, width, padding, freqs, colors, stacked, expanded,
                 tooltip, desc, hidden=False):
        """Create a DistributionBarItem, add it to the main plot and record
        it in ``bar_items``; its position there is its selection index."""
        bar = DistributionBarItem(x, width, padding, freqs, colors, stacked,
                                  expanded, tooltip, desc, hidden)
        self.plot.addItem(bar)
        self.bar_items.append(bar)

    def _disc_plot(self):
        """Draw one bar per value of the discrete ``var`` (no split).

        Bars follow the variable's value order, or descending frequency when
        ``sort_by_freq`` is set. Each tooltip shows the count and its share
        of the instances with a defined value (``valid_data``).
        """
        var = self.var
        dist = distribution.get_distribution(self.data, self.var)
        dist = np.array(dist)  # Distribution misbehaves in further operations
        if self.sort_by_freq:
            order = np.argsort(dist)[::-1]
        else:
            order = np.arange(len(dist))

        ordered_values = np.array(var.values)[order]
        self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])

        colors = [QColor(0, 128, 255)]
        total = len(self.valid_data)
        for i, freq, desc in zip(count(), dist[order], ordered_values):
            # Close the paragraph, as the tooltips built in _cont_plot do.
            tooltip = \
                "<p style='white-space:pre;'>" \
                f"<b>{escape(desc)}</b>: {int(freq)} " \
                f"({100 * freq / total:.2f} %)</p>"
            self._add_bar(
                i - 0.5, 1, 0.1, [freq], colors,
                stacked=False, expanded=False, tooltip=tooltip, desc=desc)

    def _disc_split_plot(self):
        """Draw one bar per value of the discrete ``var``, each divided by
        the values of ``cvar``.

        ``stacked_columns`` and ``show_probs`` set the bars' ``stacked`` and
        ``expanded`` flags. The "overall" percentages in the tooltips are
        relative to the instances where both ``var`` and ``cvar`` are defined
        (``valid_data``), the same base ``_cont_split_plot`` uses.
        """
        var = self.var
        conts = contingency.get_contingency(self.data, self.cvar, self.var)
        conts = np.array(conts)  # Contingency misbehaves in further operations
        if self.sort_by_freq:
            order = np.argsort(conts.sum(axis=1))[::-1]
        else:
            order = np.arange(len(conts))

        ordered_values = np.array(var.values)[order]
        self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])

        gcolors = [QColor(*col) for col in self.cvar.colors]
        gvalues = self.cvar.values
        # Instances with missing var/cvar are ignored (see the ignored_nans
        # warning), so they must not inflate the denominator either;
        # len(self.data) would make the overall shares sum to less than 100 %.
        total = len(self.valid_data)
        for i, freqs, desc in zip(count(), conts[order], ordered_values):
            self._add_bar(
                i - 0.5, 1, 0.1, freqs, gcolors,
                stacked=self.stacked_columns, expanded=self.show_probs,
                tooltip=self._split_tooltip(
                    desc, np.sum(freqs), total, gvalues, freqs),
                desc=desc)

    def _cont_plot(self):
        """Draw a histogram of the continuous ``var`` (no split).

        Bins come from the selected binning. In cumulative mode every bar
        shows the running total, and its tooltip reports that same total:
        ``str_int`` labels the bar "var < x1", so a per-bin count there
        would contradict the label. The optional fitted curve is drawn on
        the overlay.
        """
        self._set_cont_ticks()
        data = self.valid_data
        binning = self.binnings[self.number_of_bins]
        y, x = np.histogram(data, bins=binning.thresholds)
        total = len(data)
        colors = [QColor(0, 128, 255)]
        if self.fitted_distribution:
            colors[0] = colors[0].lighter(130)

        tot_freq = 0
        lasti = len(y) - 1
        width = np.min(x[1:] - x[:-1])
        # the first binning without a width puts one bar on each unique
        # value, so bars are centred on their value
        unique = self.number_of_bins == 0 and binning.width is None
        xoff = -width / 2 if unique else 0
        for i, (x0, x1), freq in zip(count(), zip(x, x[1:]), y):
            tot_freq += freq
            shown_freq = tot_freq if self.cumulative_distr else freq
            desc = self.str_int(x0, x1, not i, i == lasti, unique)
            tooltip = \
                "<p style='white-space:pre;'>" \
                f"<b>{escape(desc)}</b>: " \
                f"{shown_freq} ({100 * shown_freq / total:.2f} %)</p>"
            bar_width = width if unique else x1 - x0
            self._add_bar(
                x0 + xoff, bar_width, 0,
                [shown_freq],
                colors, stacked=False, expanded=False, tooltip=tooltip,
                desc=desc, hidden=self.hide_bars)

        if self.fitted_distribution:
            self._plot_approximations(
                x[0], x[-1], [self._fit_approximation(data)],
                [QColor(0, 0, 0)], (1,))

    def _cont_split_plot(self):
        """Draw a histogram of the continuous ``var`` split by ``cvar``.

        All groups share the bin edges of the full valid data. When a fitter
        is selected, one curve per group is fitted and drawn, weighted by
        the group's share of the valid data.
        """
        self._set_cont_ticks()
        data = self.valid_data
        binning = self.binnings[self.number_of_bins]
        # only the edges are needed; per-group counts are computed below
        _, bins = np.histogram(data, bins=binning.thresholds)
        gvalues = self.cvar.values
        varcolors = [QColor(*col) for col in self.cvar.colors]
        if self.fitted_distribution:
            gcolors = [c.lighter(130) for c in varcolors]
        else:
            gcolors = varcolors
        nvalues = len(gvalues)
        ys = []
        fitters = []
        prior_sizes = []
        for val_idx in range(nvalues):
            # valid_group_data holds cvar value indices aligned with data
            group_data = data[self.valid_group_data == val_idx]
            prior_sizes.append(len(group_data))
            ys.append(np.histogram(group_data, bins)[0])
            if self.fitted_distribution:
                fitters.append(self._fit_approximation(group_data))
        total = len(data)
        prior_sizes = np.array(prior_sizes)
        # running per-group totals for cumulative mode
        tot_freqs = np.zeros(len(ys))

        lasti = len(ys[0]) - 1
        width = np.min(bins[1:] - bins[:-1])
        unique = self.number_of_bins == 0 and binning.width is None
        xoff = -width / 2 if unique else 0
        # zip(*ys) yields, per bin, the tuple of counts of every group
        for i, x0, x1, freqs in zip(count(), bins, bins[1:], zip(*ys)):
            tot_freqs += freqs
            # copy: tot_freqs keeps being updated in place on later bins
            plotfreqs = tot_freqs.copy() if self.cumulative_distr else freqs
            desc = self.str_int(x0, x1, not i, i == lasti, unique)
            bar_width = width if unique else x1 - x0
            self._add_bar(
                x0 + xoff, bar_width, 0 if self.stacked_columns else 0.1,
                plotfreqs,
                gcolors, stacked=self.stacked_columns, expanded=self.show_probs,
                hidden=self.hide_bars,
                tooltip=self._split_tooltip(
                    desc, np.sum(plotfreqs), total, gvalues, plotfreqs),
                desc=desc)

        if fitters:
            self._plot_approximations(bins[0], bins[-1], fitters, varcolors,
                                      prior_sizes / len(data))

    def _set_cont_ticks(self):
        """Set custom bottom-axis ticks for time variables, or restore
        automatic ticks for other variables.

        For time variables the binning's short labels are placed at its
        thresholds. If all labels have the same length, alternate labels go
        to two tick levels. Otherwise there is one tick level per label
        length, with the longest labels in the first level.
        """
        axis = self.ploti.getAxis("bottom")
        if self.var and self.var.is_time:
            binning = self.binnings[self.number_of_bins]
            labels = np.array(binning.short_labels)
            thresholds = np.array(binning.thresholds)
            lengths = np.array([len(lab) for lab in labels])
            slengths = set(lengths)
            if len(slengths) == 1:
                ticks = [list(zip(thresholds[::2], labels[::2])),
                         list(zip(thresholds[1::2], labels[1::2]))]
            else:
                ticks = []
                for length in sorted(slengths, reverse=True):
                    idxs = lengths == length
                    ticks.append(list(zip(thresholds[idxs], labels[idxs])))
            axis.setTicks(ticks)
        else:
            # None lets the axis generate its own ticks again
            axis.setTicks(None)

    def _fit_approximation(self, y):
        """Fit the selected distribution to the values ``y``.

        Returns ``(pdf, description)``. ``pdf`` is ``dist.pdf`` with the
        fitted parameters bound by keyword. ``description`` is a readable
        parameter string built from the Fitters display names. Returns
        ``(None, None)`` for an empty ``y``.
        """
        def join_pars(pairs):
            strv = self.var.str_val
            return ", ".join(f"{sname}={strv(val)}" for sname, val in pairs)

        def str_params():
            # Uses `str_names` and `fitted` from the enclosing scope (they
            # are assigned below, before this is called). Names prefixed
            # with "-" go in parentheses; empty names are skipped.
            s = join_pars(
                (sname, val) for sname, val in zip(str_names, fitted)
                if sname and sname[0] != "-")
            par = join_pars(
                (sname[1:], val) for sname, val in zip(str_names, fitted)
                if sname and sname[0] == "-")
            if par:
                s += f" ({par})"
            return s

        if not y.size:
            return None, None
        _, dist, names, str_names = self.Fitters[self.fitted_distribution]
        fitted = dist.fit(y)
        params = dict(zip(names, fitted))
        return partial(dist.pdf, **params), str_params()

    def _plot_approximations(self, x0, x1, fitters, colors, prior_probs):
        """Draw fitted curves over ``[x0, x1]``, sampled at 100 points.

        ``fitters`` holds ``(pdf, description)`` pairs from
        ``_fit_approximation``. A None pdf leaves its curve at zero.
        ``prior_probs`` weights each curve; curves with a zero prior are not
        drawn. With ``show_probs`` and a split variable, each curve becomes
        ``y*p / (y*p + (sum_of_all_curves - y)*(1-p))`` and is drawn on the
        main plot. Otherwise the densities are drawn on the pdf overlay.
        """
        x = np.linspace(x0, x1, 100)
        ys = np.zeros((len(fitters), 100))
        self.curve_descriptions = [s for _, s in fitters]
        for y, (fitter, _) in zip(ys, fitters):
            if fitter is None:
                continue
            if self.Fitters[self.fitted_distribution][1] is AshCurve:
                # kernel density: higher smoothing setting -> smaller sigma
                y[:] = fitter(x, sigma=(22 - self.kde_smoothing) / 40)
            else:
                y[:] = fitter(x)
            if self.cumulative_distr:
                y[:] = np.cumsum(y)
        tots = np.sum(ys, axis=0)

        show_probs = self.show_probs and self.cvar is not None
        plot = self.ploti if show_probs else self.plot_pdf

        for y, prior_prob, color in zip(ys, prior_probs, colors):
            if not prior_prob:
                continue
            if show_probs:
                y_p = y * prior_prob
                tot = (y_p + (tots - y) * (1 - prior_prob))
                tot[tot == 0] = 1  # avoid division by zero
                y = y_p / tot
            curve = pg.PlotCurveItem(
                x=x, y=y, fillLevel=0,
                pen=pg.mkPen(width=5, color=color),
                shadowPen=pg.mkPen(width=8, color=color.darker(120)))
            plot.addItem(curve)
            self.curve_items.append(curve)
        if not show_probs:
            self.plot_pdf.autoRange()
        self._set_curve_brushes()

    def _set_curve_brushes(self):
        """Fill curves with a lighter, half-transparent version of their pen
        colour when bars are hidden, and remove the fill otherwise."""
        for curve in self.curve_items:
            brush = None
            if self.hide_bars:
                fill = curve.opts['pen'].color().lighter(160)
                fill.setAlpha(128)
                brush = pg.mkBrush(fill)
            curve.setBrush(brush)

    @staticmethod
    def _split_tooltip(valname, tot_group, total, gvalues, freqs):
        """Build the HTML table tooltip for a split bar.

        The header row shows the bar's total count and its share of
        ``total``. There is then one row per ``cvar`` value with its count,
        its share within the bar ("in group") and its share of ``total``
        ("overall").
        """
        # guard against an empty bar when computing in-group shares
        div_group = tot_group or 1
        cs = "white-space:pre; text-align: right;"
        s = f"style='{cs} padding-left: 1em'"
        snp = f"style='{cs}'"
        return f"<table style='border-collapse: collapse'>" \
               f"<tr><th {s}>{escape(valname)}:</th>" \
               f"<td {snp}><b>{int(tot_group)}</b></td>" \
               "<td/>" \
               f"<td {s}><b>{100 * tot_group / total:.2f} %</b></td></tr>" + \
               f"<tr><td/><td/><td {s}>(in group)</td><td {s}>(overall)</td>" \
               "</tr>" + \
               "".join(
                   "<tr>"
                   f"<th {s}>{value}:</th>"
                   f"<td {snp}><b>{int(freq)}</b></td>"
                   f"<td {s}>{100 * freq / div_group:.2f} %</td>"
                   f"<td {s}>{100 * freq / total:.2f} %</td>"
                   "</tr>"
                   for value, freq in zip(gvalues, freqs)) + \
               "</table>"

    def _display_legend(self):
        """Populate and show the legend.

        With a split variable there is one square marker per ``cvar`` value,
        annotated with its fitted-parameter description if any. Without a
        split there is a single entry for the fitted curve, or the legend
        stays hidden when there is no description.
        """
        assert self.is_valid  # called only from replot, so assumes data is OK
        legend = self._legend
        if self.cvar is not None:
            colors = [QColor(*col) for col in self.cvar.colors]
            descriptions = self.curve_descriptions or repeat(None)
            for color, name, desc in zip(colors, self.cvar.values,
                                         descriptions):
                suffix = f" ({desc})" if desc else ""
                marker = ScatterPlotItem(
                    pen=color, brush=color, size=10, shape="s")
                legend.addItem(marker, escape(name + suffix))
        else:
            descriptions = self.curve_descriptions
            if not descriptions or not descriptions[0]:
                legend.hide()
                return
            legend.addItem(
                pg.PlotCurveItem(pen=pg.mkPen(width=5, color=0.0)),
                descriptions[0])
        legend.show()

    # -----------------------------
    # Bins

    def recompute_binnings(self):
        """Recompute candidate binnings for a continuous ``var`` and update
        the bin-width slider.

        Time variables use ``time_binnings``. Other continuous variables use
        ``decimal_binnings``, with a minimum width from
        ``min_var_resolution``. In every other case there are no binnings
        and the slider maximum is 0. The bin index the user last chose for
        this variable is reused, clamped to the slider maximum.
        """
        # Defaults cover discrete/invalid variables and the (normally
        # impossible) all-missing column, so max_bins is always bound and
        # no stale binnings from a previous variable survive.
        self.binnings = []
        max_bins = 0
        if self.is_valid and self.var.is_continuous:
            # binning is computed on valid var data, ignoring any cvar nans
            column = self.data.get_column_view(self.var)[0].astype(float)
            if np.any(np.isfinite(column)):
                if self.var.is_time:
                    self.binnings = time_binnings(column, min_unique=5)
                    self.bin_width_label.setFixedWidth(45)
                else:
                    self.binnings = decimal_binnings(
                        column, min_width=self.min_var_resolution(self.var),
                        add_unique=10, min_unique=5)
                    self.bin_width_label.setFixedWidth(35)
                max_bins = max(len(self.binnings) - 1, 0)

        self.controls.number_of_bins.setMaximum(max_bins)
        self.number_of_bins = min(
            max_bins, self._user_var_bins.get(self.var, self.number_of_bins))
        self._set_bin_width_slider_label()

    @staticmethod
    def min_var_resolution(var):
        """Smallest meaningful bin width for ``var``.

        This is ``10 ** -number_of_decimals`` for an exact ContinuousVariable
        and 0 for anything else. The exact type test is deliberate;
        presumably it keeps subclasses such as time variables at 0.
        """
        # pylint: disable=unidiomatic-typecheck
        if type(var) is ContinuousVariable:
            return 10 ** -var.number_of_decimals
        return 0

    def str_int(self, x0, x1, first, last, unique=False):
        """Describe the interval ``[x0, x1)`` of ``var`` as text.

        Cumulative mode always yields an upper bound. Single or unique bins
        and sub-resolution intervals yield an equality. Open-ended first and
        last bins yield a one-sided bound. Otherwise the full range is given.
        """
        var = self.var
        sx0, sx1 = var.repr_val(x0), var.repr_val(x1)
        if self.cumulative_distr:
            return f"{var.name} < {sx1}"
        if (first and last) or unique:
            return f"{var.name} = {sx0}"
        if first:
            return f"{var.name} < {sx1}"
        if last:
            return f"{var.name} ≥ {sx0}"
        if sx0 == sx1 or x1 - x0 <= self.min_var_resolution(var):
            return f"{var.name} = {sx0}"
        return f"{sx0} ≤ {var.name} < {sx1}"

    # -----------------------------
    # Selection

    def _on_item_clicked(self, item, modifiers, drag):
        """Update the bar selection in response to a click or drag.

        ``item`` is the clicked bar (None clears the selection). ``drag`` is
        true while the mouse moves with the button pressed.

        - Drag: extend or shrink by the range since the last click.
        - Shift-click: extend or shrink by a range.
        - Ctrl-click: toggle a single bar.
        - Plain click: select just that bar. Clicking the only selected bar
          deselects it, but a drag started there will select.

        Outputs are not committed here; see ``_on_end_selecting``.
        """
        def add_or_remove(idx, add):
            self.drag_operation = [self.DragRemove, self.DragAdd][add]
            if add:
                self.selection.add(idx)
            else:
                if idx in self.selection:
                    # This can be False when removing with dragging and the
                    # mouse crosses unselected items
                    self.selection.remove(idx)

        def add_range(add):
            # `idx` comes from the enclosing scope (assigned below)
            if self.last_click_idx is None:
                add = True
                idx_range = {idx}
            else:
                from_idx, to_idx = sorted((self.last_click_idx, idx))
                idx_range = set(range(from_idx, to_idx + 1))
            self.drag_operation = [self.DragRemove, self.DragAdd][add]
            if add:
                self.selection |= idx_range
            else:
                self.selection -= idx_range

        self.key_operation = None
        if item is None:
            self.reset_select()
            return

        idx = self.bar_items.index(item)
        if drag:
            # Dragging has to add a range, otherwise fast dragging skips bars
            add_range(self.drag_operation == self.DragAdd)
        else:
            if modifiers & Qt.ShiftModifier:
                add_range(self.drag_operation == self.DragAdd)
            elif modifiers & Qt.ControlModifier:
                add_or_remove(idx, add=idx not in self.selection)
            else:
                if self.selection == {idx}:
                    # Clicking on a single selected bar  deselects it,
                    # but dragging from here will select
                    add_or_remove(idx, add=False)
                    self.drag_operation = self.DragAdd
                else:
                    self.selection.clear()
                    add_or_remove(idx, add=True)
        self.last_click_idx = idx

        self.show_selection()

    def _on_blank_clicked(self):
        """A click on empty plot space clears the whole selection."""
        return self.reset_select()

    def reset_select(self):
        """Clear the bar selection and all click, drag and key selection
        state, then refresh the selection overlay.

        The drag state goes back to ``DragNone`` (the value ``__init__``
        uses) rather than ``None``, so ``drag_operation`` always holds one
        of the Drag* constants.
        """
        self.selection.clear()
        self.last_click_idx = None
        self.drag_operation = self.DragNone
        self.key_operation = None
        self.show_selection()

    def _on_end_selecting(self):
        """Commit outputs once the mouse button is released after selecting."""
        return self.apply()

    def show_selection(self):
        """Redraw the selection overlay.

        Each run of adjacent selected bars gets one blue rectangle spanning
        the full height of the mark overlay. For continuous variables the
        rectangle also gets a tooltip with the interval and the count of
        instances it covers.
        """
        self.plot_mark.clear()
        if not self.is_valid:  # though if it's not, selection is empty anyway
            return

        blue = QColor(Qt.blue)
        pen = QPen(QBrush(blue), 3)
        pen.setCosmetic(True)  # constant pixel width regardless of zoom
        brush = QBrush(blue.lighter(190))

        for group in self.grouped_selection():
            group = list(group)
            left_idx, right_idx = group[0], group[-1]
            left_pad, right_pad = self._determine_padding(left_idx, right_idx)
            x0 = self.bar_items[left_idx].x0 - left_pad
            x1 = self.bar_items[right_idx].x1 + right_pad
            item = QGraphicsRectItem(x0, 0, x1 - x0, 1)
            item.setPen(pen)
            item.setBrush(brush)
            if self.var.is_continuous:
                # NOTE(review): x0/x1 include the visual padding, so the
                # interval text can differ from the bars' own edges when the
                # padding is non-zero -- confirm intended.
                valname = self.str_int(
                    x0, x1, not left_idx, right_idx == len(self.bar_items) - 1)
                inside = sum(np.sum(self.bar_items[i].freqs) for i in group)
                total = len(self.valid_data)
                item.setToolTip(
                    "<p style='white-space:pre;'>"
                    f"<b>{escape(valname)}</b>: "
                    f"{inside} ({100 * inside / total:.2f} %)")
            self.plot_mark.addItem(item)

    def _determine_padding(self, left_idx, right_idx):
        """Return ``(left_pad, right_pad)`` for a selection rectangle
        covering bars ``left_idx..right_idx``.

        Each pad is half the gap to the neighbouring bar. At a plot edge
        the pad from the other side is reused. A lone bar gets a fixed pad
        of 6 on each side.
        """
        bars = self.bar_items
        last_idx = len(bars) - 1

        def half_gap(i):
            return (bars[i + 1].x0 - bars[i].x1) / 2

        if last_idx == 0:
            return 6, 6
        if left_idx == 0 and right_idx == last_idx:
            return (half_gap(0), ) * 2

        left_pad = half_gap(left_idx - 1) if left_idx > 0 else None
        right_pad = half_gap(right_idx) if right_idx < last_idx else left_pad
        if left_idx == 0:
            left_pad = right_pad
        return left_pad, right_pad

    def grouped_selection(self):
        """Split the selected bar indices into runs of consecutive indices.

        Within a run, index minus position is constant, which is the key
        used for grouping.
        """
        ordered = enumerate(sorted(self.selection))
        runs = groupby(ordered, key=lambda pair: pair[1] - pair[0])
        return [[bar_idx for _, bar_idx in run] for _, run in runs]

    def keyPressEvent(self, e):
        """Move or extend the bar selection with the Left/Right arrow keys.

        With nothing selected, Left selects the last bar and Right the
        first. Without Shift, the selection moves by one bar, staying within
        the plot. With Shift, the selection grows in the arrow's direction,
        or shrinks if the previous Shift-arrow went the other way. Other
        keys, or a state with no bars, go to the base class. Outputs are
        committed only when the selection actually changed.
        """
        def on_nothing_selected():
            if e.key() == Qt.Key_Left:
                self.last_click_idx = len(self.bar_items) - 1
            else:
                self.last_click_idx = 0
            self.selection.add(self.last_click_idx)

        def on_key_left():
            # `first`/`last` come from the enclosing scope (set below)
            if e.modifiers() & Qt.ShiftModifier:
                if self.key_operation == Qt.Key_Right and first != last:
                    # undo a previous Shift+Right extension
                    self.selection.remove(last)
                    self.last_click_idx = last - 1
                elif first:
                    self.key_operation = Qt.Key_Left
                    self.selection.add(first - 1)
                    self.last_click_idx = first - 1
            else:
                self.selection.clear()
                self.last_click_idx = max(first - 1, 0)
                self.selection.add(self.last_click_idx)

        def on_key_right():
            if e.modifiers() & Qt.ShiftModifier:
                if self.key_operation == Qt.Key_Left and first != last:
                    # undo a previous Shift+Left extension
                    self.selection.remove(first)
                    self.last_click_idx = first + 1
                elif not self._is_last_bar(last):
                    self.key_operation = Qt.Key_Right
                    self.selection.add(last + 1)
                    self.last_click_idx = last + 1
            else:
                self.selection.clear()
                self.last_click_idx = min(last + 1, len(self.bar_items) - 1)
                self.selection.add(self.last_click_idx)

        if not self.is_valid or not self.bar_items \
                or e.key() not in (Qt.Key_Left, Qt.Key_Right):
            super().keyPressEvent(e)
            return

        prev_selection = self.selection.copy()
        if not self.selection:
            on_nothing_selected()
        else:
            first, last = min(self.selection), max(self.selection)
            if e.key() == Qt.Key_Left:
                on_key_left()
            else:
                on_key_right()

        if self.selection != prev_selection:
            self.drag_operation = self.DragAdd
            self.show_selection()
            self.apply()

    def keyReleaseEvent(self, ev):
        """Forget the Shift-extension direction when Shift is released."""
        shift_released = ev.key() == Qt.Key_Shift
        if shift_released:
            self.key_operation = None
        super().keyReleaseEvent(ev)


    # -----------------------------
    # Output

    def apply(self):
        """Send the selected, annotated and histogram tables to the outputs.

        When the plot is not valid, all three outputs receive None. For a
        continuous variable, the annotated table additionally gets a "Bin"
        column with the histogram bin of every row.
        """
        data = self.data
        selected_data = annotated_data = histogram_data = None
        if self.is_valid:
            var = self.var
            if var.is_discrete:
                get_indices = self._get_output_indices_disc
            else:
                get_indices = self._get_output_indices_cont
            group_indices, values = get_indices()
            selected = np.flatnonzero(group_indices)
            if selected.size:
                selected_data = create_groups_table(
                    data, group_indices,
                    include_unselected=False, values=values)
            annotated_data = create_annotated_table(data, selected)
            if var.is_continuous:
                bin_indices, bin_values = self._get_histogram_indices()
                annotated_data = create_groups_table(
                    annotated_data, bin_indices,
                    var_name="Bin", values=bin_values)
            histogram_data = self._get_histogram_table()

        outputs = self.Outputs
        outputs.selected_data.send(selected_data)
        outputs.annotated_data.send(annotated_data)
        outputs.histogram_data.send(histogram_data)

    def _get_output_indices_disc(self):
        group_indices = np.zeros(len(self.data), dtype=np.int32)
        col = self.data.get_column_view(self.var)[0].astype(float)
        for group_idx, val_idx in enumerate(self.selection, start=1):
            group_indices[col == val_idx] = group_idx
        values = [self.var.values[i] for i in self.selection]
        return group_indices, values

    def _get_output_indices_cont(self):
        group_indices = np.zeros(len(self.data), dtype=np.int32)
        col = self.data.get_column_view(self.var)[0].astype(float)
        values = []
        for group_idx, group in enumerate(self.grouped_selection(), start=1):
            x0 = x1 = None
            for bar_idx in group:
                minx, maxx, mask = self._get_cont_baritem_indices(col, bar_idx)
                if x0 is None:
                    x0 = minx
                x1 = maxx
                group_indices[mask] = group_idx
            # pylint: disable=undefined-loop-variable
            values.append(
                self.str_int(x0, x1, not bar_idx, self._is_last_bar(bar_idx)))
        return group_indices, values

    def _get_histogram_table(self):
        """Build a table with the frequencies shown in the histogram.

        One row per bar ("Bin" plus "Count"). When a split variable
        ``cvar`` is set, there is one row per (bar, split value) pair, with
        the split value in an extra column between "Bin" and "Count".
        """
        bin_var = DiscreteVariable("Bin",
                                   [item.desc for item in self.bar_items])
        count_var = ContinuousVariable("Count")
        if self.cvar:
            domain = Domain([bin_var, self.cvar, count_var])
            rows = [[bin_idx, split_idx, freq]
                    for bin_idx, item in enumerate(self.bar_items)
                    for split_idx, freq in enumerate(item.freqs)]
        else:
            domain = Domain([bin_var, count_var])
            rows = [[bin_idx, item.freqs[0]]
                    for bin_idx, item in enumerate(self.bar_items)]
        return Table.from_numpy(domain, rows)

    def _get_histogram_indices(self):
        group_indices = np.zeros(len(self.data), dtype=np.int32)
        col = self.data.get_column_view(self.var)[0].astype(float)
        values = []
        for bar_idx in range(len(self.bar_items)):
            x0, x1, mask = self._get_cont_baritem_indices(col, bar_idx)
            group_indices[mask] = bar_idx + 1
            values.append(
                self.str_int(x0, x1, not bar_idx, self._is_last_bar(bar_idx)))
        return group_indices, values

    def _get_cont_baritem_indices(self, col, bar_idx):
        bar_item = self.bar_items[bar_idx]
        minx = bar_item.x0
        maxx = bar_item.x1 + (bar_idx == len(self.bar_items) - 1)
        with np.errstate(invalid="ignore"):
            return minx, maxx, (col >= minx) * (col < maxx)

    def _is_last_bar(self, idx):
        return idx == len(self.bar_items) - 1

    # -----------------------------
    # Report

    def get_widget_name_extension(self):
        """Use the plotted variable as the widget-name extension."""
        plotted = self.var
        return plotted

    def send_report(self):
        """Add the plot and a caption describing it to the report.

        The scene rect is synced with the view first. Nothing is reported
        when the plot is not valid. The caption names the variable and,
        when set, the variable used to split the columns.
        """
        self.plotview.scene().setSceneRect(self.plotview.sceneRect())
        if not self.is_valid:
            return
        self.report_plot()
        if self.cumulative_distr:
            # Fixed user-visible typo ("Cummulative")
            text = f"Cumulative distribution of '{self.var.name}'"
        else:
            text = f"Distribution of '{self.var.name}'"
        if self.cvar:
            text += f" with columns split by '{self.cvar.name}'"
        self.report_caption(text)
示例#27
0
class OWHeatMap(widget.OWWidget):
    # Widget metadata shown in the Orange canvas
    name = "Heat Map"
    description = "Plot a data matrix heatmap."
    icon = "icons/Heatmap.svg"
    priority = 260
    keywords = []

    class Inputs:
        # Input table; handled by `set_dataset`
        data = Input("Data", Table)

    class Outputs:
        # Default output with the selected data
        selected_data = Output("Selected Data", Table, default=True)
        # Output under the standard "annotated data" signal name
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    # Version of the stored settings format
    settings_version = 3

    # ContextSettings below are matched against the domain passed to
    # openContext (the unfiltered input data, see `set_dataset`)
    settingsHandler = settings.DomainContextHandler()

    # Disable clustering for inputs bigger than this
    MaxClustering = 25000
    # Disable cluster leaf ordering for inputs bigger than this
    MaxOrderedClustering = 1000

    # Low/high color thresholds; edited with 0.0-1.0 sliders and passed to
    # GradientColorMap in `color_map`
    threshold_low = settings.Setting(0.0)
    threshold_high = settings.Setting(1.0)

    # Merge rows with k-means; `merge_kmeans_k` is the requested number of
    # clusters (spin box range 5-500)
    merge_kmeans = settings.Setting(False)
    merge_kmeans_k = settings.Setting(50)

    # Display column with averages
    averages: bool = settings.Setting(True)
    # Display legend
    legend: bool = settings.Setting(True)
    # Annotations
    #: text row annotation (row names)
    annotation_var = settings.ContextSetting(None)
    #: color row annotation
    annotation_color_var = settings.ContextSetting(None)
    # Discrete variable used to split the data/heatmaps (vertically)
    split_by_var = settings.ContextSetting(None)
    # Selected row/column clustering method, stored as the `Clustering`
    # member's name (converted to/from the enum in `__init__`)
    col_clustering_method: str = settings.Setting(Clustering.None_.name)
    row_clustering_method: str = settings.Setting(Clustering.None_.name)

    # Palette name used to initialize the color combo box
    palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName)
    # Index into ColumnLabelsPosData (see `_column_label_pos`)
    column_label_pos: int = settings.Setting(1)
    # Selected row indices, stored with the workflow only (schema_only);
    # None when nothing is stored. Taken as the pending selection in __init__.
    selected_rows: Optional[List[int]] = settings.Setting(None, schema_only=True)

    # Checkbox state of the auto-send control (gui.auto_send)
    auto_commit = settings.Setting(True)

    # Name of the attribute holding the graphics scene of the plot
    graph_name = "scene"

    # NOTE(review): presumably places the control-area scrollbar on the
    # left; consumed by the base widget class — confirm
    left_side_scrolling = True

    class Information(widget.OWWidget.Information):
        sampled = Msg("Data has been sampled")
        discrete_ignored = Msg("{} categorical feature{} ignored")
        # Free-form messages about row/column clustering
        row_clust = Msg("{}")
        col_clust = Msg("{}")
        sparse_densified = Msg("Showing this data may require a lot of memory")

    class Error(widget.OWWidget.Error):
        no_continuous = Msg("No numeric features")
        not_enough_features = Msg("Not enough features for column clustering")
        not_enough_instances = Msg("Not enough instances for clustering")
        not_enough_instances_k_means = Msg(
            "Not enough instances for k-means merging")
        not_enough_memory = Msg("Not enough memory to show this data")

    class Warning(widget.OWWidget.Warning):
        # Shown when k-means produced clusters without any rows
        empty_clusters = Msg("Empty clusters were removed")

    def __init__(self):
        """Restore clustering state from settings, initialize the data
        attributes and build the control-area GUI and the heatmap view."""
        super().__init__()
        # Selection restored from settings; applied by `set_dataset`
        self.__pending_selection = self.selected_rows

        # A kingdom for a save_state/restore_state.
        # Clustering methods are stored in settings by name; convert them to
        # `Clustering` members here and back just before settings are packed.
        self.col_clustering = enum_get(Clustering, self.col_clustering_method,
                                       Clustering.None_)
        self.row_clustering = enum_get(Clustering, self.row_clustering_method,
                                       Clustering.None_)

        @self.settingsAboutToBePacked.connect
        def _():
            self.col_clustering_method = self.col_clustering.name
            self.row_clustering_method = self.row_clustering.name

        self.keep_aspect = False

        #: The original data with all features (retained to
        #: preserve the domain on the output)
        self.input_data = None
        #: The effective data stripped of discrete features, and often
        #: merged using k-means
        self.data = None
        self.effective_data = None
        #: kmeans model used to merge rows of input_data
        self.kmeans_model = None
        #: merge indices derived from kmeans
        #: a list (len==k) of int ndarray where the i-th item contains
        #: the indices which merge the input_data into the heatmap row i
        self.merge_indices = None
        self.parts: Optional[Parts] = None
        self.__rows_cache = {}
        self.__columns_cache = {}

        # GUI definition
        colorbox = gui.vBox(self.controlArea, "Color")
        self.color_cb = gui.palette_combo_box(self.palette_name)
        self.color_cb.currentIndexChanged.connect(self.update_color_schema)
        colorbox.layout().addWidget(self.color_cb)

        form = QFormLayout(formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        lowslider = gui.hSlider(colorbox,
                                self,
                                "threshold_low",
                                minValue=0.0,
                                maxValue=1.0,
                                step=0.05,
                                ticks=True,
                                intOnly=False,
                                createLabel=False,
                                callback=self.update_lowslider)
        highslider = gui.hSlider(colorbox,
                                 self,
                                 "threshold_high",
                                 minValue=0.0,
                                 maxValue=1.0,
                                 step=0.05,
                                 ticks=True,
                                 intOnly=False,
                                 createLabel=False,
                                 callback=self.update_highslider)

        form.addRow("Low:", lowslider)
        form.addRow("High:", highslider)

        colorbox.layout().addLayout(form)

        mergebox = gui.vBox(
            self.controlArea,
            "Merge",
        )
        gui.checkBox(mergebox,
                     self,
                     "merge_kmeans",
                     "Merge by k-means",
                     callback=self.__update_row_clustering)
        ibox = gui.indentedBox(mergebox)
        gui.spin(ibox,
                 self,
                 "merge_kmeans_k",
                 minv=5,
                 maxv=500,
                 label="Clusters:",
                 keyboardTracking=False,
                 callbackOnReturn=True,
                 callback=self.update_merge)

        cluster_box = gui.vBox(self.controlArea, "Clustering")
        # Row clustering
        self.row_cluster_cb = cb = ComboBox(maximumContentsLength=14)
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.row_clustering, ClusteringRole)
        self.connect_control(
            "row_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_row_clustering(cb.itemData(idx, ClusteringRole))

        # Column clustering
        self.col_cluster_cb = cb = ComboBox(maximumContentsLength=14)
        cb.setModel(create_list_model(ClusteringModelData, self))
        cbselect(cb, self.col_clustering, ClusteringRole)
        self.connect_control(
            "col_clustering",
            lambda value, cb=cb: cbselect(cb, value, ClusteringRole))

        @cb.activated.connect
        def _(idx, cb=cb):
            self.set_col_clustering(cb.itemData(idx, ClusteringRole))

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
        )
        form.addRow("Rows:", self.row_cluster_cb)
        form.addRow("Columns:", self.col_cluster_cb)
        cluster_box.layout().addLayout(form)
        box = gui.vBox(self.controlArea, "Split By")

        self.row_split_model = DomainModel(
            placeholder="(None)",
            valid_types=(Orange.data.DiscreteVariable, ),
            parent=self,
        )
        # Splitting is unavailable while rows are merged by k-means
        self.row_split_cb = cb = ComboBox(
            enabled=not self.merge_kmeans,
            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=14,
            toolTip="Split the heatmap vertically by a categorical column")
        self.row_split_cb.setModel(self.row_split_model)
        self.connect_control("split_by_var",
                             lambda value, cb=cb: cbselect(cb, value))
        self.connect_control("merge_kmeans", self.row_split_cb.setDisabled)
        self.split_by_var = None

        self.row_split_cb.activated.connect(self.__on_split_rows_activated)
        box.layout().addWidget(self.row_split_cb)

        box = gui.vBox(self.controlArea, 'Annotation && Legends')

        gui.checkBox(box,
                     self,
                     'legend',
                     'Show legend',
                     callback=self.update_legend)

        gui.checkBox(box,
                     self,
                     'averages',
                     'Stripes with averages',
                     callback=self.update_averages_stripe)
        annotbox = QGroupBox("Row Annotations", flat=True)
        form = QFormLayout(annotbox,
                           formAlignment=Qt.AlignLeft,
                           labelAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow)
        self.annotation_model = DomainModel(placeholder="(None)")
        self.annotation_text_cb = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength)
        self.annotation_text_cb.setModel(self.annotation_model)
        self.annotation_text_cb.activated.connect(self.set_annotation_var)
        self.connect_control("annotation_var", self.annotation_var_changed)

        self.row_side_color_model = DomainModel(
            order=(DomainModel.CLASSES, DomainModel.Separator,
                   DomainModel.METAS),
            placeholder="(None)",
            valid_types=DomainModel.PRIMITIVE,
            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
            parent=self,
        )
        self.row_side_color_cb = ComboBoxSearch(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            minimumContentsLength=12)
        self.row_side_color_cb.setModel(self.row_side_color_model)
        self.row_side_color_cb.activated.connect(self.set_annotation_color_var)
        self.connect_control("annotation_color_var",
                             self.annotation_color_var_changed)
        form.addRow("Text", self.annotation_text_cb)
        form.addRow("Color", self.row_side_color_cb)
        box.layout().addWidget(annotbox)
        posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
        posbox.setFlat(True)
        cb = gui.comboBox(posbox,
                          self,
                          "column_label_pos",
                          callback=self.update_column_annotations)
        cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
        cb.setCurrentIndex(self.column_label_pos)
        gui.checkBox(self.controlArea,
                     self,
                     "keep_aspect",
                     "Keep aspect ratio",
                     box="Resize",
                     callback=self.__aspect_mode_changed)

        gui.rubber(self.controlArea)
        gui.auto_send(self.controlArea, self, "auto_commit")

        # Scene with heatmap
        class HeatmapScene(QGraphicsScene):
            # The grid widget currently shown in the scene (None if empty)
            widget: Optional[HeatmapGridWidget] = None

        # (was a redundant double assignment `self.scene = self.scene = ...`)
        self.scene = HeatmapScene(parent=self)
        self.view = GraphicsView(
            self.scene,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            viewportUpdateMode=QGraphicsView.FullViewportUpdate,
            widgetResizable=True,
        )
        self.view.setContextMenuPolicy(Qt.CustomContextMenu)
        self.view.customContextMenuRequested.connect(
            self._on_view_context_menu)
        self.mainArea.layout().addWidget(self.view)
        self.selected_rows = []
        self.__font_inc = QAction("Increase Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+>"))
        self.__font_dec = QAction("Decrease Font",
                                  self,
                                  shortcut=QKeySequence("ctrl+<"))
        self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1))
        self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1))
        # setShortcutVisibleInContextMenu is not available in every Qt version
        if hasattr(QAction, "setShortcutVisibleInContextMenu"):
            apply_all([self.__font_inc, self.__font_dec],
                      lambda a: a.setShortcutVisibleInContextMenu(True))
        self.addActions([self.__font_inc, self.__font_dec])

    @property
    def center_palette(self):
        """Whether the palette selected in the color combo is diverging."""
        current = self.color_cb.currentData()
        diverging_bits = current.flags & current.Diverging
        return bool(diverging_bits)

    @property
    def _column_label_pos(self) -> HeatmapGridWidget.Position:
        """Column-label position for the selected combo-box entry."""
        entry = ColumnLabelsPosData[self.column_label_pos]
        return entry[Qt.UserRole]

    def annotation_color_var_changed(self, value):
        """Show ``value`` as the current item of the row-color combo."""
        combo = self.row_side_color_cb
        cbselect(combo, value, Qt.EditRole)

    def annotation_var_changed(self, value):
        """Show ``value`` as the current item of the row-text combo."""
        combo = self.annotation_text_cb
        cbselect(combo, value, Qt.EditRole)

    def set_row_clustering(self, method: Clustering) -> None:
        """Select the row clustering ``method``; update the combo box and
        re-cluster only if it differs from the current one."""
        assert isinstance(method, Clustering)
        if self.row_clustering == method:
            return
        self.row_clustering = method
        cbselect(self.row_cluster_cb, method, ClusteringRole)
        self.__update_row_clustering()

    def set_col_clustering(self, method: Clustering) -> None:
        """Select the column clustering ``method``; update the combo box and
        re-cluster only if it differs from the current one."""
        assert isinstance(method, Clustering)
        if self.col_clustering == method:
            return
        self.col_clustering = method
        cbselect(self.col_cluster_cb, method, ClusteringRole)
        self.__update_column_clustering()

    def sizeHint(self) -> QSize:
        """Base size hint, enlarged to at least 900 x 700."""
        minimum = QSize(900, 700)
        return super().sizeHint().expandedTo(minimum)

    def color_palette(self):
        """Lookup table of the palette selected in the color combo box."""
        palette = self.color_cb.currentData()
        return palette.lookup_table()

    def color_map(self) -> GradientColorMap:
        """Build the color map from the selected palette and the low/high
        thresholds; 0 is passed as the center for diverging palettes and
        None otherwise."""
        center = 0 if self.center_palette else None
        thresholds = (self.threshold_low, self.threshold_high)
        return GradientColorMap(self.color_palette(), thresholds, center)

    def clear(self):
        """Reset all data-dependent state, models, caches and the scene."""
        self.data = self.input_data = self.effective_data = None
        self.kmeans_model = self.merge_indices = None
        # Empty each variable model, then reset the setting bound to it
        for model, setting in ((self.annotation_model, "annotation_var"),
                               (self.row_side_color_model,
                                "annotation_color_var"),
                               (self.row_split_model, "split_by_var")):
            model.set_domain(None)
            setattr(self, setting, None)
        self.parts = None
        self.clear_scene()
        self.selected_rows = []
        self.__columns_cache.clear()
        self.__rows_cache.clear()
        self.__update_clustering_enable_state(None)

    def clear_scene(self):
        """Remove the heatmap grid widget and all items from the scene.

        Signals connected to the current grid widget (if any) are
        disconnected, and the view's scene, header and footer rects are
        reset to empty rects.
        """
        grid = self.scene.widget
        if grid is not None:
            grid.layoutDidActivate.disconnect(self.__on_layout_activate)
            grid.selectionFinished.disconnect(self.on_selection_finished)
        self.scene.widget = None
        self.scene.clear()

        view = self.view
        for set_rect in (view.setSceneRect, view.setHeaderSceneRect,
                         view.setFooterSceneRect):
            set_rect(QRectF())

    @Inputs.data
    def set_dataset(self, data=None):
        """Set the input dataset to display.

        SQL tables are converted to in-memory tables (a sample is downloaded
        for large ones), empty tables are treated as no data, sparse data is
        densified, and categorical features are dropped because only numeric
        features can be shown. The unfiltered table is kept as
        ``input_data`` (used for annotations and the settings context); the
        numeric-only table becomes ``data``.

        A selection restored from settings is applied to the first non-empty
        dataset. If no heatmap can be built for it (e.g. too few instances
        or features for the chosen clustering), the selection is dropped.
        """
        self.closeContext()
        self.clear()
        self.clear_messages()

        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled()
                data_sample = data.sample_time(1, no_cache=True)
                data_sample.download_data(2000, partial=True)
                data = Table(data_sample)

        if data is not None and not len(data):
            data = None

        if data is not None and sp.issparse(data.X):
            try:
                data = data.to_dense()
            except MemoryError:
                data = None
                self.Error.not_enough_memory()
            else:
                self.Information.sparse_densified()

        input_data = data

        # Data contains no attributes or meta attributes only
        if data is not None and len(data.domain.attributes) == 0:
            self.Error.no_continuous()
            input_data = data = None

        # Data contains some discrete attributes which must be filtered
        if data is not None and \
                any(var.is_discrete for var in data.domain.attributes):
            ndisc = sum(var.is_discrete for var in data.domain.attributes)
            data = data.transform(
                Domain([
                    var for var in data.domain.attributes if var.is_continuous
                ], data.domain.class_vars, data.domain.metas))
            if not data.domain.attributes:
                self.Error.no_continuous()
                input_data = data = None
            else:
                self.Information.discrete_ignored(ndisc,
                                                  "s" if ndisc > 1 else "")

        self.data = data
        self.input_data = input_data

        if data is not None:
            self.annotation_model.set_domain(self.input_data.domain)
            self.row_side_color_model.set_domain(self.input_data.domain)
            self.annotation_var = None
            self.annotation_color_var = None
            self.row_split_model.set_domain(data.domain)
            # Default split: the discrete class, if any (may be overridden by
            # the context opened below)
            if data.domain.has_discrete_class:
                self.split_by_var = data.domain.class_var
            else:
                self.split_by_var = None
            self.openContext(self.input_data)
            if self.split_by_var not in self.row_split_model:
                self.split_by_var = None

        self.update_heatmaps()
        if data is not None and self.__pending_selection is not None:
            # update_heatmaps() builds no scene widget when it reports an
            # error (too few instances/features for clustering or k-means
            # merging). There is then nothing to select: drop the pending
            # selection instead of failing on a missing widget.
            if self.scene.widget is not None:
                self.scene.widget.selectRows(self.__pending_selection)
                self.selected_rows = self.__pending_selection
            self.__pending_selection = None

        self.unconditional_commit()

    def __on_split_rows_activated(self):
        """Apply the variable chosen in the "Split By" combo box."""
        chosen = self.row_split_cb.currentData(Qt.EditRole)
        self.set_split_variable(chosen)

    def set_split_variable(self, var):
        """Split the heatmap rows by ``var``; redraw only if it changed."""
        if var == self.split_by_var:
            return
        self.split_by_var = var
        self.update_heatmaps()

    def update_heatmaps(self):
        """Rebuild the heatmap scene for the current data and settings.

        Shows an error instead of a heatmap when the data has too few
        features for column clustering, too few instances for clustering,
        or too few instances for k-means merging. Clears the widget when
        there is no data.
        """
        if self.data is not None:
            self.clear_scene()
            # Reset only messages that depend on the heatmap settings.
            # Messages describing the input itself (sampling, ignored
            # categorical features, densified sparse data) are raised by
            # set_dataset() right before it calls this method; a blanket
            # clear_messages() here would discard them immediately.
            self.Error.clear()
            self.Warning.clear()
            self.Information.row_clust.clear()
            self.Information.col_clust.clear()
            if self.col_clustering != Clustering.None_ and \
                    len(self.data.domain.attributes) < 2:
                self.Error.not_enough_features()
            elif (self.col_clustering != Clustering.None_ or
                  self.row_clustering != Clustering.None_) and \
                    len(self.data) < 2:
                self.Error.not_enough_instances()
            elif self.merge_kmeans and len(self.data) < 3:
                self.Error.not_enough_instances_k_means()
            else:
                parts = self.construct_heatmaps(self.data, self.split_by_var)
                self.construct_heatmaps_scene(parts, self.effective_data)
                self.selected_rows = []
        else:
            self.clear()

    def update_merge(self):
        """React to a change of the k-means cluster count.

        The cached k-means model and merge indices are always dropped; the
        heatmap is rebuilt and committed only when there is data and
        merging is enabled.
        """
        self.kmeans_model = self.merge_indices = None
        if self.data is None or not self.merge_kmeans:
            return
        self.update_heatmaps()
        self.commit()

    def _make_parts(self, data, group_var=None):
        """
        Make the initial, unclustered `Parts` for `data`.

        Rows form one group, or one group per value of the discrete
        `group_var` when given. Columns always form a single group with all
        attributes. The span is the NaN-ignoring (min, max) of `data.X`.
        """
        if group_var is None:
            row_groups = [RowPart(title=None,
                                  indices=range(0, len(data)),
                                  cluster=None,
                                  cluster_ordered=None)]
        else:
            assert group_var.is_discrete
            group_col = table_column_data(data, group_var)
            row_groups = []
            for value_idx, title in enumerate(group_var.values):
                row_groups.append(
                    RowPart(title=title,
                            indices=np.flatnonzero(group_col == value_idx),
                            cluster=None,
                            cluster_ordered=None))

        all_columns = ColumnPart(title=None,
                                 indices=range(0, len(data.domain.attributes)),
                                 domain=data.domain,
                                 cluster=None,
                                 cluster_ordered=None)
        span = (np.nanmin(data.X), np.nanmax(data.X))
        return Parts(row_groups, [all_columns], span=span)

    def cluster_rows(self,
                     data: Table,
                     parts: 'Parts',
                     ordered=False) -> 'Parts':
        """Hierarchically cluster the rows inside each row group.

        Clusterings already present on a group are reused; only missing
        ones are computed (the leaf-ordered one only when ``ordered``).
        Groups whose ``can_cluster`` is false are passed through unchanged.
        Uses Euclidean distances and Ward linkage.
        """
        clustered = []
        for row in parts.rows:
            tree, tree_ordered = row.cluster, row.cluster_ordered
            if row.can_cluster:
                want_ordered = ordered and tree_ordered is None
                if tree is None or want_ordered:
                    dist = Orange.distance.Euclidean(data[row.indices])
                else:
                    dist = None
                if tree is None:
                    assert len(dist) < self.MaxClustering
                    tree = hierarchical.dist_matrix_clustering(
                        dist, linkage=hierarchical.WARD)
                if want_ordered:
                    assert len(dist) < self.MaxOrderedClustering
                    tree_ordered = hierarchical.optimal_leaf_ordering(
                        tree, dist)
            clustered.append(
                row._replace(cluster=tree, cluster_ordered=tree_ordered))
        return parts._replace(rows=clustered)

    def cluster_columns(self, data, parts, ordered=False):
        """Hierarchically cluster the columns (attributes) of ``data``.

        Only a single column group is supported. Distances come from
        Pearson correlation between columns (NaNs replaced by 0), with Ward
        linkage. Existing clusterings on the group are reused; the
        leaf-ordered one is computed only when ``ordered``.
        """
        assert len(parts.columns) == 1, "columns split is no longer supported"
        assert all(var.is_continuous for var in data.domain.attributes)

        column = parts.columns[0]
        tree, tree_ordered = column.cluster, column.cluster_ordered
        want_ordered = ordered and tree_ordered is None
        dist = None
        if tree is None or want_ordered:
            prepared = Orange.distance._preprocess(data)
            dist = np.asarray(Orange.distance.PearsonR(prepared, axis=0))
            # nan values break clustering below
            dist = np.nan_to_num(dist)

        if tree is None:
            assert dist is not None
            assert len(dist) < self.MaxClustering
            tree = hierarchical.dist_matrix_clustering(
                dist, linkage=hierarchical.WARD)
        if want_ordered:
            assert len(dist) < self.MaxOrderedClustering
            tree_ordered = hierarchical.optimal_leaf_ordering(tree, dist)

        updated = [col._replace(cluster=tree, cluster_ordered=tree_ordered)
                   for col in parts.columns]
        return parts._replace(columns=updated)

    def construct_heatmaps(self, data, group_var=None) -> 'Parts':
        """Compute the (possibly k-means merged and clustered) heatmap parts.

        With k-means merging on, the rows of ``input_data`` (restricted to
        its continuous attributes) are replaced by at most
        ``merge_kmeans_k`` cluster centroids. The model is computed only
        when ``self.kmeans_model`` is None and is reused otherwise. Row
        splitting (``group_var``) is ignored while merging. Rows/columns are
        then clustered according to the selected methods, reusing cached
        row clusterings where available.

        Side effects: sets ``effective_data``, ``kmeans_model`` and
        ``merge_indices``; may show the ``empty_clusters`` warning; updates
        the clustering enable state and the rows cache.

        Args:
            data: the numeric-only table; used directly when not merging.
            group_var: optional discrete variable splitting the rows.

        Returns:
            Parts: the row and column parts of the heatmap.
        """
        if self.merge_kmeans:
            if self.kmeans_model is None:
                effective_data = self.input_data.transform(
                    Orange.data.Domain([
                        var for var in self.input_data.domain.attributes
                        if var.is_continuous
                    ], self.input_data.domain.class_vars,
                                       self.input_data.domain.metas))
                # At least one instance fewer than clusters
                nclust = min(self.merge_kmeans_k, len(effective_data) - 1)
                self.kmeans_model = kmeans_compress(effective_data, k=nclust)
                # NOTE(review): presumably aligns the table with the domain
                # used by the k-means model; only its attributes are used below
                effective_data.domain = self.kmeans_model.domain
                # Row indices of input_data belonging to each cluster
                merge_indices = [
                    np.flatnonzero(self.kmeans_model.labels == ind)
                    for ind in range(nclust)
                ]
                not_empty_indices = [
                    i for i, x in enumerate(merge_indices) if len(x) > 0
                ]
                self.merge_indices = \
                    [merge_indices[i] for i in not_empty_indices]
                if len(merge_indices) != len(self.merge_indices):
                    self.Warning.empty_clusters()
                # One heatmap row per non-empty cluster centroid; the merged
                # table has attributes only (no class vars or metas)
                effective_data = Orange.data.Table(
                    Orange.data.Domain(effective_data.domain.attributes),
                    self.kmeans_model.centroids[not_empty_indices])
            else:
                effective_data = self.effective_data

            group_var = None
        else:
            self.kmeans_model = None
            self.merge_indices = None
            effective_data = data

        self.effective_data = effective_data

        self.__update_clustering_enable_state(effective_data)

        parts = self._make_parts(effective_data, group_var)
        # Restore/update the row/columns items descriptions from cache if
        # available
        # NOTE(review): only `.rows` is restored from the cache, so column
        # clusterings are recomputed on every call. The key also does not
        # identify the k-means model; if k-means with the same k can yield
        # different centroids after the model is reset, a cached row
        # clustering may not match the new rows — confirm kmeans_compress
        # is deterministic.
        rows_cache_key = (group_var,
                          self.merge_kmeans_k if self.merge_kmeans else None)
        if rows_cache_key in self.__rows_cache:
            parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows)

        if self.row_clustering != Clustering.None_:
            parts = self.cluster_rows(
                effective_data,
                parts,
                ordered=self.row_clustering == Clustering.OrderedClustering)
        if self.col_clustering != Clustering.None_:
            parts = self.cluster_columns(
                effective_data,
                parts,
                ordered=self.col_clustering == Clustering.OrderedClustering)

        # Cache the updated parts
        self.__rows_cache[rows_cache_key] = parts
        return parts

    def construct_heatmaps_scene(self, parts: 'Parts', data: Table) -> None:
        """Choose the clustering to display for each part and build the scene.

        Parts carry both the plain (``cluster``) and the leaf-ordered
        (``cluster_ordered``) clustering. Depending on the selected row and
        column methods, none, the plain or the ordered one is passed on as
        ``cluster`` (with ``cluster_ordered`` cleared) to ``setup_scene``.
        """
        _T = TypeVar("_T", bound=Union[RowPart, ColumnPart])

        def select_cluster(clustering: Clustering, item: _T) -> _T:
            if clustering == Clustering.None_:
                chosen = None
            elif clustering == Clustering.Clustering:
                chosen = item.cluster
            elif clustering == Clustering.OrderedClustering:
                chosen = item.cluster_ordered
            else:  # pragma: no cover
                raise TypeError()
            return item._replace(cluster=chosen, cluster_ordered=None)

        rows = [select_cluster(self.row_clustering, row)
                for row in parts.rows]
        cols = [select_cluster(self.col_clustering, col)
                for col in parts.columns]
        self.setup_scene(Parts(columns=cols, rows=rows, span=parts.span), data)

    def setup_scene(self, parts, data):
        # type: (Parts, Table) -> None
        """Create a ``HeatmapGridWidget`` for `parts` of `data` and show it.

        The grid widget is added to ``self.scene`` (and stored as
        ``self.scene.widget``). It is then configured from the current
        settings: color map, row side colors, column label position, aspect
        ratio, averages and legend. Its layout/selection signals are
        connected, and it becomes the view's central widget.

        Note: ``self.parts`` is set to the ``HeatmapGridWidget.Parts``
        built here, not to the `parts` argument.
        """
        widget = HeatmapGridWidget()
        widget.setColorMap(self.color_map())
        self.scene.addItem(widget)
        self.scene.widget = widget
        columns = [v.name for v in data.domain.attributes]
        # Convert the part descriptions into the grid widget's own Parts
        # type (this rebinds `parts`); values come from `data.X` and the
        # attribute names label the columns.
        parts = HeatmapGridWidget.Parts(
            rows=[
                HeatmapGridWidget.RowItem(r.title, r.indices, r.cluster)
                for r in parts.rows
            ],
            columns=[
                HeatmapGridWidget.ColumnItem(c.title, c.indices, c.cluster)
                for c in parts.columns
            ],
            data=data.X,
            span=parts.span,
            row_names=None,
            col_names=columns,
        )
        widget.setHeatmaps(parts)
        side = self.row_side_colors()
        if side is not None:
            # side = (colors, colormap, variable)
            widget.setRowSideColorAnnotations(side[0],
                                              side[1],
                                              name=side[2].name)
        widget.setColumnLabelsPosition(self._column_label_pos)
        widget.setAspectRatioMode(
            Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio)
        widget.setShowAverages(self.averages)
        widget.setLegendVisible(self.legend)

        widget.layoutDidActivate.connect(self.__on_layout_activate)
        widget.selectionFinished.connect(self.on_selection_finished)

        self.update_annotations()
        self.view.setCentralWidget(widget)
        self.parts = parts

    def __update_scene_rects(self):
        """Sync the scene, view, header and footer rects with the grid widget."""
        grid = self.scene.widget
        if grid is not None:
            geometry = grid.geometry()
            self.scene.setSceneRect(geometry)
            self.view.setSceneRect(geometry)
            self.view.setHeaderSceneRect(grid.headerGeometry())
            self.view.setFooterSceneRect(grid.footerGeometry())

    def __on_layout_activate(self):
        """Grid widget layout (re)activated: resync the scene rectangles."""
        self.__update_scene_rects()

    def __aspect_mode_changed(self):
        """Apply `keep_aspect` to the grid widget's aspect mode and size policy."""
        grid = self.scene.widget
        if grid is None:
            return
        keep = self.keep_aspect
        grid.setAspectRatioMode(
            Qt.KeepAspectRatio if keep else Qt.IgnoreAspectRatio)
        # With a fixed aspect ratio the vertical size policy is fixed;
        # otherwise the widget may shrink vertically.
        policy = grid.sizePolicy()
        policy.setVerticalPolicy(
            QSizePolicy.Fixed if keep else QSizePolicy.Preferred)
        grid.setSizePolicy(policy)

    def __update_clustering_enable_state(self, data):
        """Enable or disable the clustering options based on the size of `data`.

        Row clustering is allowed for at most ``MaxClustering`` rows and
        ordered row clustering for at most ``MaxOrderedClustering`` rows.
        The same limits apply to the number of attributes for columns.
        A selected method that is no longer allowed is downgraded
        (ordered -> plain -> none), and an information message explains why.

        Parameters
        ----------
        data : Optional[Table]
            Effective data; ``None`` is treated as having no rows/columns.
        """
        if data is not None:
            N = len(data)
            M = len(data.domain.attributes)
        else:
            N = M = 0

        rc_enabled = N <= self.MaxClustering
        rco_enabled = N <= self.MaxOrderedClustering
        cc_enabled = M <= self.MaxClustering
        cco_enabled = M <= self.MaxOrderedClustering
        row_clust, col_clust = self.row_clustering, self.col_clustering

        row_clust_msg = ""
        col_clust_msg = ""

        # The checks cascade: disabled ordered clustering falls back to plain
        # clustering, which the next check may in turn downgrade to none.
        if not rco_enabled and row_clust == Clustering.OrderedClustering:
            row_clust = Clustering.Clustering
            row_clust_msg = "Row cluster ordering was disabled due to the " \
                            "input matrix being too big"
        if not rc_enabled and row_clust == Clustering.Clustering:
            row_clust = Clustering.None_
            row_clust_msg = "Row clustering was disabled due to the " \
                            "input matrix being too big"

        if not cco_enabled and col_clust == Clustering.OrderedClustering:
            col_clust = Clustering.Clustering
            col_clust_msg = "Column cluster ordering was disabled due to " \
                            "the input matrix being too big"
        if not cc_enabled and col_clust == Clustering.Clustering:
            col_clust = Clustering.None_
            col_clust_msg = "Column clustering was disabled due to the " \
                            "input matrix being too big"

        self.col_clustering = col_clust
        self.row_clustering = row_clust

        self.Information.row_clust(row_clust_msg, shown=bool(row_clust_msg))
        self.Information.col_clust(col_clust_msg, shown=bool(col_clust_msg))

        # Disable/enable the combobox items for the clustering methods
        def setenabled(cb: QComboBox, clu: bool, clu_op: bool):
            model = cb.model()
            assert isinstance(model, QStandardItemModel)
            idx = cb.findData(Clustering.OrderedClustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu_op)
            idx = cb.findData(Clustering.Clustering, ClusteringRole)
            assert idx != -1
            model.item(idx).setEnabled(clu)

        setenabled(self.row_cluster_cb, rc_enabled, rco_enabled)
        setenabled(self.col_cluster_cb, cc_enabled, cco_enabled)

    def update_averages_stripe(self):
        """Show or hide the averages stripe according to `averages`."""
        grid = self.scene.widget
        if grid is None:
            return
        grid.setShowAverages(self.averages)

    def update_lowslider(self):
        """Keep the low threshold strictly below the high one, then recolor."""
        low_slider = self.controls.threshold_low
        high_slider = self.controls.threshold_high
        high_value = high_slider.value()
        if low_slider.value() >= high_value:
            low_slider.setSliderPosition(high_value - 1)
        self.update_color_schema()

    def update_highslider(self):
        """Keep the high threshold strictly above the low one, then recolor."""
        low_slider = self.controls.threshold_low
        high_slider = self.controls.threshold_high
        low_value = low_slider.value()
        if low_value >= high_slider.value():
            high_slider.setSliderPosition(low_value + 1)
        self.update_color_schema()

    def update_color_schema(self):
        """Store the selected palette name and recolor the heatmaps."""
        self.palette_name = self.color_cb.currentData().name
        grid = self.scene.widget
        if grid is None:
            return
        grid.setColorMap(self.color_map())

    def __update_column_clustering(self):
        """Column clustering changed: rebuild the heatmaps and commit."""
        self.update_heatmaps()
        self.commit()

    def __update_row_clustering(self):
        """Row clustering changed: rebuild the heatmaps and commit."""
        self.update_heatmaps()
        self.commit()

    def update_legend(self):
        """Show or hide the legend according to the `legend` setting."""
        grid = self.scene.widget
        if grid is None:
            return
        grid.setLegendVisible(self.legend)

    def row_annotation_var(self):
        """Return the variable used to label heatmap rows (may be None)."""
        selected = self.annotation_var
        return selected

    def row_annotation_data(self):
        """Return the row annotation strings from `input_data`, or None when
        no annotation variable is selected."""
        var = self.row_annotation_var()
        if var is not None:
            return column_str_from_table(self.input_data, var)
        return None

    def _merge_row_indices(self):
        """Return the k-means row merge groups when rows are merged, else None."""
        merged = self.merge_kmeans and self.kmeans_model is not None
        return self.merge_indices if merged else None

    def set_annotation_var(self, var: Union[None, Variable, int]):
        """Set the row annotation variable and refresh the row labels.

        `var` may also be an index into ``annotation_model``.
        """
        chosen = self.annotation_model[var] if isinstance(var, int) else var
        if self.annotation_var != chosen:
            self.annotation_var = chosen
            self.update_annotations()

    def update_annotations(self):
        """Refresh the grid widget's row labels from the annotation variable.

        When rows are merged (k-means), the labels of each merge group are
        combined with ``join_elided(", ", 42, ..., " ({} more)")``.
        Without an annotation variable the row labels are hidden.
        """
        grid = self.scene.widget
        if grid is None:
            return
        labels = self.row_annotation_data()
        groups = self._merge_row_indices()
        if groups is not None and labels is not None:
            def join(items):
                return join_elided(", ", 42, items, " ({} more)")
            labels = aggregate_apply(join, labels, groups)
        if labels is None:
            grid.setRowLabelsVisible(False)
            grid.setRowLabels(None)
        else:
            grid.setRowLabels(labels)
            grid.setRowLabelsVisible(True)

    def row_side_colors(self):
        """Return ``(colors, colormap, var)`` for the row side color strip,
        or None when no annotation color variable is selected."""
        var = self.annotation_color_var
        if var is None:
            return None
        values = column_data_from_table(self.input_data, var)
        # The continuous color span is taken from the un-merged column.
        span = np.nanmin(values), np.nanmax(values)
        groups = self._merge_row_indices()
        if groups is not None:
            values = aggregate(var, values, groups)
        colored, colormap = self._colorize(var, values)
        if var.is_continuous:
            colormap.span = span
        return colored, colormap, var

    def set_annotation_color_var(self, var: Union[None, Variable, int]):
        """Set the current side color annotation variable.

        `var` may also be an index into ``row_side_color_model``.
        """
        chosen = self.row_side_color_model[var] if isinstance(var, int) else var
        if self.annotation_color_var != chosen:
            self.annotation_color_var = chosen
            self.update_row_side_colors()

    def update_row_side_colors(self):
        """Push the current row side color annotation to the grid widget."""
        grid = self.scene.widget
        if grid is None:
            return
        side = self.row_side_colors()
        if side is not None:
            data, colormap, var = side
            grid.setRowSideColorAnnotations(data, colormap, var.name)
        else:
            grid.setRowSideColorAnnotations(None)

    def _colorize(self, var: Variable,
                  data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
        """Return `data` prepared for display together with a color map.

        For a discrete `var`, NaNs become -1 and values are cast to int.
        The categorical map keeps the last palette color, paired with an
        "N/A" value, only when `data` contains NaNs. For a continuous `var`,
        `data` is returned as is with a gradient map over all palette
        colors but the last.

        The input array is never modified.

        Raises
        ------
        TypeError
            If `var` is neither discrete nor continuous.
        """
        palette = var.palette  # type: Palette
        colors = np.array(
            [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
            dtype=np.uint8,
        )
        if var.is_discrete:
            mask = np.isnan(data)
            # Build a new array rather than assigning into `data`. The caller
            # may pass a view of the input table's column, and an in-place
            # `data[mask] = -1` would silently corrupt that table.
            data = np.where(mask, -1, data).astype(int)
            if mask.any():
                values = (*var.values, "N/A")
            else:
                values = var.values
                colors = colors[:-1]
            return data, CategoricalColorMap(colors, values)
        elif var.is_continuous:
            cmap = GradientColorMap(colors[:-1])
            return data, cmap
        else:
            raise TypeError(
                "cannot colorize variable of type {}".format(
                    type(var).__name__))

    def update_column_annotations(self):
        """Apply the column label position setting to the grid widget."""
        grid = self.scene.widget
        if self.data is None or grid is None:
            return
        grid.setColumnLabelsPosition(self._column_label_pos)

    def __adjust_font_size(self, diff):
        """Change the grid widget font size by `diff` points.

        The new size is applied only if it stays above 1 pt; the decrease
        action is disabled at or below 1 pt and the increase action above 32 pt.
        """
        grid = self.scene.widget
        if grid is None:
            return
        size = grid.font().pointSizeF() + diff
        self.__font_dec.setEnabled(size > 1.0)
        self.__font_inc.setEnabled(size <= 32)
        if size <= 1.0:
            return
        font = QFont()
        font.setPointSizeF(size)
        grid.setFont(font)

    def _on_view_context_menu(self, pos):
        """Pop up the view context menu (view actions, font size, aspect)."""
        grid = self.scene.widget
        if grid is None:
            return
        assert isinstance(grid, HeatmapGridWidget)
        viewport = self.view.viewport()
        menu = QMenu(viewport)
        menu.setAttribute(Qt.WA_DeleteOnClose)
        menu.addActions(self.view.actions())
        menu.addSeparator()
        menu.addActions([self.__font_inc, self.__font_dec])
        menu.addSeparator()
        keep_aspect_action = QAction("Keep aspect ratio", menu, checkable=True)
        keep_aspect_action.setChecked(self.keep_aspect)

        def on_toggled(checked):
            self.keep_aspect = checked
            self.__aspect_mode_changed()

        keep_aspect_action.toggled.connect(on_toggled)
        menu.addAction(keep_aspect_action)
        menu.popup(viewport.mapToGlobal(pos))

    def on_selection_finished(self):
        """Store the rows selected in the heatmap grid and commit them."""
        grid = self.scene.widget
        self.selected_rows = [] if grid is None else list(grid.selectedRows())
        self.commit()

    def commit(self):
        """Send the selected input rows and the selection-annotated input.

        With k-means merging enabled, each selected (merged) row is expanded
        to the original row indices it groups.
        """
        merge_indices = self.merge_indices if self.merge_kmeans else None
        data = indices = None
        if self.input_data is not None and self.selected_rows:
            indices = self.selected_rows
            if merge_indices is not None:
                indices = np.hstack([merge_indices[i] for i in indices])
            data = self.input_data[indices]

        self.Outputs.selected_data.send(data)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.input_data, indices))

    def onDeleteWidget(self):
        """Clear the widget state, then let the base class tear it down."""
        self.clear()
        super().onDeleteWidget()

    def send_report(self):
        """Report the clustering, split and annotation settings and the plot."""
        # Compare explicitly with `Clustering.None_` instead of relying on
        # the enum member's truthiness. Only an int-valued `None_` of 0
        # would be falsy; a plain Enum member is always truthy.
        col_clustered = self.col_clustering != Clustering.None_
        row_clustered = self.row_clustering != Clustering.None_
        self.report_items((
            ("Columns:", "Clustering" if col_clustered else "No sorting"),
            ("Rows:", "Clustering" if row_clustered else "No sorting"),
            ("Split:", self.split_by_var is not None
             and self.split_by_var.name),
            ("Row annotation", self.annotation_var is not None
             and self.annotation_var.name),
        ))
        self.report_plot()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Migrate settings older than version 3.

        The boolean ``row_clustering``/``col_clustering`` entries are
        replaced by ``*_clustering_method`` entries holding a
        ``Clustering`` member name: ``True`` becomes ordered clustering and
        ``False`` (or missing) becomes no clustering.
        """
        if version is None or version >= 3:
            return
        renames = (("row_clustering", "row_clustering_method"),
                   ("col_clustering", "col_clustering_method"))
        for old_key, new_key in renames:
            enabled = settings.pop(old_key, False)
            method = (Clustering.OrderedClustering if enabled
                      else Clustering.None_)
            settings[new_key] = method.name
示例#28
0
class OWMDS(OWWidget):
    name = "MDS"
    description = "Two-dimensional data projection by multidimensional " \
                  "scaling constructed from a distance matrix."
    icon = "icons/MDS.svg"

    # Input signals: data table, distance matrix, subset to highlight.
    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)
        distances = Input("Distances", Orange.misc.DistMatrix)
        data_subset = Input("Data Subset", Orange.data.Table)

    # Output signals: selected rows and data annotated with the selection.
    class Outputs:
        selected_data = Output("Selected Data",
                               Orange.data.Table,
                               default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    settings_version = 2

    #: Initialization type
    PCA, Random = 0, 1

    #: Refresh rate: (label, optimization iterations per GUI update);
    #: -1 means a single update after all `max_iter` iterations.
    RefreshRate = [("Every iteration", 1), ("Every 5 steps", 5),
                   ("Every 10 steps", 10), ("Every 25 steps", 25),
                   ("Every 50 steps", 50), ("None", -1)]

    #: Runtime state (of the optimization loop)
    Running, Finished, Waiting = 1, 2, 3

    settingsHandler = settings.DomainContextHandler()

    max_iter = settings.Setting(300)
    initialization = settings.Setting(PCA)
    # index into RefreshRate (default 3 -> "Every 25 steps")
    refresh_rate = settings.Setting(3)

    # output embedding role.
    NoRole, AttrRole, AddAttrRole, MetaRole = 0, 1, 2, 3

    auto_commit = settings.Setting(True)

    selection_indices = settings.Setting(None, schema_only=True)

    #: Percentage of all pairs displayed (ranges from 0 to 20)
    connected_pairs = settings.Setting(5)

    legend_anchor = settings.Setting(((1, 0), (1, 0)))

    graph = SettingProvider(OWMDSGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    graph_name = "graph.plot_widget.plotItem"

    # Error messages. `not_enough_rows` / `matrix_too_small` are set by the
    # input handlers; the others by initialization and the optimization loop.
    class Error(OWWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        matrix_too_small = Msg("Input matrix must be at least 2x2")
        no_attributes = Msg("Data has no attributes")
        mismatching_dimensions = \
            Msg("Data and distances dimensions do not match.")
        out_of_memory = Msg("Out of memory")
        optimization_error = Msg("Error during optimization\n{}")

    def __init__(self):
        """Create the widget state, the control area and the MDS plot.

        Ends by calling ``_initialize`` to reset the (initially empty)
        inputs; this sets ``self.embedding`` to None.
        """
        super().__init__()
        #: Input dissimilarity matrix
        self.matrix = None  # type: Optional[Orange.misc.DistMatrix]
        #: Effective data used for plot styling/annotations. Can be from the
        #: input signal (`self.signal_data`) or the row items of the input
        #: matrix (`self.matrix_data`)
        self.data = None  # type: Optional[Orange.data.Table]
        #: Input subset data table
        self.subset_data = None  # type: Optional[Orange.data.Table]
        #: Data table from the `self.matrix.row_items` (if present)
        self.matrix_data = None  # type: Optional[Orange.data.Table]
        #: Input data table
        self.signal_data = None

        # Cached indices of the most similar pairs (see `connect_pairs`)
        self._similar_pairs = None
        self._subset_mask = None  # type: Optional[np.ndarray]
        # Set by the input handlers; `handleNewSignals` re-initializes if set
        self._invalidated = False
        self.effective_matrix = None
        self._curve = None
        self._primitive_metas = ()
        self._data_metas = None

        # Variables passed to the graph as the x/y coordinates
        self.variable_x = ContinuousVariable("mds-x")
        self.variable_y = ContinuousVariable("mds-y")

        # Current optimization coroutine (see `__set_update_loop`)
        self.__update_loop = None
        # timer for scheduling updates
        # (single-shot, zero interval; each timeout runs one `__next_step`)
        self.__timer = QTimer(self, singleShot=True, interval=0)
        self.__timer.timeout.connect(self.__next_step)
        self.__state = OWMDS.Waiting
        self.__in_next_step = False
        self.__draw_similar_pairs = False

        box = gui.vBox(self.controlArea, "MDS Optimization")
        form = QFormLayout(labelAlignment=Qt.AlignLeft,
                           formAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
                           verticalSpacing=10)

        form.addRow("Max iterations:",
                    gui.spin(box, self, "max_iter", 10, 10**4, step=1))

        form.addRow(
            "Initialization:",
            gui.radioButtons(box,
                             self,
                             "initialization",
                             btnLabels=("PCA (Torgerson)", "Random"),
                             callback=self.__invalidate_embedding))

        box.layout().addLayout(form)
        form.addRow(
            "Refresh:",
            gui.comboBox(box,
                         self,
                         "refresh_rate",
                         items=[t for t, _ in OWMDS.RefreshRate],
                         callback=self.__invalidate_refresh))
        gui.separator(box, 10)
        self.runbutton = gui.button(box,
                                    self,
                                    "Run",
                                    callback=self._toggle_run)

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWMDSGraph(self,
                                box,
                                "MDSGraph",
                                view_box=MDSInteractiveViewBox)
        box.layout().addWidget(self.graph.plot_widget)
        self.plot = self.graph.plot_widget

        g = self.graph.gui
        box = g.point_properties_box(self.controlArea)
        self.models = g.points_models

        gui.hSlider(box,
                    self,
                    "connected_pairs",
                    label="Show similar pairs:",
                    minValue=0,
                    maxValue=20,
                    createLabel=False,
                    callback=self._on_connected_changed)
        g.add_widgets(ids=[g.JitterSizeSlider], widget=box)

        box = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([
            g.ShowLegend, g.ToolTipShowsAll, g.ClassDensity,
            g.LabelOnlySelected
        ], box)

        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        palette = self.graph.plot_widget.palette()
        self.graph.set_palette(palette)

        gui.rubber(self.controlArea)

        self.graph.box_zoom_select(self.controlArea)

        # `box` here is still the "Plot Properties" box created above.
        gui.auto_commit(box,
                        self,
                        "auto_commit",
                        "Send Selected",
                        checkbox_label="Send selected automatically",
                        box=None)

        self.plot.getPlotItem().hideButtons()
        self.plot.setRenderHint(QPainter.Antialiasing)

        self.graph.jitter_continuous = True
        self._initialize()

    def reset_graph_data(self, *_):
        """Rescale and redraw the graph data if present, then draw the
        similar-pair connections."""
        has_data = self.data is not None
        if has_data:
            self.graph.rescale_data()
            self.update_graph()
        self.connect_pairs()

    def update_colors(self):
        """No-op in this widget."""

    def update_density(self):
        """Redraw the graph after a density setting change, keeping the view."""
        self.update_graph(reset_view=False)

    def update_regression_line(self):
        """Redraw the graph after a regression line change, keeping the view."""
        self.update_graph(reset_view=False)

    def init_attr_values(self):
        """Reset the point-property models and graph attributes to `data`.

        Color defaults to the class variable; shape, size and label are
        cleared. "Stress" is inserted as the second entry of
        ``self.models[2]``. Empty or missing data uses no domain.
        """
        if self.data and len(self.data) and self.data.domain:
            domain = self.data.domain
        else:
            domain = None
        for model in self.models:
            model.set_domain(domain)
        graph = self.graph
        graph.attr_color = self.data.domain.class_var if domain else None
        graph.attr_shape = None
        graph.attr_size = None
        graph.attr_label = None
        third = self.models[2]
        third[:] = third[0:1] + ["Stress"] + third[1:]

    def prepare_data(self):
        """No-op in this widget."""

    def update_graph(self, reset_view=True, **_):
        """Clear the zoom stack and push the MDS coordinates to the graph.

        `reset_view` is accepted for interface compatibility but unused.
        """
        self.graph.zoomStack = []
        if self.graph.data is not None:
            self.graph.update_data(self.variable_x, self.variable_y, True)

    def selection_changed(self):
        """Selection in the plot changed: commit the new output."""
        self.commit()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input data set.

        Tables with fewer than 2 rows are rejected (treated as None) and
        reported. If the table matches the size of the current distance
        matrix, it only replaces the styling data; otherwise the embedding
        is invalidated.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        too_small = data is not None and len(data) < 2
        if too_small:
            self.Error.not_enough_rows()
            data = None
        else:
            self.Error.not_enough_rows.clear()

        self.signal_data = data

        same_size = (self.matrix is not None and data is not None
                     and len(self.matrix) == len(data))
        if same_size:
            self.closeContext()
            self.data = data
            self.openContext(data)
        else:
            self._invalidated = True

        if data is None:
            self._primitive_metas = ()
            self._data_metas = None
        else:
            metas = data.domain.metas
            keys = [i for i, var in enumerate(metas) if var.is_primitive()]
            self._primitive_metas = tuple(metas[i] for i in keys)
            self._data_metas = data.metas[:, keys]

    @Inputs.distances
    def set_disimilarity(self, matrix):
        """Set the dissimilarity (distance) matrix.

        Matrices smaller than 2x2 are rejected (treated as None) and
        reported. ``matrix_data`` is set to the matrix's row items, or to
        None when the matrix is missing or has no row items.

        Parameters
        ----------
        matrix : Optional[Orange.misc.DistMatrix]
        """

        if matrix is not None and len(matrix) < 2:
            self.Error.matrix_too_small()
            matrix = None
        else:
            self.Error.matrix_too_small.clear()

        self.matrix = matrix
        # Always reset `matrix_data`. Otherwise a new matrix without row
        # items would keep the previous matrix's row items, which
        # `_initialize` could then use as plot data of the wrong length.
        if matrix is not None and matrix.row_items:
            self.matrix_data = matrix.row_items
        else:
            self.matrix_data = None
        self._invalidated = True

    @Inputs.data_subset
    def set_subset_data(self, subset_data):
        """Set a subset of `data` input to highlight in the plot.

        Parameters
        ----------
        subset_data: Optional[Orange.data.Table]
        """
        self.subset_data = subset_data
        # Drop the cached subset mask; handleNewSignals recomputes it.
        self._subset_mask = None  # type: Optional[np.ndarray]
        # The alpha control is enabled only when no subset is given.
        no_subset = subset_data is None
        self.controls.graph.alpha_value.setEnabled(no_subset)
        self._invalidated = True

    def _clear(self):
        """Drop cached similar pairs, stop any update loop and go to Waiting."""
        self._similar_pairs = None
        # Stopping the loop switches the state to Finished...
        self.__set_update_loop(None)
        # ...which is overridden: nothing has run for the new inputs yet.
        self.__state = OWMDS.Waiting

    def _clear_plot(self):
        """Remove all items from the plot widget."""
        plot_widget = self.graph.plot_widget
        plot_widget.clear()

    def _initialize(self):
        """Reset the state and derive `data` / `effective_matrix` from inputs.

        `data` is the input table if present; otherwise it is the row items
        of the distance matrix, which are dropped when the matrix has
        ``axis == 0``. Without an input matrix, Euclidean distances between
        the MDS-preprocessed data rows are used as `effective_matrix`.
        """
        # clear everything
        self.closeContext()
        self._clear()
        # Clear only the errors reported here or by the optimization.
        # `not_enough_rows` and `matrix_too_small` are set by the input
        # handlers just before `handleNewSignals` calls this method. Clearing
        # the whole group would hide them before the user ever saw them.
        self.Error.mismatching_dimensions.clear()
        self.Error.no_attributes.clear()
        self.Error.out_of_memory.clear()
        self.Error.optimization_error.clear()
        self.data = None
        self.effective_matrix = None
        self.embedding = None
        self.init_attr_values()

        # if no data nor matrix is present reset plot
        if self.signal_data is None and self.matrix is None:
            return

        if self.signal_data is not None and self.matrix is not None and \
                len(self.signal_data) != len(self.matrix):
            self.Error.mismatching_dimensions()
            self._update_plot()
            return

        if self.signal_data is not None:
            self.data = self.signal_data
        elif self.matrix_data is not None:
            self.data = self.matrix_data

        if self.matrix is not None:
            self.effective_matrix = self.matrix
            if self.matrix.axis == 0 and self.data is self.matrix_data:
                self.data = None
        elif self.data.domain.attributes:
            preprocessed_data = Orange.projection.MDS().preprocess(self.data)
            self.effective_matrix = Orange.distance.Euclidean(
                preprocessed_data)
        else:
            self.Error.no_attributes()
            return

        self.init_attr_values()
        self.openContext(self.data)

    def _toggle_run(self):
        """Run button handler: stop a running optimization (and refresh the
        output), or start/resume one otherwise."""
        if self.__state != OWMDS.Running:
            self.start()
            return
        self.stop()
        self._invalidate_output()

    def start(self):
        """Start or resume the optimization unless it is already running.

        A finished run resumes from the current embedding. A waiting widget
        starts only when an effective distance matrix is available.
        """
        state = self.__state
        resumable = state == OWMDS.Finished
        startable = (state == OWMDS.Waiting
                     and self.effective_matrix is not None)
        if resumable or startable:
            self.__start()

    def stop(self):
        """Interrupt a running optimization; no-op otherwise."""
        if self.__state != OWMDS.Running:
            return
        self.__set_update_loop(None)

    def __start(self):
        """Start (or resume) the MDS optimization as a stepped coroutine.

        The optimization runs in chunks of ``step_size`` iterations, taken
        from the ``refresh_rate`` setting; "None" (-1) means all
        ``max_iter`` iterations in one chunk. ``self.embedding`` is passed
        as the initial embedding. Chunks are executed by ``__next_step``,
        scheduled through the update timer.
        """
        self.__draw_similar_pairs = False
        X = self.effective_matrix
        init = self.embedding

        # number of iterations per single GUI update step
        _, step_size = OWMDS.RefreshRate[self.refresh_rate]
        if step_size == -1:
            step_size = self.max_iter

        def update_loop(X, max_iter, step, init):
            """
            Return an iterator over successively improved MDS embeddings.

            Yields ``(embedding, stress, progress)`` after every `step`
            iterations. Stops after `max_iter` iterations, or when the
            normalized stress decreases by less than the MDS ``eps``.
            """
            # NOTE: this code MUST NOT call into QApplication.processEvents
            done = False
            iterations_done = 0
            # `np.float` was removed in NumPy 1.24; use the builtin `float`.
            oldstress = np.finfo(float).max
            init_type = "PCA" if self.initialization == OWMDS.PCA else "random"

            while not done:
                step_iter = min(max_iter - iterations_done, step)
                mds = Orange.projection.MDS(dissimilarity="precomputed",
                                            n_components=2,
                                            n_init=1,
                                            max_iter=step_iter,
                                            init_type=init_type,
                                            init_data=init)

                mdsfit = mds(X)
                iterations_done += step_iter

                embedding, stress = mdsfit.embedding_, mdsfit.stress_
                # Normalized stress is used only for the convergence test;
                # the raw stress is what gets yielded.
                stress /= np.sqrt(np.sum(embedding**2, axis=1)).sum()

                if iterations_done >= max_iter:
                    done = True
                elif (oldstress - stress) < mds.params["eps"]:
                    done = True
                init = embedding
                oldstress = stress

                yield embedding, mdsfit.stress_, iterations_done / max_iter

        # `__set_update_loop` already initializes the progress bar for a new
        # loop, so no separate `progressBarInit` call is needed here.
        self.__set_update_loop(update_loop(X, self.max_iter, step_size, init))

    def __set_update_loop(self, loop):
        """
        Set the update `loop` coroutine.

        The `loop` is a generator yielding `(embedding, stress, progress)`
        tuples where `embedding` is a `(N, 2) ndarray` of current updated
        MDS points, `stress` is the current stress and `progress` a float
        ratio (0 <= progress <= 1)

        If an existing update coroutine loop is already in place it is
        interrupted (i.e. closed).

        Passing ``None`` only stops the current loop; the widget is then
        unblocked and switched to the `Finished` state.

        .. note::
            The `loop` must not explicitly yield control flow to the event
            loop (i.e. call `QApplication.processEvents`)

        """
        if self.__update_loop is not None:
            # Close the current generator and finish its progress bar.
            self.__update_loop.close()
            self.__update_loop = None
            self.progressBarFinished(processEvents=None)

        self.__update_loop = loop

        if loop is not None:
            self.setBlocking(True)
            self.progressBarInit(processEvents=None)
            self.setStatusMessage("Running")
            self.runbutton.setText("Stop")
            self.__state = OWMDS.Running
            # Schedule the first `__next_step` (timer timeout slot).
            self.__timer.start()
        else:
            self.setBlocking(False)
            self.setStatusMessage("")
            self.runbutton.setText("Start")
            self.__state = OWMDS.Finished
            self.__timer.stop()

    def __next_step(self):
        """Advance the update loop by one step (update timer slot).

        On success, it stores and plots the new embedding, updates progress
        and reschedules the timer. When the loop is exhausted, it stops the
        loop, commits the output and redraws with similar pairs enabled.
        A MemoryError or any other exception stops the loop and is reported
        as a widget error.
        """
        if self.__update_loop is None:
            return

        # Guard against re-entrant calls.
        assert not self.__in_next_step
        self.__in_next_step = True

        loop = self.__update_loop
        self.Error.out_of_memory.clear()
        try:
            embedding, _, progress = next(self.__update_loop)
            # The loop must not have been replaced while it was running.
            assert self.__update_loop is loop
        except StopIteration:
            self.__set_update_loop(None)
            self.unconditional_commit()
            self.__draw_similar_pairs = True
            self._update_plot()
        except MemoryError:
            self.Error.out_of_memory()
            self.__set_update_loop(None)
            self.__draw_similar_pairs = True
        except Exception as exc:
            # Top-level boundary of the optimization: report, don't crash.
            self.Error.optimization_error(str(exc))
            self.__set_update_loop(None)
            self.__draw_similar_pairs = True
        else:
            self.progressBarSet(100.0 * progress, processEvents=None)
            self.embedding = embedding
            self._update_plot()
            # schedule next update
            self.__timer.start()

        self.__in_next_step = False

    def __invalidate_embedding(self):
        """Reset the embedding to a fresh initialization and restart.

        The embedding is replaced by the Torgerson (PCA) solution or by
        uniform random points, depending on `initialization`. A running
        optimization is restarted from it. No-op without an embedding.
        """
        if self.embedding is None:
            return
        was_state = self.__state
        if self.__update_loop is not None:
            self.__set_update_loop(None)

        X = self.effective_matrix
        if self.initialization == OWMDS.PCA:
            fresh = torgerson(X)
        else:
            fresh = np.random.rand(len(X), 2)
        self.embedding = fresh

        self._update_plot()

        # Restart the optimization if it was interrupted.
        if was_state == OWMDS.Running:
            self.__start()

    def __invalidate_refresh(self):
        """Refresh rate changed: restart a running optimization so the new
        step size takes effect."""
        was_running = self.__state == OWMDS.Running

        if self.__update_loop is not None:
            self.__set_update_loop(None)

        # TODO: decrease the max iteration count by the already
        # completed iterations count.
        if was_running:
            self.__start()

    def handleNewSignals(self):
        """Process the inputs received since the last call.

        Re-initializes and starts the optimization if an input invalidated
        it, and recomputes the subset mask when needed. Then redraws the
        plot and commits the output.
        """
        if self._invalidated:
            self._invalidated = False
            self._initialize()
            self.start()
        self.__draw_similar_pairs = False

        if self._subset_mask is None and self.subset_data is not None and \
                self.data is not None:
            # `np.in1d` is deprecated since NumPy 2.0. `np.isin` returns a
            # mask shaped like its first argument, which is the same result
            # for 1-D inputs.
            self._subset_mask = np.isin(self.data.ids, self.subset_data.ids)

        self._update_plot(new=True)
        self.unconditional_commit()

    def _invalidate_output(self):
        """Output is out of date: commit it again."""
        self.commit()

    def _on_connected_changed(self):
        """Similar-pairs slider moved: drop cached pairs and redraw."""
        self._similar_pairs = None
        self._update_plot()

    def _update_plot(self, new=False):
        """Clear the plot, then draw the embedding, or reset the graph data
        when there is no embedding."""
        self._clear_plot()

        if self.embedding is None:
            self.graph.new_data(None)
        else:
            self._setup_plot(new=new)

    def connect_pairs(self):
        """Draw line segments between the most similar pairs of points.

        Nothing is drawn unless ``self.connected_pairs`` is non-zero and
        drawing of similar pairs is enabled. Pairs come from the upper
        triangle of ``self.effective_matrix`` and are taken in increasing
        order of their matrix value (presumably distances, i.e. the most
        similar pairs first). The count is
        ``min(n*(n-1)//2 * connected_pairs // 100,
        MAX_N_PAIRS * connected_pairs // 20)``. The interleaved pair indices
        are cached in ``self._similar_pairs`` until it is reset to None.
        """
        if not (self.connected_pairs and self.__draw_similar_pairs):
            return
        emb_x, emb_y = self.graph.get_xy_data_positions(
            self.variable_x, self.variable_y, self.graph.valid_data)
        if self._similar_pairs is None:
            # This code requires storing lower triangle of X (n x n / 2
            # doubles), n x n / 2 * 2 indices to X, n x n / 2 indices for
            # argsort result. If this becomes an issue, it can be reduced to
            # n x n argsort indices by argsorting the entire X. Then we
            # take the first n + 2 * p indices. We compute their coordinates
            # i, j in the original matrix. We keep those for which i < j.
            # n + 2 * p will suffice to exclude the diagonal (i = j). If the
            # number of those for which i < j is smaller than p, we instead
            # take i > j. Among those that remain, we take the first p.
            # Assuming that MDS can't show so many points that memory could
            # become an issue, I preferred using simpler code.
            m = self.effective_matrix
            n = len(m)
            p = min(n * (n - 1) // 2 * self.connected_pairs // 100,
                    MAX_N_PAIRS * self.connected_pairs // 20)
            indcs = np.triu_indices(n, 1)
            # Positions (within the upper triangle) of the p smallest values.
            # Named `order` rather than `sorted` to avoid shadowing the builtin.
            order = np.argsort(m[indcs])[:p]
            # Interleave row/column indices: [i0, j0, i1, j1, ...].
            self._similar_pairs = fpairs = np.empty(2 * p, dtype=int)
            fpairs[::2] = indcs[0][order]
            fpairs[1::2] = indcs[1][order]
        emb_x_pairs = emb_x[self._similar_pairs].reshape((-1, 2))
        emb_y_pairs = emb_y[self._similar_pairs].reshape((-1, 2))

        # Filter out zero distance lines (in embedding coords).
        # Null (zero length) line causes bad rendering artifacts
        # in Qt when using the raster graphics system (see gh-issue: 1668).
        (x1, x2), (y1, y2) = (emb_x_pairs.T, emb_y_pairs.T)
        pairs_mask = ~(np.isclose(x1, x2) & np.isclose(y1, y2))
        emb_x_pairs = emb_x_pairs[pairs_mask, :]
        emb_y_pairs = emb_y_pairs[pairs_mask, :]
        if self._curve:
            self.graph.plot_widget.removeItem(self._curve)
        # connect="pairs" draws a separate segment for each consecutive
        # (start, end) pair of points in the flattened arrays.
        self._curve = pg.PlotCurveItem(emb_x_pairs.ravel(),
                                       emb_y_pairs.ravel(),
                                       pen=pg.mkPen(0.8,
                                                    width=2,
                                                    cosmetic=True),
                                       connect="pairs",
                                       antialias=True)
        self.graph.plot_widget.addItem(self._curve)

    def _setup_plot(self, new=False):
        """Build the table shown in the plot and pass it to the graph.

        The table holds the original attributes, the two embedding
        coordinates (as ``self.variable_x``/``self.variable_y``) and
        ``self._primitive_metas`` as attributes. It keeps the original class
        variables. If a subset mask exists, the matching rows are passed as
        subset data.

        :param new: passed through to ``graph.new_data``.
        """
        embedding_xy = np.vstack(
            (self.embedding[:, 0], self.embedding[:, 1])).T

        source_domain = self.data.domain
        plot_attributes = (source_domain.attributes
                           + (self.variable_x, self.variable_y)
                           + self._primitive_metas)
        domain = Domain(attributes=plot_attributes,
                        class_vars=source_domain.class_vars)

        column_blocks = [self.data.X, embedding_xy]
        if self._data_metas is not None:
            column_blocks.append(self._data_metas)
        data = Table.from_numpy(domain, X=np.hstack(column_blocks),
                                Y=self.data.Y)

        if self._subset_mask is None:
            subset_data = None
        else:
            subset_data = data[self._subset_mask]

        self.graph.new_data(data, subset_data=subset_data, new=new)
        self.graph.update_data(self.variable_x, self.variable_y, True)
        self.connect_pairs()

    def commit(self):
        """Send the selected data and the data annotated with the selection.

        With an embedding but no data, the output is a two-column table of
        the coordinates. With data as well, the coordinates are appended to
        the data as two new meta columns.

        The coordinate columns are named from "mds-x"/"mds-y". The names
        are made unique against every name already used in the data domain:
        attributes, class variables *and* metas. The new columns are metas,
        so a clash with an existing meta (e.g. when the output of one MDS
        widget is fed to another) must be avoided as well.
        """
        if self.embedding is not None:
            used_names = []
            if self.data is not None:
                # Iterating a Domain yields attributes and class variables
                # only; metas have to be added explicitly.
                used_names = [var.name for var in self.data.domain]
                used_names += [var.name for var in self.data.domain.metas]
            names = get_unique_names(used_names, ["mds-x", "mds-y"])
            output = embedding = Orange.data.Table.from_numpy(
                Orange.data.Domain([
                    ContinuousVariable(names[0]),
                    ContinuousVariable(names[1])
                ]), self.embedding)
        else:
            output = embedding = None

        if self.embedding is not None and self.data is not None:
            domain = self.data.domain
            domain = Orange.data.Domain(
                domain.attributes, domain.class_vars,
                domain.metas + embedding.domain.attributes)
            output = self.data.transform(domain)
            # The two new metas are the last columns; fill them with the
            # embedding coordinates.
            output.metas[:, -2:] = embedding.X

        selection = self.graph.get_selection()
        if output is not None and len(selection) > 0:
            selected = output[selection]
        else:
            selected = None
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(
            create_annotated_table(output, selection))

    def onDeleteWidget(self):
        """Tear down the widget.

        Runs the base-class teardown, then clears the plot and the
        widget's own state via ``_clear_plot`` and ``_clear``.
        """
        super().onDeleteWidget()
        self._clear_plot()
        self._clear()

    def send_report(self):
        """Report the plot and a caption listing the visual encodings.

        Nothing is reported when there is no data. The caption lists the
        variables used for color, label, shape and size, plus the jittering
        percentage if jittering is enabled.
        """
        if self.data is None:
            return

        graph = self.graph

        def var_name(var):
            # Falsy values (e.g. None) are passed through unchanged.
            return var.name if var else var

        jittering = graph.jitter_size != 0 and "{} %".format(graph.jitter_size)
        items = (
            ("Color", var_name(graph.attr_color)),
            ("Label", var_name(graph.attr_label)),
            ("Shape", var_name(graph.attr_shape)),
            ("Size", var_name(graph.attr_size)),
            ("Jittering", jittering),
        )
        caption = report.render_items_vert(items)

        self.report_plot()
        if caption:
            self.report_caption(caption)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade settings saved by versions older than 2 (in place).

        Plot-related settings move into a nested ``"graph"`` dict under
        their new names, and ``autocommit`` is copied to ``auto_commit``.
        A KeyError propagates if an expected old setting is missing.
        """
        if version >= 2:
            return

        renames = (("label_only_selected", "label_only_selected"),
                   ("symbol_opacity", "alpha_value"),
                   ("symbol_size", "point_width"),
                   ("jitter", "jitter_size"))
        settings_["graph"] = {new: settings_[old] for old, new in renames}
        settings_["auto_commit"] = settings_["autocommit"]

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade contexts saved by versions older than 2 (in place).

        The old ``color_value``, ``shape_value``, ``size_value`` and
        ``label_value`` entries of ``context.values`` are converted and
        stored in a new nested ``"graph"`` dict under ``attr_color``,
        ``attr_shape``, ``attr_size`` and ``attr_label``.
        """
        if version < 2:
            # NOTE(review): domain, n_domain and c_domain are computed but the
            # loop below binds them to `_` and never uses them; they look like
            # leftovers from filtering candidates by the second element of
            # each ordered_domain entry.
            domain = context.ordered_domain
            n_domain = [t for t in context.ordered_domain if t[1] == 2]
            c_domain = [t for t in context.ordered_domain if t[1] == 1]
            context_values_graph = {}
            for _, old_val, new_val in ((domain, "color_value", "attr_color"),
                                        (c_domain, "shape_value",
                                         "attr_shape"),
                                        (n_domain, "size_value", "attr_size"),
                                        (domain, "label_value", "attr_label")):
                # Old values are pairs, presumably (name, type code) as
                # encoded by the context handler -- TODO confirm.
                tmp = context.values[old_val]
                if tmp[1] >= 0:
                    # Non-negative codes are shifted by 100, presumably to the
                    # version-2 encoding of variable values (TODO confirm).
                    context_values_graph[new_val] = (tmp[0], tmp[1] + 100)
                elif tmp[0] != "Stress":
                    # Other negative-coded values are dropped ...
                    context_values_graph[new_val] = None
                else:
                    # ... except "Stress", which is kept unchanged.
                    context_values_graph[new_val] = tmp
            context.values["graph"] = context_values_graph
示例#29
0
class OWEditDomain(widget.OWWidget):
    """Rename variables and their values.

    All variables of the input domain (attributes, class variables and
    metas) are listed in the 'Domain Features' view. The selected variable
    is edited in a type-specific editor. Each edit is recorded in
    ``domain_change_hints``, keyed by the description of the original
    variable, so it can be re-applied when the context is restored.
    """
    name = "Edit Domain"
    description = "Rename features and their values."
    icon = "icons/EditDomain.svg"
    priority = 3125

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        data = Output("Data", Orange.data.Table)

    settingsHandler = settings.DomainContextHandler()

    # description of an original variable -> description of its edited
    # replacement
    domain_change_hints = settings.ContextSetting({})
    # Row of the selected variable in the 'Domain Features' view, -1 if none.
    # The widget always treats this as an int, so the default is an int too
    # (it used to be an empty dict).
    selected_index = settings.ContextSetting(-1)

    autocommit = settings.Setting(True)

    def __init__(self):
        super().__init__()

        self.data = None
        self.input_vars = ()
        self._invalidated = False

        box = gui.vBox(self.controlArea, "Domain Features")

        self.domain_model = itemmodels.VariableListModel()
        self.domain_view = QListView(selectionMode=QListView.SingleSelection)
        self.domain_view.setModel(self.domain_model)
        self.domain_view.selectionModel().selectionChanged.connect(
            self._on_selection_changed)
        box.layout().addWidget(self.domain_view)

        box = gui.hBox(self.controlArea)
        gui.button(box, self, "Reset Selected", callback=self.reset_selected)
        gui.button(box, self, "Reset All", callback=self.reset_all)

        gui.auto_commit(self.controlArea, self, "autocommit", "Apply")

        box = gui.vBox(self.mainArea, "Edit")
        self.editor_stack = QStackedWidget()

        # Order matters: open_editor selects editors by these indices
        # (0: discrete, 1: continuous, 2: any other variable type).
        self.editor_stack.addWidget(DiscreteVariableEditor())
        self.editor_stack.addWidget(ContinuousVariableEditor())
        self.editor_stack.addWidget(VariableEditor())

        box.layout().addWidget(self.editor_stack)

        self.Error.add_message("duplicate_var_name",
                               "A variable name is duplicated.")

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set input data set."""
        self.closeContext()
        self.clear()
        self.data = data

        if self.data is not None:
            self._initialize()
            self.openContext(self.data)
            # Re-apply stored edits and selection from the opened context.
            self._restore()

        self.unconditional_commit()

    def clear(self):
        """Clear the widget state."""
        self.data = None
        self.domain_model[:] = []
        # A tuple, as in __init__ and _initialize.
        self.input_vars = ()
        self.domain_change_hints = {}
        self.selected_index = -1

    def reset_selected(self):
        """Reset the currently selected variable to its original state."""
        ind = self.selected_var_index()
        if ind >= 0:
            var = self.input_vars[ind]
            desc = variable_description(var, skip_attributes=True)
            if desc in self.domain_change_hints:
                del self.domain_change_hints[desc]

            self.domain_model[ind] = var
            self.editor_stack.currentWidget().set_data(var)
            self._invalidate()

    def reset_all(self):
        """Reset all variables to their original state."""
        self.domain_change_hints = {}
        if self.data is not None:
            # Replace the edited variables in the view with the originals.
            self.domain_model[:] = self.input_vars
            # Re-select the previously selected row; skip when there was no
            # selection (-1), as _restore does.
            if self.selected_index >= 0:
                itemmodels.select_row(self.domain_view, self.selected_index)
            self._invalidate()

    def selected_var_index(self):
        """Return the selected row in 'Domain Features' view (-1 if none)."""
        rows = self.domain_view.selectedIndexes()
        assert len(rows) <= 1
        return rows[0].row() if rows else -1

    def _initialize(self):
        """Fill the view with the input's variables followed by its metas."""
        domain = self.data.domain
        self.input_vars = tuple(domain) + domain.metas
        self.domain_model[:] = list(self.input_vars)

    def _restore(self):
        """Apply stored edit hints to the view and restore the selection."""
        def transform(var):
            # Rebuild the edited variable from its stored description; its
            # values are computed from the original variable via Identity.
            vdesc = variable_description(var, skip_attributes=True)
            if vdesc in self.domain_change_hints:
                return variable_from_description(
                    self.domain_change_hints[vdesc],
                    compute_value=Orange.preprocess.transformation.Identity(
                        var))
            else:
                return var

        self.domain_model[:] = map(transform, self.input_vars)

        # Restore the variable selection if possible
        index = self.selected_index
        if index >= len(self.input_vars):
            index = 0 if len(self.input_vars) else -1
        if index >= 0:
            itemmodels.select_row(self.domain_view, index)

    def _on_selection_changed(self):
        self.selected_index = self.selected_var_index()
        self.open_editor(self.selected_index)

    def open_editor(self, index):
        """Show the editor matching the type of the variable at `index`."""
        self.clear_editor()
        if index < 0:
            return

        var = self.domain_model[index]

        editor_index = 2
        if var.is_discrete:
            editor_index = 0
        elif var.is_continuous:
            editor_index = 1
        editor = self.editor_stack.widget(editor_index)
        self.editor_stack.setCurrentWidget(editor)

        editor.set_data(var)
        editor.variable_changed.connect(self._on_variable_changed)

    def clear_editor(self):
        """Disconnect and empty the currently shown editor."""
        current = self.editor_stack.currentWidget()
        try:
            current.variable_changed.disconnect(self._on_variable_changed)
        except (TypeError, RuntimeError):
            # The slot was not connected (PyQt raises TypeError, PySide
            # RuntimeError); nothing to disconnect.
            pass
        current.set_data(None)

    def _on_variable_changed(self):
        """User edited the current variable in editor."""
        # Valid rows are 0 .. len - 1.
        assert 0 <= self.selected_index < len(self.domain_model)
        editor = self.editor_stack.currentWidget()

        # Replace the variable in the 'Domain Features' view/model
        old_var = self.input_vars[self.selected_index]
        new_var = editor.get_data().copy(
            compute_value=Orange.preprocess.transformation.Identity(old_var))
        self.domain_model[self.selected_index] = new_var

        # Store the transformation hint.
        old_var_desc = variable_description(old_var, skip_attributes=True)
        self.domain_change_hints[old_var_desc] = variable_description(new_var)

        self._invalidate()

    def _invalidate(self):
        self.commit()

    def commit(self):
        """Send the changed data to output.

        Sends None (and shows an error) if any two variables share a name.
        """
        new_data = None
        var_names = [vn.name for vn in self.domain_model]
        self.Error.duplicate_var_name.clear()
        if self.data is not None:
            if len(var_names) == len(set(var_names)):
                # The model lists attributes, class vars and metas in that
                # order (see _initialize); split it back into those parts.
                input_domain = self.data.domain
                n_attrs = len(input_domain.attributes)
                n_class_vars = len(input_domain.class_vars)
                all_new_vars = list(self.domain_model)
                attrs = all_new_vars[:n_attrs]
                class_vars = all_new_vars[n_attrs:n_attrs + n_class_vars]
                new_metas = all_new_vars[n_attrs + n_class_vars:]
                new_domain = Orange.data.Domain(attrs, class_vars, new_metas)
                new_data = self.data.transform(new_domain)
            else:
                self.Error.duplicate_var_name()

        self.Outputs.data.send(new_data)

    def sizeHint(self):
        sh = super().sizeHint()
        return sh.expandedTo(QSize(660, 550))

    def send_report(self):
        if self.data is not None:
            self.report_raw(
                "",
                EditDomainReport(old_domain=chain(self.data.domain.variables,
                                                  self.data.domain.metas),
                                 new_domain=self.domain_model).to_html())
        else:
            self.report_data(None)
示例#30
0
class OWImpute(OWWidget):
    """Impute missing values in the data table.

    A default method applies to every variable. Individual variables may
    override it with their own method, stored in ``variable_methods``
    (row in the variable list -> method).
    """
    name = "Impute"
    description = "Impute missing values in the data table."
    icon = "icons/Impute.svg"
    priority = 2130

    inputs = [("Data", Orange.data.Table, "set_data"),
              ("Learner", Learner, "set_learner")]
    outputs = [("Data", Orange.data.Table)]

    class Error(OWWidget.Error):
        imputation_failed = Msg("Imputation failed for '{}'")

    DEFAULT_LEARNER = SimpleTreeLearner()
    # Order matters: the index constants below and the radio-button ids
    # refer to positions in this list.
    METHODS = [
        AsDefault(),
        impute.DoNotImpute(),
        impute.Average(),
        impute.AsValue(),
        impute.Model(DEFAULT_LEARNER),
        impute.Random(),
        impute.DropInstances(),
        impute.Default()
    ]
    DEFAULT, DO_NOT_IMPUTE, MODEL_BASED_IMPUTER, AS_INPUT = 0, 1, 4, 7

    settingsHandler = settings.DomainContextHandler()

    _default_method_index = settings.Setting(DO_NOT_IMPUTE)
    # row in the variable list -> method chosen for that variable
    variable_methods = settings.ContextSetting({})
    autocommit = settings.Setting(True)

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        super().__init__()
        # copy METHODS (some are modified by the widget)
        self.methods = copy.deepcopy(OWImpute.METHODS)

        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(10, 10, 10, 10)
        self.controlArea.layout().addLayout(main_layout)

        box = QGroupBox(title=self.tr("Default Method"), flat=False)
        box_layout = QVBoxLayout(box)
        main_layout.addWidget(box)

        button_group = QButtonGroup()
        button_group.buttonClicked[int].connect(self.set_default_method)
        for i, method in enumerate(self.methods):
            # Methods flagged columns_only are offered per variable only.
            if not method.columns_only:
                button = QRadioButton(method.name)
                button.setChecked(i == self.default_method_index)
                button_group.addButton(button, i)
                box_layout.addWidget(button)

        self.default_button_group = button_group

        box = QGroupBox(title=self.tr("Individual Attribute Settings"),
                        flat=False)
        main_layout.addWidget(box)

        horizontal_layout = QHBoxLayout(box)

        self.varview = QListView(selectionMode=QListView.ExtendedSelection)
        self.varview.setItemDelegate(DisplayFormatDelegate())
        self.varmodel = itemmodels.VariableListModel()
        self.varview.setModel(self.varmodel)
        self.varview.selectionModel().selectionChanged.connect(
            self._on_var_selection_changed)
        self.selection = self.varview.selectionModel()

        horizontal_layout.addWidget(self.varview)

        method_layout = QVBoxLayout()
        horizontal_layout.addLayout(method_layout)

        button_group = QButtonGroup()
        for i, method in enumerate(self.methods):
            button = QRadioButton(text=method.name)
            button_group.addButton(button, i)
            method_layout.addWidget(button)

        self.value_combo = QComboBox(
            minimumContentsLength=8,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
            activated=self._on_value_selected)
        self.value_double = QDoubleSpinBox(
            editingFinished=self._on_value_selected,
            minimum=-1000.,
            maximum=1000.,
            singleStep=.1,
            decimals=3,
        )
        self.value_stack = value_stack = QStackedWidget()
        value_stack.addWidget(self.value_combo)
        value_stack.addWidget(self.value_double)
        method_layout.addWidget(value_stack)

        button_group.buttonClicked[int].connect(
            self.set_method_for_current_selection)

        method_layout.addStretch(2)

        reset_button = QPushButton("Restore All to Default",
                                   checked=False,
                                   checkable=False,
                                   clicked=self.reset_variable_methods,
                                   default=False,
                                   autoDefault=False)
        method_layout.addWidget(reset_button)

        self.variable_button_group = button_group

        box = gui.auto_commit(self.controlArea,
                              self,
                              "autocommit",
                              "Apply",
                              orientation=Qt.Horizontal,
                              checkbox_label="Apply automatically")
        box.layout().insertSpacing(0, 80)
        box.layout().insertWidget(0, self.report_button)

        self.data = None
        self.learner = None
        self.modified = False
        self.default_method = self.methods[self.default_method_index]
        # Keep the "default" entry in sync with the restored default, as the
        # default_method_index setter does; otherwise it is only updated
        # after the user first changes the default method.
        self.methods[self.DEFAULT].method = self.default_method

    @property
    def default_method_index(self):
        """Index (into ``self.methods``) of the default imputation method."""
        return self._default_method_index

    @default_method_index.setter
    def default_method_index(self, index):
        if self._default_method_index != index:
            self._default_method_index = index
            self.default_button_group.button(index).setChecked(True)
            self.default_method = self.methods[self.default_method_index]
            self.methods[self.DEFAULT].method = self.default_method

            # update variable view
            self.update_varview()
            self._invalidate()

    def set_default_method(self, index):
        """Set the current selected default imputation method.
        """
        self.default_method_index = index

    @check_sql_input
    def set_data(self, data):
        """Set the input data and reset per-variable methods from context."""
        self.closeContext()
        self.varmodel[:] = []
        self.variable_methods = {}
        self.modified = False
        self.data = data

        if data is not None:
            self.varmodel[:] = data.domain.variables
            self.openContext(data.domain)

        self.update_varview()
        self.unconditional_commit()

    def set_learner(self, learner):
        """Set the learner used by the model-based imputer.

        Without a learner the class-level DEFAULT_LEARNER is used. When a
        learner is given, model-based imputation becomes the default method.
        """
        self.learner = learner or self.DEFAULT_LEARNER
        imputer = self.methods[self.MODEL_BASED_IMPUTER]
        imputer.learner = self.learner

        button = self.default_button_group.button(self.MODEL_BASED_IMPUTER)
        button.setText(imputer.name)

        variable_button = self.variable_button_group.button(
            self.MODEL_BASED_IMPUTER)
        variable_button.setText(imputer.name)

        if learner is not None:
            self.default_method_index = self.MODEL_BASED_IMPUTER

        self.update_varview()
        self.commit()

    def get_method_for_column(self, column_index):
        """Returns the imputation method for column by its index.

        `column_index` is either an int or a model index (its row is used).
        """
        if not isinstance(column_index, int):
            column_index = column_index.row()

        return self.variable_methods.get(column_index,
                                         self.methods[self.DEFAULT])

    def _invalidate(self):
        self.modified = True
        self.commit()

    def commit(self):
        """Impute the input data and send the result.

        Warnings and errors from a previous run are cleared first, so they
        do not linger when the input is removed or empty. None is sent for
        no input; an empty table is sent unchanged.
        """
        self.warning()
        self.Error.imputation_failed.clear()

        data = self.data
        if data is not None and len(data):
            data = self._impute_data()

        self.send("Data", data)
        self.modified = False

    def _impute_data(self):
        """Return the non-empty ``self.data`` with the chosen imputations.

        Each variable is processed by its own method or the default one.
        Variables a method cannot handle are kept unchanged and reported
        together in one warning. Rows flagged by DropInstances methods are
        removed. Returns None (and shows an error) if an imputer raises.
        """
        data = self.data
        drop_mask = np.zeros(len(data), bool)
        n_attributes = len(data.domain.attributes)
        attributes = []
        class_vars = []
        unsupported = []
        failed = False

        with self.progressBar(len(self.varmodel)) as progress:
            for i, var in enumerate(self.varmodel):
                method = self.variable_methods.get(i, self.default_method)

                try:
                    if not method.supports_variable(var):
                        unsupported.append(var.name)
                    elif isinstance(method, impute.DropInstances):
                        drop_mask |= method(data, var)
                    else:
                        var = method(data, var)
                except Exception:  # pylint: disable=broad-except
                    self.Error.imputation_failed(var.name)
                    failed = True
                    break

                # A single variable is wrapped; otherwise the imputer's result
                # is presumably a sequence of variables.
                if isinstance(var, Orange.data.Variable):
                    var = [var]

                if i < n_attributes:
                    attributes.extend(var)
                else:
                    class_vars.extend(var)

                progress.advance()

        if unsupported:
            # Report all such variables at once; repeated self.warning calls
            # would keep only the last one.
            self.warning("Default method can not handle '{}'".format(
                "', '".join(unsupported)))
        if failed:
            return None

        domain = Orange.data.Domain(attributes, class_vars,
                                    data.domain.metas)
        return data.from_table(domain, data[~drop_mask])

    def send_report(self):
        specific = []
        for i, var in enumerate(self.varmodel):
            method = self.variable_methods.get(i, None)
            if method is not None:
                specific.append("{} ({})".format(var.name, str(method)))

        default = self.default_method.name
        if specific:
            self.report_items((("Default method", default),
                               ("Specific imputers", ", ".join(specific))))
        else:
            self.report_items((("Method", default), ))

    def _on_var_selection_changed(self):
        """Sync the method buttons and value editor with the selection.

        If all selected variables share one method (by type and
        parameters), its button is checked; otherwise no button is checked.
        Buttons of methods that do not support every selected variable are
        disabled. The value editor is the spin box when no discrete variable
        is selected, and the combo with the variable's values for a single
        discrete variable. Otherwise it is disabled.
        """
        indexes = self.selection.selectedIndexes()
        methods = [self.get_method_for_column(i.row()) for i in indexes]

        def method_key(method):
            """
            Decompose method into its type and parameters.
            """
            # The return value should be hashable and  __eq__ comparable
            if isinstance(method, AsDefault):
                return AsDefault, (method.method, )
            elif isinstance(method, impute.Model):
                return impute.Model, (method.learner, )
            elif isinstance(method, impute.Default):
                return impute.Default, (method.default, )
            else:
                return type(method), None

        methods = set(method_key(m) for m in methods)
        selected_vars = [self.varmodel[index.row()] for index in indexes]
        has_discrete = any(var.is_discrete for var in selected_vars)
        fixed_value = None
        value_stack_enabled = False
        current_value_widget = None

        if len(methods) == 1:
            method_type, parameters = methods.pop()
            for i, m in enumerate(self.methods):
                if method_type == type(m):
                    self.variable_button_group.button(i).setChecked(True)

            if method_type is impute.Default:
                (fixed_value, ) = parameters

        elif self.variable_button_group.checkedButton() is not None:
            # Uncheck the current button
            self.variable_button_group.setExclusive(False)
            self.variable_button_group.checkedButton().setChecked(False)
            self.variable_button_group.setExclusive(True)
            assert self.variable_button_group.checkedButton() is None

        for method, button in zip(self.methods,
                                  self.variable_button_group.buttons()):
            enabled = all(
                method.supports_variable(var) for var in selected_vars)
            button.setEnabled(enabled)

        if not has_discrete:
            value_stack_enabled = True
            current_value_widget = self.value_double
        elif len(selected_vars) == 1:
            value_stack_enabled = True
            current_value_widget = self.value_combo
            self.value_combo.clear()
            self.value_combo.addItems(selected_vars[0].values)
        else:
            value_stack_enabled = False
            current_value_widget = None
            self.variable_button_group.button(self.AS_INPUT).setEnabled(False)

        self.value_stack.setEnabled(value_stack_enabled)
        if current_value_widget is not None:
            self.value_stack.setCurrentWidget(current_value_widget)
            if fixed_value is not None:
                if current_value_widget is self.value_combo:
                    self.value_combo.setCurrentIndex(fixed_value)
                elif current_value_widget is self.value_double:
                    self.value_double.setValue(fixed_value)
                else:
                    assert False

    def set_method_for_current_selection(self, method_index):
        """Apply the method at `method_index` to the selected variables."""
        indexes = self.selection.selectedIndexes()
        self.set_method_for_indexes(indexes, method_index)

    def set_method_for_indexes(self, indexes, method_index):
        """Set the method at `method_index` for the variables at `indexes`.

        DEFAULT removes the per-variable override. AS_INPUT creates an
        ``impute.Default`` with the value from the current value editor
        (combo index or spin-box value). Any other method is copied once and
        shared by the given variables.
        """
        if method_index == self.DEFAULT:
            for index in indexes:
                self.variable_methods.pop(index.row(), None)
        elif method_index == OWImpute.AS_INPUT:
            current = self.value_stack.currentWidget()
            if current is self.value_combo:
                value = self.value_combo.currentIndex()
            else:
                value = self.value_double.value()
            for index in indexes:
                method = impute.Default(default=value)
                self.variable_methods[index.row()] = method
        else:
            method = self.methods[method_index].copy()
            for index in indexes:
                self.variable_methods[index.row()] = method

        self.update_varview(indexes)
        self._invalidate()

    def update_varview(self, indexes=None):
        """Store each variable's method in the model (Qt.UserRole).

        Updates the given model indexes, or every row when `indexes` is None.
        """
        if indexes is None:
            indexes = map(self.varmodel.index, range(len(self.varmodel)))

        for index in indexes:
            self.varmodel.setData(index,
                                  self.get_method_for_column(index.row()),
                                  Qt.UserRole)

    def _on_value_selected(self):
        # The fixed 'Value' in the widget has been changed by the user.
        self.variable_button_group.button(self.AS_INPUT).setChecked(True)
        self.set_method_for_current_selection(self.AS_INPUT)

    def reset_variable_methods(self):
        """Drop all per-variable methods so every variable uses the default."""
        indexes = list(map(self.varmodel.index, range(len(self.varmodel))))
        self.set_method_for_indexes(indexes, self.DEFAULT)
        self.variable_button_group.button(self.DEFAULT).setChecked(True)