Ejemplo n.º 1
0
    def test_create(self, SettingProvider):
        """Check that ``SettingsHandler.create`` builds a handler for a widget class.

        :type SettingProvider: unittest.mock.Mock
        """
        with patch.object(SettingsHandler, 'read_defaults') as read_defaults:
            created = SettingsHandler.create(SimpleWidget)

            self.assertEqual(created.widget_class, SimpleWidget)
            # Creating a handler has to instantiate a SettingProvider for the
            # widget class (the provider traverses the widget definition and
            # collects its settings) and to read the stored defaults once.
            SettingProvider.assert_called_once_with(SimpleWidget)
            read_defaults.assert_called_once_with()
Ejemplo n.º 2
0
    def test_initialize_with_no_provider(self, SettingProvider):
        """Initializing a widget with an undeclared provider warns exactly once.

        :type SettingProvider: unittest.mock.Mock
        """
        handler = SettingsHandler()
        # The handler's own provider reports that it has no provider
        # registered for the widget.
        provider_lookup = Mock(return_value=None)
        handler.provider = Mock(get_provider=provider_lookup)
        substitute = Mock()
        SettingProvider.return_value = substitute
        widget = SimpleWidget()

        # initializing an undeclared provider should display a warning
        with warnings.catch_warnings(record=True) as caught:
            handler.initialize(widget)

            self.assertEqual(1, len(caught))

        # A provider is then created for the widget class and asked to
        # initialize the widget without any stored data.
        SettingProvider.assert_called_once_with(SimpleWidget)
        substitute.initialize.assert_called_once_with(widget, None)
Ejemplo n.º 3
0
class BaseWidget:
    """Minimal widget-like class with one setting and one settings provider."""

    # No settings handler is attached at class level.
    settingsHandler = None

    # Setting declared with the default value True.
    show_graph = Setting(True)

    # Declares `graph` as the attribute through which settings of the
    # Graph class are provided.
    graph = SettingProvider(Graph)

    def __init__(self):
        """Initialize the settings first, then create the Graph component."""
        initialize_settings(self)
        # The instance attribute of the same name holds the actual Graph.
        self.graph = Graph()
class MockWidget(OWWidget):
    """Test widget hosting a filter component, a pagination component and a sign-in dialog."""

    name = 'MockWidget'
    want_main_area = False

    # Each component's settings are declared through a SettingProvider under
    # the attribute name that later holds the component instance.
    filter_component = SettingProvider(CollapsibleFilterComponent)
    pagination_component = SettingProvider(PaginationComponent)

    # Signal carrying two booleans.
    pagination_availability = pyqtSignal(bool, bool)

    # While __init__ runs, `get_credential_manager` in the resolwe components
    # module is replaced by a mock that returns the CredentialManager built
    # here (once, at class-definition time) for 'resolwe_credentials_test'.
    # NOTE(review): `patch` used as a decorator passes the created mock as an
    # extra positional argument; here it lands in *args and is forwarded to
    # super().__init__ -- confirm this is intended.
    @patch(
        'orangecontrib.bioinformatics.widgets.components.resolwe.get_credential_manager',
        return_value=CredentialManager('resolwe_credentials_test'),
    )
    def __init__(self, *args, **kwargs):
        """Create the widget, its two components in the control area and the dialog."""
        super().__init__(*args, **kwargs)
        self.filter_component = CollapsibleFilterComponent(
            self, self.controlArea)
        self.pagination_component = PaginationComponent(self, self.controlArea)
        self.sign_in_dialog = SignIn(self)
Ejemplo n.º 5
0
class Widget(BaseWidget):
    """BaseWidget subclass adding a second setting and a second settings provider."""

    # Setting declared with the default value True.
    show_zoom_toolbar = Setting(True)

    # Declares `zoom_toolbar` as the attribute through which settings of
    # the ZoomToolbar class are provided.
    zoom_toolbar = SettingProvider(ZoomToolbar)

    def __init__(self):
        """Run the base initialization, initialize settings and add the toolbar."""
        super().__init__()
        # BaseWidget.__init__ already calls initialize_settings; it is
        # invoked once more here, on the subclass instance.
        initialize_settings(self)

        self.zoom_toolbar = ZoomToolbar()
Ejemplo n.º 6
0
class MyWidget(OWWidget):
    """Dummy widget with one plain setting and one component that provides settings."""

    name = "Dummy"

    # Setting declared with the default value 42.
    field = Setting(42)
    # Settings of DummyComponent are provided through the `component` attribute.
    component = SettingProvider(DummyComponent)

    def __init__(self):
        """Create the widget and attach a DummyComponent to it."""
        super().__init__()

        # The component receives this widget as its argument.
        self.component = DummyComponent(self)
        self.widget = None
class MockWidget(OWWidget):
    """Test widget hosting a GeneSetSelection component in its control area."""

    name = 'Mock'
    want_main_area = False
    # Settings of GeneSetSelection are provided through `selection_component`.
    selection_component = SettingProvider(GeneSetSelection)

    class Error(OWWidget.Error):
        # NOTE(review): presumably the component reports errors through this
        # message name on its parent widget -- confirm against GeneSetSelection.
        custom_gene_sets_table_format = Msg('FooBar')

    def __init__(self, *args, **kwargs):
        """Create the widget and place the selection component in the control area."""
        super().__init__(*args, **kwargs)

        self.selection_component = GeneSetSelection(self, self.controlArea)
Ejemplo n.º 8
0
class MyWidget(OWWidget):
    """Widget with two check boxes: one bound to its own attribute, one to a component's."""

    # Plain class attribute (not a Setting) that the first check box is bound to.
    foo = True

    # Settings of DummyComponent are provided through the `component` attribute.
    component = SettingProvider(DummyComponent)

    def __init__(self):
        """Create the component and the two check boxes in the control area."""
        super().__init__()

        self.component = DummyComponent(self)
        # Check box bound to the widget's own attribute "foo".
        self.foo_control = gui.checkBox(self.controlArea, self, "foo", "")
        # Check box bound through the dotted name "component.foo".
        self.component_foo_control = gui.checkBox(self.controlArea, self,
                                                  "component.foo", "")
Ejemplo n.º 9
0
class OWCurves(OWWidget):
    """Widget that shows the input table in a ``CurvePlot`` and outputs the selection.

    It sends the annotated input on "Data" and the selected rows on
    "Selection".
    """

    name = "Curves"
    inputs = [("Data", Orange.data.Table, 'set_data', Default),
              ("Data subset", Orange.data.Table, 'set_subset', Default)]
    outputs = [("Selection", Orange.data.Table), ("Data", Orange.data.Table)]
    icon = "icons/curves.svg"

    settingsHandler = DomainContextHandler(metas_in_res=True)

    curveplot = SettingProvider(CurvePlot)

    class Information(OWWidget.Information):
        showing_sample = Msg("Showing {} of {} curves.")

    class Warning(OWWidget.Warning):
        no_x = Msg("No continuous features in input data.")

    def __init__(self):
        """Build the widget: the control area is hidden, the plot fills the main area."""
        super().__init__()
        self.controlArea.hide()

        plot = CurvePlot(self, select=SELECTMANY)
        self.curveplot = plot
        self.mainArea.layout().addWidget(plot)

        self.resize(900, 700)
        self.graph_name = "curveplot.plotview"

    def set_data(self, data):
        """Handle the "Data" input: show it in the plot and refresh the outputs."""
        self.Information.showing_sample.clear()
        self.Warning.no_x.clear()

        plot = self.curveplot
        # The context is closed before and reopened after the plot gets the
        # new data; the view is refreshed only once the context is open.
        self.closeContext()
        plot.set_data(data)
        self.openContext(data)
        plot.update_view()

        if data is not None and len(plot.data_x) == 0:
            self.Warning.no_x()

        sampled = plot.sampled_indices
        if sampled and len(sampled) != len(plot.data):
            self.Information.showing_sample(len(sampled), len(data))

        self.selection_changed()

    def set_subset(self, data):
        """Handle the "Data subset" input by passing its row ids (or None) to the plot."""
        subset_ids = data.ids if data else None
        self.curveplot.set_data_subset(subset_ids)

    def selection_changed(self):
        """Send the annotated table on "Data" and the selected rows on "Selection"."""
        plot = self.curveplot
        self.send("Data", create_annotated_table(
            plot.data, sorted(plot.selected_indices)))

        # "Selection" carries None unless something is selected in a
        # non-empty table.
        if plot.selected_indices and plot.data:
            chosen = plot.data[sorted(plot.selected_indices)]
        else:
            chosen = None
        self.send("Selection", chosen)
Ejemplo n.º 10
0
class SimpleWidget:
    """Plain widget stand-in for settings tests: settings, a component and mocked migrations."""

    settings_version = 1

    # Regular setting with the default value 42.
    setting = Setting(42)
    # Setting flagged schema_only, with the default value None.
    schema_only_setting = Setting(None, schema_only=True)
    # Ordinary class attribute that is not declared as a Setting.
    non_setting = 5

    # Settings of Component are provided through the `component` attribute.
    component = SettingProvider(Component)

    def __init__(self):
        """Create the Component instance."""
        self.component = Component()

    # Migration hooks are Mock objects, so calls to them are recorded.
    migrate_settings = Mock()
    migrate_context = Mock()
Ejemplo n.º 11
0
class SimpleWidget(QObject):
    """QObject-based widget stand-in for settings tests, with a signal and mocked migrations."""

    settings_version = 1

    # Regular setting with the default value 42.
    setting = Setting(42)
    # Setting flagged schema_only, with the default value None.
    schema_only_setting = Setting(None, schema_only=True)
    # Setting whose default is a list.
    list_setting = Setting([])
    # Ordinary class attribute that is not declared as a Setting.
    non_setting = 5

    # Settings of Component are provided through the `component` attribute.
    component = SettingProvider(Component)
    # NOTE(review): presumably emitted by the settings handler before the
    # settings are packed -- confirm against the handler implementation.
    settingsAboutToBePacked = Signal()

    def __init__(self):
        """Initialize the QObject base and create the Component instance."""
        super().__init__()
        self.component = Component()

    # Migration hooks are Mock objects, so calls to them are recorded.
    migrate_settings = Mock()
    migrate_context = Mock()
Ejemplo n.º 12
0
class OWManifoldLearning(OWWidget):
    """Widget that projects the input table with a selectable manifold method.

    One of ``MANIFOLD_METHODS`` is chosen in a combo box; the parameter
    editor at the same position is shown and the resulting embedding is sent
    to the ``transformed_data`` output.
    """
    name = "流形学习(Manifold Learning)"
    description = "非线性降维。"
    icon = "icons/Manifold.svg"
    priority = 2200
    keywords = []
    settings_version = 2

    class Inputs:
        data = Input("数据(Data)", Table, replaces=['Data'])

    class Outputs:
        transformed_data = Output("转换的数据(Transformed data)", Table, dynamic=False, replaces=['Transformed data'])

    # The order must match `parameter_editors` built in __init__: both are
    # indexed by `manifold_method_index`.
    MANIFOLD_METHODS = (TSNE, MDS, Isomap, LocallyLinearEmbedding,
                        SpectralEmbedding)

    tsne_editor = SettingProvider(TSNEParametersEditor)
    mds_editor = SettingProvider(MDSParametersEditor)
    isomap_editor = SettingProvider(IsomapParametersEditor)
    lle_editor = SettingProvider(LocallyLinearEmbeddingParametersEditor)
    spectral_editor = SettingProvider(SpectralEmbeddingParametersEditor)

    resizing_enabled = False
    want_main_area = False

    manifold_method_index = Setting(0)
    n_components = Setting(2)
    auto_apply = Setting(True)

    class Error(OWWidget.Error):
        n_neighbors_too_small = Msg("For chosen method and components, "
                                    "neighbors must be greater than {}")
        manifold_error = Msg("{}")
        sparse_not_supported = Msg("Sparse data is not supported.")
        out_of_memory = Msg("Out of memory")

    class Warning(OWWidget.Warning):
        graph_not_connected = Msg("Disconnected graph, embedding may not work")

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade stored t-SNE editor settings from versions before 2.

        ``init_index`` is renamed to ``initialization_index`` and both it and
        ``metric_index`` are clamped to the last valid position of the
        corresponding value list. ``settings`` is modified in place.
        """
        if version < 2:
            tsne_settings = settings.get('tsne_editor', {})
            # Fixup initialization index
            if 'init_index' in tsne_settings:
                idx = tsne_settings.pop('init_index')
                # The largest valid index is len - 1, not len.
                idx = min(
                    idx, len(TSNEParametersEditor.initialization_values) - 1)
                tsne_settings['initialization_index'] = idx
            # We removed several metrics here
            if 'metric_index' in tsne_settings:
                idx = tsne_settings['metric_index']
                idx = min(idx, len(TSNEParametersEditor.metric_values) - 1)
                tsne_settings['metric_index'] = idx

    def __init__(self):
        """Build the method combo, the parameter editors and the output controls."""
        self.data = None

        # GUI
        method_box = gui.vBox(self.controlArea, "方法")
        self.manifold_methods_combo = gui.comboBox(
            method_box, self, "manifold_method_index",
            items=[m.name for m in self.MANIFOLD_METHODS],
            callback=self.manifold_method_changed)

        self.params_box = gui.vBox(self.controlArea, "参数")

        self.tsne_editor = TSNEParametersEditor(self)
        self.mds_editor = MDSParametersEditor(self)
        self.isomap_editor = IsomapParametersEditor(self)
        self.lle_editor = LocallyLinearEmbeddingParametersEditor(self)
        self.spectral_editor = SpectralEmbeddingParametersEditor(self)
        self.parameter_editors = [
            self.tsne_editor, self.mds_editor, self.isomap_editor,
            self.lle_editor, self.spectral_editor]

        # All editors live in the same box; only the one belonging to the
        # current method is visible.
        for editor in self.parameter_editors:
            self.params_box.layout().addWidget(editor)
            editor.hide()
        self.params_widget = self.parameter_editors[self.manifold_method_index]
        self.params_widget.show()

        output_box = gui.vBox(self.controlArea, "输出")
        self.n_components_spin = gui.spin(
            output_box, self, "n_components", 1, 10, label="成分:",
            alignment=Qt.AlignRight, callbackOnReturn=True,
            callback=self.settings_changed)
        self.apply_button = gui.auto_apply(self.controlArea, self, box=False, commit=self.apply)

    def manifold_method_changed(self):
        """Show the editor of the newly selected method and re-apply."""
        self.params_widget.hide()
        self.params_widget = self.parameter_editors[self.manifold_method_index]
        self.params_widget.show()
        self.apply()

    def settings_changed(self):
        """Re-apply after a parameter change."""
        self.apply()

    @Inputs.data
    def set_data(self, data):
        """Store the input table, bound the component count by it and apply."""
        self.data = data
        # Without data the spin box falls back to a maximum of 10.
        self.n_components_spin.setMaximum(len(self.data.domain.attributes)
                                          if self.data else 10)
        self.unconditional_apply()

    def apply(self):
        """Compute the embedding with the chosen method and send it to the output.

        Sparse input, ``ValueError``, ``MemoryError`` and ``LinAlgError`` are
        reported through ``self.Error``; in those cases, and without data,
        ``None`` is sent.
        """
        builtin_warn = warnings.warn

        def _handle_disconnected_graph_warning(msg, *args, **kwargs):
            # warnings.warn also accepts Warning instances, which have no
            # startswith(); compare on the textual form instead.
            if str(msg).startswith("Graph is not fully connected"):
                self.Warning.graph_not_connected()
            else:
                builtin_warn(msg, *args, **kwargs)

        out = None
        data = self.data
        method = self.MANIFOLD_METHODS[self.manifold_method_index]
        have_data = data is not None and len(data)
        self.Error.clear()
        self.Warning.clear()

        if have_data and data.is_sparse():
            self.Error.sparse_not_supported()
        elif have_data:
            domain = Domain([ContinuousVariable("C{}".format(i))
                             for i in range(self.n_components)],
                            data.domain.class_vars,
                            data.domain.metas)
            try:
                # warnings.warn is swapped only for the duration of the
                # projection; the finally clause restores it.
                warnings.warn = _handle_disconnected_graph_warning
                projector = method(**self.get_method_parameters(data, method))
                model = projector(data)
                if isinstance(model, TSNEModel):
                    out = model.embedding
                else:
                    X = model.embedding_
                    out = Table(domain, X, data.Y, data.metas)
            except ValueError as e:
                if e.args[0] == "for method='hessian', n_neighbors " \
                                "must be greater than [n_components" \
                                " * (n_components + 3) / 2]":
                    n = self.n_components * (self.n_components + 3) / 2
                    self.Error.n_neighbors_too_small("{}".format(n))
                else:
                    self.Error.manifold_error(e.args[0])
            except MemoryError:
                self.Error.out_of_memory()
            except np.linalg.LinAlgError as e:
                self.Error.manifold_error(str(e))
            finally:
                warnings.warn = builtin_warn

        self.Outputs.transformed_data.send(out)

    def get_method_parameters(self, data, method):
        """Return ``n_components`` merged with the visible editor's parameters.

        ``data`` and ``method`` are accepted but not used here.
        """
        parameters = dict(n_components=self.n_components)
        parameters.update(self.params_widget.get_parameters())
        return parameters

    def send_report(self):
        """Report the method, its parameters and, when present, the data."""
        method = self.MANIFOLD_METHODS[self.manifold_method_index]
        self.report_items((("Method", method.name),))
        parameters = self.get_method_parameters(self.data, method)
        self.report_items("Method parameters", tuple(parameters.items()))
        if self.data:
            self.report_data("Data", self.data)
class OWKaplanMeier(OWWidget):
    """Widget plotting estimated curves for a time/event pair, optionally per group.

    The rows falling into the intervals selected on the plot are sent to the
    ``selected_data`` output.
    """
    name = 'Kaplan-Meier Plot'
    # TODO
    description = ''
    # TODO
    icon = ''
    priority = 0

    show_confidence_interval: bool
    show_confidence_interval = Setting(False)

    show_median_line: bool
    show_median_line = Setting(False)

    show_censored_data: bool
    show_censored_data = Setting(False)

    settingsHandler = PerfectDomainContextHandler()
    time_var = ContextSetting(None)
    event_var = ContextSetting(None)
    group_var: Optional[DiscreteVariable] = ContextSetting(None)

    graph = SettingProvider(KaplanMeierPlot)

    auto_commit: bool = Setting(False, schema_only=True)

    class Inputs:
        data = Input('Data', Table)

    class Outputs:
        selected_data = Output('Data', Table)

    def __init__(self, *args, **kwargs):
        """Build the variable combos, the display options, the plot and the commit button."""
        super().__init__(*args, **kwargs)

        self.data: Optional[Table] = None
        self.plot_curves = None

        time_var_model = DomainModel(valid_types=(ContinuousVariable, ))
        event_var_model = DomainModel(valid_types=DomainModel.PRIMITIVE)
        group_var_model = DomainModel(placeholder='(None)',
                                      valid_types=(DiscreteVariable, ))

        box = gui.vBox(self.controlArea, 'Time', margin=0)
        gui.comboBox(box,
                     self,
                     'time_var',
                     model=time_var_model,
                     callback=self.on_controls_changed)

        box = gui.vBox(self.controlArea, 'Event', margin=0)
        gui.comboBox(box,
                     self,
                     'event_var',
                     model=event_var_model,
                     callback=self.on_controls_changed)

        box = gui.vBox(self.controlArea, 'Group', margin=0)
        gui.comboBox(box,
                     self,
                     'group_var',
                     model=group_var_model,
                     callback=self.on_controls_changed)

        box = gui.vBox(self.controlArea, 'Display options')
        gui.checkBox(
            box,
            self,
            'show_confidence_interval',
            label='Confidence intervals',
            callback=self.on_display_option_changed,
        )

        gui.checkBox(
            box,
            self,
            'show_median_line',
            label='Median',
            callback=self.on_display_option_changed,
        )

        gui.checkBox(
            box,
            self,
            'show_censored_data',
            label='Censored data',
            callback=self.on_display_option_changed,
        )

        self.graph: KaplanMeierPlot = KaplanMeierPlot(parent=self)
        self.graph.selection_changed.connect(self.commit)
        self.mainArea.layout().addWidget(self.graph)

        plot_gui = OWPlotGUI(self)
        plot_gui.box_zoom_select(self.controlArea)

        gui.rubber(self.controlArea)

        self.commit_button = gui.auto_commit(self.controlArea,
                                             self,
                                             'auto_commit',
                                             '&Commit',
                                             box=False)

    @Inputs.data
    def set_data(self, data: Table):
        """Handle the data input: reset the controls, redraw the curves, commit.

        A missing or empty input drops the stored table, the curves and the
        selection, and sends ``None`` to the output.
        """
        self.closeContext()
        if not data:
            # Without this reset the widget would keep plotting and
            # outputting the previous table after the input is removed.
            self.data = None
            self.graph.curves = {}
            self.graph.selection = {}
            self.graph.update_plot(**self._get_plot_options())
            self.commit()
            return

        self.data = data
        self.controls.time_var.model().set_domain(data.domain)
        self.controls.event_var.model().set_domain(data.domain)
        self.controls.group_var.model().set_domain(data.domain)
        # Start from a clean state; openContext below may set the variables
        # again from a stored context.
        self.time_var = None
        self.event_var = None
        self.group_var = None
        self.graph.selection = {}
        self.openContext(data.domain)

        self.graph.curves = dict(enumerate(self.generate_plot_curves()))
        self.graph.update_plot(**self._get_plot_options())
        self.commit()

    def _get_plot_options(self):
        """Return the display check boxes as keyword arguments for ``update_plot``."""
        return {
            'confidence_interval': self.show_confidence_interval,
            'median': self.show_median_line,
            'censored': self.show_censored_data,
        }

    def on_display_option_changed(self) -> None:
        """Redraw the plot with the current display options."""
        self.graph.update_plot(**self._get_plot_options())

    def on_controls_changed(self):
        """Rebuild the curves after a variable change, clear the selection and commit."""
        if not self.data:
            return

        self.graph.curves = dict(enumerate(self.generate_plot_curves()))
        self.graph.clear_selection()
        self.graph.update_plot(**self._get_plot_options())
        self.commit()

    def _get_discrete_var_color(self, index: Optional[int]):
        """Return the group variable's color at ``index`` as a list, else None."""
        if self.group_var is not None and index is not None:
            return list(self.group_var.colors[index])
        return None

    def generate_plot_curves(self) -> List[EstimatedFunctionCurve]:
        """Build one curve for the whole table or one per non-empty group.

        Returns an empty list while the time or the event variable is unset.
        """
        if self.time_var is None or self.event_var is None:
            return []

        time, _ = self.data.get_column_view(self.time_var)
        events, _ = self.data.get_column_view(self.event_var)

        if self.group_var:
            groups, _ = self.data.get_column_view(self.group_var)
            group_indexes = list(range(len(self.group_var.values)))
            colors = [
                self._get_discrete_var_color(index) for index in group_indexes
            ]
            # One boolean row mask per group value.
            masks = groups == np.reshape(group_indexes, (-1, 1))

            # Groups without any row produce no curve.
            return [
                EstimatedFunctionCurve(time[mask],
                                       events[mask],
                                       color=color,
                                       label=label) for mask, color, label in
                zip(masks, colors, self.group_var.values) if mask.any()
            ]

        else:
            return [EstimatedFunctionCurve(time, events)]

    def commit(self):
        """Send the rows whose time lies in the selected interval(s), or None."""
        if not self.graph.selection:
            self.Outputs.selected_data.send(None)
            return

        time, _ = self.data.get_column_view(self.time_var)
        if self.group_var is None:
            time_interval = self.graph.selection[0].x
            start, end = time_interval[0], time_interval[-1]
            selection = np.argwhere((time >= start)
                                    & (time <= end)).reshape(-1).astype(int)
        else:
            group, _ = self.data.get_column_view(self.group_var)
            # Curves are numbered consecutively over the groups that occur
            # in the data (empty groups get no curve), so a curve id indexes
            # the list of non-empty groups rather than the group values.
            # NOTE(review): assumes graph.selection is keyed by the same
            # curve ids as graph.curves -- confirm against KaplanMeierPlot.
            present_groups = [
                index for index in range(len(self.group_var.values))
                if (group == index).any()
            ]
            selection = []
            for curve_id, time_interval in self.graph.selection.items():
                start, end = time_interval.x[0], time_interval.x[-1]
                in_curve = ((time >= start) & (time <= end)
                            & (group == present_groups[curve_id]))
                selection += (
                    np.argwhere(in_curve).reshape(-1).astype(int).tolist())
            selection = sorted(selection)

        self.Outputs.selected_data.send(self.data[selection, :])

    def sizeHint(self):
        """Return the preferred size of the widget."""
        return QSize(1280, 620)
Ejemplo n.º 14
0
class OWRadviz(OWAnchorProjectionWidget):
    """Anchor projection widget using a ``RadViz`` projector.

    The variables to project are kept in ``model_selected``; the remaining
    usable variables are kept in ``model_other``.
    """
    name = "Radviz"
    description = "显示Radviz投影"
    icon = "icons/Radviz.svg"
    priority = 241
    keywords = ["viz"]

    settings_version = 2

    # Stored as (name, vartype) pairs, see __model_selected_changed.
    selected_vars = ContextSetting([])
    vizrank = SettingProvider(RadvizVizRank)
    GRAPH_CLASS = OWRadvizGraph
    graph = SettingProvider(OWRadvizGraph)

    class Warning(OWAnchorProjectionWidget.Warning):
        invalid_embedding = widget.Msg("选择的特征没有投影")
        removed_vars = widget.Msg("不显示具有两个以上值的分类变量")

    def __init__(self):
        """Create the variable models and the VizRank button, then the base widget."""
        # The models and the button are created before super().__init__()
        # because _add_controls uses them.
        # NOTE(review): assumes the base initializer is what calls
        # _add_controls -- confirm against OWAnchorProjectionWidget.
        self.model_selected = VariableListModel(enable_dnd=True)
        self.model_selected.removed.connect(self.__model_selected_changed)
        self.model_other = VariableListModel(enable_dnd=True)

        self.vizrank, self.btn_vizrank = RadvizVizRank.add_vizrank(
            None, self, "Suggest features", self.vizrank_set_attrs)
        super().__init__()

    def _add_controls(self):
        """Add the variable selection with the VizRank button above the base controls."""
        self.variables_selection = VariablesSelection(self,
                                                      self.model_selected,
                                                      self.model_other,
                                                      self.controlArea)
        self.variables_selection.added.connect(self.__model_selected_changed)
        self.variables_selection.removed.connect(self.__model_selected_changed)
        self.variables_selection.add_remove.layout().addWidget(
            self.btn_vizrank)
        super()._add_controls()
        # Remove the stretch item from the control area.
        self.controlArea.layout().removeWidget(self.control_area_stretch)
        self.control_area_stretch.setParent(None)

    @property
    def primitive_variables(self):
        """Continuous and two-valued discrete variables among the variables and metas."""
        if self.data is None or self.data.domain is None:
            return []
        dom = self.data.domain
        return [
            v for v in chain(dom.variables, dom.metas)
            if v.is_continuous or v.is_discrete and len(v.values) == 2
        ]

    @property
    def effective_variables(self):
        """Copy of the currently selected variables."""
        return self.model_selected[:]

    def vizrank_set_attrs(self, *attrs):
        """Select the given variables; all other usable ones go to the other model."""
        if not attrs:
            return
        self.model_selected[:] = attrs[:]
        self.model_other[:] = [
            var for var in self.primitive_variables if var not in attrs
        ]
        self.__model_selected_changed()

    def __model_selected_changed(self):
        """Store the selection as a setting, then rebuild projection, plot and output."""
        self.selected_vars = [(var.name, vartype(var))
                              for var in self.model_selected]
        self.init_projection()
        self.setup_plot()
        self.commit()

    def colors_changed(self):
        """Run the base handler and refresh the VizRank state."""
        super().colors_changed()
        self._init_vizrank()

    def set_data(self, data):
        """Run the base handler, refresh VizRank and initialize the projection."""
        super().set_data(data)
        self._init_vizrank()
        self.init_projection()

    def use_context(self):
        """Fill both variable models from ``selected_vars`` or with defaults.

        Without a stored selection the first five usable variables are
        selected and the class variable, if usable, is listed last among the
        others.
        """
        self.model_selected.clear()
        self.model_other.clear()
        if self.data is not None and len(self.selected_vars):
            d, selected = self.data.domain, [v[0] for v in self.selected_vars]
            self.model_selected[:] = [d[name] for name in selected]
            self.model_other[:] = [
                d[attr.name] for attr in self.primitive_variables
                if attr.name not in selected
            ]
        elif self.data is not None:
            d, variables = self.data.domain, self.primitive_variables
            class_var = [variables.pop(variables.index(d.class_var))] \
                if d.class_var in variables else []
            self.model_selected[:] = variables[:5]
            self.model_other[:] = variables[5:] + class_var

    def _init_vizrank(self):
        """Enable the VizRank button when the data allows ranking, and initialize it.

        Ranking requires data with more than three usable variables, a color
        attribute that is not entirely missing, more than one valid row and
        no column of ``X`` with zero standard deviation.
        """
        is_enabled = self.data is not None and \
            len(self.primitive_variables) > 3 and \
            self.attr_color is not None and \
            not np.isnan(self.data.get_column_view(
                self.attr_color)[0].astype(float)).all() and \
            len(self.data[self.valid_data]) > 1 and \
            np.all(np.nan_to_num(np.nanstd(self.data.X, 0)) != 0)
        self.btn_vizrank.setEnabled(is_enabled)
        if is_enabled:
            self.vizrank.initialize()

    def check_data(self):
        """Run the base checks and warn when some primitive variables are not usable."""
        super().check_data()
        if self.data is not None:
            domain = self.data.domain
            vars_ = chain(domain.variables, domain.metas)
            n_vars = sum(v.is_primitive() for v in vars_)
            if len(self.primitive_variables) < n_vars:
                self.Warning.removed_vars()

    def init_attr_values(self):
        """Run the base initialization and clear the stored selection."""
        super().init_attr_values()
        self.selected_vars = []

    def _manual_move(self, anchor_idx, x, y):
        """Move an anchor to the unit-circle point in the direction of ``(x, y)``."""
        angle = np.arctan2(y, x)
        super()._manual_move(anchor_idx, np.cos(angle), np.sin(angle))

    def _send_components_x(self):
        """Return the base components with a row of their angles appended."""
        components_ = super()._send_components_x()
        # Reversing the rows passes them to arctan2 in (y, x) order.
        angle = np.arctan2(*components_[::-1])
        # np.vstack replaces the deprecated alias np.row_stack.
        return np.vstack((components_, angle))

    def _send_components_metas(self):
        """Return the base component metas with an "angle" row appended."""
        return np.vstack((super()._send_components_metas(), ["angle"]))

    def clear(self):
        """Run the base clear and install a fresh ``RadViz`` projector."""
        super().clear()
        self.projector = RadViz()

    @classmethod
    def migrate_context(cls, context, version):
        """Copy the four attr_* values of pre-version-2 contexts from "graph" to the top level."""
        if version < 2:
            values = context.values
            values["attr_color"] = values["graph"]["attr_color"]
            values["attr_size"] = values["graph"]["attr_size"]
            values["attr_shape"] = values["graph"]["attr_shape"]
            values["attr_label"] = values["graph"]["attr_label"]
Ejemplo n.º 15
0
class OWScatterPlot(OWDataProjectionWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    # Widget metadata.
    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    # Adds a "Features" input to the inputs of the base widget.
    class Inputs(OWDataProjectionWidget.Inputs):
        features = Input("Features", AttributeList)

    # Adds a "Features" output to the outputs of the base widget.
    class Outputs(OWDataProjectionWidget.Outputs):
        features = Output("Features", AttributeList, dynamic=False)

    # Settings; the axis variables are stored as context settings.
    settings_version = 5
    auto_sample = Setting(True)
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    tooltip_shows_all = Setting(True)

    # Graph class used by this widget and the provider of its settings.
    GRAPH_CLASS = OWScatterPlotGraph
    graph = SettingProvider(OWScatterPlotGraph)
    embedding_variables_names = None

    # Emitted by set_attr_from_combo with the current x and y variables.
    xy_changed_manually = Signal(Variable, Variable)

    class Warning(OWDataProjectionWidget.Warning):
        # formatted with the names of the x and y variables
        missing_coords = Msg("Plot cannot be displayed because '{}' or '{}' "
                             "is missing for all data points.")

    class Information(OWDataProjectionWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")
        # formatted with the names of the x and y variables
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        """Declare GUI handles and SQL sampling state, then build the widget.

        All attributes below are assigned before ``super().__init__()`` is
        called. The sampling timer has a 1200 ms interval and calls
        :meth:`add_data` on timeout.
        """
        # GUI elements; None until the control builders assign them
        self.attr_box: QGroupBox = None
        self.xy_model: DomainModel = None
        self.cb_attr_x: ComboBoxSearch = None
        self.cb_attr_y: ComboBoxSearch = None
        self.vizrank: ScatterPlotVizRank = None
        self.vizrank_button: QPushButton = None
        self.sampling: QGroupBox = None

        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        sample_timer = QTimer(self, interval=1200)
        sample_timer.timeout.connect(self.add_data)
        self.__timer = sample_timer
        super().__init__()

        # manually register Matplotlib file writers on a private copy
        writers = self.graph_writers.copy()
        writers.extend([MatplotlibFormat, MatplotlibPDFFormat])
        self.graph_writers = writers

    def _add_controls(self):
        """Build the control area.

        Adds the axis and sampling boxes, then the inherited controls, and
        finally the scatter-plot specific effect and plot options.
        """
        self._add_controls_axis()
        self._add_controls_sampling()
        super()._add_controls()

        plot_gui = self.gui
        plot_gui.add_widget(plot_gui.JitterNumericValues, self._effects_box)
        plot_gui.add_widgets(
            [plot_gui.ShowGridLines, plot_gui.ToolTipShowsAll,
             plot_gui.RegressionLine],
            self._plot_box)

        regression_tooltip = (
            "If checked, fit line to group (minimize distance from points);\n"
            "otherwise fit y as a function of x (minimize vertical distances)")
        gui.checkBox(
            self._plot_box, self,
            value="graph.orthonormal_regression",
            label="Treat variables as independent",
            callback=self.graph.update_regression_line,
            tooltip=regression_tooltip,
            disabledBy=self.cb_reg_line)

    def _add_controls_axis(self):
        """Create the 'Axes' box with the x/y combo boxes and the VizRank button."""
        self.attr_box = gui.vBox(
            self.controlArea, 'Axes', spacing=2 if gui.is_macstyle() else 8)
        self.xy_model = DomainModel(
            DomainModel.MIXED, valid_types=DomainModel.PRIMITIVE)

        def axis_combo(value, label):
            # both axis combos share the model, callback and layout options
            return gui.comboBox(
                self.attr_box, self, value,
                label=label,
                callback=self.set_attr_from_combo,
                model=self.xy_model,
                labelWidth=50,
                orientation=Qt.Horizontal,
                sendSelectedValue=True,
                contentsLength=12,
                searchable=True,
            )

        self.cb_attr_x = axis_combo("attr_x", "Axis x:")
        self.cb_attr_y = axis_combo("attr_y", "Axis y:")

        vizrank_box = gui.hBox(self.attr_box)
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

    def _add_controls_sampling(self):
        """Create the 'Sampling' box; it starts hidden.

        Its commit action calls :meth:`add_data` with a sampling time of 1.
        """
        sampling_box = gui.auto_commit(
            self.controlArea, self, "auto_sample", "Sample",
            box="Sampling",
            callback=self.switch_sampling,
            commit=lambda: self.add_data(1))
        sampling_box.setVisible(False)
        self.sampling = sampling_box

    @property
    def effective_variables(self):
        """Return ``[attr_x, attr_y]`` when both are truthy, else an empty list."""
        if self.attr_x and self.attr_y:
            return [self.attr_x, self.attr_y]
        return []

    @property
    def effective_data(self):
        """Return the data transformed to a domain of the effective variables.

        When both axes carry the same variable name, only ``attr_x`` is kept.
        """
        variables = self.effective_variables
        same_name = variables and self.attr_x.name == self.attr_y.name
        if same_name:
            variables = [self.attr_x]
        return self.data.transform(Domain(variables))

    def _vizrank_color_change(self):
        """Re-initialize VizRank and update the state of its button.

        The button is enabled only when no check fails; otherwise the
        message of the first failing check becomes its tooltip.
        """
        self.vizrank.initialize()

        def disabled_reason():
            # checks run in order; the first failure determines the message
            data = self.data
            if data is None:
                return "No data on input"
            if data.is_sparse():
                return "Data is sparse"
            if len(self.xy_model) < 3:
                return "Not enough features for ranking"
            if self.attr_color is None:
                return "Color variable is not selected"
            color_column = data.get_column_view(self.attr_color)[0]
            if np.isnan(color_column.astype(float)).all():
                return "Color variable has no values"
            return ""

        reason = disabled_reason()
        self.vizrank_button.setEnabled(not reason)
        self.vizrank_button.setToolTip(reason)

    def set_data(self, data):
        """Set the input data and resolve attribute settings stored by name.

        Settings restored from versions < 3.3.9 stored ``attr_*`` as plain
        names; each such string is replaced by the variable of that name
        from the corresponding model, or ``None`` when none matches.
        """
        super().set_data(data)
        self._vizrank_color_change()

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            return next(
                (el for el in iterable
                 if isinstance(el, Variable) and el.name == name),
                None)

        # models are fetched lazily, only for attributes stored as strings
        legacy_attrs = (
            ("attr_x", lambda: self.xy_model),
            ("attr_y", lambda: self.xy_model),
            ("attr_label", lambda: self.gui.label_model),
            ("attr_color", lambda: self.gui.color_model),
            ("attr_shape", lambda: self.gui.shape_model),
            ("attr_size", lambda: self.gui.size_model),
        )
        for attr, get_model in legacy_attrs:
            stored = getattr(self, attr)
            if isinstance(stored, str):
                setattr(self, attr, findvar(stored, get_model()))

    def check_data(self):
        """Validate the input data, sampling large SQL tables.

        SQL tables with an approximate length below 4000 are converted to
        a ``Table`` whole. Larger ones are sampled, the sampling box is
        shown, and the timer is started when ``auto_sample`` is set. Data
        without rows or without variables is replaced by ``None``.
        """
        super().check_data()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None

        if isinstance(self.data, SqlTable):
            if self.data.approx_len() >= 4000:
                self.Information.sampled_sql()
                self.sql_data = self.data
                sample = self.data.sample_time(0.8, no_cache=True)
                sample.download_data(2000, partial=True)
                self.data = Table(sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()
            else:
                self.data = Table(self.data)

        if self.data is None:
            return
        if len(self.data) == 0 or len(self.data.domain.variables) == 0:
            self.data = None

    def get_embedding(self):
        """Return the x and y columns stacked as two columns, or ``None``.

        ``None`` is returned when there is no data or either column is
        unavailable. Otherwise ``valid_data`` is set to the mask of rows in
        which both coordinates are finite; an information message is shown
        when only some rows are invalid and a warning when all are.
        """
        self.valid_data = None
        if self.data is None:
            return None

        xs = self.get_column(self.attr_x, filter_valid=False)
        ys = self.get_column(self.attr_y, filter_valid=False)
        if xs is None or ys is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        valid = np.isfinite(xs) & np.isfinite(ys)
        self.valid_data = valid
        if not np.all(valid):
            box = self.Information if np.any(valid) else self.Warning
            box.missing_coords(self.attr_x.name, self.attr_y.name)
        return np.vstack((xs, ys)).T

    # Tooltip
    def _point_tooltip(self, point_id, skip_attrs=()):
        """Return the HTML tooltip text for a point.

        The text lists the escaped x and y values. When
        ``tooltip_shows_all`` is set and the parent returns text for the
        remaining attributes, the axis values are shown in bold above it.
        ``skip_attrs`` is not used; the axis variables are passed to the
        parent instead.
        """
        instance = self.data[point_id]
        axes = (self.attr_x, self.attr_y)
        lines = [escape(f"{var.name} = {instance[var]}") for var in axes]
        text = "<br/>".join(lines)
        if not self.tooltip_shows_all:
            return text
        others = super()._point_tooltip(point_id, skip_attrs=axes)
        if others:
            text = "<b>{}</b><br/><br/>{}".format(text, others)
        return text

    def can_draw_regresssion_line(self):
        """Return whether a regression line can be drawn.

        This requires data with a domain and both axis variables to be set
        and continuous. Unset (``None``) axes, which ``init_attr_values``
        produces for an empty axis model, yield ``False`` instead of raising
        ``AttributeError``.
        """
        return self.data is not None and \
               self.data.domain is not None and \
               self.attr_x is not None and \
               self.attr_y is not None and \
               self.attr_x.is_continuous and \
               self.attr_y.is_continuous

    def add_data(self, time=0.4):
        """Download another sample of the SQL table and append it to the data.

        Sampling ends (the timer is stopped) when more than 2000 rows are
        already loaded, or when there is no SQL table to sample from.

        Args:
            time: value passed to ``sample_time`` of the SQL table.
        """
        # sql_data is reset to None by check_data; without it there is
        # nothing to sample, so stop instead of failing on None
        if self.sql_data is None or (self.data and len(self.data) > 2000):
            self.__timer.stop()
            return
        data_sample = self.sql_data.sample_time(time, no_cache=True)
        if data_sample:
            data_sample.download_data(2000, partial=True)
            data = Table(data_sample)
            self.data = Table.concatenate((self.data, data), axis=0)
            self.handleNewSignals()

    def init_attr_values(self):
        """Reset the axis model to the data's domain and pick default axes.

        The first two entries of the model become x and y. With a single
        entry both axes use it, and with an empty model both are ``None``.
        """
        super().init_attr_values()
        data = self.data
        has_rows = data and len(data)
        self.xy_model.set_domain(data.domain if has_rows else None)

        model = self.xy_model
        if model:
            self.attr_x = model[0]
        else:
            self.attr_x = None
        if len(model) >= 2:
            self.attr_y = model[1]
        else:
            self.attr_y = self.attr_x

    def switch_sampling(self):
        """Stop the sampling timer and restart sampling if it is enabled.

        When ``auto_sample`` is set and an SQL table is present, one sample
        is fetched immediately and the timer is started again.
        """
        self.__timer.stop()
        if not (self.auto_sample and self.sql_data):
            return
        self.add_data()
        self.__timer.start()

    def set_subset_data(self, subset_data):
        """Set the data subset, converting small SQL tables to ``Table``.

        SQL tables whose approximate length is not below ``AUTO_DL_LIMIT``
        are rejected with a warning and handled as no subset.
        """
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() >= AUTO_DL_LIMIT:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
            else:
                subset_data = Table(subset_data)
        super().set_subset_data(subset_data)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        """Update the widget after the input signals have been set.

        When the "Features" input provides attributes that are all in the
        data domain, the first two are used as axes and the axis box and
        VizRank are disabled; otherwise both stay enabled.
        """
        self.attr_box.setEnabled(True)
        self.vizrank.setEnabled(True)

        features = self.attribute_selection_list
        data = self.data
        features_usable = (
            features
            and data is not None
            and data.domain is not None
            and all(attr in data.domain for attr in features))
        if features_usable:
            self.attr_x, self.attr_y = features[:2]
            self.attr_box.setEnabled(False)
            self.vizrank.setEnabled(False)

        super().handleNewSignals()
        if self._domain_invalidated:
            self.graph.update_axes()
            self._domain_invalidated = False
        self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Handle the "Features" input.

        With at least two attributes, the first two are remembered and the
        widget is marked as invalidated if they differ from the current
        axes. Otherwise the remembered list is cleared.
        """
        if not attributes or len(attributes) < 2:
            self.attribute_selection_list = None
            return
        self.attribute_selection_list = attributes[:2]
        axes_differ = self.attr_x != attributes[0] \
            or self.attr_y != attributes[1]
        self._invalidated = self._invalidated or axes_differ

    def set_attr(self, attr_x, attr_y):
        """Set both axis variables.

        ``attr_changed`` is called only when at least one of them differs
        from the current value.
        """
        unchanged = not (attr_x != self.attr_x or attr_y != self.attr_y)
        if unchanged:
            return
        self.attr_x, self.attr_y = attr_x, attr_y
        self.attr_changed()

    def set_attr_from_combo(self):
        """Callback of the axis combo boxes.

        Applies the change and then emits ``xy_changed_manually`` with the
        current x and y variables.
        """
        self.attr_changed()
        x_var, y_var = self.attr_x, self.attr_y
        self.xy_changed_manually.emit(x_var, y_var)

    def attr_changed(self):
        """React to changed axes.

        Updates the regression-line control, redraws the plot and schedules
        a deferred commit.
        """
        can_regress = self.can_draw_regresssion_line()
        self.cb_reg_line.setEnabled(can_regress)
        self.setup_plot()
        self.commit.deferred()

    def get_axes(self):
        """Return the axis variables: x under ``"bottom"``, y under ``"left"``."""
        return dict(bottom=self.attr_x, left=self.attr_y)

    def colors_changed(self):
        """Run the parent handler, then refresh the VizRank button state."""
        super().colors_changed()
        # VizRank availability depends on the color variable
        self._vizrank_color_change()

    @gui.deferred
    def commit(self):
        """Send the parent's outputs followed by the "Features" output."""
        super().commit()
        self.send_features()

    def send_features(self):
        """Send the truthy axis variables as an ``AttributeList``.

        ``None`` is sent when neither axis is set.
        """
        axes = (self.attr_x, self.attr_y)
        selected = AttributeList([attr for attr in axes if attr])
        self.Outputs.features.send(selected or None)

    def get_widget_name_extension(self):
        """Return ``"<x name> vs <y name>"``, or ``None`` without data."""
        if self.data is None:
            return None
        return f"{self.attr_x.name} vs {self.attr_y.name}"

    def _get_send_report_caption(self):
        """Return the rendered report caption.

        Lists the color, label, shape and size variables. The jittering
        entry is the jitter size when an axis is discrete or continuous
        jittering is on, and a falsy value otherwise.
        """
        jittered = (self.attr_x.is_discrete or self.attr_y.is_discrete
                    or self.graph.jitter_continuous)
        name_of = self._get_caption_var_name
        items = (
            ("Color", name_of(self.attr_color)),
            ("Label", name_of(self.attr_label)),
            ("Shape", name_of(self.attr_shape)),
            ("Size", name_of(self.attr_size)),
            ("Jittering", jittered and self.graph.jitter_size),
        )
        return report.render_items_vert(items)

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade stored settings in place.

        * below version 2: a non-empty ``selection`` is converted to
          ``selection_group`` pairs ``(index, 1)``;
        * below version 3: ``auto_send_selection`` is copied to
          ``auto_commit`` and ``selection_group`` to ``selection``;
        * below version 5: ``jitter_continuous`` defaults to ``True`` in
          existing ``graph`` settings.
        """
        if version < 2:
            old_selection = settings.get("selection")
            if old_selection:
                settings["selection_group"] = [
                    (index, 1) for index in old_selection]
        if version < 3:
            renames = (("auto_send_selection", "auto_commit"),
                       ("selection_group", "selection"))
            for old_key, new_key in renames:
                if old_key in settings:
                    settings[new_key] = settings[old_key]
        if version < 5 and "graph" in settings:
            settings["graph"].setdefault("jitter_continuous", True)

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context in place.

        Below version 3, the ``attr_*`` values kept under ``"graph"`` are
        copied to the top level. Below version 4, the context is rejected
        when the second element of the stored ``attr_x`` or ``attr_y``
        leaves remainder 1 modulo 100.

        Raises:
            IncompatibleContext: for such contexts below version 4.
        """
        values = context.values
        if version < 3:
            for key in ("attr_color", "attr_size",
                        "attr_shape", "attr_label"):
                values[key] = values["graph"][key]
        if version < 4:
            # NOTE(review): presumably code 1 marks a discrete variable --
            # confirm against the context handler's encoding
            if any(values[axis][1] % 100 == 1
                   for axis in ("attr_x", "attr_y")):
                raise IncompatibleContext()
Ejemplo n.º 16
0
class OWDataProjectionWidget(OWProjectionWidgetBase, openclass=True):
    """
    Base widget for widgets that get Data and Data Subset (both
    Orange.data.Table) on the input, and output Selected Data and Data
    (both Orange.data.Table).

    Besides that, the widget displays the data as a two-dimensional
    projection of points.
    """
    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    class Warning(OWProjectionWidgetBase.Warning):
        too_many_labels = Msg(
            "Too many labels to show (zoom in or label only selected)")
        subset_not_subset = Msg(
            "Subset data contains some instances that do not appear in "
            "input data")
        subset_independent = Msg(
            "No subset data instances appear in input data")
        transparent_subset = Msg(
            "Increase opacity if subset is difficult to see")

    # Settings; selection and visual_settings are declared schema-only.
    settingsHandler = DomainContextHandler()
    selection = Setting(None, schema_only=True)
    visual_settings = Setting({}, schema_only=True)
    auto_commit = Setting(True)

    # GRAPH_CLASS is instantiated in _add_graph; embedding_variables_names
    # are the names requested for the two projection columns.
    GRAPH_CLASS = OWScatterPlotBase
    graph = SettingProvider(OWScatterPlotBase)
    graph_name = "graph.plot_widget.plotItem"
    embedding_variables_names = ("proj-x", "proj-y")
    buttons_area_orientation = Qt.Vertical

    # Emitted with the new input data / the selected output data.
    input_changed = Signal(object)
    output_changed = Signal(object)

    def __init__(self):
        """Initialize the widget state, build the GUI and create the
        visual settings dialog."""
        super().__init__()
        self.subset_data = None
        self.subset_indices = None
        # stored selection, applied by apply_selection once it fits the plot
        self.__pending_selection = self.selection
        self._invalidated = True
        self._domain_invalidated = True
        self.setup_gui()
        initial_settings = self.graph.parameter_setter.initial_settings
        VisualSettingsDialog(self, initial_settings)

    # GUI
    def setup_gui(self):
        """Build the graph, controls and buttons, then emit both change
        signals with ``None``."""
        for build in (self._add_graph, self._add_controls,
                      self._add_buttons):
            build()
        self.input_changed.emit(None)
        self.output_changed.emit(None)

    def _add_graph(self):
        """Create the graph in the main area and connect its
        ``too_many_labels`` signal to the matching warning."""
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = self.GRAPH_CLASS(self, box)
        box.layout().addWidget(self.graph.plot_widget)

        def show_too_many_labels(too_many):
            # the warning is shown or hidden according to the signal value
            return self.Warning.too_many_labels(shown=too_many)

        self.graph.too_many_labels.connect(show_too_many_labels)

    def _add_controls(self):
        """Create the plot GUI helper and the point, effects and plot
        property boxes in the control area."""
        plot_gui = OWPlotGUI(self)
        self.gui = plot_gui
        control_area = self.controlArea
        self._point_box = plot_gui.point_properties_box(control_area)
        self._effects_box = plot_gui.effects_box(control_area)
        self._plot_box = plot_gui.plot_properties_box(control_area)

    def _add_buttons(self):
        """Add a rubber to the control area and the zoom/select and
        auto-send controls to the buttons area."""
        gui.rubber(self.controlArea)
        buttons_area = self.buttonsArea
        self.gui.box_zoom_select(buttons_area)
        gui.auto_send(buttons_area, self, "auto_commit")

    @property
    def effective_variables(self):
        """Return the attributes of the data domain."""
        domain = self.data.domain
        return domain.attributes

    @property
    def effective_data(self):
        """Return the data transformed to the effective variables, keeping
        the class variables and metas of the original domain."""
        domain = self.data.domain
        effective_domain = Domain(
            self.effective_variables, domain.class_vars, domain.metas)
        return self.data.transform(effective_domain)

    # Input
    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input data and decide what needs to be recomputed.

        ``_invalidated`` stays ``False`` only when data existed before and
        the old and new effective data have equal ``X``;
        ``_domain_invalidated`` stays ``False`` only when their domain
        checksums match. When invalidated, the widget is cleared and
        ``input_changed`` is emitted with the new data.
        """
        data_existed = self.data is not None
        # capture the previous effective data before self.data is replaced
        effective_data = self.effective_data if data_existed else None
        self.closeContext()
        self.data = data
        self.check_data()
        self.init_attr_values()
        self.openContext(self.data)
        # compare the old effective data with the one computed from new data
        self._invalidated = not (data_existed
                                 and self.data is not None and array_equal(
                                     effective_data.X, self.effective_data.X))
        self._domain_invalidated = not (
            data_existed and self.data is not None
            and effective_data.domain.checksum()
            == self.effective_data.domain.checksum())
        if self._invalidated:
            self.clear()
            self.input_changed.emit(data)
        self.enable_controls()

    def check_data(self):
        """Clear all widget messages; subclasses extend this with checks."""
        self.clear_messages()

    def enable_controls(self):
        """Enable the class-density control only when density can be drawn."""
        density_possible = self.can_draw_density()
        self.cb_class_density.setEnabled(density_possible)

    @Inputs.data_subset
    @check_sql_input
    def set_subset_data(self, subset):
        """Store the data subset; it is processed in ``handleNewSignals``."""
        self.subset_data = subset

    def handleNewSignals(self):
        """Process the subset, refresh the plot and commit unconditionally.

        An invalidated plot is set up anew; otherwise only the point
        properties are updated.
        """
        self._handle_subset_data()
        if not self._invalidated:
            self.graph.update_point_props()
        else:
            self._invalidated = False
            self.setup_plot()
        self._update_opacity_warning()
        self.unconditional_commit()

    def _handle_subset_data(self):
        """Store the ids of the subset and warn about mismatches.

        A warning is shown when the subset shares no ids with the data,
        and a different one when only some of its ids are missing from
        the data. Without data or subset the id set is empty.
        """
        self.Warning.subset_independent.clear()
        self.Warning.subset_not_subset.clear()
        if self.data is None or self.subset_data is None:
            self.subset_indices = set()
            return
        subset_ids = set(self.subset_data.ids)
        self.subset_indices = subset_ids
        data_ids = set(self.data.ids)
        if subset_ids.isdisjoint(data_ids):
            self.Warning.subset_independent()
        elif not subset_ids <= data_ids:
            self.Warning.subset_not_subset()

    def _update_opacity_warning(self):
        """Show the transparency warning when a subset is present and the
        graph's ``alpha_value`` is below 128."""
        hard_to_see = self.subset_indices and self.graph.alpha_value < 128
        self.Warning.transparent_subset(shown=hard_to_see)

    def get_subset_mask(self):
        """Return a boolean mask over the valid rows marking subset members.

        Returns ``None`` when there are no subset ids.
        """
        subset_ids = self.subset_indices
        if not subset_ids:
            return None
        valid_rows = self.data[self.valid_data]
        membership = (row.id in subset_ids for row in valid_rows)
        return np.fromiter(membership, dtype=bool, count=len(valid_rows))

    # Plot
    def get_embedding(self):
        """Return the embedding of the data; must be overridden.

        The overriding method should return the embedding for all data,
        valid and invalid. Coordinates of invalid data should be set to 0
        (in some cases to NaN). It should also set ``self.valid_data``.

        Returns:
            np.array: Array of embedding coordinates with shape
            len(self.data) x 2

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def get_coordinates_data(self):
        """Return the transposed embedding of the valid rows.

        Returns ``(None, None)`` when there is no embedding or no valid row.
        """
        embedding = self.get_embedding()
        if embedding is None:
            return None, None
        valid_points = embedding[self.valid_data]
        if not len(valid_points):
            return None, None
        return valid_points.T

    def setup_plot(self):
        """Reset the graph and re-apply the selection.

        A non-empty current selection replaces the pending one first.
        """
        self.graph.reset_graph()
        current = self.selection
        if current:
            self.__pending_selection = current
        self.apply_selection()

    # Selection
    def apply_selection(self):
        """Apply the pending selection to the graph, if it fits.

        Nothing happens without data, without a non-empty pending
        selection, or when its largest index is not below the number of
        valid points. Otherwise the pending ``(index, group)`` pairs are
        written to the graph selection and the pending selection is reset.
        """
        pending = self.__pending_selection
        if self.data is None or pending is None or not len(pending):
            return
        n_valid = self.graph.n_valid
        if max(index for index, _ in pending) >= n_valid:
            return

        index_group = np.array(pending).T
        selection = np.zeros(n_valid, dtype=np.uint8)
        selection[index_group[0]] = index_group[1]

        self.selection = pending
        self.__pending_selection = None
        self.graph.selection = selection
        self.graph.update_selection_colors()
        if self.graph.label_only_selected:
            self.graph.update_labels()

    def selection_changed(self):
        """Store the graph selection as ``(index, group)`` pairs and commit.

        Only non-zero groups are stored. For SQL data, or when the graph
        has no selection, ``None`` is stored.
        """
        if self.data and isinstance(self.data, SqlTable):
            graph_selection = None
        else:
            graph_selection = self.graph.selection

        if graph_selection is None:
            self.selection = None
        else:
            self.selection = [(index, group)
                              for index, group in enumerate(graph_selection)
                              if group]
        self.commit()

    # Output
    def commit(self):
        """Send the output data."""
        self.send_data()

    def send_data(self):
        """Send the selected and the annotated data.

        The graph selection, defined on valid rows only, is expanded to a
        group index for every row of the data. ``output_changed`` is emitted
        with the selected data before it is sent.
        """
        graph = self.graph
        data = self._get_projection_data()
        group_sel = None
        if graph.selection is not None:
            group_sel = np.zeros(len(data), dtype=int)
            group_sel[self.valid_data] = graph.selection

        selected = self._get_selected_data(
            data, graph.get_selection(), group_sel)
        self.output_changed.emit(selected)
        self.Outputs.selected_data.send(selected)

        annotated = self._get_annotated_data(data, group_sel, graph.selection)
        self.Outputs.annotated_data.send(annotated)

    def _get_projection_data(self):
        """Return the data with the embedding stored in two added metas.

        The data is returned unchanged when it is ``None`` or when
        ``embedding_variables_names`` is ``None``.
        """
        source = self.data
        if source is None or self.embedding_variables_names is None:
            return source
        domain = source.domain
        extended = Domain(
            domain.attributes, domain.class_vars,
            domain.metas + self._get_projection_variables())
        projected = source.transform(extended)
        with projected.unlocked(projected.metas):
            # the projection variables were appended as the last two metas
            projected.metas[:, -2:] = self.get_embedding()
        return projected

    def _get_projection_variables(self):
        """Return two continuous variables for the projection columns.

        Their names come from ``get_unique_names`` applied to the data
        domain and ``embedding_variables_names``.
        """
        names = get_unique_names(self.data.domain,
                                 self.embedding_variables_names)
        x_name, y_name = names[0], names[1]
        return ContinuousVariable(x_name), ContinuousVariable(y_name)

    @staticmethod
    def _get_selected_data(data, selection, group_sel):
        """Return the table of selected groups, or ``None`` when
        ``selection`` is empty."""
        if not len(selection):
            return None
        return create_groups_table(data, group_sel, False, "Group")

    @staticmethod
    def _get_annotated_data(data, group_sel, graph_sel):
        """Return the data annotated with the selection.

        Args:
            data: the data to annotate; ``None`` yields ``None``.
            group_sel: group index for every row of ``data``, or ``None``.
            graph_sel: the graph's selection array, or ``None``.

        When any group index in ``graph_sel`` is above 1, a groups table is
        returned; otherwise a table annotated with the selected rows.
        """
        if data is None:
            return None
        # np.any(... > 1) equals np.max(...) > 1 for non-empty arrays, but
        # does not raise on a zero-size selection
        if graph_sel is not None and np.any(np.asarray(graph_sel) > 1):
            return create_groups_table(data, group_sel)
        if group_sel is None:
            mask = np.full((len(data), ), False)
        else:
            mask = np.nonzero(group_sel)[0]
        return create_annotated_table(data, mask)

    # Report
    def send_report(self):
        """Report the plot, followed by the caption if there is one.

        Nothing is reported without data.
        """
        if self.data is None:
            return

        caption_text = self._get_send_report_caption()
        self.report_plot()
        if caption_text:
            self.report_caption(caption_text)

    def _get_send_report_caption(self):
        """Return the rendered report caption.

        Lists the color, label, shape and size variables. The jittering
        entry is ``"<size> %"`` for a non-zero jitter size and ``False``
        otherwise.
        """
        name_of = self._get_caption_var_name
        jitter_size = self.graph.jitter_size
        items = (
            ("Color", name_of(self.attr_color)),
            ("Label", name_of(self.attr_label)),
            ("Shape", name_of(self.attr_shape)),
            ("Size", name_of(self.attr_size)),
            ("Jittering", jitter_size != 0 and "{} %".format(jitter_size)),
        )
        return report.render_items_vert(items)

    # Customize plot
    def set_visual_settings(self, key, value):
        """Apply a visual setting to the graph and remember it in
        ``visual_settings``."""
        setter = self.graph.parameter_setter
        setter.set_parameter(key, value)
        self.visual_settings[key] = value

    @staticmethod
    def _get_caption_var_name(var):
        """Return the name of a ``Variable``; other values are returned
        unchanged."""
        if isinstance(var, Variable):
            return var.name
        return var

    # Misc
    def sizeHint(self):
        """Return the size hint of the widget, 1132 x 708."""
        width, height = 1132, 708
        return QSize(width, height)

    def clear(self):
        """Reset the stored selection and the graph selection to ``None``."""
        self.selection = None
        self.graph.selection = None

    def onDeleteWidget(self):
        """Run the parent clean-up, then release the plot's view box and
        clear the plot widget and the graph."""
        super().onDeleteWidget()
        plot_widget = self.graph.plot_widget
        plot_widget.getViewBox().deleteLater()
        plot_widget.clear()
        self.graph.clear()
Ejemplo n.º 17
0
class OWDataProjectionWidget(OWProjectionWidgetBase):
    """
    Base widget for widgets that get Data and Data Subset (both
    Orange.data.Table) on the input, and output Selected Data and Data
    (both Orange.data.Table).

    Besides that, the widget displays the data as a two-dimensional
    projection of points.
    """
    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    # Settings; selection is declared schema-only.
    settingsHandler = DomainContextHandler()
    selection = Setting(None, schema_only=True)
    auto_commit = Setting(True)

    # GRAPH_CLASS is instantiated in _add_graph; embedding_variables_names
    # are the names requested for the two projection columns.
    GRAPH_CLASS = OWScatterPlotBase
    graph = SettingProvider(OWScatterPlotBase)
    graph_name = "graph.plot_widget.plotItem"
    embedding_variables_names = ("proj-x", "proj-y")

    def __init__(self):
        """Initialize the subset and selection state and build the GUI."""
        super().__init__()
        self.subset_data = None
        self.subset_indices = None
        # stored selection, applied by apply_selection once a plot exists
        self.__pending_selection = self.selection
        self.setup_gui()

    # GUI
    def setup_gui(self):
        """Build the graph first, then the controls."""
        for build in (self._add_graph, self._add_controls):
            build()

    def _add_graph(self):
        """Create the graph and place its plot widget in the main area."""
        container = gui.vBox(self.mainArea, True, margin=0)
        self.graph = self.GRAPH_CLASS(self, container)
        container.layout().addWidget(self.graph.plot_widget)

    def _add_controls(self):
        """Fill the control area.

        Adds the point, effects and plot property boxes, a stretching
        spacer, the zoom/select box and the auto-commit controls.
        """
        control_area = self.controlArea
        plot_gui = self.graph.gui
        self._point_box = plot_gui.point_properties_box(control_area)
        self._effects_box = plot_gui.effects_box(control_area)
        self._plot_box = plot_gui.plot_properties_box(control_area)

        stretch_box = gui.widgetBox(control_area)
        self.control_area_stretch = stretch_box
        stretch_box.layout().addStretch(100)

        self.graph.box_zoom_select(control_area)
        gui.auto_commit(control_area, self, "auto_commit",
                        "Send Selection", "Send Automatically")

    # Input
    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Replace the input data.

        Attribute values are re-initialized unless both the old and the new
        data are truthy and their domain checksums match.
        """
        # must be evaluated before clear() drops the previous data
        domain_unchanged = (
            self.data and data
            and data.domain.checksum() == self.data.domain.checksum())
        self.closeContext()
        self.clear()
        self.data = data
        self.check_data()
        if not domain_unchanged:
            self.init_attr_values()
        self.openContext(self.data)
        density_possible = self.can_draw_density()
        self.cb_class_density.setEnabled(density_possible)

    def check_data(self):
        """Clear all widget messages; subclasses extend this with checks."""
        self.clear_messages()

    @Inputs.data_subset
    @check_sql_input
    def set_subset_data(self, subset):
        """Store the data subset and the set of ids of its instances.

        The opacity control is enabled only while there is no subset.
        """
        self.subset_data = subset
        # use an empty set (not a dict) so the attribute is always a set
        self.subset_indices = {e.id for e in subset} \
            if subset is not None else set()
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """Redraw the plot and commit the outputs."""
        self.setup_plot()
        self.commit()

    def get_subset_mask(self):
        """Return a boolean mask over the valid rows marking subset members.

        Returns ``None`` when there are no subset ids.
        """
        if self.subset_indices:
            # an explicit dtype keeps the mask boolean even when there are
            # no valid rows (np.array of an empty list would be float64)
            return np.array([ex.id in self.subset_indices
                             for ex in self.data[self.valid_data]],
                            dtype=bool)
        return None

    # Plot
    def get_embedding(self):
        """Return the embedding of the data; must be overridden.

        The overriding method should return the embedding for all data,
        valid and invalid. Coordinates of invalid data should be set to 0
        (in some cases to NaN). It should also set ``self.valid_data``.

        Returns:
            np.array: Array of embedding coordinates with shape
            len(self.data) x 2

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def get_coordinates_data(self):
        """Return the first two rows of the transposed valid embedding.

        Returns ``(None, None)`` when there is no embedding.
        """
        embedding = self.get_embedding()
        if embedding is None:
            return None, None
        valid_points = embedding[self.valid_data]
        return valid_points.T[:2]

    def setup_plot(self):
        """Reset the graph and re-apply the selection.

        A non-empty current selection replaces the pending one first.
        """
        self.graph.reset_graph()
        current = self.selection
        if current:
            self.__pending_selection = current
        self.apply_selection()

    # Selection
    def apply_selection(self):
        """Apply the pending selection to the graph.

        Nothing happens without data, without a pending selection, or when
        the graph has no valid points. Pending ``(index, group)`` pairs
        whose index does not address a valid point are ignored.
        """
        pending = self.__pending_selection
        if self.data is None or pending is None or not self.graph.n_valid:
            return

        n_valid = self.graph.n_valid
        selection = np.zeros(n_valid, dtype=np.uint8)
        # the selection array has one entry per valid point, so indices
        # must be checked against n_valid rather than len(self.data)
        index_group = [(index, group) for index, group in pending
                       if index < n_valid]
        if index_group:
            # an empty list cannot be split into index and group rows
            indices, groups = np.array(index_group).T
            selection[indices] = groups

        self.selection = pending
        self.__pending_selection = None
        self.graph.selection = selection
        self.graph.update_selection_colors()

    def selection_changed(self):
        """Store the graph selection as ``(index, group)`` pairs and commit.

        Only non-zero groups are stored. For SQL data, or when the graph
        has no selection, ``None`` is stored.
        """
        if self.data and isinstance(self.data, SqlTable):
            graph_selection = None
        else:
            graph_selection = self.graph.selection

        if graph_selection is None:
            self.selection = None
        else:
            self.selection = [(index, group)
                              for index, group in enumerate(graph_selection)
                              if group]
        self.commit()

    # Output
    def commit(self):
        """Send the output data."""
        self.send_data()

    def send_data(self):
        """Send the selected and the annotated data.

        The graph selection, defined on valid rows only, is expanded to a
        group index for every row of the data.
        """
        graph = self.graph
        data = self._get_projection_data()
        group_sel = None
        if graph.selection is not None:
            group_sel = np.zeros(len(data), dtype=int)
            group_sel[self.valid_data] = graph.selection

        selected = self._get_selected_data(
            data, graph.get_selection(), group_sel)
        self.Outputs.selected_data.send(selected)

        annotated = self._get_annotated_data(
            data, graph.get_selection(), group_sel, graph.selection)
        self.Outputs.annotated_data.send(annotated)

    def _get_projection_data(self):
        """Return the data with the embedding stored in two added metas.

        The data is returned unchanged when it is ``None`` or when
        ``embedding_variables_names`` is ``None``.
        """
        source = self.data
        if source is None or self.embedding_variables_names is None:
            return source
        domain = source.domain
        extended = Domain(
            domain.attributes, domain.class_vars,
            domain.metas + self._get_projection_variables())
        projected = source.transform(extended)
        # the projection variables were appended as the last two metas
        projected.metas[:, -2:] = self.get_embedding()
        return projected

    def _get_projection_variables(self):
        """Return two continuous variables for the projection columns.

        Their names come from ``get_unique_names`` applied to the names of
        all domain variables and metas and ``embedding_variables_names``.
        """
        domain = self.data.domain
        taken_names = [var.name for var in domain.variables + domain.metas]
        names = get_unique_names(taken_names, self.embedding_variables_names)
        x_name, y_name = names[0], names[1]
        return ContinuousVariable(x_name), ContinuousVariable(y_name)

    @staticmethod
    def _get_selected_data(data, selection, group_sel):
        """Return the table of selected groups, or ``None`` when
        ``selection`` is empty."""
        if not len(selection):
            return None
        return create_groups_table(data, group_sel, False, "Group")

    @staticmethod
    def _get_annotated_data(data, selection, group_sel, graph_sel):
        """Return the data annotated with the selection.

        Args:
            data: the data to annotate.
            selection: selected rows, used for the plain annotation.
            group_sel: group index for every row of ``data``, or ``None``.
            graph_sel: the graph's selection array, or ``None``.

        When any group index in ``graph_sel`` is above 1, a groups table is
        returned; otherwise a table annotated with ``selection``.
        """
        # np.any(... > 1) equals np.max(...) > 1 for non-empty arrays, but
        # does not raise on a zero-size selection
        if graph_sel is not None and np.any(np.asarray(graph_sel) > 1):
            return create_groups_table(data, group_sel)
        return create_annotated_table(data, selection)

    # Report
    def send_report(self):
        """Report the plot, followed by the caption if there is one.

        Nothing is reported without data.
        """
        if self.data is None:
            return

        caption_text = self._get_send_report_caption()
        self.report_plot()
        if caption_text:
            self.report_caption(caption_text)

    def _get_send_report_caption(self):
        """Return the rendered report caption.

        Lists the color, label, shape and size variables. The jittering
        entry is ``"<size> %"`` for a non-zero jitter size and ``False``
        otherwise.
        """
        name_of = self._get_caption_var_name
        jitter_size = self.graph.jitter_size
        items = (
            ("Color", name_of(self.attr_color)),
            ("Label", name_of(self.attr_label)),
            ("Shape", name_of(self.attr_shape)),
            ("Size", name_of(self.attr_size)),
            ("Jittering", jitter_size != 0 and "{} %".format(jitter_size)),
        )
        return report.render_items_vert(items)

    @staticmethod
    def _get_caption_var_name(var):
        """Return the name of a ``Variable``; other values are returned
        unchanged."""
        if isinstance(var, Variable):
            return var.name
        return var

    # Misc
    def sizeHint(self):
        """Return the size hint of the widget, 1132 x 708."""
        width, height = 1132, 708
        return QSize(width, height)

    def clear(self):
        """Drop the data, its validity mask and both selections."""
        self.data = None
        self.valid_data = None
        # reset the stored selection as well as the one held by the graph
        self.selection = None
        self.graph.selection = None

    def onDeleteWidget(self):
        """Run the parent clean-up, then release the plot's view box and
        clear the plot widget."""
        super().onDeleteWidget()
        plot_widget = self.graph.plot_widget
        plot_widget.getViewBox().deleteLater()
        plot_widget.clear()
Ejemplo n.º 18
0
class OWtSNE(OWDataProjectionWidget, ConcurrentWidgetMixin):
    """Two-dimensional data projection widget based on t-SNE.

    The computation is started through ``ConcurrentWidgetMixin.start`` (see
    ``run``). Intermediate results (PCA projection, initialization,
    affinities and the t-SNE embedding) are stored on the widget as they
    arrive in ``on_partial_result``. Before a new run, only those parts
    flagged in ``_invalidated`` are reset.
    """

    name = "t-SNE"
    description = "Two-dimensional data projection with t-SNE."
    icon = "icons/TSNE.svg"
    priority = 920
    keywords = ["tsne"]

    settings_version = 4
    perplexity = ContextSetting(30)
    multiscale = ContextSetting(False)
    exaggeration = ContextSetting(1)
    pca_components = ContextSetting(_DEFAULT_PCA_COMPONENTS)
    normalize = ContextSetting(True)

    GRAPH_CLASS = OWtSNEGraph
    graph = SettingProvider(OWtSNEGraph)
    embedding_variables_names = ("t-SNE-x", "t-SNE-y")

    # Use `invalidated` descriptor so we don't break the usage of
    # `_invalidated` in `OWDataProjectionWidget`, but still allow finer control
    # over which parts of the embedding to invalidate
    _invalidated = invalidated()

    class Information(OWDataProjectionWidget.Information):
        modified = Msg("The parameter settings have been changed. Press "
                       "\"Start\" to rerun with the new settings.")

    class Error(OWDataProjectionWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        not_enough_cols = Msg("Input data needs at least 2 attributes")
        constant_data = Msg("Input data is constant")
        no_valid_data = Msg("No projection due to no valid data")

    def __init__(self):
        OWDataProjectionWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        # Intermediate results of the computation; each is filled in by
        # `on_partial_result` and reset by `run`/`clear`
        self.pca_projection = None  # type: Optional[Table]
        self.initialization = None  # type: Optional[np.ndarray]
        self.affinities = None      # type: Optional[openTSNE.affinity.Affinities]
        self.tsne_embedding = None  # type: Optional[manifold.TSNEModel]
        self.iterations_done = 0    # type: int

    def _add_controls(self):
        """Add the "Optimize" box, then the controls of the base class."""
        self._add_controls_start_box()
        super()._add_controls()

    def _add_controls_start_box(self):
        """Build the "Optimize" box.

        The box holds the perplexity spin box, the multiscale check box, the
        exaggeration and PCA-components sliders, the normalization check box
        and the Start button. Each control's callback invalidates the
        corresponding stage of the computation.
        """
        box = gui.vBox(self.controlArea, box="Optimize")
        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
        )

        self.perplexity_spin = gui.spin(
            box, self, "perplexity", 1, 500, step=1, alignment=Qt.AlignRight,
            callback=self._invalidate_affinities, addToLayout=False
        )
        self.controls.perplexity.setDisabled(self.multiscale)
        form.addRow("Perplexity:", self.perplexity_spin)
        form.addRow(gui.checkBox(
            box, self, "multiscale", label="Preserve global structure",
            callback=self._multiscale_changed, addToLayout=False
        ))

        sbe = gui.hBox(self.controlArea, False, addToLayout=False)
        gui.hSlider(
            sbe, self, "exaggeration", minValue=1, maxValue=4, step=1,
            callback=self._invalidate_tsne_embedding,
        )
        form.addRow("Exaggeration:", sbe)

        sbp = gui.hBox(self.controlArea, False, addToLayout=False)
        gui.hSlider(
            sbp, self, "pca_components", minValue=2, maxValue=_MAX_PCA_COMPONENTS,
            step=1, callback=self._invalidate_pca_projection,
        )
        form.addRow("PCA components:", sbp)

        self.normalize_cbx = gui.checkBox(
            box, self, "normalize", "Normalize data",
            callback=self._invalidate_pca_projection, addToLayout=False
        )
        form.addRow(self.normalize_cbx)

        box.layout().addLayout(form)

        self.run_button = gui.button(box, self, "Start", callback=self._toggle_run)

    def _multiscale_changed(self):
        """Disable the perplexity control under multiscale; invalidate affinities."""
        self.controls.perplexity.setDisabled(self.multiscale)
        self._invalidate_affinities()

    def _invalidate_pca_projection(self):
        """Flag the PCA projection and everything computed after it."""
        self._invalidated.pca_projection = True
        self._invalidate_affinities()

    def _invalidate_affinities(self):
        """Flag the affinities and the t-SNE embedding."""
        self._invalidated.affinities = True
        self._invalidate_tsne_embedding()

    def _invalidate_tsne_embedding(self):
        """Flag the embedding, stop the running task and mark as modified."""
        self._invalidated.tsne_embedding = True
        self._stop_running_task()
        self._set_modified(True)

    def _stop_running_task(self):
        """Cancel the current task and reset the run button to "Start"."""
        self.cancel()
        self.run_button.setText("Start")

    def _set_modified(self, state):
        """Mark the widget (GUI) as containing modified state."""
        if self.data is None:
            # Does not apply when we have no data
            state = False
        self.Information.modified(shown=state)

    def check_data(self):
        """Validate the input data.

        On failure the matching error is shown and ``self.data`` is set to
        None. The checks are: at least 2 rows, at least 2 attributes and,
        for dense data, at least one finite value and a non-constant ``X``.
        """
        def error(err):
            err()
            self.data = None

        # `super().check_data()` clears all messages so we have to remember if
        # it was shown
        # pylint: disable=assignment-from-no-return
        should_show_modified_message = self.Information.modified.is_shown()
        super().check_data()

        if self.data is None:
            return

        self.Information.modified(shown=should_show_modified_message)

        if len(self.data) < 2:
            error(self.Error.not_enough_rows)

        elif len(self.data.domain.attributes) < 2:
            error(self.Error.not_enough_cols)

        elif not self.data.is_sparse():
            if np.all(~np.isfinite(self.data.X)):
                error(self.Error.no_valid_data)
            else:
                with warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore", "Degrees of freedom .*", RuntimeWarning)
                    if np.nan_to_num(np.nanstd(self.data.X, axis=0)).sum() \
                            == 0:
                        error(self.Error.constant_data)

    def get_embedding(self):
        """Return the embedding coordinates, or None if there is no embedding.

        Also updates ``self.valid_data``: None without an embedding, otherwise
        an all-True boolean mask with one entry per embedded row.
        """
        if self.tsne_embedding is None:
            self.valid_data = None
            return None

        embedding = self.tsne_embedding.embedding.X
        self.valid_data = np.ones(len(embedding), dtype=bool)
        return embedding

    def _toggle_run(self):
        """Pause the running task, or start/resume one if none is running."""
        # If no data, there's nothing to do
        if self.data is None:
            return

        # Pause task
        if self.task is not None:
            self.cancel()
            self.run_button.setText("Resume")
            self.commit.deferred()
        # Resume task
        else:
            self.run()

    def handleNewSignals(self):
        # We don't bother with the granular invalidation flags because
        # `super().handleNewSignals` will just set all of them to False or will
        # do nothing. However, it's important we remember its state because we
        # won't call `run` if needed. `run` also relies on the state of
        # `_invalidated` to properly set the intermediate values to None
        prev_invalidated = bool(self._invalidated)
        super().handleNewSignals()
        self._invalidated = prev_invalidated

        if self._invalidated:
            self.run()

    def init_attr_values(self):
        """Reset the PCA-components slider and exaggeration for new data."""
        super().init_attr_values()

        if self.data is not None:
            n_attrs = len(self.data.domain.attributes)
            max_components = min(_MAX_PCA_COMPONENTS, n_attrs)
        else:
            max_components = _MAX_PCA_COMPONENTS

        # We set this to the default number of components here so it resets
        # properly, any previous settings will be restored from context
        # settings a little later
        self.controls.pca_components.setMaximum(max_components)
        self.controls.pca_components.setValue(_DEFAULT_PCA_COMPONENTS)

        self.exaggeration = 1

    def enable_controls(self):
        """Enable or disable controls according to the current data."""
        super().enable_controls()

        if self.data is not None:
            # PCA doesn't support normalization on sparse data, as this would
            # require centering and normalizing the matrix
            self.normalize_cbx.setDisabled(self.data.is_sparse())
            if self.data.is_sparse():
                self.normalize = False
                self.normalize_cbx.setToolTip(
                    "Data normalization is not supported on sparse matrices."
                )
            else:
                self.normalize_cbx.setToolTip("")

        # Disable the perplexity spin box if multiscale is turned on
        self.controls.perplexity.setDisabled(self.multiscale)

    def run(self):
        """Reset the invalidated intermediate results and start a new task.

        Returns the result of ``self.start(...)``, or None when there is no
        data to run on.
        """
        # Reset invalidated values as indicated by the flags
        if self._invalidated.pca_projection:
            self.pca_projection = None
        if self._invalidated.affinities:
            self.affinities = None
        if self._invalidated.tsne_embedding:
            self.iterations_done = 0
            self.tsne_embedding = None

        self._set_modified(False)
        self._invalidated = False

        # When the data is invalid, it is set to `None` and an error is set,
        # therefore it would be erroneous to clear the error here
        if self.data is not None:
            self.run_button.setText("Stop")

        # Cancel current running task
        self.cancel()

        if self.data is None:
            return

        task = Task(
            data=self.data,
            normalize=self.normalize,
            pca_components=self.pca_components,
            pca_projection=self.pca_projection,
            perplexity=self.perplexity,
            multiscale=self.multiscale,
            exaggeration=self.exaggeration,
            initialization=self.initialization,
            affinities=self.affinities,
            tsne_embedding=self.tsne_embedding,
            iterations_done=self.iterations_done,
        )
        return self.start(TSNERunner.run, task)

    # The `__ensure_task_same_for_*` helpers are sanity checks (asserts) that
    # a task's parameters and results still match the widget's current state.

    def __ensure_task_same_for_pca(self, task: Task):
        assert self.data is not None
        assert task.normalize == self.normalize
        assert task.pca_components == self.pca_components
        assert isinstance(task.pca_projection, Table) and \
            len(task.pca_projection) == len(self.data)

    def __ensure_task_same_for_initialization(self, task: Task):
        assert isinstance(task.initialization, np.ndarray) and \
            len(task.initialization) == len(self.data)

    def __ensure_task_same_for_affinities(self, task: Task):
        assert task.perplexity == self.perplexity
        assert task.multiscale == self.multiscale

    def __ensure_task_same_for_embedding(self, task: Task):
        assert task.exaggeration == self.exaggeration
        assert isinstance(task.tsne_embedding, manifold.TSNEModel) and \
            len(task.tsne_embedding.embedding) == len(self.data)

    def on_partial_result(self, value):
        # type: (Tuple[str, Task]) -> None
        """Store one intermediate result of the running task.

        `value` is a ``(which, task)`` pair where `which` names the stage
        that has just finished. Raises RuntimeError for an unknown stage.
        """
        which, task = value

        if which == "pca_projection":
            self.__ensure_task_same_for_pca(task)
            self.pca_projection = task.pca_projection
        elif which == "initialization":
            self.__ensure_task_same_for_pca(task)
            self.__ensure_task_same_for_initialization(task)
            self.initialization = task.initialization
        elif which == "affinities":
            self.__ensure_task_same_for_pca(task)
            self.__ensure_task_same_for_affinities(task)
            self.affinities = task.affinities
        elif which == "tsne_embedding":
            self.__ensure_task_same_for_pca(task)
            self.__ensure_task_same_for_initialization(task)
            self.__ensure_task_same_for_affinities(task)
            self.__ensure_task_same_for_embedding(task)

            prev_embedding, self.tsne_embedding = self.tsne_embedding, task.tsne_embedding
            self.iterations_done = task.iterations_done
            # If this is the first partial result we've gotten, we've got to
            # setup the plot
            if prev_embedding is None:
                self.setup_plot()
            # Otherwise, just update the point positions
            else:
                self.graph.update_coordinates()
                self.graph.update_density()
        else:
            raise RuntimeError(
                "Unrecognized partial result called with `%s`" % which
            )

    def on_done(self, task):
        # type: (Task) -> None
        """Reset the run button, sanity-check the final task and commit."""
        self.run_button.setText("Start")
        # NOTE: All of these have already been set by on_partial_result,
        # we double check that they are aliases
        if task.pca_projection is not None:
            self.__ensure_task_same_for_pca(task)
            assert task.pca_projection is self.pca_projection
        if task.initialization is not None:
            self.__ensure_task_same_for_initialization(task)
            assert task.initialization is self.initialization
        if task.affinities is not None:
            assert task.affinities is self.affinities
        if task.tsne_embedding is not None:
            self.__ensure_task_same_for_embedding(task)
            assert task.tsne_embedding is self.tsne_embedding

        self.commit.deferred()

    def _get_projection_data(self):
        """Return the data with the embedding in its last two meta columns.

        Returns None when there is no data. When an embedding exists, the
        domain's two projection variables are replaced with
        ``self.tsne_embedding.domain.attributes``.
        """
        if self.data is None:
            return None

        data = self.data.transform(
            Domain(
                self.data.domain.attributes,
                self.data.domain.class_vars,
                self.data.domain.metas + self._get_projection_variables()
            )
        )
        with data.unlocked(data.metas):
            data.metas[:, -2:] = self.get_embedding()
        if self.tsne_embedding is not None:
            data.domain = Domain(
                self.data.domain.attributes,
                self.data.domain.class_vars,
                self.data.domain.metas + self.tsne_embedding.domain.attributes,
            )
        return data

    def clear(self):
        """Clear widget state. Note that this doesn't clear the data."""
        super().clear()
        self.run_button.setText("Start")
        self.cancel()
        self.pca_projection = None
        self.initialization = None
        self.affinities = None
        self.tsne_embedding = None
        self.iterations_done = 0

    def onDeleteWidget(self):
        """Clear all state, drop the data and shut down before deletion."""
        self.clear()
        self.data = None
        self.shutdown()
        super().onDeleteWidget()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade stored settings to the current ``settings_version``.

        * before version 3: ``selection_indices`` is copied to ``selection``;
        * before version 4: ``max_iter`` is removed.
        """
        if version < 3:
            if "selection_indices" in settings:
                settings["selection"] = settings["selection_indices"]
        if version < 4:
            settings.pop("max_iter", None)

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context to the current ``settings_version``.

        Before version 3, the ``attr_*`` values are copied from the nested
        ``"graph"`` entry to the top level of ``context.values``.
        """
        # Contexts without a "graph" entry have nothing to move; indexing it
        # unconditionally would raise KeyError and abort the migration
        if version < 3 and "graph" in context.values:
            values = context.values
            values["attr_color"] = values["graph"]["attr_color"]
            values["attr_size"] = values["graph"]["attr_size"]
            values["attr_shape"] = values["graph"]["attr_shape"]
            values["attr_label"] = values["graph"]["attr_label"]
Ejemplo n.º 19
0
class OWMDS(OWDataProjectionWidget, ConcurrentWidgetMixin):
    """Two-dimensional data projection widget based on MDS.

    The widget accepts a data table and/or a distance matrix. The matrix
    actually used (``effective_matrix``) is either the input matrix or, when
    only data is given, Euclidean distances computed from the preprocessed
    data. The optimization is started through ``ConcurrentWidgetMixin.start``
    (see ``_run``).
    """

    name = "MDS"
    description = "Two-dimensional data projection by multidimensional " \
                  "scaling constructed from a distance matrix."
    icon = "icons/MDS.svg"
    keywords = ["multidimensional scaling", "multi dimensional scaling"]

    class Inputs(OWDataProjectionWidget.Inputs):
        distances = Input("Distances", DistMatrix)

    settings_version = 3

    #: Initialization type
    PCA, Random, Jitter = 0, 1, 2

    #: Refresh rate
    RefreshRate = [("Every iteration", 1), ("Every 5 steps", 5),
                   ("Every 10 steps", 10), ("Every 25 steps", 25),
                   ("Every 50 steps", 50), ("None", -1)]

    max_iter = settings.Setting(300)
    initialization = settings.Setting(PCA)
    #: Index into `RefreshRate`
    refresh_rate = settings.Setting(3)

    GRAPH_CLASS = OWMDSGraph
    graph = SettingProvider(OWMDSGraph)
    embedding_variables_names = ("mds-x", "mds-y")

    class Error(OWDataProjectionWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        matrix_too_small = Msg("Input matrix must be at least 2x2")
        no_attributes = Msg("Data has no attributes")
        mismatching_dimensions = \
            Msg("Data and distances dimensions do not match.")
        out_of_memory = Msg("Out of memory")
        optimization_error = Msg("Error during optimization\n{}")

    def __init__(self):
        OWDataProjectionWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        #: Input dissimilarity matrix
        self.matrix = None  # type: Optional[DistMatrix]
        #: Data table from the `self.matrix.row_items` (if present)
        self.matrix_data = None  # type: Optional[Table]
        #: Input data table
        self.signal_data = None

        self.embedding = None  # type: Optional[np.ndarray]
        self.effective_matrix = None  # type: Optional[DistMatrix]

        self.graph.pause_drawing_pairs()

        # Insert "Stress" as the second entry of the size model's order
        self.size_model = self.gui.points_models[2]
        self.size_model.order = \
            self.gui.points_models[2].order[:1] \
            + ("Stress", ) + \
            self.gui.points_models[2].order[1:]

    def _add_controls(self):
        """Add optimization controls, base controls and the pairs slider."""
        self._add_controls_optimization()
        super()._add_controls()
        self.gui.add_control(self._effects_box,
                             gui.hSlider,
                             "Show similar pairs:",
                             master=self.graph,
                             value="connected_pairs",
                             minValue=0,
                             maxValue=20,
                             createLabel=False,
                             callback=self._on_connected_changed)

    def _add_controls_optimization(self):
        """Build the box with the run button, refresh combo and init buttons."""
        box = gui.vBox(self.controlArea, box=True)
        self.run_button = gui.button(box, self, "Start", self._toggle_run)
        gui.comboBox(box,
                     self,
                     "refresh_rate",
                     label="Refresh: ",
                     orientation=Qt.Horizontal,
                     items=[t for t, _ in OWMDS.RefreshRate],
                     callback=self.__refresh_rate_combo_changed)
        hbox = gui.hBox(box, margin=0)
        gui.button(hbox, self, "PCA", callback=self.do_PCA)
        gui.button(hbox, self, "Randomize", callback=self.do_random)
        gui.button(hbox, self, "Jitter", callback=self.do_jitter)

    def __refresh_rate_combo_changed(self):
        # Restart a running optimization so the new refresh rate is used
        if self.task is not None:
            self._run()

    def set_data(self, data):
        """Set the input dataset.

        Data with fewer than 2 rows is rejected: an error is shown and the
        input is treated as missing.

        Parameters
        ----------
        data : Optional[Table]
        """
        if data is not None and len(data) < 2:
            self.Error.not_enough_rows()
            data = None
        else:
            self.Error.not_enough_rows.clear()

        self.signal_data = data

    @Inputs.distances
    def set_disimilarity(self, matrix):
        """Set the dissimilarity (distance) matrix.

        A matrix with fewer than 2 rows is rejected: an error is shown and
        the input is treated as missing.

        Parameters
        ----------
        matrix : Optional[Orange.misc.DistMatrix]
        """

        if matrix is not None and len(matrix) < 2:
            self.Error.matrix_too_small()
            matrix = None
        else:
            self.Error.matrix_too_small.clear()

        self.matrix = matrix
        self.matrix_data = matrix.row_items if matrix is not None else None

    def clear(self):
        """Clear base state, cancel the task and drop the embedding."""
        super().clear()
        self.cancel()
        self.embedding = None
        self.graph.set_effective_matrix(None)

    def _initialize(self):
        """Derive ``data`` and ``effective_matrix`` from the current inputs.

        Sets ``self._invalidated`` to False only when a matrix existed before
        and the new effective matrix equals it; otherwise the widget state is
        cleared. On missing or inconsistent inputs, the state is cleared and
        the method returns early with ``_invalidated`` left True.
        """
        matrix_existed = self.effective_matrix is not None
        effective_matrix = self.effective_matrix
        self._invalidated = True
        self.data = None
        self.effective_matrix = None
        self.closeContext()
        self.clear_messages()

        # if no data nor matrix is present reset plot
        if self.signal_data is None and self.matrix is None:
            self.clear()
            self.init_attr_values()
            return

        if self.signal_data is not None and self.matrix is not None and \
                len(self.signal_data) != len(self.matrix):
            self.Error.mismatching_dimensions()
            self.clear()
            self.init_attr_values()
            return

        # Data from the data input takes precedence over the matrix's items
        if self.signal_data is not None:
            self.data = self.signal_data
        elif self.matrix_data is not None:
            self.data = self.matrix_data

        if self.matrix is not None:
            self.effective_matrix = self.matrix
            # For a by-columns matrix whose data comes from the matrix itself,
            # replace the data with a table of attribute names ("labels")
            if self.matrix.axis == 0 and self.data is not None \
                    and self.data is self.matrix_data:
                names = [[attr.name] for attr in self.data.domain.attributes]
                domain = Domain([], metas=[StringVariable("labels")])
                self.data = Table.from_list(domain, names)
        elif self.data.domain.attributes:
            preprocessed_data = MDS().preprocess(self.data)
            self.effective_matrix = Euclidean(preprocessed_data)
        else:
            self.Error.no_attributes()
            self.clear()
            self.init_attr_values()
            return

        self.init_attr_values()
        self.openContext(self.data)
        self._invalidated = not (
            matrix_existed and self.effective_matrix is not None
            and array_equal(effective_matrix, self.effective_matrix))
        if self._invalidated:
            self.clear()
        self.graph.set_effective_matrix(self.effective_matrix)

    def init_attr_values(self):
        """Reset attribute settings; label by "labels" for by-column matrices."""
        super().init_attr_values()
        # The "labels" variable only exists when `_initialize` built the
        # table of attribute names. Data from the data input need not have
        # it, and an unconditional lookup would raise KeyError.
        if self.matrix is not None and self.matrix.axis == 0 and \
                self.data is not None and len(self.data) and \
                "labels" in self.data.domain:
            self.attr_label = self.data.domain["labels"]

    def _toggle_run(self):
        """Pause the running optimization, or start/resume one."""
        if self.task is not None:
            self.cancel()
            self.run_button.setText("Resume")
            self.commit()
        else:
            self._run()

    def _run(self):
        """Start the MDS optimization from the current embedding, if any."""
        if self.effective_matrix is None:
            return
        self.graph.pause_drawing_pairs()
        self.run_button.setText("Stop")
        _, step_size = OWMDS.RefreshRate[self.refresh_rate]
        # -1 ("None") means no intermediate refresh: one step of max_iter
        if step_size == -1:
            step_size = self.max_iter
        init_type = "PCA" if self.initialization == OWMDS.PCA else "random"
        self.start(run_mds, self.effective_matrix, self.max_iter, step_size,
                   init_type, self.embedding)

    # ConcurrentWidgetMixin
    def on_partial_result(self, result: Result):
        """Store an intermediate embedding and refresh the plot."""
        assert isinstance(result.embedding, np.ndarray)
        assert len(result.embedding) == len(self.effective_matrix)
        first_result = self.embedding is None
        new_embedding = result.embedding
        need_update = new_embedding is not self.embedding
        self.embedding = new_embedding
        if first_result:
            self.setup_plot()
        else:
            if need_update:
                self.graph.update_coordinates()
                self.graph.update_density()

    def on_done(self, result: Result):
        """Store the final embedding, resume drawing pairs and commit."""
        assert isinstance(result.embedding, np.ndarray)
        assert len(result.embedding) == len(self.effective_matrix)
        self.embedding = result.embedding
        self.graph.resume_drawing_pairs()
        self.run_button.setText("Start")
        self.commit()

    def on_exception(self, ex: Exception):
        """Show an error for a failed optimization and reset the controls."""
        if isinstance(ex, MemoryError):
            self.Error.out_of_memory()
        else:
            self.Error.optimization_error(str(ex))
        self.graph.resume_drawing_pairs()
        self.run_button.setText("Start")

    def do_PCA(self):
        """Re-initialize the embedding with the PCA initialization."""
        self.do_initialization(self.PCA)

    def do_random(self):
        """Re-initialize the embedding with random positions."""
        self.do_initialization(self.Random)

    def do_jitter(self):
        """Jitter the current embedding."""
        self.do_initialization(self.Jitter)

    def do_initialization(self, init_type: int):
        """Re-initialize the embedding, then redraw the plot and commit."""
        self.run_button.setText("Start")
        self.__invalidate_embedding(init_type)
        self.setup_plot()
        self.commit()

    def __invalidate_embedding(self, initialization=PCA):
        def jitter_coord(part):
            # Perturb in place by up to 1/20 of the coordinate's span
            span = np.max(part) - np.min(part)
            part += np.random.uniform(-span / 20, span / 20, len(part))

        # reset/invalidate the MDS embedding, to the default initialization
        # (Random or PCA), restarting the optimization if necessary.
        if self.effective_matrix is None:
            self.graph.reset_graph()
            return

        X = self.effective_matrix

        if initialization == OWMDS.PCA:
            self.embedding = torgerson(X)
        elif initialization == OWMDS.Random:
            self.embedding = np.random.rand(len(X), 2)
        elif self.embedding is not None:
            # Jitter perturbs the existing embedding; without an embedding
            # there is nothing to perturb (subscripting None would raise)
            jitter_coord(self.embedding[:, 0])
            jitter_coord(self.embedding[:, 1])

        # restart the optimization if it was interrupted.
        if self.task is not None:
            self._run()

    def handleNewSignals(self):
        """Re-initialize from the inputs and restart if they changed."""
        self._initialize()
        self.input_changed.emit(self.data)
        if self._invalidated:
            self.graph.pause_drawing_pairs()
            self.__invalidate_embedding()
            self.enable_controls()
            if self.effective_matrix is not None:
                self._run()
        super().handleNewSignals()

    def _on_connected_changed(self):
        """Redraw the lines between similar pairs after the slider moved."""
        self.graph.set_effective_matrix(self.effective_matrix)
        self.graph.update_pairs(reconnect=True)

    def setup_plot(self):
        """Set up the plot and, if there is an embedding, the pair lines."""
        super().setup_plot()
        if self.embedding is not None:
            self.graph.update_pairs(reconnect=True)

    def get_size_data(self):
        """Return point sizes; the stress values when sizing by "Stress"."""
        if self.attr_size == "Stress":
            return stress(self.embedding, self.effective_matrix)
        else:
            return super().get_size_data()

    def get_embedding(self):
        """Return the embedding and update ``self.valid_data``.

        ``valid_data`` becomes an all-True boolean mask with one entry per
        embedded row, or None when there is no embedding.
        """
        self.valid_data = np.ones(len(self.embedding), dtype=bool) \
            if self.embedding is not None else None
        return self.embedding

    def _get_projection_data(self):
        """Return the output table holding the embedding.

        Returns None without an embedding. Without data (matrix-only input),
        a table with just the two embedding variables is returned; otherwise
        the base-class implementation is used.
        """
        if self.embedding is None:
            return None

        if self.data is None:
            x_name, y_name = self.embedding_variables_names
            variables = ContinuousVariable(x_name), ContinuousVariable(y_name)
            return Table(Domain(variables), self.embedding)
        return super()._get_projection_data()

    def onDeleteWidget(self):
        """Shut down the concurrent machinery before the base cleanup."""
        self.shutdown()
        super().onDeleteWidget()

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade stored settings to the current ``settings_version``.

        * before version 2: graph-related settings are copied, under their
          new names, into a nested ``"graph"`` dict and ``autocommit`` is
          copied to ``auto_commit``;
        * before version 3: ``connected_pairs`` is copied into ``"graph"``.
        """
        if version < 2:
            settings_graph = {}
            for old, new in (("label_only_selected", "label_only_selected"),
                             ("symbol_opacity", "alpha_value"),
                             ("symbol_size", "point_width"), ("jitter",
                                                              "jitter_size")):
                settings_graph[new] = settings_[old]
            settings_["graph"] = settings_graph
            settings_["auto_commit"] = settings_["autocommit"]

        if version < 3:
            if "connected_pairs" in settings_:
                connected_pairs = settings_["connected_pairs"]
                # Version-2 settings may have no "graph" dict yet; create it
                # rather than fail with KeyError
                settings_.setdefault("graph", {})["connected_pairs"] = \
                    connected_pairs

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context to the current ``settings_version``.

        * before version 2: ``*_value`` entries are rewritten as ``attr_*``
          entries and replace ``context.values``;
        * before version 3: the ``attr_*`` values are copied from the nested
          ``"graph"`` entry (if present) to the top level.
        """
        if version < 2:
            domain = context.ordered_domain
            n_domain = [t for t in context.ordered_domain if t[1] == 2]
            c_domain = [t for t in context.ordered_domain if t[1] == 1]
            context_values = {}
            for _, old_val, new_val in ((domain, "color_value", "attr_color"),
                                        (c_domain, "shape_value",
                                         "attr_shape"),
                                        (n_domain, "size_value", "attr_size"),
                                        (domain, "label_value", "attr_label")):
                tmp = context.values[old_val]
                if tmp[1] >= 0:
                    context_values[new_val] = (tmp[0], tmp[1] + 100)
                elif tmp[0] != "Stress":
                    context_values[new_val] = None
                else:
                    context_values[new_val] = tmp
            context.values = context_values

        if version < 3 and "graph" in context.values:
            values = context.values
            values["attr_color"] = values["graph"]["attr_color"]
            values["attr_size"] = values["graph"]["attr_size"]
            values["attr_shape"] = values["graph"]["attr_shape"]
            values["attr_label"] = values["graph"]["attr_label"]
Ejemplo n.º 20
0
class OWScatterPlot(OWWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140

    class Inputs:
        # Input signals of the widget. Each one is handled by the method
        # decorated with it: set_data, set_subset_data and
        # set_shown_attributes respectively.
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        features = Input("Features", AttributeList)

    class Outputs:
        # Output signals of the widget. selected_data and annotated_data are
        # sent from send_data(); features is sent from send_features().
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        features = Output("Features", Table, dynamic=False)

    settingsHandler = DomainContextHandler()

    auto_send_selection = Setting(True)
    auto_sample = Setting(True)
    toolbar_selection = Setting(0)

    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    selection = Setting(None, schema_only=True)

    graph = SettingProvider(OWScatterPlotGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    graph_name = "graph.plot_widget.plotItem"

    class Information(OWWidget.Information):
        # Raised by set_data() when the input is an SqlTable that is shown
        # as a sample instead of being downloaded whole.
        sampled_sql = Msg("Large SQL table; showing a sample.")

    def __init__(self):
        """Create the graph, the control-area widgets and the zoom actions.

        The main area holds the plot widget of an ``OWScatterPlotGraph``.
        The control area receives, in this order: the axis combos with the
        VizRank button and the jittering controls, the SQL sampling box
        (hidden until needed), the point-properties box, the "Plot
        Properties" box, the zoom/select toolbar and the "Send Selection"
        auto-commit box.
        """
        super().__init__()

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWScatterPlotGraph(self, box, "ScatterPlot")
        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        # Draw both axes with the palette's text colour.
        axispen = QPen(self.palette().color(QPalette.Text))
        axis = plot.getAxis("bottom")
        axis.setPen(axispen)

        axis = plot.getAxis("left")
        axis.setPen(axispen)

        # State filled in later by the input handlers.
        self.data = None  # Orange.data.Table
        self.subset_data = None  # Orange.data.Table
        self.data_metas_X = None  # self.data, where primitive metas are moved to X
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        # Timer whose timeout calls add_data(); it is started in set_data()
        # and switch_sampling().
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)

        # Options shared by the two axis combos.
        common_options = dict(
            labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True,
            valueType=str)
        box = gui.vBox(self.controlArea, "Axis Data")
        dmod = DomainModel
        # Both axis combos share a single model.
        self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(
            box, self, "attr_x", label="Axis x:", callback=self.update_attr,
            model=self.xy_model, **common_options)
        self.cb_attr_y = gui.comboBox(
            box, self, "attr_y", label="Axis y:", callback=self.update_attr,
            model=self.xy_model, **common_options)

        # VizRank button, indented by the combos' label width; set_attr() is
        # passed as its callback.
        vizrank_box = gui.hBox(box)
        gui.separator(vizrank_box, width=common_options["labelWidth"])
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

        gui.separator(box)

        # Jittering slider: 0 is labelled "None", values below 1 are shown
        # with one decimal and larger ones as integers.
        gui.valueSlider(
            box, self, value='graph.jitter_size', label='Jittering: ',
            values=self.jitter_sizes, callback=self.reset_graph_data,
            labelFormat=lambda x:
            "None" if x == 0 else ("%.1f %%" if x < 1 else "%d %%") % x)
        gui.checkBox(
            gui.indentedBox(box), self, 'graph.jitter_continuous',
            'Jitter numeric values', callback=self.reset_graph_data)

        # Sampling box; set_data() makes it visible for large SQL tables.
        self.sampling = gui.auto_commit(
            self.controlArea, self, "auto_sample", "Sample", box="Sampling",
            callback=self.switch_sampling, commit=lambda: self.add_data(1))
        self.sampling.setVisible(False)

        g = self.graph.gui
        g.point_properties_box(self.controlArea)
        # All models that init_attr_values() points at the current domain.
        self.models = [self.xy_model] + g.points_models

        box = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([g.ShowLegend, g.ShowGridLines], box)
        gui.checkBox(
            box, self, value='graph.tooltip_shows_all',
            label='Show all data on mouse hover')
        self.cb_class_density = gui.checkBox(
            box, self, value='graph.class_density', label='Show class density',
            callback=self.update_density)
        self.cb_reg_line = gui.checkBox(
            box, self, value='graph.show_reg_line',
            label='Show regression line', callback=self.update_regression_line)
        gui.checkBox(
            box, self, 'graph.label_only_selected',
            'Label only selected points', callback=self.graph.update_labels)

        # Toolbar buttons are wired to the graph's own click handlers.
        self.zoom_select_toolbar = g.zoom_select_toolbar(
            gui.vBox(self.controlArea, "Zoom/Select"), nomargin=True,
            buttons=[g.StateButtonsBegin, g.SimpleSelect, g.Pan, g.Zoom,
                     g.StateButtonsEnd, g.ZoomReset]
        )
        buttons = self.zoom_select_toolbar.buttons
        buttons[g.Zoom].clicked.connect(self.graph.zoom_button_clicked)
        buttons[g.Pan].clicked.connect(self.graph.pan_button_clicked)
        buttons[g.SimpleSelect].clicked.connect(self.graph.select_button_clicked)
        buttons[g.ZoomReset].clicked.connect(self.graph.reset_button_clicked)
        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea, self, "auto_send_selection",
                        "Send Selection", "Send Automatically")

        def zoom(s):
            """Zoom in/out by factor `s`."""
            viewbox = plot.getViewBox()
            # scaleBy scales the view's bounds (the axis range)
            viewbox.scaleBy((1 / s, 1 / s))

        def fit_to_view():
            """Let the view box choose a range automatically."""
            viewbox = plot.getViewBox()
            viewbox.autoRange()

        # Keyboard actions for zooming in, zooming out and fitting the view.
        zoom_in = QAction(
            "Zoom in", self, triggered=lambda: zoom(1.25)
        )
        zoom_in.setShortcuts([QKeySequence(QKeySequence.ZoomIn),
                              QKeySequence(self.tr("Ctrl+="))])
        zoom_out = QAction(
            "Zoom out", self, shortcut=QKeySequence.ZoomOut,
            triggered=lambda: zoom(1 / 1.25)
        )
        zoom_fit = QAction(
            "Fit in view", self,
            shortcut=QKeySequence(Qt.ControlModifier | Qt.Key_0),
            triggered=fit_to_view
        )
        self.addActions([zoom_in, zoom_out, zoom_fit])

    def keyPressEvent(self, event):
        """Run the inherited key-press handling, then refresh the tooltip.

        The graph's tooltip is updated with the modifiers reported by `event`.
        """
        super().keyPressEvent(event)
        graph = self.graph
        graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Run the inherited key-release handling, then refresh the tooltip.

        The graph's tooltip is updated with the modifiers reported by `event`.
        """
        super().keyReleaseEvent(event)
        graph = self.graph
        graph.update_tooltip(event.modifiers())

    # def settingsFromWidgetCallback(self, handler, context):
    #     context.selectionPolygons = []
    #     for curve in self.graph.selectionCurveList:
    #         xs = [curve.x(i) for i in range(curve.dataSize())]
    #         ys = [curve.y(i) for i in range(curve.dataSize())]
    #         context.selectionPolygons.append((xs, ys))

    # def settingsToWidgetCallback(self, handler, context):
    #     selections = getattr(context, "selectionPolygons", [])
    #     for (xs, ys) in selections:
    #         c = SelectionCurve("")
    #         c.setData(xs,ys)
    #         c.attach(self.graph)
    #         self.graph.selectionCurveList.append(c)

    def reset_graph_data(self, *_):
        """Rescale the graph's data and redraw; do nothing without data.

        Positional arguments are accepted and ignored.
        """
        if self.data is None:
            return
        self.graph.rescale_data()
        self.update_graph()

    @Inputs.data
    def set_data(self, data):
        """Handle a new table on the "Data" input.

        An ``SqlTable`` whose ``approx_len()`` is below 4000 is converted to
        a ``Table``. A larger one is kept in ``self.sql_data``, a sample of
        it is downloaded, and the sampling box is shown. The timer is
        started when ``auto_sample`` is set.

        Data with no rows or an empty domain is treated as no data. Nothing
        further happens when the new data has the same checksum as the
        current data. Otherwise the context is closed and reopened around
        the update, and attribute values are re-initialised unless the
        domain checksum is unchanged.

        Finally, any ``attr_*`` value that was restored as a string is
        replaced by the variable of that name from the matching model.
        """
        self.clear_messages()
        self.Information.sampled_sql.clear()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None

        if isinstance(data, SqlTable):
            if data.approx_len() < 4000:
                data = Table(data)
            else:
                self.Information.sampled_sql()
                self.sql_data = data
                sample = data.sample_time(0.8, no_cache=True)
                sample.download_data(2000, partial=True)
                data = Table(sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        # a table without rows or without columns counts as no data
        if data is not None and not (len(data) and len(data.domain)):
            data = None
        if self.data and data and self.data.checksum() == data.checksum():
            return

        self.closeContext()
        domain_unchanged = (
            self.data and data and
            data.domain.checksum() == self.data.domain.checksum())
        self.data = data
        self.data_metas_X = self.move_primitive_metas_to_X(data)
        if not domain_unchanged:
            self.init_attr_values()

        self.vizrank.initialize()
        table = self.data
        self.vizrank.attrs = \
            table.domain.attributes if table is not None else []
        self.vizrank_button.setEnabled(
            table is not None and not table.is_sparse() and
            table.domain.class_var is not None and
            len(table.domain.attributes) > 1 and len(table) > 1)
        if table is not None and table.domain.class_var is None \
                and len(table.domain.attributes) > 1 and len(table) > 1:
            tip = "Data with a class variable is required."
        else:
            tip = ""
        self.vizrank_button.setToolTip(tip)
        self.openContext(self.data)

        def lookup(name, candidates):
            """Return the Orange.data.Variable named `name`, or None."""
            return next(
                (item for item in candidates
                 if isinstance(item, Orange.data.Variable)
                 and item.name == name),
                None)

        # handle restored settings from  < 3.3.9 when attr_* were stored
        # by name; the models are fetched lazily, only for string values
        graph = self.graph
        stored_by_name = (
            (self, "attr_x", lambda: self.xy_model),
            (self, "attr_y", lambda: self.xy_model),
            (graph, "attr_label", lambda: graph.gui.label_model),
            (graph, "attr_color", lambda: graph.gui.color_model),
            (graph, "attr_shape", lambda: graph.gui.shape_model),
            (graph, "attr_size", lambda: graph.gui.size_model),
        )
        for owner, attr_name, get_model in stored_by_name:
            current = getattr(owner, attr_name)
            if isinstance(current, str):
                setattr(owner, attr_name, lookup(current, get_model()))

    def add_data(self, time=0.4):
        """Append another sample of the SQL table to the shown data.

        When more than 2000 rows are already present, the timer is stopped
        and nothing is added. Otherwise a sample is requested from
        ``self.sql_data`` with ``sample_time(time, no_cache=True)``. A
        truthy sample is downloaded, concatenated to ``self.data``, and the
        widget is refreshed through ``handleNewSignals()``.

        :param time: value passed on to ``sample_time``
        """
        current = self.data
        if current and len(current) > 2000:
            return self.__timer.stop()
        sample = self.sql_data.sample_time(time, no_cache=True)
        if not sample:
            return None
        sample.download_data(2000, partial=True)
        extra_rows = Table(sample)
        self.data = Table.concatenate((self.data, extra_rows), axis=0)
        self.data_metas_X = self.move_primitive_metas_to_X(self.data)
        self.handleNewSignals()
        return None

    def switch_sampling(self):
        """Stop the sampling timer and restart it if sampling is enabled.

        The timer is restarted, after one immediate ``add_data()`` call,
        only when ``auto_sample`` is set and ``sql_data`` is truthy.
        """
        self.__timer.stop()
        if not (self.auto_sample and self.sql_data):
            return
        self.add_data()
        self.__timer.start()

    def move_primitive_metas_to_X(self, data):
        """Return `data` transformed so that primitive metas are attributes.

        The new domain's attributes are the primitive variables among the
        original attributes and metas. The class variables are kept, and
        only non-primitive metas remain metas. ``None`` is returned
        unchanged.
        """
        if data is None:
            return data
        domain = data.domain
        attributes = [var for var in domain.attributes + domain.metas
                      if var.is_primitive()]
        metas = [var for var in domain.metas if not var.is_primitive()]
        return data.transform(Domain(attributes, domain.class_vars, metas))

    @Inputs.data_subset
    def set_subset_data(self, subset_data):
        """Handle a new table on the "Data Subset" input.

        An ``SqlTable`` with ``approx_len()`` below ``AUTO_DL_LIMIT`` is
        converted to a ``Table``. A larger one is rejected with a warning
        and treated as no subset. The alpha-value control is enabled only
        while there is no subset.
        """
        self.warning()
        subset = subset_data
        if isinstance(subset, SqlTable):
            small_enough = subset.approx_len() < AUTO_DL_LIMIT
            if small_enough:
                subset = Table(subset)
            else:
                self.warning("Data subset does not support large Sql tables")
                subset = None
        self.subset_data = self.move_primitive_metas_to_X(subset)
        no_subset = subset is None
        self.controls.graph.alpha_value.setEnabled(no_subset)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        """Push the current data to the graph and refresh everything once.

        The graph receives the data and the subset, both made dense. If a
        pending attribute list is set and all its attributes are in the
        graph's domain, its first two entries become the x and y
        attributes. The pending list is then dropped, the graph is redrawn,
        the density and regression-line check boxes are enabled according to
        the graph, the stored selection is applied and the outputs are
        committed.
        """
        graph = self.graph
        dense_data = self.sparse_to_dense(self.data_metas_X)
        dense_subset = self.sparse_to_dense(self.subset_data)
        graph.new_data(dense_data, dense_subset)

        requested = self.attribute_selection_list
        if requested and graph.domain and \
                all(attr in graph.domain for attr in requested):
            self.attr_x, self.attr_y = requested[0], requested[1]
        self.attribute_selection_list = None

        self.update_graph()
        density_possible = self.graph.can_draw_density()
        self.cb_class_density.setEnabled(density_possible)
        regression_possible = self.graph.can_draw_regresssion_line()
        self.cb_reg_line.setEnabled(regression_possible)
        self.apply_selection()
        self.unconditional_commit()

    def prepare_data(self):
        """
        Hand the densified data and subset to the graph with ``new=False``.

        Only when dealing with sparse matrices.
        GH-2152
        """
        dense_data = self.sparse_to_dense(self.data_metas_X)
        dense_subset = self.sparse_to_dense(self.subset_data)
        self.graph.new_data(dense_data, dense_subset, new=False)

    def sparse_to_dense(self, input_data=None):
        """Return a dense table limited to the attributes currently in use.

        ``None`` and non-sparse tables are returned unchanged. A sparse
        table is transformed to a domain holding only the columns used for
        x, y, colour, shape, size and label, and its ``X`` (and ``Y``, if
        sparse) is converted with ``toarray()``.
        """
        if input_data is None or not input_data.is_sparse():
            return input_data
        graph = self.graph
        in_use = {self.attr_x,
                  self.attr_y,
                  graph.attr_color,
                  graph.attr_shape,
                  graph.attr_size,
                  graph.attr_label}
        columns = [index for index, var in enumerate(input_data.domain)
                   if var in in_use]
        dense = input_data.transform(
            input_data.domain.select_columns(columns))
        dense.X = dense.X.toarray()
        # TODO: remove once we make sure Y is always dense.
        if sp.issparse(dense.Y):
            dense.Y = dense.Y.toarray()
        return dense

    def apply_selection(self):
        """Apply selection saved in workflow.

        Does nothing unless both data and a stored selection exist.
        Otherwise the graph gets a fresh all-zero ``uint8`` selection array,
        stored indices that are not below the number of rows are dropped,
        the remaining ones are set to 1 and the colours are updated.
        """
        if self.data is None or self.selection is None:
            return
        n_rows = len(self.data)
        self.graph.selection = np.zeros(n_rows, dtype=np.uint8)
        valid = [index for index in self.selection if index < n_rows]
        self.selection = valid
        self.graph.selection[valid] = 1
        self.graph.update_colors(keep_colors=True)

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Remember the first two of `attributes` as the requested axes.

        Anything falsy or shorter than two items clears the request.
        """
        usable = attributes and len(attributes) >= 2
        self.attribute_selection_list = attributes[:2] if usable else None

    def get_shown_attributes(self):
        """Return the ``(attr_x, attr_y)`` pair currently set on the widget."""
        x_attr, y_attr = self.attr_x, self.attr_y
        return x_attr, y_attr

    def init_attr_values(self):
        """Point all models at the current domain and reset the attributes.

        x becomes the first item of the axis model (or ``None`` when the
        model is empty). y becomes the second item, or the same as x when
        there are fewer than two. Colour becomes the class variable when
        there is a domain. Shape, size and label are cleared.
        """
        domain = self.data and self.data.domain
        for model in self.models:
            model.set_domain(domain)

        xy = self.xy_model
        self.attr_x = xy[0] if xy else None
        if len(xy) >= 2:
            self.attr_y = xy[1]
        else:
            self.attr_y = self.attr_x

        graph = self.graph
        class_var = self.data.domain.class_var if domain else None
        graph.attr_color = class_var or None
        graph.attr_shape = None
        graph.attr_size = None
        graph.attr_label = None

    def set_attr(self, attr_x, attr_y):
        """Set both axis attributes and run ``update_attr()``."""
        self.attr_x = attr_x
        self.attr_y = attr_y
        self.update_attr()

    def update_attr(self):
        """React to a change of the axis attributes.

        Re-prepares the data, redraws the graph, re-evaluates whether the
        density and regression-line check boxes are enabled, and sends the
        features output.
        """
        self.prepare_data()
        self.update_graph()
        density_possible = self.graph.can_draw_density()
        self.cb_class_density.setEnabled(density_possible)
        regression_possible = self.graph.can_draw_regresssion_line()
        self.cb_reg_line.setEnabled(regression_possible)
        self.send_features()

    def update_colors(self):
        """Re-prepare the data and re-evaluate the density check box."""
        self.prepare_data()
        density_possible = self.graph.can_draw_density()
        self.cb_class_density.setEnabled(density_possible)

    def update_density(self):
        """Redraw the graph without resetting the view."""
        keep_view = dict(reset_view=False)
        self.update_graph(**keep_view)

    def update_regression_line(self):
        """Redraw the graph without resetting the view."""
        keep_view = dict(reset_view=False)
        self.update_graph(**keep_view)

    def update_graph(self, reset_view=True, **_):
        """Clear the graph's zoom stack and redraw it for the current axes.

        Nothing is drawn while the graph has no data. Extra keyword
        arguments are accepted and ignored.

        :param reset_view: passed on to the graph's ``update_data``
        """
        graph = self.graph
        graph.zoomStack = []
        if graph.data is not None:
            graph.update_data(self.attr_x, self.attr_y, reset_view)

    def selection_changed(self):
        """Send the data outputs again after the selection has changed."""
        send = self.send_data
        send()

    @staticmethod
    def create_groups_table(data, selection):
        """Return `data` with an extra meta column holding `selection`.

        The new discrete meta is named via ``get_next_name`` starting from
        "Selection group". Its values are "Unselected" followed by "G1",
        "G2", ... up to the maximum of `selection`. ``None`` is returned
        when `data` is ``None``.

        :param data: table to annotate, or ``None``
        :param selection: array reshaped to one value per row of `data`
        """
        if data is None:
            return None
        source = data.domain
        taken = [var.name for var in source.variables + source.metas]
        group_name = get_next_name(taken, "Selection group")
        labels = ["Unselected"] + ["G{}".format(i + 1)
                                   for i in range(np.max(selection))]
        group_var = DiscreteVariable(group_name, labels)
        extended = Domain(source.attributes, source.class_vars,
                          source.metas + (group_var,))
        table = data.transform(extended)
        # the new column is the one past the original metas
        first_new = len(source.metas)
        table.metas[:, first_new:] = selection.reshape(len(data), 1)
        return table

    def send_data(self):
        """Send the selected and the annotated data, and store the selection.

        For an ``SqlTable`` the whole table is sent as the selected data.
        For other data the rows returned by the graph's ``get_selection()``
        are sent, or ``None`` when nothing is selected.

        The annotated output is a groups table when the graph's selection
        array contains a value above 1, and an annotated table otherwise.

        A non-empty selection is then stored in ``self.selection``.
        """
        selected = None
        selection = None
        # TODO: Implement selection for sql data
        graph = self.graph
        if isinstance(self.data, SqlTable):
            selected = self.data
        elif self.data is not None:
            selection = graph.get_selection()
            if len(selection) > 0:
                selected = self.data[selection]
        if graph.selection is not None and np.max(graph.selection) > 1:
            annotated = self.create_groups_table(self.data, graph.selection)
        else:
            annotated = create_annotated_table(self.data, selection)
        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

        # Store current selection in a setting that is stored in workflow.
        # The check is on the local `selection`: it stays None without data
        # or for SQL data, where len() would raise. Testing the stored
        # setting instead would also prevent a first selection from ever
        # being stored.
        if selection is not None and len(selection):
            self.selection = list(selection)

    def send_features(self):
        """Send the two axis attributes on the "Features" output.

        A table named "Features" with one string meta column "feature" and
        two rows (x, then y) is sent when at least one of the two
        attributes is truthy; otherwise ``None`` is sent.
        """
        if self.attr_x or self.attr_y:
            meta_only = Domain([], metas=(StringVariable(name="feature"),))
            rows = [[self.attr_x], [self.attr_y]]
            features = Table(meta_only, rows)
            features.name = "Features"
        else:
            features = None
        self.Outputs.features.send(features)

    def commit(self):
        """Send the data outputs, then the features output."""
        self.send_data()
        # features go out after the data
        self.send_features()

    def get_widget_name_extension(self):
        """Return "<x name> vs <y name>", or ``None`` when there is no data."""
        if self.data is None:
            return None
        x_name, y_name = self.attr_x.name, self.attr_y.name
        return "{} vs {}".format(x_name, y_name)

    def send_report(self):
        """Report the plot with a caption listing the point attributes.

        Nothing is reported without data. The caption lists the names of
        the colour, label, shape and size attributes and the jittering
        value. Jittering is the graph's ``jitter_size`` only when an axis
        attribute is discrete or continuous jittering is on.
        """
        if self.data is None:
            return

        def name(var):
            return var and var.name

        graph = self.graph
        color_name = name(graph.attr_color)
        label_name = name(graph.attr_label)
        shape_name = name(graph.attr_shape)
        size_name = name(graph.attr_size)
        jittering = (self.attr_x.is_discrete or
                     self.attr_y.is_discrete or
                     graph.jitter_continuous) and graph.jitter_size
        caption = report.render_items_vert((
            ("Color", color_name),
            ("Label", label_name),
            ("Shape", shape_name),
            ("Size", size_name),
            ("Jittering", jittering)))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    def onDeleteWidget(self):
        """Run the inherited clean-up, then release the plot's contents.

        The plot widget's view box is scheduled for deletion and the plot
        widget is cleared.
        """
        super().onDeleteWidget()
        plot_widget = self.graph.plot_widget
        plot_widget.getViewBox().deleteLater()
        self.graph.plot_widget.clear()
Ejemplo n.º 21
0
class OWMosaicDisplay(OWWidget):
    name = "Mosaic Display"
    description = "Display data in a mosaic plot."
    icon = "icons/MosaicDisplay.svg"
    priority = 220
    keywords = []

    class Inputs:
        # Input signals of the widget, handled by set_data and
        # set_subset_data respectively.
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        # Output signals of the widget; both are sent from send_selection().
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = DomainContextHandler()
    vizrank = SettingProvider(MosaicVizRank)
    settings_version = 2
    use_boxes = Setting(True)
    variable1: Variable = ContextSetting(None)
    variable2: Variable = ContextSetting(None)
    variable3: Variable = ContextSetting(None)
    variable4: Variable = ContextSetting(None)
    variable_color: DiscreteVariable = ContextSetting(None)
    selection = Setting(set(), schema_only=True)

    BAR_WIDTH = 5
    SPACING = 4
    ATTR_NAME_OFFSET = 20
    ATTR_VAL_OFFSET = 3
    BLUE_COLORS = [
        QColor(255, 255, 255),
        QColor(210, 210, 255),
        QColor(110, 110, 255),
        QColor(0, 0, 255)
    ]
    RED_COLORS = [
        QColor(255, 255, 255),
        QColor(255, 200, 200),
        QColor(255, 100, 100),
        QColor(255, 0, 0)
    ]
    graph_name = "canvas"

    attrs_changed_manually = Signal(list)

    class Warning(OWWidget.Warning):
        # incompatible_subset is raised in handleNewSignals() when the
        # subset, transformed to the data's domain, has only NaNs in X and Y.
        incompatible_subset = Msg("Data subset is incompatible with Data")
        no_valid_data = Msg("No valid data")
        # no_cont_selection_sql is raised in send_selection() for SqlTable
        # data whose discretized copy differs from the data itself.
        no_cont_selection_sql = \
            Msg("Selection of numeric features on SQL is not supported")

    def __init__(self):
        """Create the canvas and the control-area widgets.

        The main area holds a graphics view on ``self.canvas``. The control
        area gets four variable combos, the VizRank button and the
        "Interior Coloring" box with the colour combo and the "Compare with
        total" check box.
        """
        super().__init__()

        self.data = None
        self.discrete_data = None
        self.subset_data = None
        self.subset_indices = None
        # Keep the stored selection aside and start with an empty one;
        # handleNewSignals() applies the pending value once data is present.
        self.__pending_selection = self.selection
        self.selection = set()

        self.color_data = None

        self.areas = []

        # Clicks on the view that are passed to `handler` clear the selection.
        self.canvas = QGraphicsScene(self)
        self.canvas_view = ViewWithPress(self.canvas,
                                         handler=self.clear_selection)
        self.mainArea.layout().addWidget(self.canvas_view)
        self.canvas_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvas_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvas_view.setRenderHint(QPainter.Antialiasing)

        box = gui.vBox(self.controlArea, box=True)
        # The first combo uses a model without a placeholder; combos 2-4
        # share a model that adds a "(None)" placeholder.
        self.model_1 = DomainModel(order=DomainModel.MIXED,
                                   valid_types=DomainModel.PRIMITIVE)
        self.model_234 = DomainModel(order=DomainModel.MIXED,
                                     valid_types=DomainModel.PRIMITIVE,
                                     placeholder="(None)")
        # One combo per setting variable1 .. variable4.
        self.attr_combos = [
            gui.comboBox(box,
                         self,
                         value="variable{}".format(i),
                         orientation=Qt.Horizontal,
                         contentsLength=12,
                         searchable=True,
                         callback=self.attr_changed,
                         model=self.model_1 if i == 1 else self.model_234)
            for i in range(1, 5)
        ]
        # set_attr() is passed to VizRank as its callback.
        self.vizrank, self.vizrank_button = MosaicVizRank.add_vizrank(
            box, self, "Find Informative Mosaics", self.set_attr)

        box2 = gui.vBox(self.controlArea, box="Interior Coloring")
        # The placeholder stands for "no colour variable".
        self.color_model = DomainModel(order=DomainModel.MIXED,
                                       valid_types=DomainModel.PRIMITIVE,
                                       placeholder="(Pearson residuals)")
        self.cb_attr_color = gui.comboBox(box2,
                                          self,
                                          value="variable_color",
                                          orientation=Qt.Horizontal,
                                          contentsLength=12,
                                          labelWidth=50,
                                          searchable=True,
                                          callback=self.set_color_data,
                                          model=self.color_model)
        # set_color_data() enables this check box only while a colour
        # variable is set.
        self.bar_button = gui.checkBox(box2,
                                       self,
                                       'use_boxes',
                                       label='Compare with total',
                                       callback=self.update_graph)
        gui.rubber(self.controlArea)

    def sizeHint(self):
        """Return the preferred widget size, 720 by 530."""
        width, height = 720, 530
        return QSize(width, height)

    def _get_discrete_data(self, data):
        """
        Discretize continuous attributes.
        Return None when there is no data, no rows, or no primitive attributes.
        """
        if (data is None or not len(data) or not any(
                attr.is_discrete or attr.is_continuous
                for attr in chain(data.domain.variables, data.domain.metas))):
            return None
        elif any(attr.is_continuous for attr in data.domain.variables):
            return Discretize(method=EqualFreq(n=4),
                              remove_const=False,
                              discretize_classes=True,
                              discretize_metas=True)(data)
        else:
            return data

    def init_combos(self, data):
        """Point the combo models at `data` and pick default variables.

        With ``data`` of ``None`` the models are emptied and all five
        variables are cleared. Otherwise ``variable1`` and ``variable2``
        become the first and, if there is one, the second item of the
        first combo's model. ``variable3`` and ``variable4`` are cleared,
        and the colour variable becomes the class variable of the data's
        domain.

        :param data: table whose domain the combos should offer, or ``None``
        """
        # The domain is taken from the argument, not from self.data, so the
        # method works on whatever table it is given.
        domain = None if data is None else data.domain
        for model in (self.model_1, self.model_234, self.color_model):
            model.set_domain(domain)

        if data is None:
            self.variable1 = self.variable2 = self.variable3 \
                = self.variable4 = self.variable_color = None
            return

        if len(self.model_1) > 0:
            self.variable1 = self.model_1[0]
            self.variable2 = self.model_1[min(1, len(self.model_1) - 1)]
        else:
            # Nothing selectable: do not keep variables of earlier data.
            self.variable1 = self.variable2 = None
        self.variable3 = self.variable4 = None
        self.variable_color = domain.class_var  # None is OK, too

    def get_disc_attr_list(self):
        """Return the chosen variables as found in the discretized domain.

        Each truthy one of ``variable1`` .. ``variable4`` is looked up by
        name in ``self.discrete_data.domain``, and the order is preserved.
        """
        chosen = (self.variable1, self.variable2, self.variable3,
                  self.variable4)
        result = []
        for var in chosen:
            if var:
                result.append(self.discrete_data.domain[var.name])
        return result

    def set_attr(self, *attrs):
        """Set ``variable1`` .. ``variable4`` from `attrs` and reset the graph.

        Each truthy attribute is replaced by the variable of the same name
        from the data's domain, and falsy ones are stored as given. Exactly
        four values are expected by the unpacking.
        """
        resolved = []
        for attr in attrs:
            resolved.append(attr and self.data.domain[attr.name])
        (self.variable1, self.variable2,
         self.variable3, self.variable4) = resolved
        self.reset_graph()

    def attr_changed(self):
        """Emit ``attrs_changed_manually`` with the chosen variables, then
        reset the graph."""
        chosen = self.get_disc_attr_list()
        self.attrs_changed_manually.emit(chosen)
        self.reset_graph()

    def resizeEvent(self, e):
        """Run ``OWWidget``'s resize handling, then redraw the graph."""
        base_handler = OWWidget.resizeEvent
        base_handler(self, e)
        self.update_graph()

    def showEvent(self, ev):
        """Run ``OWWidget``'s show handling, then redraw the graph."""
        base_handler = OWWidget.showEvent
        base_handler(self, ev)
        self.update_graph()

    @Inputs.data
    def set_data(self, data):
        """Handle a new table on the "Data" input.

        An ``SqlTable`` with ``approx_len()`` above ``LARGE_TABLE`` is
        replaced by a time-limited sample. The context is closed, VizRank
        is reset, and its button is enabled only for data with more than
        one row and at least one attribute. Without data the combos are
        cleared. Otherwise they are initialised from the data and the
        context is opened.
        """
        too_large = isinstance(data, SqlTable) \
            and data.approx_len() > LARGE_TABLE
        if too_large:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data

        self.vizrank.stop_and_reset()
        self.vizrank_button.setEnabled(
            self.data is not None and len(self.data) > 1
            and len(self.data.domain.attributes) >= 1)

        if self.data is not None:
            self.init_combos(self.data)
            self.openContext(self.data)
        else:
            self.discrete_data = None
            self.init_combos(None)

    @Inputs.data_subset
    def set_subset_data(self, data):
        """Store the table received on the "Data Subset" input."""
        subset = data
        self.subset_data = subset

    # this is called by widget after setData and setSubsetData are called.
    # this way the graph is updated only once
    def handleNewSignals(self):
        """Recompute subset membership and selection, then redraw and send.

        The subset is transformed to the data's domain. If both its ``X``
        and ``Y`` contain only NaNs, a warning is raised. Otherwise
        ``subset_indices`` gets one boolean per data row telling whether
        the row's id occurs in the subset.

        A pending selection is applied when there is data. In every other
        case the selection is emptied.
        """
        self.Warning.incompatible_subset.clear()
        self.subset_indices = None
        have_data = self.data is not None

        if have_data and self.subset_data:
            projected = self.subset_data.transform(self.data.domain)
            all_missing = np.all(np.isnan(projected.X)) \
                and np.all(np.isnan(projected.Y))
            if all_missing:
                self.Warning.incompatible_subset()
            else:
                subset_ids = {row.id for row in projected}
                self.subset_indices = [row.id in subset_ids
                                       for row in self.data]

        if have_data and self.__pending_selection is not None:
            self.selection = self.__pending_selection
            self.__pending_selection = None
        else:
            self.selection = set()

        self.set_color_data()
        self.update_graph()
        self.send_selection()

    def clear_selection(self):
        """Empty the selection, restyle the areas and send the outputs."""
        nothing_selected = set()
        self.selection = nothing_selected
        self.update_selection_rects()
        self.send_selection()

    def coloring_changed(self):
        """Notify VizRank that the colouring changed, then redraw."""
        vizrank = self.vizrank
        vizrank.coloring_changed()
        self.update_graph()

    def reset_graph(self):
        """Clear the selection, then redraw the graph."""
        self.clear_selection()
        # redraw only after the selection is gone
        self.update_graph()

    def set_color_data(self):
        """Rebuild the colour-specific tables after a colour change.

        Does nothing without data. The "Compare with total" check box is
        enabled only while a colour variable is set. ``color_data`` is the
        data converted to a domain whose attributes are all truthy items of
        the first combo's model except the colour variable, with the colour
        variable as class. ``discrete_data`` is its discretized form.
        VizRank is reset and its button enabled.
        """
        if self.data is None:
            return
        color_var = self.variable_color
        self.bar_button.setEnabled(color_var is not None)
        others = [var for var in self.model_1
                  if var and var is not self.variable_color]
        color_domain = Domain(others, self.variable_color, None)
        self.color_data = self.data.from_table(color_domain, self.data)
        self.discrete_data = self._get_discrete_data(self.color_data)
        self.vizrank.stop_and_reset()
        self.vizrank_button.setEnabled(True)
        self.coloring_changed()

    def update_selection_rects(self):
        """Restyle every area: a 3-wide dotted black pen when its index is
        selected, a default pen otherwise."""
        plain_pen = QPen()
        selected_pen = QPen(Qt.black, 3, Qt.DotLine)
        for index, (_cols, _vals, rect) in enumerate(self.areas):
            if index in self.selection:
                rect.setPen(selected_pen)
            else:
                rect.setPen(plain_pen)

    def select_area(self, index, ev):
        """Update the selection after a click on the area `index`.

        Only left-button events are handled. With the Control modifier the
        area is toggled in the selection; without it the area becomes the
        only selected one. The areas are then restyled and the outputs
        sent.
        """
        if ev.button() != Qt.LeftButton:
            return
        clicked = {index}
        if ev.modifiers() & Qt.ControlModifier:
            self.selection ^= clicked
        else:
            self.selection = clicked
        self.update_selection_rects()
        self.send_selection()

    def send_selection(self):
        """Send the rows that fall into the selected areas.

        Without a selection or without data, ``None`` is sent as the
        selected data and the annotated output marks no rows.

        Otherwise each selected area becomes a conjunction of discrete
        filters built from the columns and values stored for it in
        ``self.areas``. Several areas are combined as a disjunction. The
        filter is applied to the discretized data, and the matching row
        ids are mapped back to indices into ``self.data``.
        """
        # Clear before the early return, so the warning from an earlier
        # selection does not stay on display once the selection is emptied.
        self.Warning.no_cont_selection_sql.clear()
        if not self.selection or self.data is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(
                create_annotated_table(self.data, []))
            return
        if self.discrete_data is not self.data \
                and isinstance(self.data, SqlTable):
            self.Warning.no_cont_selection_sql()
        filters = []
        for i in self.selection:
            cols, vals, _ = self.areas[i]
            filters.append(
                filter.Values(
                    filter.FilterDiscrete(col, [val])
                    for col, val in zip(cols, vals)))
        if len(filters) > 1:
            filters = filter.Values(filters, conjunction=False)
        else:
            filters = filters[0]
        selection = filters(self.discrete_data)
        idset = set(selection.ids)
        sel_idx = [i for i, row_id in enumerate(self.data.ids)
                   if row_id in idset]
        # Send rows of the original data rather than of its discretized copy.
        if self.discrete_data is not self.data:
            selection = self.data[sel_idx]

        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(
            create_annotated_table(self.data, sel_idx))

    def send_report(self):
        """Add the canvas to the report as a plot."""
        canvas = self.canvas
        self.report_plot(canvas)

    def update_graph(self):
        """Redraw the mosaic: rectangles, value/attribute labels and legend.

        The canvas and ``self.areas`` are cleared first.  Drawing stops early
        when there is no discretized data, when ``data[:, unique]`` is empty
        (``Warning.no_valid_data`` is raised), when one of the attributes has
        no values (a message is put on the canvas instead), or when the view
        is too small to fit the square.

        The nested helpers communicate through closure variables that are
        assigned further down in this method (``conditionaldict``,
        ``distributiondict``, ``conditionalsubsetdict``, ``apriori_dists``,
        ``class_var``, ``data``, ``drawn_sides``, ``draw_positions``,
        ``max_ylabel_w1``/``max_ylabel_w2``, ``xoff``); they may only be
        called after those assignments.
        """
        spacing = self.SPACING
        bar_width = self.BAR_WIDTH

        def get_counts(attr_vals, values):
            """Calculate rectangles' widths; if all are 0, they are set to 1.

            ``attr_vals`` is the "-"-joined key of the values chosen so far
            (empty at the first level); counts are read from
            ``conditionaldict`` for every value in ``values``.

            Returns a tuple ``(total, counts)``.
            """
            if not attr_vals:
                counts = [conditionaldict[val] for val in values]
            else:
                counts = [
                    conditionaldict[attr_vals + "-" + val] for val in values
                ]
            total = sum(counts)
            if total == 0:
                # avoid division by zero in the callers: split evenly
                counts = [1] * len(values)
                total = sum(counts)
            return total, counts

        def draw_data(attr_list,
                      x0_x1,
                      y0_y1,
                      side,
                      condition,
                      total_attrs,
                      used_attrs,
                      used_vals,
                      attr_vals=""):
            """Recursively split the given rectangle by ``attr_list[0]``.

            ``x0_x1`` and ``y0_y1`` are the extents of the rectangle to split.
            Even ``side`` splits along x, odd ``side`` along y; each recursion
            level uses ``side + 1``.  ``condition`` is the HTML tooltip text
            accumulated so far, ``used_attrs``/``used_vals`` are the
            attributes and values selected at the enclosing levels, and
            ``attr_vals`` is the matching "-"-joined ``conditionaldict`` key.
            At the last attribute the cells are drawn with ``add_rect``.
            """
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if conditionaldict[attr_vals] == 0:
                # empty cell: draw only its outline and do not split further
                add_rect(x0,
                         x1,
                         y0,
                         y1,
                         "",
                         used_attrs,
                         used_vals,
                         attr_vals=attr_vals)
                # store coordinates for later drawing of labels
                draw_text(side, attr_list[0], (x0, x1), (y0, y1), total_attrs,
                          used_attrs, used_vals, attr_vals)
                return

            attr = attr_list[0]
            # how much smaller rectangles do we draw
            edge = len(attr_list) * spacing
            values = get_variable_values_sorted(attr)
            if side % 2:
                values = values[::-1]  # reverse names if necessary

            # number of gaps between the rectangles of this attribute
            n_gaps = len(values) - 1
            if side % 2 == 0:  # we are drawing on the x axis
                # remove the space needed for separating different attr. values
                whole = max(0, (x1 - x0) - edge * n_gaps)
                # If the gaps alone do not fit, shrink them to fill the span.
                # With a single value there is no gap to resize, and dividing
                # by n_gaps would raise ZeroDivisionError.
                if whole == 0 and n_gaps:
                    edge = (x1 - x0) / float(n_gaps)
            else:  # we are drawing on the y axis
                whole = max(0, (y1 - y0) - edge * n_gaps)
                if whole == 0 and n_gaps:
                    edge = (y1 - y0) / float(n_gaps)

            total, counts = get_counts(attr_vals, values)

            # when visualizing the third attribute and the first attribute has
            # the last value, reverse the order in which the boxes are drawn;
            # otherwise, if the last cell, nearest to the labels of the fourth
            # attribute, is empty, we wouldn't be able to position the labels
            valrange = list(range(len(values)))
            if len(attr_list + used_attrs) == 4 and len(used_attrs) == 2:
                attr1values = get_variable_values_sorted(used_attrs[0])
                if used_vals[0] == attr1values[-1]:
                    valrange = valrange[::-1]

            for i in valrange:
                # offsets of the i-th rectangle relative to x0 (or y0):
                # i gaps plus the cumulative share of the preceding counts
                start = i * edge + whole * float(sum(counts[:i]) / total)
                end = i * edge + whole * float(sum(counts[:i + 1]) / total)
                val = values[i]
                htmlval = to_html(val)
                newattrvals = attr_vals + "-" + val if attr_vals else val

                tooltip = "{}&nbsp;&nbsp;&nbsp;&nbsp;{}: <b>{}</b><br/>".format(
                    condition, attr.name, htmlval)
                attrs = used_attrs + [attr]
                vals = used_vals + [val]
                args = attrs, vals, newattrvals
                if side % 2 == 0:  # if we are moving horizontally
                    if len(attr_list) == 1:
                        add_rect(x0 + start, x0 + end, y0, y1, tooltip, *args)
                    else:
                        draw_data(attr_list[1:], (x0 + start, x0 + end),
                                  (y0, y1), side + 1, tooltip, total_attrs,
                                  *args)
                else:
                    if len(attr_list) == 1:
                        add_rect(x0, x1, y0 + start, y0 + end, tooltip, *args)
                    else:
                        draw_data(attr_list[1:], (x0, x1),
                                  (y0 + start, y0 + end), side + 1, tooltip,
                                  total_attrs, *args)
            draw_text(side, attr_list[0], (x0, x1), (y0, y1), total_attrs,
                      used_attrs, used_vals, attr_vals)

        def draw_text(side, attr, x0_x1, y0_y1, total_attrs, used_attrs,
                      used_vals, attr_vals):
            """Draw the value labels and the name of ``attr`` for ``side``.

            Labels are drawn at most once per side (tracked in
            ``drawn_sides``).  When called for an empty cell, only its
            coordinates are remembered in ``draw_positions`` so that a later
            call for a non-empty cell on the same side can reuse them.
            """
            x0, x1 = x0_x1
            y0, y1 = y0_y1
            if side in drawn_sides:
                return

            # the text on the right will be drawn when we are processing
            # visualization of the last value of the first attribute
            if side == 3:
                attr1values = get_variable_values_sorted(used_attrs[0])
                if used_vals[0] != attr1values[-1]:
                    return

            if not conditionaldict[attr_vals]:
                if side not in draw_positions:
                    draw_positions[side] = (x0, x1, y0, y1)
                return
            else:
                if side in draw_positions:
                    # restore the positions of attribute values and name
                    (x0, x1, y0, y1) = draw_positions[side]

            drawn_sides.add(side)

            values = get_variable_values_sorted(attr)
            if side % 2:
                values = values[::-1]

            # total space taken by the gaps between rectangles on this side;
            # the bool factors below subtract it from the relevant dimension
            spaces = spacing * (total_attrs - side) * (len(values) - 1)
            width = x1 - x0 - spaces * (side % 2 == 0)
            height = y1 - y0 - spaces * (side % 2 == 1)

            # calculate position of first attribute
            currpos = 0
            total, counts = get_counts(attr_vals, values)
            # alignment of the label, indexed by side
            aligns = [
                Qt.AlignTop | Qt.AlignHCenter, Qt.AlignRight | Qt.AlignVCenter,
                Qt.AlignBottom | Qt.AlignHCenter,
                Qt.AlignLeft | Qt.AlignVCenter
            ]
            align = aligns[side]
            for i, val in enumerate(values):
                if distributiondict[val] != 0:
                    perc = counts[i] / float(total)
                    rwidth = width * perc
                    # candidate label coordinates, indexed by side
                    xs = [
                        x0 + currpos + rwidth / 2, x0 - self.ATTR_VAL_OFFSET,
                        x0 + currpos + rwidth / 2, x1 + self.ATTR_VAL_OFFSET
                    ]
                    ys = [
                        y1 + self.ATTR_VAL_OFFSET,
                        y0 + currpos + height * 0.5 * perc,
                        y0 - self.ATTR_VAL_OFFSET,
                        y0 + currpos + height * 0.5 * perc
                    ]

                    CanvasText(self.canvas,
                               val,
                               xs[side],
                               ys[side],
                               align,
                               max_width=rwidth if side == 0 else None)
                    space = height if side % 2 else width
                    currpos += perc * space + spacing * (total_attrs - side)

            # position of the attribute name, indexed by side
            xs = [
                x0 + (x1 - x0) / 2, x0 - max_ylabel_w1 - self.ATTR_VAL_OFFSET,
                x0 + (x1 - x0) / 2, x1 + max_ylabel_w2 + self.ATTR_VAL_OFFSET
            ]
            ys = [
                y1 + self.ATTR_VAL_OFFSET + self.ATTR_NAME_OFFSET,
                y0 + (y1 - y0) / 2,
                y0 - self.ATTR_VAL_OFFSET - self.ATTR_NAME_OFFSET,
                y0 + (y1 - y0) / 2
            ]
            CanvasText(self.canvas,
                       attr.name,
                       xs[side],
                       ys[side],
                       align,
                       bold=True,
                       vertical=side % 2)

        def add_rect(x0,
                     x1,
                     y0,
                     y1,
                     condition,
                     used_attrs,
                     used_vals,
                     attr_vals=""):
            """Draw one mosaic cell and register it in ``self.areas``.

            The clickable outline is always drawn.  For non-empty cells the
            interior is either colored by the Pearson residual (no color
            variable) or split into class-colored stripes, optionally with
            bars for the prior class distribution and for the data subset.
            ``condition`` is the HTML prefix of the cell's tooltip.
            """
            area_index = len(self.areas)
            # make degenerate (zero-size) rectangles at least 1 unit large
            x1 += (x0 == x1)
            y1 += (y0 == y1)
            # rectangles of width and height 1 are not shown - increase
            y1 += (x1 - x0 + y1 - y0 == 2)
            colors = class_var and [QColor(*col) for col in class_var.colors]

            def select_area(_, ev):
                self.select_area(area_index, ev)

            def rect(x, y, w, h, z, pen_color=None, brush_color=None, **args):
                """Create a clickable ``CanvasRectangle`` on the canvas.

                Without ``pen_color`` the rectangle's default colors are
                used; ``brush_color`` defaults to ``pen_color``.
                """
                if pen_color is None:
                    return CanvasRectangle(self.canvas,
                                           x,
                                           y,
                                           w,
                                           h,
                                           z=z,
                                           onclick=select_area,
                                           **args)
                if brush_color is None:
                    brush_color = pen_color
                return CanvasRectangle(self.canvas,
                                       x,
                                       y,
                                       w,
                                       h,
                                       pen_color,
                                       brush_color,
                                       z=z,
                                       onclick=select_area,
                                       **args)

            def line(x1, y1, x2, y2):
                """Add a white, 2 units wide separator line to the canvas."""
                r = QGraphicsLineItem(x1, y1, x2, y2, None)
                self.canvas.addItem(r)
                r.setPen(QPen(Qt.white, 2))
                r.setZValue(30)

            outer_rect = rect(x0, y0, x1 - x0, y1 - y0, 30)
            self.areas.append((used_attrs, used_vals, outer_rect))
            if not conditionaldict[attr_vals]:
                return

            if self.variable_color is None:
                # expected count under independence of the attributes
                s = sum(apriori_dists[0])
                expected = s * reduce(
                    mul, (apriori_dists[i][used_vals[i]] / float(s)
                          for i in range(len(used_vals))))
                actual = conditionaldict[attr_vals]
                pearson = float((actual - expected) / sqrt(expected))
                if pearson == 0:
                    ind = 0
                else:
                    # color index grows with log2 of the residual, capped at 3
                    ind = max(0, min(int(log(abs(pearson), 2)), 3))
                color = [self.RED_COLORS, self.BLUE_COLORS][pearson > 0][ind]
                rect(x0, y0, x1 - x0, y1 - y0, -20, color)
                outer_rect.setToolTip(
                    condition + "<hr/>" + "Expected instances: %.1f<br>"
                    "Actual instances: %d<br>"
                    "Standardized (Pearson) residual: %.1f" %
                    (expected, conditionaldict[attr_vals], pearson))
            else:
                cls_values = get_variable_values_sorted(class_var)
                prior = get_distribution(data, class_var.name)
                # stack one stripe per class value, top to bottom in y
                total = 0
                for i, value in enumerate(cls_values):
                    val = conditionaldict[attr_vals + "-" + value]
                    if val == 0:
                        continue
                    if i == len(cls_values) - 1:
                        # last stripe takes whatever height is left
                        v = y1 - y0 - total
                    else:
                        v = ((y1 - y0) * val) / conditionaldict[attr_vals]
                    rect(x0, y0 + total, x1 - x0, v, -20, colors[i])
                    total += v

                if self.use_boxes and \
                        abs(x1 - x0) > bar_width and abs(y1 - y0) > bar_width:
                    # bar with the prior class distribution at the left edge
                    total = 0
                    line(x0 + bar_width, y0, x0 + bar_width, y1)
                    n = sum(prior)
                    for i, (val, color) in enumerate(zip(prior, colors)):
                        if i == len(prior) - 1:
                            h = y1 - y0 - total
                        else:
                            h = (y1 - y0) * val / n
                        rect(x0, y0 + total, bar_width, h, 20, color)
                        total += h

                if conditionalsubsetdict:
                    if conditionalsubsetdict[attr_vals]:
                        if self.subset_indices is not None:
                            # bar with the subset's class distribution at
                            # the right edge
                            line(x1 - bar_width, y0, x1 - bar_width, y1)
                            total = 0
                            n = conditionalsubsetdict[attr_vals]
                            if n:
                                for i, (cls, color) in \
                                        enumerate(zip(cls_values, colors)):
                                    val = conditionalsubsetdict[attr_vals +
                                                                "-" + cls]
                                    if val == 0:
                                        continue
                                    if i == len(prior) - 1:
                                        v = y1 - y0 - total
                                    else:
                                        v = ((y1 - y0) * val) / n
                                    rect(x1 - bar_width, y0 + total, bar_width,
                                         v, 15, color)
                                    total += v

                actual = [
                    conditionaldict[attr_vals + "-" + cls_values[i]]
                    for i in range(len(prior))
                ]
                n_actual = sum(actual)
                if n_actual > 0:
                    apriori = [prior[key] for key in cls_values]
                    n_apriori = sum(apriori)
                    text = "<br/>".join(
                        "<b>%s</b>: %d / %.1f%% (Expected %.1f / %.1f%%)" %
                        (cls, act, 100.0 * act / n_actual,
                         apr / n_apriori * n_actual, 100.0 * apr / n_apriori)
                        for cls, act, apr in zip(cls_values, actual, apriori))
                else:
                    text = ""
                # ``text`` is joined with "<br/>" and has no trailing
                # separator, so it is used whole; slicing off the last four
                # characters would cut the end of the last class line.
                outer_rect.setToolTip("{}<hr>Instances: {}<br><br>{}".format(
                    condition, n_actual, text))

        def create_legend():
            """Build the legend items and wrap them to the view's width.

            Without a color variable the legend lists residual intervals;
            otherwise it lists the class values with their colors.
            """
            if self.variable_color is None:
                names = [
                    "<-8", "-8:-4", "-4:-2", "-2:2", "2:4", "4:8", ">8",
                    "Residuals:"
                ]
                colors = self.RED_COLORS[::-1] + self.BLUE_COLORS[1:]
                edges = repeat(Qt.black)
            else:
                names = get_variable_values_sorted(class_var)
                edges = colors = [QColor(*col) for col in class_var.colors]

            items = []
            size = 8
            # NOTE(review): zip() stops at the shortest input, so names
            # without a matching color are silently left out - confirm that
            # this is intended for the trailing "Residuals:" entry.
            for name, color, edgecolor in zip(names, colors, edges):
                item = QGraphicsItemGroup()
                item.addToGroup(
                    CanvasRectangle(None, -size / 2, -size / 2, size, size,
                                    edgecolor, color))
                item.addToGroup(
                    CanvasText(None, name, size, 0, Qt.AlignVCenter))
                items.append(item)
            return wrap_legend_items(items,
                                     hspacing=20,
                                     vspacing=16 + size,
                                     max_width=self.canvas_view.width() - xoff)

        self.canvas.clear()
        self.areas = []

        data = self.discrete_data
        if data is None:
            return
        attr_list = self.get_disc_attr_list()
        class_var = data.domain.class_var
        # TODO: check this
        # data = Preprocessor_dropMissing(data)

        unique = [v.name for v in set(attr_list + [class_var]) if v]
        if len(data[:, unique]) == 0:
            self.Warning.no_valid_data()
            return
        else:
            self.Warning.no_valid_data.clear()

        attrs = [attr for attr in attr_list if not attr.values]
        if attrs:
            CanvasText(self.canvas,
                       "Feature {} has no values".format(attrs[0]),
                       (self.canvas_view.width() - 120) / 2,
                       self.canvas_view.height() / 2)
            return
        if self.variable_color is None:
            # marginal distributions, needed for the expected counts
            apriori_dists = [
                get_distribution(data, attr) for attr in attr_list
            ]
        else:
            apriori_dists = []

        def get_max_label_width(attr):
            """Return the width (in whole units) of ``attr``'s widest label."""
            values = get_variable_values_sorted(attr)
            maxw = 0
            for val in values:
                t = CanvasText(self.canvas, val, 0, 0, bold=0, show=False)
                maxw = max(int(t.boundingRect().width()), maxw)
            return maxw

        xoff = 20

        # get the maximum width of rectangle
        width = 20
        max_ylabel_w1 = max_ylabel_w2 = 0
        if len(attr_list) > 1:
            # reserve room for the second attribute's name and value labels
            text = CanvasText(self.canvas, attr_list[1].name, bold=1, show=0)
            max_ylabel_w1 = min(get_max_label_width(attr_list[1]), 150)
            width = 5 + text.boundingRect().height() + \
                self.ATTR_VAL_OFFSET + max_ylabel_w1
            xoff = width
            if len(attr_list) == 4:
                # ... and for the labels of the fourth attribute
                text = CanvasText(self.canvas,
                                  attr_list[3].name,
                                  bold=1,
                                  show=0)
                max_ylabel_w2 = min(get_max_label_width(attr_list[3]), 150)
                width += text.boundingRect().height() + \
                    self.ATTR_VAL_OFFSET + max_ylabel_w2 - 10

        legend = create_legend()

        # get the maximum height of rectangle
        yoff = 45
        legendoff = yoff + self.ATTR_NAME_OFFSET + self.ATTR_VAL_OFFSET + 35
        square_size = min(
            self.canvas_view.width() - width - 20,
            self.canvas_view.height() - legendoff -
            legend.boundingRect().height())

        if square_size < 0:
            return  # canvas is too small to draw rectangles
        self.canvas_view.setSceneRect(0, 0, self.canvas_view.width(),
                                      self.canvas_view.height())

        # state shared by the draw_text calls
        drawn_sides = set()
        draw_positions = {}

        conditionaldict, distributiondict = \
            get_conditional_distribution(data, attr_list)
        conditionalsubsetdict = None
        if self.subset_indices:
            conditionalsubsetdict, _ = get_conditional_distribution(
                self.discrete_data[self.subset_indices], attr_list)

        # draw rectangles
        draw_data(attr_list, (xoff, xoff + square_size),
                  (yoff, yoff + square_size), 0, "", len(attr_list), [], [])

        # center the legend horizontally with respect to the square
        self.canvas.addItem(legend)
        legend.setPos(
            xoff - legend.boundingRect().x() +
            max(0, (square_size - legend.boundingRect().width()) / 2),
            legendoff + square_size)
        self.update_selection_rects()

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context to the current format.

        Contexts older than version 2 are passed to
        ``settings.migrate_str_to_variable`` with ``"(None)"`` as the
        placeholder for "no variable"; newer contexts are left untouched.
        """
        if version >= 2:
            return
        settings.migrate_str_to_variable(
            context, none_placeholder="(None)")
Ejemplo n.º 22
0
class OWLinearProjection(OWAnchorProjectionWidget):
    """Widget that projects the selected variables onto a plane.

    The projection is computed by a projector picked according to the
    ``placement`` setting (circular placement, LDA or PCA).
    """
    name = "Linear Projection"
    description = ("A multi-axis projection of data onto "
                   "a two-dimensional plane.")
    icon = "icons/LinearProjection.svg"
    priority = 240
    keywords = []

    # labels of the placement radio buttons and the report caption
    Projection_name = {
        Placement.Circular: "Circular Placement",
        Placement.LDA: "Linear Discriminant Analysis",
        Placement.PCA: "Principal Component Analysis",
    }

    settings_version = 6

    placement = Setting(Placement.Circular)
    selected_vars = ContextSetting([])
    vizrank = SettingProvider(LinearProjectionVizRank)
    GRAPH_CLASS = OWLinProjGraph
    graph = SettingProvider(OWLinProjGraph)

    left_side_scrolling = True

    class Error(OWAnchorProjectionWidget.Error):
        no_cont_features = Msg("Plotting requires numeric features")

    def _add_controls(self):
        """Build the control area: variables, placement, inherited controls
        and the "Hide radius" slider; finally drop the stretch item."""
        self._add_controls_variables()
        self._add_controls_placement()
        super()._add_controls()

        slider_options = dict(
            master=self.graph, value="hide_radius",
            minValue=0, maxValue=100, step=10,
            createLabel=False, callback=self.__radius_slider_changed,
        )
        self.gui.add_control(
            self._effects_box, gui.hSlider, "Hide radius:", **slider_options)

        stretch = self.control_area_stretch
        self.controlArea.layout().removeWidget(stretch)
        stretch.setParent(None)

    def _add_controls_variables(self):
        """Add the variable selection view and the "Suggest Features"
        button."""
        model = VariableSelectionModel(self.selected_vars)
        self.model_selected = model
        variables_selection(self.controlArea, self, model)
        model.selection_changed.connect(self.__model_selected_changed)

        vizrank, button = LinearProjectionVizRank.add_vizrank(
            None, self, "Suggest Features", self.__vizrank_set_attrs)
        self.vizrank = vizrank
        self.btn_vizrank = button
        self.controlArea.layout().addWidget(button)

    def _add_controls_placement(self):
        """Add the radio buttons for choosing the placement."""
        placement_box = gui.widgetBox(
            self.controlArea, True,
            sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
        labels = [self.Projection_name[option] for option in Placement]
        self.radio_placement = gui.radioButtonsInBox(
            placement_box, self, "placement", btnLabels=labels,
            callback=self.__placement_radio_changed)

    @property
    def continuous_variables(self):
        """Continuous variables and metas of the data's domain.

        An empty list is returned when there is no data or no domain.
        """
        data = self.data
        domain = None if data is None else data.domain
        if domain is None:
            return []
        candidates = chain(domain.variables, domain.metas)
        return [var for var in candidates if var.is_continuous]

    @property
    def effective_variables(self):
        """The variables currently selected for the projection."""
        return self.selected_vars

    def __vizrank_set_attrs(self, attrs):
        """Replace the selected variables with ``attrs``, unless empty."""
        if attrs:
            self.selected_vars[:] = attrs
            # Ugly, but the alternative is to have yet another signal to
            # which the view will have to connect
            self.model_selected.selection_changed.emit()

    def __model_selected_changed(self):
        """React to a changed variable selection: recompute and replot."""
        self.projection = None
        self._check_options()
        self.init_projection()
        self.setup_plot()
        self.commit()

    def __placement_radio_changed(self):
        """React to a changed placement: reset projector and replot."""
        radius_enabled = self.placement != Placement.Circular
        self.controls.graph.hide_radius.setEnabled(radius_enabled)
        self.projection = None
        self.projector = None
        self._init_vizrank()
        self.init_projection()
        self.setup_plot()
        self.commit()

    def __radius_slider_changed(self):
        """Forward a change of the "Hide radius" slider to the graph."""
        self.graph.update_radius()

    def colors_changed(self):
        """Re-initialize vizrank after the inherited color handling."""
        super().colors_changed()
        self._init_vizrank()

    def set_data(self, data):
        """Set the data, then refresh options, vizrank and the projection."""
        super().set_data(data)
        self._check_options()
        self._init_vizrank()
        self.init_projection()

    def _check_options(self):
        """Enable the placement buttons that are applicable to the data.

        LDA is disabled (and the placement reset to circular if LDA was
        chosen) when the data has no discrete class or ``data.Y`` holds
        fewer than three unique values.
        """
        buttons = self.radio_placement.buttons
        for button in buttons:
            button.setEnabled(True)

        data = self.data
        lda_allowed = data is None or (
            data.domain.has_discrete_class
            and len(np.unique(data.Y)) >= 3)
        if not lda_allowed:
            buttons[Placement.LDA].setEnabled(False)
            if self.placement == Placement.LDA:
                self.placement = Placement.Circular

        radius_enabled = self.placement != Placement.Circular
        self.controls.graph.hide_radius.setEnabled(radius_enabled)

    def _init_vizrank(self):
        """Enable or disable the vizrank button; explain why in its tooltip.

        When enabled, vizrank is (re)initialized.
        """
        def reason_disabled():
            # first applicable reason, or "" when nothing prevents vizrank
            if self.data is None:
                return "There is no data."
            if self.attr_color is None:
                return "Color variable has to be selected"
            if self.attr_color.is_continuous and \
                    self.placement == Placement.LDA:
                return "Suggest Features does not work for Linear " \
                       "Discriminant Analysis Projection when " \
                       "continuous color variable is selected."
            others = [var for var in self.continuous_variables
                      if var is not self.attr_color]
            if len(others) < 3:
                return "Not enough available continuous variables"
            if np.sum(np.all(np.isfinite(self.data.X), axis=1)) < 2:
                return "Not enough valid data instances"
            return ""

        msg = reason_disabled()
        is_enabled = False
        if not msg:
            # the color column must have at least one defined value
            color_column = self.data.get_column_view(self.attr_color)[0]
            is_enabled = not np.isnan(color_column.astype(float)).all()
        self.btn_vizrank.setToolTip(msg)
        self.btn_vizrank.setEnabled(is_enabled)
        if is_enabled:
            self.vizrank.initialize()

    def check_data(self):
        """Run the inherited checks; reject data without numeric features."""
        super().check_data()
        if self.data is not None and not self.continuous_variables:
            self.Error.no_cont_features()
            self.data = None

    def init_attr_values(self):
        """Select the first three continuous variables; offer all of them."""
        super().init_attr_values()
        continuous = self.continuous_variables
        self.selected_vars[:] = continuous[:3]
        self.model_selected[:] = continuous

    def init_projection(self):
        """Create the projector matching the chosen placement."""
        placement = self.placement
        if placement == Placement.Circular:
            self.projector = CircularPlacement()
        elif placement == Placement.LDA:
            self.projector = LDA(solver="eigen", n_components=2)
        elif placement == Placement.PCA:
            projector = PCA(n_components=2)
            self.projector = projector
            projector.component = 2
            projector.preprocessors = PCA.preprocessors + [Normalize()]

        super().init_projection()

    def get_coordinates_data(self):
        """Return x and y coordinates of the valid points.

        The embedding of valid rows is centered on its mean and divided by
        its span (constant columns are left unscaled).  ``(None, None)`` is
        returned when there is no embedding; a one-column embedding gets
        zeros as its y coordinates.
        """
        embedding = self.get_embedding()
        if embedding is None:
            return None, None

        points = embedding[self.valid_data]
        span = np.max(points, axis=0) - np.min(points, axis=0)
        span[span == 0] = 1
        norm_emb = (points - np.mean(points, axis=0)) / span

        if embedding.shape[1] == 1:
            return norm_emb.ravel(), np.zeros(len(norm_emb), dtype=float)
        return norm_emb.T

    def _get_send_report_caption(self):
        """Render the report caption describing the current plot."""
        caption_var = self._get_caption_var_name
        items = (
            ("Projection", self.Projection_name[self.placement]),
            ("Color", caption_var(self.attr_color)),
            ("Label", caption_var(self.attr_label)),
            ("Shape", caption_var(self.attr_shape)),
            ("Size", caption_var(self.attr_size)),
            ("Jittering", self.graph.jitter_size != 0 and
             "{} %".format(self.graph.jitter_size)),
        )
        return report.render_items_vert(items)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade stored settings, step by step, to the current version."""
        if version < 2:
            settings_["point_width"] = settings_["point_size"]
        if version < 3:
            # graph-related settings moved into the "graph" sub-dictionary
            moved = (("jitter_size", "jitter_value"),
                     ("point_width", "point_width"),
                     ("alpha_value", "alpha_value"),
                     ("class_density", "class_density"))
            settings_["graph"] = {new: settings_[old] for new, old in moved}
        if version < 4:
            if "radius" in settings_:
                settings_["graph"]["hide_radius"] = settings_["radius"]
            mask = settings_.get("selection_indices")
            if mask is not None:
                # boolean mask -> list of (index, group) pairs
                settings_["selection"] = [
                    (i, 1) for i, selected in enumerate(mask) if selected]
        if version < 5:
            if "placement" in settings_ and \
                    settings_["placement"] not in Placement:
                settings_["placement"] = Placement.Circular

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context, step by step, to the current version."""
        values = context.values
        if version < 2:
            domain = context.ordered_domain
            c_domain = [t for t in context.ordered_domain if t[1] == 2]
            d_domain = [t for t in context.ordered_domain if t[1] == 1]
            conversions = ((domain, "color_index", "attr_color"),
                           (d_domain, "shape_index", "attr_shape"),
                           (c_domain, "size_index", "attr_size"))
            for source, old_key, new_key in conversions:
                # stored indices are shifted by one
                index = context.values[old_key][0] - 1
                if 0 <= index < len(source):
                    var_name, var_type = source[index][0], source[index][1]
                    values[new_key] = (var_name, var_type + 100)
                else:
                    values[new_key] = None
        if version < 3:
            values["graph"] = {
                key: values[key]
                for key in ("attr_color", "attr_shape", "attr_size")}
        if version == 3:
            for key in ("attr_color", "attr_size",
                        "attr_shape", "attr_label"):
                values[key] = values["graph"][key]
        if version < 6 and "selected_vars" in values:
            values["selected_vars"] = (values["selected_vars"], -3)

    # for backward compatibility with settings < 6, pull the enum from global
    # namespace into class
    Placement = Placement
Ejemplo n.º 23
0
class OWFile(widget.OWWidget, RecentPathsWComboMixin):
    """Widget that reads a data table from a local file or a URL.

    The loaded table is sent on the "Data" output after the edits made in
    the embedded domain editor have been applied.
    """
    # Widget metadata.
    name = "File"
    id = "orange.widgets.data.file"
    description = "Read data from an input file or network " \
                  "and send a data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["data", "file", "load", "read"]
    outputs = [
        widget.OutputSignal(
            "Data",
            Table,
            doc="Attribute-valued data set read from the input file.")
    ]

    want_main_area = False

    SEARCH_PATHS = [("sample-datasets", get_sample_datasets_dir())]
    # Size (as reported by os.path.getsize) above which the last used file
    # is not loaded automatically when the widget is constructed.
    SIZE_LIMIT = 1e7
    # Possible values of the `source` setting.
    LOCAL_FILE, URL = range(2)

    settingsHandler = PerfectDomainContextHandler()

    # Overload RecentPathsWidgetMixin.recent_paths to set defaults
    recent_paths = Setting([
        RecentPath("", "sample-datasets", "iris.tab"),
        RecentPath("", "sample-datasets", "titanic.tab"),
        RecentPath("", "sample-datasets", "housing.tab"),
        RecentPath("", "sample-datasets", "heart_disease.tab"),
    ])
    recent_urls = Setting([])
    # Which of LOCAL_FILE / URL is currently selected.
    source = Setting(LOCAL_FILE)
    # Bound to the sheet combo box (see __init__).
    xls_sheet = ContextSetting("")
    sheet_names = Setting({})
    url = Setting("")

    # Saved to / restored from the context in storeSpecificSettings and
    # retrieveSpecificSettings.
    variables = ContextSetting([])

    # Filter string for the open-file dialog: one "All readable files" entry
    # followed by one entry per distinct reader, in order of first
    # appearance in FileFormat.readers.
    dlg_formats = ("All readable files ({});;".format(
        '*' + ' *'.join(FileFormat.readers.keys())) + ";;".join(
            "{} (*{})".format(f.DESCRIPTION, ' *'.join(f.EXTENSIONS))
            for f in sorted(set(FileFormat.readers.values()),
                            key=list(FileFormat.readers.values()).index)))

    domain_editor = SettingProvider(DomainEditor)

    class Warning(widget.OWWidget.Warning):
        file_too_big = widget.Msg(
            "The file is too large to load automatically."
            " Press Reload to load.")

    class Error(widget.OWWidget.Error):
        file_not_found = widget.Msg("File not found.")

    def __init__(self):
        """Build the widget GUI and schedule the initial data load.

        The control area holds a grid with the "File:" row (recent-file
        combo, browse and reload buttons), a hidden sheet selector and the
        "URL:" row, followed by the info box, the domain editor and the
        bottom button row.
        """
        super().__init__()
        RecentPathsWComboMixin.__init__(self)
        self.domain = None
        self.data = None
        self.loaded_file = ""
        self.reader = None

        # Grid holding the source radio buttons and their controls.
        layout = QGridLayout()
        gui.widgetBox(self.controlArea, margin=0, orientation=layout)
        vbox = gui.radioButtons(None,
                                self,
                                "source",
                                box=True,
                                addSpace=True,
                                callback=self.load_data,
                                addToLayout=False)

        # Row 0: local file source.
        rb_button = gui.appendRadioButton(vbox, "File:", addToLayout=False)
        layout.addWidget(rb_button, 0, 0, Qt.AlignVCenter)

        box = gui.hBox(None, addToLayout=False, margin=0)
        box.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.activated[int].connect(self.select_file)
        box.layout().addWidget(self.file_combo)
        layout.addWidget(box, 0, 1)

        file_button = gui.button(None,
                                 self,
                                 '...',
                                 callback=self.browse_file,
                                 autoDefault=False)
        file_button.setIcon(self.style().standardIcon(QStyle.SP_DirOpenIcon))
        file_button.setSizePolicy(Policy.Maximum, Policy.Fixed)
        layout.addWidget(file_button, 0, 2)

        reload_button = gui.button(None,
                                   self,
                                   "Reload",
                                   callback=self.load_data,
                                   autoDefault=False)
        reload_button.setIcon(self.style().standardIcon(
            QStyle.SP_BrowserReload))
        reload_button.setSizePolicy(Policy.Fixed, Policy.Fixed)
        layout.addWidget(reload_button, 0, 3)

        # Row 2: sheet selector; hidden until _update_sheet_combo shows it.
        self.sheet_box = gui.hBox(None, addToLayout=False, margin=0)
        self.sheet_combo = gui.comboBox(
            None,
            self,
            "xls_sheet",
            callback=self.select_sheet,
            sendSelectedValue=True,
        )
        self.sheet_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_label = QLabel()
        self.sheet_label.setText('Sheet')
        self.sheet_label.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_box.layout().addWidget(self.sheet_label, Qt.AlignLeft)
        self.sheet_box.layout().addWidget(self.sheet_combo, Qt.AlignVCenter)
        layout.addWidget(self.sheet_box, 2, 1)
        self.sheet_box.hide()

        # Row 3: URL source, an editable combo backed by recent_urls.
        rb_button = gui.appendRadioButton(vbox, "URL:", addToLayout=False)
        layout.addWidget(rb_button, 3, 0, Qt.AlignVCenter)

        self.url_combo = url_combo = QComboBox()
        url_model = NamedURLModel(self.sheet_names)
        url_model.wrap(self.recent_urls)
        url_combo.setLineEdit(LineEditSelectOnFocus())
        url_combo.setModel(url_model)
        url_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        url_combo.setEditable(True)
        url_combo.setInsertPolicy(url_combo.InsertAtTop)
        url_edit = url_combo.lineEdit()
        # Widen the left text margin of the URL line edit by 5.
        l, t, r, b = url_edit.getTextMargins()
        url_edit.setTextMargins(l + 5, t, r, b)
        layout.addWidget(url_combo, 3, 1, 3, 3)
        url_combo.activated.connect(self._url_set)

        box = gui.vBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        box = gui.widgetBox(self.controlArea, "Columns (Double click to edit)")
        self.domain_editor = DomainEditor(self)
        self.editor_model = self.domain_editor.model()
        box.layout().addWidget(self.domain_editor)

        # Bottom row: documentation data sets, report and apply buttons.
        box = gui.hBox(self.controlArea)
        gui.button(box,
                   self,
                   "Browse documentation data sets",
                   callback=lambda: self.browse_file(True),
                   autoDefault=False)
        gui.rubber(box)
        box.layout().addWidget(self.report_button)
        self.report_button.setFixedWidth(170)

        self.apply_button = gui.button(box,
                                       self,
                                       "Apply",
                                       callback=self.apply_domain_edit)
        self.apply_button.setEnabled(False)
        self.apply_button.setFixedWidth(170)
        # Any change in the editor model enables the Apply button.
        self.editor_model.dataChanged.connect(
            lambda: self.apply_button.setEnabled(True))

        self.set_file_list()
        # Must not call open_file from within __init__. open_file
        # explicitly re-enters the event loop (by a progress bar)

        self.setAcceptDrops(True)

        # Do not auto-load a local file larger than SIZE_LIMIT; warn instead
        # and leave loading to the Reload button.
        if self.source == self.LOCAL_FILE:
            last_path = self.last_path()
            if last_path and os.path.exists(last_path) and \
                    os.path.getsize(last_path) > self.SIZE_LIMIT:
                self.Warning.file_too_big()
                return

        # Defer the load until the event loop is running.
        QTimer.singleShot(0, self.load_data)

    def sizeHint(self):
        """Return the preferred widget size, 600 wide by 550 high."""
        preferred_width, preferred_height = 600, 550
        return QSize(preferred_width, preferred_height)

    def select_file(self, n):
        """Select the n-th recent path and load it as a local file.

        Delegates the selection itself to the parent implementation; the
        file is loaded only when there is at least one recent path.
        """
        assert n < len(self.recent_paths)
        super().select_file(n)
        if not self.recent_paths:
            return
        self.source = self.LOCAL_FILE
        self.load_data()
        self.set_file_list()

    def select_sheet(self):
        """Store the sheet chosen in the combo on the first recent path and
        reload the data."""
        chosen_sheet = self.sheet_combo.currentText()
        most_recent = self.recent_paths[0]
        most_recent.sheet = chosen_sheet
        self.load_data()

    def _url_set(self):
        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Let the user pick a file in a dialog and load it.

        Parameters
        ----------
        in_demos : bool
            When true, the dialog starts in the sample data sets directory;
            otherwise it starts at the last used path, or the home
            directory when there is none.
        """
        if not in_demos:
            start_file = self.last_path() or os.path.expanduser("~/")
        else:
            start_file = get_sample_datasets_dir()
            if not os.path.exists(start_file):
                QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with documentation data sets")
                return

        filename, _ = QFileDialog.getOpenFileName(
            self, 'Open Orange Data File', start_file, self.dlg_formats)
        if filename:
            self.add_path(filename)
            self.source = self.LOCAL_FILE
            self.load_data()

    def load_data(self):
        """Open the selected source, read a table from it and send it out.

        The current context is closed, the domain editor and messages are
        reset, and a reader is obtained from `_get_reader`. On any failure
        (missing local file, no reader, exception while reading) `None` is
        sent on the "Data" output, `self.data` is cleared, the sheet
        selector is hidden and the info label describes the problem. On
        success the table is stored in `self.data`, the context is opened
        on its domain and `apply_domain_edit` sends the data.
        """
        # We need to catch any exception type since anything can happen in
        # file readers
        # pylint: disable=broad-except
        self.closeContext()
        self.domain_editor.set_domain(None)
        self.apply_button.setEnabled(False)
        self.clear_messages()
        self.set_file_list()

        # Only the local-file source depends on the most recent path; a
        # missing local file must not prevent loading from a URL.
        if self.source == self.LOCAL_FILE:
            path = self.last_path()
            if path and not os.path.exists(path):
                self.Error.file_not_found()
                # Drop any previously loaded table so that it is not
                # reported or re-applied after the output was cleared.
                self.data = None
                self.send("Data", None)
                self.info.setText("No data.")
                self.sheet_box.hide()
                return

        error = None
        try:
            self.reader = self._get_reader()
        except Exception as ex:
            error = ex
        else:
            if self.reader is None:
                self.data = None
                self.send("Data", None)
                self.info.setText("No data.")
                self.sheet_box.hide()
                return

        if error is None:
            self._update_sheet_combo()
            with catch_warnings(record=True) as caught:
                try:
                    data = self.reader.read()
                except Exception as ex:
                    log.exception(ex)
                    error = ex
                # Show only the most recent warning raised while reading.
                self.warning(caught[-1].message.args[0] if caught else '')

        if error is not None:
            self.data = None
            self.send("Data", None)
            self.info.setText("An error occurred:\n{}".format(error))
            self.sheet_box.hide()
            return

        self.info.setText(self._describe(data))

        # NOTE(review): for a URL source this still records the last local
        # path as the origin of the data -- confirm whether that is intended.
        self.loaded_file = self.last_path()
        add_origin(data, self.loaded_file)
        self.data = data
        self.openContext(data.domain)
        self.apply_domain_edit()  # sends data

    def _get_reader(self):
        """

        Returns
        -------
        FileFormat
        """
        if self.source == self.LOCAL_FILE:
            reader = FileFormat.get_reader(self.last_path())
            if self.recent_paths and self.recent_paths[0].sheet:
                reader.select_sheet(self.recent_paths[0].sheet)
            return reader
        elif self.source == self.URL:
            url = self.url_combo.currentText().strip()
            if url:
                return UrlReader(url)

    def _update_sheet_combo(self):
        if len(self.reader.sheets) < 2:
            self.sheet_box.hide()
            self.reader.select_sheet(None)
            return

        self.sheet_combo.clear()
        self.sheet_combo.addItems(self.reader.sheets)
        self._select_active_sheet()
        self.sheet_box.show()

    def _select_active_sheet(self):
        if self.reader.sheet:
            try:
                idx = self.reader.sheets.index(self.reader.sheet)
                self.sheet_combo.setCurrentIndex(idx)
            except ValueError:
                # Requested sheet does not exist in this file
                self.reader.select_sheet(None)
        else:
            self.sheet_combo.setCurrentIndex(0)

    def _describe(self, table):
        domain = table.domain
        text = ""

        attrs = getattr(table, "attributes", {})
        descs = [
            attrs[desc] for desc in ("Name", "Description") if desc in attrs
        ]
        if len(descs) == 2:
            descs[0] = "<b>{}</b>".format(descs[0])
        if descs:
            text += "<p>{}</p>".format("<br/>".join(descs))

        text += "<p>{} instance(s), {} feature(s), {} meta attribute(s)".\
            format(len(table), len(domain.attributes), len(domain.metas))
        if domain.has_continuous_class:
            text += "<br/>Regression; numerical class."
        elif domain.has_discrete_class:
            text += "<br/>Classification; discrete class with {} values.".\
                format(len(domain.class_var.values))
        elif table.domain.class_vars:
            text += "<br/>Multi-target; {} target variables.".format(
                len(table.domain.class_vars))
        else:
            text += "<br/>Data has no target variable."
        text += "</p>"

        if 'Timestamp' in table.domain:
            # Google Forms uses this header to timestamp responses
            text += '<p>First entry: {}<br/>Last entry: {}</p>'.format(
                table[0, 'Timestamp'], table[-1, 'Timestamp'])
        return text

    def storeSpecificSettings(self):
        """Save a copy of `variables` on the current context."""
        snapshot = self.variables[:]
        self.current_context.modified_variables = snapshot

    def retrieveSpecificSettings(self):
        """Restore `variables` in place from the current context, if the
        context carries saved variables."""
        missing = object()
        saved = getattr(self.current_context, "modified_variables", missing)
        if saved is not missing:
            self.variables[:] = saved

    def apply_domain_edit(self):
        """Send the loaded data with the domain editor's changes applied.

        When data is loaded, a new table is built from the domain and the
        columns returned by the domain editor; weights, name, ids and
        attributes are carried over from the loaded table. Without data,
        `None` is sent. The Apply button is disabled afterwards.
        """
        source = self.data
        if source is None:
            table = source
        else:
            domain, (attr_cols, class_cols, meta_cols) = \
                self.domain_editor.get_domain(source.domain, source)
            if len(attr_cols):
                X = np.array(attr_cols).T
            else:
                X = np.empty((len(source), 0))
            y = np.array(class_cols).T if len(class_cols) else None
            # Any string meta forces an object array for all metas.
            has_string_meta = any(
                isinstance(var, StringVariable) for var in domain.metas)
            meta_dtype = object if has_string_meta else float
            if len(meta_cols):
                metas = np.array(meta_cols, dtype=meta_dtype).T
            else:
                metas = None
            table = Table.from_numpy(domain, X, y, metas, source.W)
            table.name = source.name
            table.ids = np.array(source.ids)
            table.attributes = getattr(source, 'attributes', {})

        self.send("Data", table)
        self.apply_button.setEnabled(False)

    def get_widget_name_extension(self):
        """Return the loaded file's base name without its extension."""
        file_name = os.path.basename(self.loaded_file)
        stem, _ = os.path.splitext(file_name)
        return stem

    def send_report(self):
        """Add the data source and the loaded data to the report.

        For a local file the file name (with the home directory shortened
        to "~" and the selected sheet appended when the sheet selector is
        visible) and the format are reported; for a URL source, the
        resource and its format. Without data only "No file." is reported.
        """
        def get_ext_name(filename):
            # Map a file extension to a format name; "unknown" if the
            # extension is not registered in FileFormat.names.
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            home = os.path.expanduser("~")
            if self.loaded_file.startswith(home):
                # os.path.join does not like ~
                name = "~" + os.path.sep + \
                       self.loaded_file[len(home):].lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            # Look the format up on the real file name: once the sheet name
            # is appended, the "extension" no longer matches any format.
            file_format = get_ext_name(self.loaded_file)
            if self.sheet_combo.isVisible():
                name += " ({})".format(self.sheet_combo.currentText())
            self.report_items("File", [("File name", name),
                                       ("Format", file_format)])
        else:
            # The URL that is loaded is the combo text (see _get_reader);
            # the `url` setting is only a fallback since nothing in this
            # class assigns it.
            url = self.url_combo.currentText().strip() or self.url
            self.report_items("Data", [("Resource", url),
                                       ("Format", get_ext_name(url))])

        self.report_data("Data", self.data)

    def dragEnterEvent(self, event):
        """Accept the drag when a reader can be found for its first URL."""
        urls = event.mimeData().urls()
        if not urls:
            return
        first = urls[0]
        try:
            FileFormat.get_reader(
                OSX_NSURL_toLocalFile(first) or first.toLocalFile())
            event.acceptProposedAction()
        except IOError:
            # No reader for this file: leave the drag unaccepted.
            pass

    def dropEvent(self, event):
        """Load the first dropped file as the local-file source."""
        urls = event.mimeData().urls()
        if not urls:
            return
        first = urls[0]  # only the first file is used
        local_path = OSX_NSURL_toLocalFile(first) or first.toLocalFile()
        self.add_path(local_path)
        self.source = self.LOCAL_FILE
        self.load_data()
Ejemplo n.º 24
0
class OWMDS(OWWidget):
    """Widget showing a two-dimensional MDS projection of the input.

    The projection is computed from the "Distances" input when given, or
    otherwise from Euclidean distances over the "Data" input, and is
    optimized iteratively while the plot is refreshed.
    """
    name = "MDS"
    description = "Two-dimensional data projection by multidimensional " \
                  "scaling constructed from a distance matrix."
    icon = "icons/MDS.svg"

    class Inputs:
        data = Input("Data", Orange.data.Table, default=True)
        distances = Input("Distances", Orange.misc.DistMatrix)
        data_subset = Input("Data Subset", Orange.data.Table)

    class Outputs:
        selected_data = Output("Selected Data",
                               Orange.data.Table,
                               default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table)

    settings_version = 2

    #: Initialization type
    PCA, Random = 0, 1

    #: Refresh rate
    #: (label, iterations per GUI update); -1 means a single update after
    #: all iterations (see `__start`).
    RefreshRate = [("Every iteration", 1), ("Every 5 steps", 5),
                   ("Every 10 steps", 10), ("Every 25 steps", 25),
                   ("Every 50 steps", 50), ("None", -1)]

    #: Runtime state
    Running, Finished, Waiting = 1, 2, 3

    settingsHandler = settings.DomainContextHandler()

    max_iter = settings.Setting(300)
    initialization = settings.Setting(PCA)
    #: Index into `RefreshRate`
    refresh_rate = settings.Setting(3)

    # output embedding role.
    NoRole, AttrRole, AddAttrRole, MetaRole = 0, 1, 2, 3

    auto_commit = settings.Setting(True)

    selection_indices = settings.Setting(None, schema_only=True)

    #: Percentage of all pairs displayed (ranges from 0 to 20)
    connected_pairs = settings.Setting(5)

    legend_anchor = settings.Setting(((1, 0), (1, 0)))

    graph = SettingProvider(OWMDSGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    graph_name = "graph.plot_widget.plotItem"

    class Error(OWWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        matrix_too_small = Msg("Input matrix must be at least 2x2")
        no_attributes = Msg("Data has no attributes")
        mismatching_dimensions = \
            Msg("Data and distances dimensions do not match.")
        out_of_memory = Msg("Out of memory")
        optimization_error = Msg("Error during optimization\n{}")

    def __init__(self):
        """Initialize the widget state and build its GUI.

        The control area gets the optimization box (iterations,
        initialization, refresh rate, run button), the point and plot
        property boxes and the commit controls; the main area holds the
        MDS graph.
        """
        super().__init__()
        #: Input dissimilarity matrix
        self.matrix = None  # type: Optional[Orange.misc.DistMatrix]
        #: Effective data used for plot styling/annotations. Can be from the
        #: input signal (`self.signal_data`) or the input matrix
        #: (`self.matrix.data`)
        self.data = None  # type: Optional[Orange.data.Table]
        #: Input subset data table
        self.subset_data = None  # type: Optional[Orange.data.Table]
        #: Data table from the `self.matrix.row_items` (if present)
        self.matrix_data = None  # type: Optional[Orange.data.Table]
        #: Input data table
        self.signal_data = None

        self._similar_pairs = None
        self._subset_mask = None  # type: Optional[np.ndarray]
        #: Set by the input handlers; makes handleNewSignals re-initialize.
        self._invalidated = False
        #: Distance matrix actually used for the optimization.
        self.effective_matrix = None
        self._curve = None

        self.variable_x = ContinuousVariable("mds-x")
        self.variable_y = ContinuousVariable("mds-y")

        #: Generator producing successive embeddings (see __start).
        self.__update_loop = None
        # timer for scheduling updates
        self.__timer = QTimer(self, singleShot=True, interval=0)
        self.__timer.timeout.connect(self.__next_step)
        self.__state = OWMDS.Waiting
        #: Re-entrancy guard for __next_step.
        self.__in_next_step = False
        self.__draw_similar_pairs = False

        box = gui.vBox(self.controlArea, "MDS Optimization")
        form = QFormLayout(labelAlignment=Qt.AlignLeft,
                           formAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
                           verticalSpacing=10)

        form.addRow("Max iterations:",
                    gui.spin(box, self, "max_iter", 10, 10**4, step=1))

        form.addRow(
            "Initialization:",
            gui.radioButtons(box,
                             self,
                             "initialization",
                             btnLabels=("PCA (Torgerson)", "Random"),
                             callback=self.__invalidate_embedding))

        box.layout().addLayout(form)
        form.addRow(
            "Refresh:",
            gui.comboBox(box,
                         self,
                         "refresh_rate",
                         items=[t for t, _ in OWMDS.RefreshRate],
                         callback=self.__invalidate_refresh))
        gui.separator(box, 10)
        self.runbutton = gui.button(box,
                                    self,
                                    "Run",
                                    callback=self._toggle_run)

        # The graph lives in the main area.
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWMDSGraph(self,
                                box,
                                "MDSGraph",
                                view_box=MDSInteractiveViewBox)
        box.layout().addWidget(self.graph.plot_widget)
        self.plot = self.graph.plot_widget

        g = self.graph.gui
        box = g.point_properties_box(self.controlArea)
        self.models = g.points_models
        # Insert "Stress" as the second entry in the order of the third
        # model (presumably the point-size model -- confirm).
        self.models[2].order = \
            self.models[2].order[:1] + ("Stress", ) + self.models[2].order[1:]

        gui.hSlider(box,
                    self,
                    "connected_pairs",
                    label="Show similar pairs:",
                    minValue=0,
                    maxValue=20,
                    createLabel=False,
                    callback=self._on_connected_changed)
        g.add_widgets(ids=[g.JitterSizeSlider], widget=box)

        box = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([
            g.ShowLegend, g.ToolTipShowsAll, g.ClassDensity,
            g.LabelOnlySelected
        ], box)

        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        palette = self.graph.plot_widget.palette()
        self.graph.set_palette(palette)

        gui.rubber(self.controlArea)

        self.graph.box_zoom_select(self.controlArea)

        # `box` is still the "Plot Properties" box at this point.
        gui.auto_commit(box,
                        self,
                        "auto_commit",
                        "Send Selected",
                        checkbox_label="Send selected automatically",
                        box=None)

        self.plot.getPlotItem().hideButtons()
        self.plot.setRenderHint(QPainter.Antialiasing)

        self.graph.jitter_continuous = True
        self._initialize()

    def reset_graph_data(self, *_):
        """Rescale and redraw the graph when data is present, then redraw
        the pair connections. Positional arguments are ignored."""
        has_data = self.data is not None
        if has_data:
            self.graph.rescale_data()
            self.update_graph()
        self.connect_pairs()

    def update_colors(self):
        """Do nothing; this widget has no separate color update."""
        return None

    def update_density(self):
        """Redraw the graph without resetting the view."""
        redraw = self.update_graph
        redraw(reset_view=False)

    def update_regression_line(self):
        """Redraw the graph without resetting the view."""
        redraw = self.update_graph
        redraw(reset_view=False)

    def init_attr_values(self):
        """Reset the attribute models and the graph's attribute choices.

        Every model gets the domain of the current data, or None when
        there is no data, the data is empty or its domain is falsy. The
        color attribute defaults to the class variable; shape, size and
        label are cleared.
        """
        data = self.data
        domain = None
        if data and len(data):
            domain = data.domain or None
        for model in self.models:
            model.set_domain(domain)
        graph = self.graph
        if domain:
            graph.attr_color = self.data.domain.class_var
        else:
            graph.attr_color = None
        graph.attr_shape = None
        graph.attr_size = None
        graph.attr_label = None

    def prepare_data(self):
        """Do nothing; no data preparation is needed here."""
        return None

    def update_graph(self, reset_view=True, **_):
        """Clear the zoom stack and redraw the graph if it has data.

        `reset_view` and any extra keyword arguments are accepted but not
        used.
        """
        graph = self.graph
        graph.zoomStack = []
        if graph.data is not None:
            graph.update_data(self.variable_x, self.variable_y, True)

    def selection_changed(self):
        """Commit the output after the selection has changed."""
        commit = self.commit
        commit()

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle the "Data" input.

        Data with fewer than two rows raises the `not_enough_rows` error
        and is treated as no data. When a matrix of matching length is
        already present, the data only replaces the styling data and the
        context is reopened; otherwise the widget is flagged for
        re-initialization.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        too_few_rows = data is not None and len(data) < 2
        if too_few_rows:
            self.Error.not_enough_rows()
            data = None
        else:
            self.Error.not_enough_rows.clear()

        self.signal_data = data

        if self.matrix is None or data is None or \
                len(self.matrix) != len(data):
            self._invalidated = True
            return

        self.closeContext()
        self.data = data
        self.init_attr_values()
        self.openContext(data)

    @Inputs.distances
    def set_disimilarity(self, matrix):
        """Set the dissimilarity (distance) matrix.

        A matrix with fewer than two rows raises the `matrix_too_small`
        error and is treated as no matrix. `matrix_data` always follows
        the new matrix: it is the matrix's `row_items` when those are
        present and None otherwise. The widget is flagged for
        re-initialization.

        Parameters
        ----------
        matrix : Optional[Orange.misc.DistMatrix]
        """

        if matrix is not None and len(matrix) < 2:
            self.Error.matrix_too_small()
            matrix = None
        else:
            self.Error.matrix_too_small.clear()

        self.matrix = matrix
        # Reset unconditionally so that row items of a previous matrix are
        # never kept for a new matrix that has none.
        self.matrix_data = None
        if matrix is not None and matrix.row_items:
            self.matrix_data = matrix.row_items
        self._invalidated = True

    @Inputs.data_subset
    def set_subset_data(self, subset_data):
        """Handle the "Data Subset" input.

        Stores the subset, drops the cached subset mask and enables the
        graph's alpha-value control only while no subset is given.

        Parameters
        ----------
        subset_data: Optional[Orange.data.Table]
        """
        self.subset_data = subset_data
        # the mask is recomputed lazily in handleNewSignals
        self._subset_mask = None  # type: Optional[np.ndarray]
        no_subset = subset_data is None
        self.controls.graph.alpha_value.setEnabled(no_subset)

    def _clear(self):
        """Forget the similar pairs, stop any running update loop and put
        the widget into the Waiting state."""
        self._similar_pairs = None
        # Stopping the loop sets the state to Finished, so the Waiting
        # state has to be assigned afterwards.
        self.__set_update_loop(None)
        self.__state = OWMDS.Waiting

    def _clear_plot(self):
        self.graph.plot_widget.clear()

    def _initialize(self):
        """Reset the widget and derive the working data and matrix.

        After clearing the state, the styling data is taken from the data
        input or, failing that, from the matrix's row items. The effective
        matrix is the input matrix when present, otherwise Euclidean
        distances computed over the preprocessed data. Mismatching input
        sizes or data without attributes raise the corresponding error.
        """
        # Reset everything derived from the inputs.
        self.closeContext()
        self._clear()
        self.Error.clear()
        self.data = None
        self.effective_matrix = None
        self.embedding = None
        self.init_attr_values()

        signal_data, matrix = self.signal_data, self.matrix

        # Nothing on either input: leave the widget in the reset state.
        if signal_data is None and matrix is None:
            return

        if signal_data is not None and matrix is not None and \
                len(signal_data) != len(matrix):
            self.Error.mismatching_dimensions()
            self._update_plot()
            return

        if signal_data is not None:
            self.data = signal_data
        elif self.matrix_data is not None:
            self.data = self.matrix_data

        if matrix is None:
            if not self.data.domain.attributes:
                self.Error.no_attributes()
                return
            preprocessed_data = Orange.projection.MDS().preprocess(self.data)
            self.effective_matrix = Orange.distance.Euclidean(
                preprocessed_data)
        else:
            self.effective_matrix = matrix
            # For a matrix with axis 0, its row items are not used as the
            # styling data.
            if matrix.axis == 0 and self.data is self.matrix_data:
                self.data = None

        self.init_attr_values()
        self.openContext(self.data)

    def _toggle_run(self):
        """Run-button callback: stop a running optimization (and update
        the output), or start one otherwise."""
        if self.__state != OWMDS.Running:
            self.start()
            return
        self.stop()
        self._invalidate_output()

    def start(self):
        """Start the optimization unless it is already running.

        A finished run is resumed; a waiting widget starts only when an
        effective matrix is available.
        """
        state = self.__state
        if state == OWMDS.Running:
            return
        # Resume/continue from a previous run
        resumable = state == OWMDS.Finished
        ready = state == OWMDS.Waiting and \
            self.effective_matrix is not None
        if resumable or ready:
            self.__start()

    def stop(self):
        """Stop the optimization if it is running."""
        if self.__state != OWMDS.Running:
            return
        self.__set_update_loop(None)

    def __start(self):
        """Start (or resume) the iterative MDS optimization.

        Builds a generator that refines the embedding in chunks of
        iterations, chosen by the refresh rate, and installs it as the
        update loop. The current embedding, if any, is the starting point.
        """
        self.__draw_similar_pairs = False
        X = self.effective_matrix
        init = self.embedding

        # number of iterations per single GUI update step
        _, step_size = OWMDS.RefreshRate[self.refresh_rate]
        if step_size == -1:
            # "None" refresh rate: run all iterations in a single step.
            step_size = self.max_iter

        def update_loop(X, max_iter, step, init):
            """
            return an iterator over successive improved MDS point embeddings.

            Yields `(embedding, stress, progress)` after each chunk of at
            most `step` iterations, and stops after `max_iter` iterations
            or once the normalized stress improves by less than the MDS
            `eps` parameter.
            """
            # NOTE: this code MUST NOT call into QApplication.processEvents
            done = False
            iterations_done = 0
            # `np.float` was only an alias of the builtin `float` and no
            # longer exists in current NumPy releases.
            oldstress = np.finfo(float).max
            init_type = "PCA" if self.initialization == OWMDS.PCA else "random"

            while not done:
                step_iter = min(max_iter - iterations_done, step)
                mds = Orange.projection.MDS(dissimilarity="precomputed",
                                            n_components=2,
                                            n_init=1,
                                            max_iter=step_iter,
                                            init_type=init_type,
                                            init_data=init)

                mdsfit = mds(X)
                iterations_done += step_iter

                embedding, stress = mdsfit.embedding_, mdsfit.stress_
                # Normalize the stress by the summed point norms before
                # comparing it with the previous chunk.
                stress /= np.sqrt(np.sum(embedding**2, axis=1)).sum()

                if iterations_done >= max_iter:
                    done = True
                elif (oldstress - stress) < mds.params["eps"]:
                    done = True
                # The next chunk continues from the current embedding.
                init = embedding
                oldstress = stress

                yield embedding, mdsfit.stress_, iterations_done / max_iter

        self.__set_update_loop(update_loop(X, self.max_iter, step_size, init))
        self.progressBarInit(processEvents=None)

    def __set_update_loop(self, loop):
        """
        Install `loop` as the update coroutine, replacing any current one.

        `loop` is a generator yielding `(embedding, stress, progress)`
        tuples: `embedding` is a `(N, 2) ndarray` with the current MDS
        points, `stress` the current stress and `progress` a float ratio
        (0 <= progress <= 1). Passing None just stops the current loop.

        A loop that is already installed is closed first and its progress
        bar is finished.

        .. note::
            The `loop` must not explicitly yield control flow to the event
            loop (i.e. call `QApplication.processEvents`)

        """
        previous = self.__update_loop
        if previous is not None:
            previous.close()
            self.__update_loop = None
            self.progressBarFinished(processEvents=None)

        self.__update_loop = loop

        if loop is None:
            # Stopped: unblock the widget and cancel the pending step.
            self.setBlocking(False)
            self.setStatusMessage("")
            self.runbutton.setText("Start")
            self.__state = OWMDS.Finished
            self.__timer.stop()
            return

        # Running: block the widget and schedule the first step.
        self.setBlocking(True)
        self.progressBarInit(processEvents=None)
        self.setStatusMessage("Running")
        self.runbutton.setText("Stop")
        self.__state = OWMDS.Running
        self.__timer.start()

    def __next_step(self):
        """Advance the update loop by one step (timer callback).

        On a new embedding the plot and the progress bar are updated and
        the next step is scheduled. When the loop is exhausted it is
        removed, the output is committed and the plot is redrawn with the
        similar pairs. A `MemoryError` or any other exception stops the
        loop and raises the corresponding widget error.
        """
        if self.__update_loop is None:
            return

        assert not self.__in_next_step
        self.__in_next_step = True

        try:
            loop = self.__update_loop
            # Clear errors of a previous run before taking a new step.
            self.Error.out_of_memory.clear()
            self.Error.optimization_error.clear()
            try:
                embedding, _, progress = next(self.__update_loop)
                assert self.__update_loop is loop
            except StopIteration:
                self.__set_update_loop(None)
                self.unconditional_commit()
                self.__draw_similar_pairs = True
                self._update_plot()
            except MemoryError:
                self.Error.out_of_memory()
                self.__set_update_loop(None)
                self.__draw_similar_pairs = True
            except Exception as exc:
                self.Error.optimization_error(str(exc))
                self.__set_update_loop(None)
                self.__draw_similar_pairs = True
            else:
                self.progressBarSet(100.0 * progress, processEvents=None)
                self.embedding = embedding
                self._update_plot()
                # schedule next update
                self.__timer.start()
        finally:
            # Always release the guard, even if a handler above raised;
            # otherwise every later step would fail on the assertion.
            self.__in_next_step = False

    def __invalidate_embedding(self):
        """Reset the embedding to the selected initialization.

        Does nothing when there is no embedding yet. A running
        optimization is stopped first and restarted from the new
        embedding afterwards.
        """
        if self.embedding is None:
            return

        # Remember whether we were running before the loop is stopped.
        was_running = self.__state == OWMDS.Running
        if self.__update_loop is not None:
            self.__set_update_loop(None)

        distances = self.effective_matrix
        if self.initialization == OWMDS.PCA:
            fresh = torgerson(distances)
        else:
            fresh = np.random.rand(len(distances), 2)
        self.embedding = fresh

        self._update_plot()

        # restart the optimization if it was interrupted.
        if was_running:
            self.__start()

    def __invalidate_refresh(self):
        """Drop the current update loop and restart it if it was running."""
        # Remember the state before the update loop is dropped below.
        was_running = self.__state == OWMDS.Running

        if self.__update_loop is not None:
            self.__set_update_loop(None)

        # restart the optimization if it was interrupted.
        # TODO: decrease the max iteration count by the already
        # completed iterations count.
        if was_running:
            self.__start()

    def handleNewSignals(self):
        """React to a completed batch of input signals.

        If the inputs were invalidated, the widget is re-initialized and the
        optimization started. A subset mask is computed (once) by matching
        the ids of the subset data against the ids of the data. The plot is
        then rebuilt and the outputs are committed unconditionally.
        """
        if self._invalidated:
            self.__draw_similar_pairs = False
            self._invalidated = False
            self._initialize()
            self.start()

        if self._subset_mask is None and self.subset_data is not None and \
                self.data is not None:
            # np.isin replaces the deprecated np.in1d.
            self._subset_mask = np.isin(self.data.ids, self.subset_data.ids)

        self._update_plot(new=True)
        self.unconditional_commit()

    def _invalidate_output(self):
        """Re-send the outputs by calling ``commit``."""
        self.commit()

    def _on_connected_changed(self):
        """Discard the cached similar pairs and redraw the connecting lines."""
        # Setting the cache to None makes connect_pairs recompute the pairs.
        self._similar_pairs = None
        self.connect_pairs()

    def _update_plot(self, new=False):
        self._clear_plot()

        if self.embedding is not None:
            self._setup_plot(new=new)
        else:
            self.graph.new_data(None)

    def connect_pairs(self):
        """Draw lines between the most similar pairs of points.

        Any previously drawn curve is removed first. Nothing is drawn unless
        ``connected_pairs`` is non-zero and drawing of similar pairs is
        enabled. The indices of the pairs with the smallest values in
        ``effective_matrix`` are computed once and cached in
        ``_similar_pairs``.
        """
        if self._curve:
            self.graph.plot_widget.removeItem(self._curve)
        if not (self.connected_pairs and self.__draw_similar_pairs):
            return
        emb_x, emb_y = self.graph.get_xy_data_positions(
            self.variable_x, self.variable_y, self.graph.valid_data)
        if self._similar_pairs is None:
            # This code requires storing lower triangle of X (n x n / 2
            # doubles), n x n / 2 * 2 indices to X, n x n / 2 indices for
            # argsort result. If this becomes an issue, it can be reduced to
            # n x n argsort indices by argsorting the entire X. Then we
            # take the first n + 2 * p indices. We compute their coordinates
            # i, j in the original matrix. We keep those for which i < j.
            # n + 2 * p will suffice to exclude the diagonal (i = j). If the
            # number of those for which i < j is smaller than p, we instead
            # take i > j. Among those that remain, we take the first p.
            # Assuming that MDS can't show so many points that memory could
            # become an issue, I preferred using simpler code.
            m = self.effective_matrix
            n = len(m)
            p = min(n * (n - 1) // 2 * self.connected_pairs // 100,
                    MAX_N_PAIRS * self.connected_pairs // 20)
            indcs = np.triu_indices(n, 1)
            # Named `order` rather than `sorted` so the builtin is not
            # shadowed.
            order = np.argsort(m[indcs])[:p]
            # Interleave row and column indices: [i0, j0, i1, j1, ...].
            self._similar_pairs = fpairs = np.empty(2 * p, dtype=int)
            fpairs[::2] = indcs[0][order]
            fpairs[1::2] = indcs[1][order]
        emb_x_pairs = emb_x[self._similar_pairs].reshape((-1, 2))
        emb_y_pairs = emb_y[self._similar_pairs].reshape((-1, 2))

        # Filter out zero distance lines (in embedding coords).
        # Null (zero length) line causes bad rendering artifacts
        # in Qt when using the raster graphics system (see gh-issue: 1668).
        (x1, x2), (y1, y2) = (emb_x_pairs.T, emb_y_pairs.T)
        pairs_mask = ~(np.isclose(x1, x2) & np.isclose(y1, y2))
        emb_x_pairs = emb_x_pairs[pairs_mask, :]
        emb_y_pairs = emb_y_pairs[pairs_mask, :]
        self._curve = pg.PlotCurveItem(emb_x_pairs.ravel(),
                                       emb_y_pairs.ravel(),
                                       pen=pg.mkPen(0.8,
                                                    width=2,
                                                    cosmetic=True),
                                       connect="pairs",
                                       antialias=True)
        self.graph.plot_widget.addItem(self._curve)

    def _setup_plot(self, new=False):
        """Hand the data, extended with the embedding columns, to the graph.

        The first two embedding columns are appended to the data as the
        attributes ``variable_x`` and ``variable_y``; the graph then receives
        this table (and its subset, if a subset mask is set) and the pair
        connections are redrawn.
        """
        coords = np.column_stack((self.embedding[:, 0], self.embedding[:, 1]))

        source = self.data
        source_domain = source.domain
        plot_domain = Domain(
            attributes=source_domain.attributes + (self.variable_x,
                                                   self.variable_y),
            class_vars=source_domain.class_vars,
            metas=source_domain.metas)
        plot_data = Table.from_numpy(plot_domain,
                                     X=hstack((source.X, coords)),
                                     Y=source.Y,
                                     metas=source.metas)

        if self._subset_mask is not None:
            subset_data = plot_data[self._subset_mask]
        else:
            subset_data = None

        self.graph.new_data(plot_data, subset_data=subset_data, new=new)
        self.graph.update_data(self.variable_x, self.variable_y, True)
        self.connect_pairs()

    def commit(self):
        """Send the selected and the annotated data to the outputs.

        When an embedding exists, its two columns (named uniquely with
        respect to the data's variables, based on "mds-x"/"mds-y") are
        appended to the data as meta attributes. The selected rows go to
        ``selected_data``; ``annotated_data`` receives a groups table when
        more than one selection group exists and an annotated table
        otherwise.
        """
        embedding = None
        if self.embedding is not None:
            names = get_unique_names(
                [v.name for v in self.data.domain.variables],
                ["mds-x", "mds-y"])
            embedding_domain = Orange.data.Domain([
                ContinuousVariable(names[0]),
                ContinuousVariable(names[1])
            ])
            embedding = Orange.data.Table.from_numpy(
                embedding_domain, self.embedding)
        output = embedding

        if embedding is not None and self.data is not None:
            data_domain = self.data.domain
            extended = Orange.data.Domain(
                data_domain.attributes, data_domain.class_vars,
                data_domain.metas + embedding.domain.attributes)
            output = self.data.transform(extended)
            # The two new meta columns are the last ones in the domain.
            output.metas[:, -2:] = embedding.X

        selection = self.graph.get_selection()
        selected = None
        if output is not None and len(selection) > 0:
            selected = output[selection]

        groups = self.graph.selection
        if groups is not None and np.max(groups) > 1:
            annotated = create_groups_table(output, groups)
        else:
            annotated = create_annotated_table(output, selection)

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)

    def onDeleteWidget(self):
        """Run the parent's delete handler, then clear the plot and state."""
        super().onDeleteWidget()
        self._clear_plot()
        self._clear()

    def send_report(self):
        """Add the plot and a caption with the visual settings to the report.

        Nothing is reported when the widget has no data.
        """
        if self.data is None:
            return

        def name(var):
            # Falsy variables (e.g. None) are passed through unchanged.
            return var and var.name

        graph = self.graph
        items = (
            ("Color", name(graph.attr_color)),
            ("Label", name(graph.attr_label)),
            ("Shape", name(graph.attr_shape)),
            ("Size", name(graph.attr_size)),
            ("Jittering",
             graph.jitter_size != 0 and "{} %".format(graph.jitter_size)),
        )
        caption = report.render_items_vert(items)
        self.report_plot()
        if caption:
            self.report_caption(caption)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade stored settings older than version 2, in place.

        Four plot settings are copied (some under a new name) into a nested
        ``"graph"`` dictionary, and ``"autocommit"`` is copied to
        ``"auto_commit"``. Settings of version 2 or later are left untouched.
        """
        if version >= 2:
            return

        # old setting name -> name inside the nested "graph" settings
        renamed = {
            "label_only_selected": "label_only_selected",
            "symbol_opacity": "alpha_value",
            "symbol_size": "point_width",
            "jitter": "jitter_size",
        }
        settings_["graph"] = {
            new: settings_[old] for old, new in renamed.items()}
        settings_["auto_commit"] = settings_["autocommit"]

    @classmethod
    def migrate_context(cls, context, version):
        """Upgrade a stored context older than version 2, in place.

        The four visual-attribute values are moved into a nested ``"graph"``
        dictionary in ``context.values`` under their new names. For each old
        ``(name, index)`` value:

        * a non-negative index is shifted by 100;
        * a negative index is kept as is when the name is ``"Stress"``;
        * any other value with a negative index becomes ``None``.

        Contexts of version 2 or later are left untouched.
        """
        if version >= 2:
            return

        # The original also built three domain lists from
        # context.ordered_domain that were never used; they are dropped.
        renamed = (("color_value", "attr_color"),
                   ("shape_value", "attr_shape"),
                   ("size_value", "attr_size"),
                   ("label_value", "attr_label"))
        context_values_graph = {}
        for old_val, new_val in renamed:
            tmp = context.values[old_val]
            if tmp[1] >= 0:
                context_values_graph[new_val] = (tmp[0], tmp[1] + 100)
            elif tmp[0] != "Stress":
                context_values_graph[new_val] = None
            else:
                context_values_graph[new_val] = tmp
        context.values["graph"] = context_values_graph
Ejemplo n.º 25
0
class OWNxExplorer(OWDataProjectionWidget):
    """Widget for visually exploring a network and its properties."""

    # Widget metadata.
    name = "Network Explorer"
    description = "Visually explore the network and its properties."
    icon = "icons/NetworkExplorer.svg"
    priority = 6420

    # Input signals; the network is the default input.
    class Inputs:
        node_data = Input("Node Data", Table)
        node_subset = Input("Node Subset", Table)
        network = Input("Network", Network, default=True)
        node_distances = Input("Node Distances", Orange.misc.DistMatrix)

    # Output signals, in addition to those inherited from the parent widget.
    class Outputs(OWDataProjectionWidget.Outputs):
        subgraph = Output("Selected sub-network", Network)
        unselected_subgraph = Output("Remaining sub-network", Network)
        distances = Output("Distance matrix", Orange.misc.DistMatrix)

    UserAdviceMessages = [
        widget.Message("Double clicks select connected components",
                       widget.Message.Information),
    ]

    GRAPH_CLASS = GraphView
    graph = SettingProvider(GraphView)

    # Stored settings for the layout controls.
    layout_density = Setting(10)
    observe_weights = Setting(True)

    # Stored settings for the spin boxes of the marking criteria.
    mark_hops = Setting(1)
    mark_min_conn = Setting(5)
    mark_max_conn = Setting(5)
    mark_most_conn = Setting(1)

    alpha_value = 255  # Override the setting from parent

    class Warning(OWDataProjectionWidget.Warning):
        distance_matrix_mismatch = widget.Msg(
            "Distance matrix size doesn't match the number of network nodes "
            "and will be ignored.")
        no_graph_found = widget.Msg("Node data is given, graph data is missing.")

    class Error(OWDataProjectionWidget.Error):
        data_size_mismatch = widget.Msg(
            "Length of the data does not match the number of nodes.")
        network_too_large = widget.Msg("Network is too large to visualize.")
        single_node_graph = widget.Msg("I don't do single-node graphs today.")

    def __init__(self):
        """Initialize the widget state and the timer for text searches."""
        # These are already needed in super().__init__()
        self.number_of_nodes = 0
        self.number_of_edges = 0
        self.nHighlighted = 0
        self.nSelected = 0
        self.nodes_per_edge = 0
        self.edges_per_node = 0

        # Index into mark_criteria and the text used by the text criteria.
        self.mark_mode = 0
        self.mark_text = ""

        super().__init__()

        # Inputs and the values derived from them.
        self.network = None
        self.node_data = None
        self.distance_matrix = None
        self.edges = None
        self.positions = None

        # State of the layout optimization running in a separate thread.
        self._optimizer = None
        self._animation_thread = None
        self._stop_optimization = False

        self.marked_nodes = None
        # Timer used to delay update_marks while the user types a search text.
        self.searchStringTimer = QTimer(self)
        self.searchStringTimer.timeout.connect(self.update_marks)
        self.set_mark_mode()
        self.setMinimumWidth(600)

    def sizeHint(self):
        """Return the preferred widget size, 800 x 600."""
        return QSize(800, 600)

    def _add_controls(self):
        """Build the control boxes and the marking box, in display order."""
        self.gui = OWPlotGUI(self)
        self._add_info_box()
        self.gui.point_properties_box(self.controlArea)
        self._add_effects_box()
        self.gui.plot_properties_box(self.controlArea)
        self._add_mark_box()
        # Label-based marking depends on the label attribute, so marks are
        # refreshed when the user picks another one.
        self.controls.attr_label.activated.connect(self.on_change_label_attr)

    def _add_info_box(self):
        """Build the "Layout" box: statistics, layout buttons and options."""
        info = gui.vBox(self.controlArea, box="Layout")
        # The labels are formatted from the widget attributes named in them.
        gui.label(
            info, self,
            "Nodes: %(number_of_nodes)i (%(nodes_per_edge).2f per edge); "
            "%(nSelected)i selected")
        gui.label(
            info, self,
            "Edges: %(number_of_edges)i (%(edges_per_node).2f per node)")
        lbox = gui.hBox(info)
        self.relayout_button = gui.button(
            lbox, self, 'Improve', callback=self.improve, autoDefault=False,
            tooltip="Optimize the current layout, with a small initial jerk")
        # The stop button starts hidden; set_buttons toggles it against the
        # other two buttons.
        self.stop_button = gui.button(
            lbox, self, 'Stop', callback=self.stop_relayout, autoDefault=False,
            hidden=True)
        self.randomize_button = gui.button(
            lbox, self, 'Re-layout', callback=self.restart, autoDefault=False,
            tooltip="Restart laying out from random positions")
        gui.hSlider(info, self, "layout_density", minValue=1, maxValue=50,
                    label="Gravity", orientation=Qt.Horizontal,
                    callback_finished=self.improve,
                    tooltip="Lower values improve optimization,\n"
                            "higher work better for graph with many small "
                            "components")
        gui.checkBox(info, self, "observe_weights",
                     label="Make edges with large weights shorter",
                     callback=self.improve)

    def _add_effects_box(self):
        """Build the "Widths and Sizes" box with node and edge controls."""
        gbox = self.gui.create_gridbox(self.controlArea, box="Widths and Sizes")
        self.gui.add_widget(self.gui.PointSize, gbox)
        gbox.layout().itemAtPosition(1, 0).widget().setText("Node Size:")
        self.gui.add_control(
            gbox, gui.hSlider, "Edge width:",
            master=self, value='graph.edge_width',
            minValue=1, maxValue=10, step=1,
            callback=self.graph.update_edges)
        box = gui.vBox(None)
        gbox.layout().addWidget(box, 3, 0, 1, 2)
        gui.separator(box)
        self.checkbox_relative_edges = gui.checkBox(
            box, self, 'graph.relative_edge_widths',
            'Scale edge widths to weights',
            callback=self.graph.update_edges)
        self.checkbox_show_weights = gui.checkBox(
            box, self, 'graph.show_edge_weights',
            'Show edge weights',
            callback=self.graph.update_edge_labels)
        # Kept under its own name: this checkbox used to be assigned to
        # checkbox_show_weights as well, which overwrote the reference to the
        # "Show edge weights" checkbox created just above.
        self.checkbox_label_selected_edges = gui.checkBox(
            box, self, 'graph.label_selected_edges',
            'Label only edges of selected nodes',
            callback=self.graph.update_edge_labels)

        # This is ugly: create a slider that controls alpha_value so that
        # parent can enable and disable it - although it's never added to any
        # layout and visible to the user
        gui.hSlider(None, self, "graph.alpha_value")

    def _add_mark_box(self):
        """Build the marking box and define the marking criteria.

        ``self.mark_criteria`` is set to a list of triples
        ``(description, control, function)``: the text shown in the combo
        box, an optional control (spin box or line edit) that belongs to the
        criterion, and a function that returns the indices of nodes to mark
        (or ``None``). The position of a criterion in the list is the value
        of ``self.mark_mode``.
        """
        hbox = gui.hBox(None, box=True)
        self.mainArea.layout().addWidget(hbox)
        vbox = gui.hBox(hbox)

        def spin(value, label, minv, maxv):
            # Create a spin box bound to setting `value`; changes re-mark.
            return gui.spin(
                vbox, self, value, label=label, minv=minv, maxv=maxv,
                step=1,
                alignment=Qt.AlignRight, callback=self.update_marks).box

        def text_line():
            # Create a line edit bound to mark_text. Typing (re)starts a
            # 300 ms timer whose timeout is connected to update_marks.
            def set_search_string_timer():
                self.searchStringTimer.stop()
                self.searchStringTimer.start(300)

            return gui.lineEdit(
                gui.hBox(vbox), self, "mark_text", label="Text: ",
                orientation=Qt.Horizontal, minimumWidth=50,
                callback=set_search_string_timer, callbackOnType=True).box

        def _mark_by_labels(marker):
            # Apply `marker` to the node labels and the lower-cased search
            # text; None if there is no text or there are no labels.
            txt = self.mark_text.lower()
            if not txt:
                return None
            labels = self.get_label_data()
            if labels is None:
                return None
            return marker(np.char.array(labels), txt)

        def mark_label_starts():
            return _mark_by_labels(
                lambda labels, txt: np.flatnonzero(labels.lower().startswith(txt)))

        def mark_label_contains():
            return _mark_by_labels(
                lambda labels, txt: np.flatnonzero(labels.lower().find(txt) != -1))

        def mark_text():
            # Search the text in all values of each data instance; values are
            # joined with "\x00".
            txt = self.mark_text.lower()
            if not txt or self.data is None:
                return None
            return np.array(
                [i for i, inst in enumerate(self.data)
                 if txt in "\x00".join(map(str, inst.list)).lower()])

        def mark_reachable():
            selected = self.graph.get_selection()
            if selected is None:
                return None
            return self.get_reachable(selected)

        def mark_close():
            # Nodes within mark_hops steps of the selection, without the
            # selected nodes themselves.
            selected = self.graph.get_selection()
            if selected is None:
                return None
            neighbours = set(selected)
            last_round = list(neighbours)
            for _ in range(self.mark_hops):
                next_round = set()
                for neigh in last_round:
                    next_round |= set(self.network.neighbours(neigh))
                neighbours |= next_round
                last_round = next_round
            neighbours -= set(selected)
            return np.array(list(neighbours))

        def mark_from_input():
            # Nodes whose instance id appears in the subset input.
            if self.subset_data is None or self.data is None:
                return None
            ids = set(self.subset_data.ids)
            return np.array(
                [i for i, ex in enumerate(self.data) if ex.id in ids])

        def mark_most_connections():
            n = self.mark_most_conn
            if n >= self.number_of_nodes:
                return np.arange(self.number_of_nodes)
            degrees = self.network.degrees()
            # pylint: disable=invalid-unary-operand-type
            # The n-th largest degree is the threshold; all nodes that reach
            # it are marked, so ties can give more than n nodes.
            min_degree = np.partition(degrees, -n)[-n]
            return np.flatnonzero(degrees >= min_degree)

        def mark_more_than_any_neighbour():
            degrees = self.network.degrees()
            # initial=0 gives the maximum for nodes without neighbours.
            return np.array(
                [node for node, degree in enumerate(degrees)
                 if degree > np.max(degrees[self.network.neighbours(node)],
                                    initial=0)])

        def mark_more_than_average_neighbour():
            degrees = self.network.degrees()
            # Nodes without neighbours are compared against 0.
            return np.array(
                [node for node, degree, neighbours in (
                    (node, degree, self.network.neighbours(node))
                     for node, degree in enumerate(degrees))
                 if degree > (np.mean(degrees[neighbours]) if neighbours.size else 0)
                 ]
            )

        # NOTE: on_change_label_attr refers to the two label criteria by
        # their positions (1 and 2) in this list.
        self.mark_criteria = [
            ("(Select criteria for marking)", None, lambda: np.zeros((0,))),
            ("Mark nodes whose label starts with", text_line(), mark_label_starts),
            ("Mark nodes whose label contains", text_line(), mark_label_contains),
            ("Mark nodes whose data that contains", text_line(), mark_text),
            ("Mark nodes reachable from selected", None, mark_reachable),

            ("Mark nodes in vicinity of selection",
             spin("mark_hops", "Number of hops:", 1, 20),
             mark_close),

            ("Mark nodes from subset signal", None, mark_from_input),

            ("Mark nodes with few connections",
             spin("mark_max_conn", "Max. connections:", 0, 1000),
             lambda: np.flatnonzero(self.network.degrees() <= self.mark_max_conn)),

            ("Mark nodes with many connections",
             spin("mark_min_conn", "Min. connections:", 1, 1000),
             lambda: np.flatnonzero(self.network.degrees() >= self.mark_min_conn)),

            ("Mark nodes with most connections",
             spin("mark_most_conn", "Number of marked:", 1, 1000),
             mark_most_connections),

            ("Mark nodes with more connections than any neighbour", None,
             mark_more_than_any_neighbour),

            ("Mark nodes with more connections than average neighbour", None,
             mark_more_than_average_neighbour)
        ]

        cb = gui.comboBox(
            hbox, self, "mark_mode",
            items=[item for item, *_ in self.mark_criteria],
            maximumContentsLength=-1, callback=self.set_mark_mode)
        # Move the combo box in front of the criterion controls.
        hbox.layout().insertWidget(0, cb)

        gui.rubber(hbox)
        self.btselect = gui.button(
            hbox, self, "Select", callback=self.select_marked)
        self.btadd = gui.button(
            hbox, self, "Add to Selection", callback=self.select_add_marked)
        self.btgroup = gui.button(
            hbox, self, "Add New Group", callback=self.select_as_group)

    def set_mark_mode(self, mode=None):
        """Switch the marking criterion and refresh the marks.

        If ``mode`` is given it becomes the new ``mark_mode``. Only the
        control that belongs to the current criterion is shown; the controls
        of all other criteria are hidden. A pending search timer is stopped
        before the marks are updated.
        """
        if mode is not None:
            self.mark_mode = mode

        for index, (_, control, _) in enumerate(self.mark_criteria):
            if not control:
                continue
            if index == self.mark_mode:
                control.show()
            else:
                control.hide()

        self.searchStringTimer.stop()
        self.update_marks()

    def update_marks(self):
        """Recompute the marked nodes with the current criterion.

        Does nothing without a network. An empty result is stored as
        ``None``. The graph's marks (and its labels, when only selected
        labels are shown) and the selection buttons are then refreshed.
        """
        if self.network is None:
            return

        find_marked = self.mark_criteria[self.mark_mode][2]
        marked = find_marked()
        if marked is not None and not marked.size:
            marked = None
        self.marked_nodes = marked

        graph = self.graph
        graph.update_marks()
        if graph.label_only_selected:
            graph.update_labels()
        self.update_selection_buttons()

    def update_selection_buttons(self):
        """Show or hide the selection buttons to match the current state.

        Without marked nodes all three buttons are hidden. Otherwise
        "Select" is shown, and the other two depend on the number of
        selection groups:

        * no selection: both are hidden;
        * one group: "Add to Selection" is shown;
        * several groups: "Add to Group" and "Add New Group" are shown.
        """
        if self.marked_nodes is None:
            self.btselect.hide()
            self.btadd.hide()
            self.btgroup.hide()
            return
        self.btselect.show()

        # graph.selection holds a group number per node (0 = not selected),
        # as used in send_data; graph.get_selection() returns node *indices*,
        # whose maximum says nothing about the number of groups.
        groups = self.graph.selection
        if groups is None or not len(groups):
            n_groups = 0
        else:
            n_groups = np.max(groups)

        if n_groups == 0:
            self.btadd.hide()
            self.btgroup.hide()
        elif n_groups == 1:
            self.btadd.setText("Add to Selection")
            self.btadd.show()
            self.btgroup.hide()
        else:
            self.btadd.setText("Add to Group")
            self.btadd.show()
            self.btgroup.show()

    def selection_changed(self):
        """Update the selection count, the buttons and the marks."""
        super().selection_changed()
        self.nSelected = 0 if self.selection is None else len(self.selection)
        self.update_selection_buttons()
        # Some marking criteria depend on the selection, so marks are redone.
        self.update_marks()

    def select_marked(self):
        """Pass the marked nodes to the graph's ``selection_select``."""
        self.graph.selection_select(self.marked_nodes)

    def select_add_marked(self):
        """Pass the marked nodes to the graph's ``selection_append``."""
        self.graph.selection_append(self.marked_nodes)

    def select_as_group(self):
        """Pass the marked nodes to the graph's ``selection_new_group``."""
        self.graph.selection_new_group(self.marked_nodes)

    def on_change_label_attr(self):
        """Refresh the marks if the current criterion is based on labels."""
        # 1 and 2 are the positions of the "label starts with" and
        # "label contains" criteria in mark_criteria.
        if self.mark_mode in (1, 2):
            self.update_marks()

    @Inputs.node_data
    def set_node_data(self, data):
        """Store the node data input; it is processed in handleNewSignals."""
        self.node_data = data

    @Inputs.node_subset
    def set_node_subset(self, data):
        """Store the node subset input as ``subset_data``."""
        # It would be better to call super, but this fails because super
        # is decorated to set the partial summary for signal "Subset Data",
        # which does not exist for this widget (OWNxExplorer.Inputs is not
        # derived from OWDataProjectionWidget.Inputs in order to rename the
        # signal)
        self.subset_data = data

    @Inputs.node_distances
    def set_items_distance_matrix(self, matrix):
        """Store the distance matrix input and discard the node positions."""
        self.distance_matrix = matrix
        # With positions set to None, handleNewSignals recomputes the edges
        # and lays the network out again.
        self.positions = None

    @Inputs.network
    def set_graph(self, graph):
        """Store the network input and update the node and edge statistics.

        The marking criterion, the search text and the node positions are
        reset. A missing or empty network, or one with more than 100000
        nodes and edges in total, is rejected: the widget's network is set
        to ``None`` and the statistics to zero. The latter case also shows
        the ``network_too_large`` error.
        """

        def set_graph_none(error=None):
            # Remove a "too large" error left over from a previous network;
            # it is shown again right below if it still applies.
            self.Error.network_too_large.clear()
            if error is not None:
                error()
            self.network = None
            self.number_of_nodes = self.edges_per_node = 0
            self.number_of_edges = self.nodes_per_edge = 0

        def compute_stats():
            self.number_of_nodes = graph.number_of_nodes()
            self.number_of_edges = graph.number_of_edges()
            # The caller has already rejected networks without nodes.
            self.edges_per_node = self.number_of_edges / self.number_of_nodes
            self.nodes_per_edge = \
                self.number_of_nodes / max(1, self.number_of_edges)

        self.mark_text = ""
        self.set_mark_mode(0)
        self.positions = None

        if not graph or graph.number_of_nodes() == 0:
            set_graph_none()
            return
        if graph.number_of_nodes() + graph.number_of_edges() > 100000:
            set_graph_none(self.Error.network_too_large)
            return
        self.Error.clear()

        self.network = graph
        compute_stats()

    def handleNewSignals(self):
        """Combine the received inputs and refresh the plot.

        A running layout optimization is stopped first. The data shown is
        the node data input if its length matches the number of nodes, and
        otherwise is derived from the network's nodes. If the positions were
        invalidated (new network or distance matrix), the edges are
        recomputed and the network is laid out again from random positions;
        otherwise only the point properties are updated.
        """
        network = self.network

        def set_actual_data():
            self.closeContext()
            self.Error.data_size_mismatch.clear()
            self.Warning.no_graph_found.clear()
            self._invalid_data = False
            if network is None:
                if self.node_data is not None:
                    self.Warning.no_graph_found()
                return
            n_nodes = len(network.nodes)
            if self.node_data is not None:
                if len(self.node_data) != n_nodes:
                    self.Error.data_size_mismatch()
                    self._invalid_data = True
                    self.data = None
                else:
                    self.data = self.node_data
            if self.node_data is None:
                if isinstance(network.nodes, Table):
                    self.data = network.nodes
                elif isinstance(network.nodes, np.ndarray) \
                        and (len(network.nodes.shape) == 1
                             or network.nodes.shape[1] == 1):
                    # Nodes given as a single column of values become a
                    # table with one string meta attribute named "label".
                    self.data = Table.from_numpy(
                        Domain([], None, [StringVariable("label")]),
                        np.zeros((len(network.nodes), 0)),
                        None,
                        metas=network.nodes.reshape((n_nodes, 1))
                    )
                else:
                    self.data = None

            if self.data is not None:
                # Replicate the necessary parts of set_data
                self.valid_data = np.full(len(self.data), True, dtype=bool)
                self.init_attr_values()
                self.openContext(self.data)
                self.cb_class_density.setEnabled(self.can_draw_density())

        def set_actual_edges():
            def set_checkboxes(value):
                self.checkbox_show_weights.setEnabled(value)
                self.checkbox_relative_edges.setEnabled(value)

            self.Warning.distance_matrix_mismatch.clear()

            if network is None:
                self.edges = None
                set_checkboxes(False)
                return

            set_checkboxes(True)
            if network.number_of_edges(0):
                self.edges = network.edges[0].edges.tocoo()
            else:
                self.edges = sp.coo_matrix((0, 3))
            if self.distance_matrix is not None:
                if len(self.distance_matrix) != self.number_of_nodes:
                    self.Warning.distance_matrix_mismatch()
                else:
                    # Replace the edge weights by the distances between the
                    # end nodes. A float dtype keeps fractional distances;
                    # an integer dtype would truncate them.
                    self.edges.data = np.fromiter(
                        (self.distance_matrix[u, v]
                         for u, v in zip(self.edges.row, self.edges.col)),
                        dtype=float, count=len(self.edges.row)
                    )
            # Weight-related checkboxes make no sense if all weights are
            # equal; all-zero weights are additionally replaced by ones.
            if np.allclose(self.edges.data, 0):
                self.edges.data[:] = 1
                set_checkboxes(False)
            elif len(set(self.edges.data)) == 1:
                set_checkboxes(False)

        self.stop_optimization_and_wait()
        set_actual_data()
        super()._handle_subset_data()
        if self.positions is None:
            set_actual_edges()
            self.set_random_positions()
            self.graph.reset_graph()
            self.relayout(True)
        else:
            self.graph.update_point_props()
        self.update_marks()
        self.update_selection_buttons()

    def init_attr_values(self):
        """Initialize attribute values; default the label for array nodes.

        When no node data was given and the data was derived from a network
        whose nodes are a numpy array, the single meta attribute of that
        data is used as the label.
        """
        super().init_attr_values()
        if self.node_data is not None or self.data is None:
            return
        if isinstance(self.network.nodes, np.ndarray):
            metas = self.data.domain.metas
            assert len(metas) == 1
            self.attr_label = metas[0]

    def randomize(self):
        """Assign random positions to nodes and update the coordinates."""
        self.set_random_positions()
        self.graph.update_coordinates()

    def set_random_positions(self):
        """Set ``positions`` to random coordinates, or to None.

        With a network, every node gets a pair of coordinates drawn
        uniformly from [0, 1). Without a network, ``positions`` is reset to
        ``None``.
        """
        if self.network is None:
            # Fixed typo: this used to assign to `self.position`, leaving
            # stale coordinates in `self.positions`.
            self.positions = None
        else:
            self.positions = np.random.uniform(size=(self.number_of_nodes, 2))

    def get_reachable(self, initial):
        """Return the nodes reachable from the given initial nodes.

        The result is a numpy array that starts with the initial nodes,
        followed by the other reachable nodes in the order in which they
        were discovered.
        """
        queue = list(initial)
        seen = set(queue)
        position = 0
        # `queue` grows while it is being processed; the loop ends when no
        # undiscovered neighbours remain.
        while position < len(queue):
            neighbours = self.network.neighbours(queue[position])
            fresh = set(neighbours) - seen
            queue.extend(fresh)
            seen.update(fresh)
            position += 1
        return np.array(queue)

    def send_data(self):
        """Send the data outputs and the network-specific outputs.

        Without a selection, the whole network is sent as the remaining
        sub-network and the other two outputs are cleared. Otherwise the
        selected and the remaining sub-networks are sent, together with the
        corresponding part of the distance matrix (if there is one).
        """
        super().send_data()

        outputs = self.Outputs
        selected_indices = self.graph.get_selection()
        if selected_indices is None or len(selected_indices) == 0:
            outputs.subgraph.send(None)
            outputs.unselected_subgraph.send(self.network)
            outputs.distances.send(None)
            return

        groups = self.graph.selection
        selected_net = self.network.subgraph(selected_indices)
        selected_net.nodes = \
            self._get_selected_data(self.data, selected_indices, groups)
        outputs.subgraph.send(selected_net)

        # Nodes in group 0 are the ones that are not selected.
        remaining = np.flatnonzero(groups == 0)
        outputs.unselected_subgraph.send(self.network.subgraph(remaining))

        matrix = self.distance_matrix
        if matrix is None:
            outputs.distances.send(None)
        else:
            outputs.distances.send(matrix.submatrix(sorted(selected_indices)))

    def get_coordinates_data(self):
        """Return the transposed positions, or ``(None, None)`` if unset."""
        if self.positions is None:
            return None, None
        return self.positions.T

    def get_embedding(self):
        """Return the node positions (``None`` if there are none)."""
        return self.positions

    def get_subset_mask(self):
        """Return the parent's subset mask, or ``None`` without data."""
        if self.data is None:
            return None
        return super().get_subset_mask()

    def get_edges(self):
        """Return the edges as set in handleNewSignals (may be ``None``)."""
        return self.edges

    def is_directed(self):
        """Tell whether the first edge set of the network is directed.

        Returns False when there is no network.
        """
        return self.network is not None and self.network.edges[0].directed

    def get_marked_nodes(self):
        """Return the indices of marked nodes, or ``None`` if none."""
        return self.marked_nodes

    def set_buttons(self, running):
        """Show the stop button while running, the layout buttons otherwise."""
        self.stop_button.setHidden(not running)
        self.relayout_button.setHidden(running)
        self.randomize_button.setHidden(running)

    def stop_relayout(self):
        """Ask the running optimization to stop and restore the buttons."""
        # The flag is read by the optimizer's callback in relayout.
        self._stop_optimization = True
        self.set_buttons(running=False)

    def restart(self):
        """Lay the network out again, starting from random positions."""
        self.relayout(restart=True)

    def improve(self):
        """Continue optimizing the layout from the current positions."""
        self.relayout(restart=False)

    # TODO: Stop relayout if new data is received
    def relayout(self, restart):
        """Start the layout optimization in a separate thread.

        Does nothing if there are no edges. If ``restart`` is true (or there
        are no positions yet), the optimization starts from random
        positions. While it runs, the graph is drawn with simplifications
        and the stop button replaces the layout buttons.

        NOTE(review): nothing here stops an optimization that is already
        running (see the TODO above this method) - confirm that callers
        cannot trigger this while one is in progress.
        """
        if self.edges is None:
            return
        if restart or self.positions is None:
            self.set_random_positions()
        self.progressbar = gui.ProgressBar(self, 100)
        self.set_buttons(running=True)
        self._stop_optimization = False

        # Booleans multiply the flags by 0 or 1, i.e. a simplification is
        # added only when its condition holds.
        Simplifications = self.graph.Simplifications
        self.graph.set_simplifications(
            Simplifications.NoDensity
            + Simplifications.NoLabels * (len(self.graph.labels) > 20)
            + Simplifications.NoEdgeLabels * (len(self.graph.edge_labels) > 20)
            + Simplifications.NoEdges * (self.number_of_edges > 30000))

        # For large graphs no intermediate positions are emitted.
        large_graph = self.number_of_nodes + self.number_of_edges > 30000

        class LayoutOptimizer(QObject):
            update = Signal(np.ndarray, float)
            done = Signal(np.ndarray)
            stopped = Signal()

            def __init__(self, widget):
                super().__init__()
                self.widget = widget

            def send_update(self, positions, progress):
                # Callback passed to fruchterman_reingold; returns False once
                # the widget has asked the optimization to stop.
                if not large_graph:
                    self.update.emit(np.array(positions), progress)
                return not self.widget._stop_optimization

            def run(self):
                widget = self.widget
                edges = widget.edges
                nnodes = widget.number_of_nodes
                init_temp = 0.05 if restart else 0.2
                k = widget.layout_density / 10 / np.sqrt(nnodes)
                sample_ratio =  None if nnodes < 1000 else 1000 / nnodes
                fruchterman_reingold(
                    widget.positions, edges, widget.observe_weights,
                    FR_ALLOWED_TIME, k, init_temp, sample_ratio,
                    callback_step=4, callback=self.send_update)
                self.done.emit(widget.positions)
                self.stopped.emit()

        def update(positions, progress):
            # Handler for intermediate positions emitted by the optimizer.
            self.progressbar.advance(progress)
            self.positions = positions
            self.graph.update_coordinates()

        def done(positions):
            # Handler for the final positions: restore buttons and drawing.
            self.positions = positions
            self.set_buttons(running=False)
            self.graph.set_simplifications(
                self.graph.Simplifications.NoSimplifications)
            self.graph.update_coordinates()
            self.progressbar.finish()

        def thread_finished():
            self._optimizer = None
            self._animation_thread = None

        # The optimizer is moved to the new thread and its run method is
        # connected to the thread's started signal.
        self._optimizer = LayoutOptimizer(self)
        self._animation_thread = QThread()
        self._optimizer.update.connect(update)
        self._optimizer.done.connect(done)
        self._optimizer.stopped.connect(self._animation_thread.quit)
        self._optimizer.moveToThread(self._animation_thread)
        self._animation_thread.started.connect(self._optimizer.run)
        self._animation_thread.finished.connect(thread_finished)
        self._animation_thread.start()

    def stop_optimization_and_wait(self):
        """Stop a running layout optimization and wait for its thread.

        Does nothing when no optimization thread exists.
        """
        thread = self._animation_thread
        if thread is None:
            return
        # The flag is read by the optimizer's callback in relayout.
        self._stop_optimization = True
        thread.quit()
        thread.wait()
        self._animation_thread = None

    def onDeleteWidget(self):
        """Stop a running optimization before the widget is deleted."""
        self.stop_optimization_and_wait()
        super().onDeleteWidget()

    def send_report(self):
        """Report the network statistics, data, visual settings and plot.

        Nothing is reported when there is no network.
        """
        network = self.network
        if network is None:
            return

        graph_info = [
            ("Number of vertices", network.number_of_nodes()),
            ("Number of edges", network.number_of_edges()),
            ("Vertices per edge", round(self.nodes_per_edge, 3)),
            ("Edges per vertex", round(self.edges_per_node, 3)),
        ]
        self.report_items('Graph info', graph_info)

        if self.data is not None:
            self.report_data("Data", self.data)

        visual_attrs = (self.attr_color, self.attr_shape,
                        self.attr_size, self.attr_label)
        if any(visual_attrs):
            caption = self._get_caption_var_name
            self.report_items(
                "Visual settings",
                [("Color", caption(self.attr_color)),
                 ("Label", caption(self.attr_label)),
                 ("Shape", caption(self.attr_shape)),
                 ("Size", caption(self.attr_size))])
        self.report_plot()
Ejemplo n.º 26
0
class OWtSNE(OWDataProjectionWidget):
    """Data projection widget that embeds the data in two dimensions with t-SNE.

    The embedding is computed incrementally: a generator yields intermediate
    projections and a single-shot timer pulls one step at a time, redrawing
    the plot after each step.
    """
    name = "t-SNE"
    description = "Two-dimensional data projection with t-SNE."
    icon = "icons/TSNE.svg"
    priority = 920
    keywords = ["tsne"]

    settings_version = 3
    max_iter = Setting(300)
    perplexity = Setting(30)
    pca_components = Setting(20)

    GRAPH_CLASS = OWtSNEGraph
    graph = SettingProvider(OWtSNEGraph)
    embedding_variables_names = ("t-SNE-x", "t-SNE-y")

    #: Runtime state
    Running, Finished, Waiting = 1, 2, 3

    class Outputs(OWDataProjectionWidget.Outputs):
        preprocessor = Output("Preprocessor", Preprocess)

    class Error(OWDataProjectionWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        constant_data = Msg("Input data is constant")
        no_attributes = Msg("Data has no attributes")
        out_of_memory = Msg("Out of memory")
        optimization_error = Msg("Error during optimization\n{}")
        no_valid_data = Msg("No projection due to no valid data")

    def __init__(self):
        super().__init__()
        self.pca_data = None
        self.projection = None
        # generator producing intermediate projections, or None when idle
        self.__update_loop = None
        # timer for scheduling updates
        self.__timer = QTimer(self, singleShot=True, interval=1,
                              timeout=self.__next_step)
        self.__state = OWtSNE.Waiting
        # re-entrancy guard for __next_step
        self.__in_next_step = False
        self.__draw_similar_pairs = False

    def _add_controls(self):
        """Add the start box ahead of the inherited controls and trim the models."""
        self._add_controls_start_box()
        super()._add_controls()
        # Because sc data frequently has many genes,
        # showing all attributes in combo boxes can cause problems
        # QUICKFIX: Remove a separator and attributes from order
        # (leaving just the class and metas)
        self.models = self.gui.points_models
        for model in self.models:
            model.order = model.order[:-2]

    def _add_controls_start_box(self):
        """Build the box with the iteration, perplexity and PCA controls."""
        box = gui.vBox(self.controlArea, True)
        form = QFormLayout(
            labelAlignment=Qt.AlignLeft,
            formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
            verticalSpacing=10
        )

        form.addRow(
            "Max iterations:",
            gui.spin(box, self, "max_iter", 1, 2000, step=50))

        form.addRow(
            "Perplexity:",
            gui.spin(box, self, "perplexity", 1, 100, step=1))

        box.layout().addLayout(form)

        gui.separator(box, 10)
        self.runbutton = gui.button(box, self, "Run", callback=self._toggle_run)

        gui.separator(box, 10)
        gui.hSlider(box, self, "pca_components", label="PCA components:",
                    minValue=2, maxValue=50, step=1)

    def check_data(self):
        """Run the inherited checks, then reject data t-SNE cannot handle.

        On rejection the matching error is shown and ``self.data`` is reset
        to ``None``.
        """
        def error(err):
            err()
            self.data = None

        super().check_data()
        if self.data is not None:
            if len(self.data) < 2:
                error(self.Error.not_enough_rows)
            elif not self.data.domain.attributes:
                error(self.Error.no_attributes)
            elif not self.data.is_sparse() and \
                    np.allclose(self.data.X - self.data.X[0], 0):
                error(self.Error.constant_data)
            elif not self.data.is_sparse() and \
                    np.all(~np.isfinite(self.data.X)):
                error(self.Error.no_valid_data)

    def get_embedding(self):
        """Return the coordinates to plot and update ``valid_data``.

        Returns ``None`` without data, random normally distributed points
        while no projection has been computed yet, and the embedding of the
        current projection otherwise.
        """
        if self.data is None:
            self.valid_data = None
            return None
        elif self.projection is None:
            embedding = np.random.normal(size=(len(self.data), 2))
        else:
            embedding = self.projection.embedding.X
        self.valid_data = np.ones(len(embedding), dtype=bool)
        return embedding

    def _toggle_run(self):
        """Callback of the run button: stop and commit when running, else start."""
        if self.__state == OWtSNE.Running:
            self.stop()
            self.commit()
        else:
            self.start()

    def start(self):
        """Start the optimization, or just redraw when there is nothing to run."""
        if not self.data or self.__state == OWtSNE.Running:
            self.graph.update_coordinates()
        elif self.__state in (OWtSNE.Finished, OWtSNE.Waiting):
            self.__start()

    def stop(self):
        """Stop a running optimization; does nothing in any other state."""
        if self.__state == OWtSNE.Running:
            self.__set_update_loop(None)

    def pca_preprocessing(self):
        """Compute ``pca_data`` unless it already has ``pca_components`` columns."""
        if self.pca_data is not None and \
                self.pca_data.X.shape[1] == self.pca_components:
            return
        pca = PCA(n_components=self.pca_components, random_state=0)
        model = pca(self.data)
        self.pca_data = model(self.data)

    def __start(self):
        """Create the update generator and hand it to ``__set_update_loop``."""
        self.pca_preprocessing()
        # Continue from the current embedding when one exists.
        initial = 'random' if self.projection is None \
            else self.projection.embedding.X
        step_size = 50

        def update_loop(data, max_iter, step, embedding):
            # NOTE: this code MUST NOT call into QApplication.processEvents
            done = False
            iterations_done = 0

            while not done:
                # the last step may be shorter than `step`
                step_iter = min(max_iter - iterations_done, step)
                projection = compute_tsne(
                    data, self.perplexity, step_iter, embedding)
                embedding = projection.embedding.X
                iterations_done += step_iter
                if iterations_done >= max_iter:
                    done = True

                yield projection, iterations_done / max_iter

        # __set_update_loop initializes the progress bar itself, so it is not
        # initialized a second time here.
        self.__set_update_loop(update_loop(
            self.pca_data, self.max_iter, step_size, initial))

    def __set_update_loop(self, loop):
        """Replace the active update generator and switch the widget state.

        Any previous generator is closed first. With a generator the widget
        enters the Running state and the step timer is started; with ``None``
        it enters the Finished state and the timer is stopped.
        """
        if self.__update_loop is not None:
            self.__update_loop.close()
            self.__update_loop = None
            self.progressBarFinished(processEvents=None)

        self.__update_loop = loop

        if loop is not None:
            self.setBlocking(True)
            self.progressBarInit(processEvents=None)
            self.setStatusMessage("Running")
            self.runbutton.setText("Stop")
            self.__state = OWtSNE.Running
            self.__timer.start()
        else:
            self.setBlocking(False)
            self.setStatusMessage("")
            self.runbutton.setText("Start")
            self.__state = OWtSNE.Finished
            self.__timer.stop()

    def __next_step(self):
        """Timer callback: pull one step from the generator and redraw.

        When the generator is exhausted the result is committed; on an error
        the matching message is shown and the loop is stopped.
        """
        if self.__update_loop is None:
            return

        assert not self.__in_next_step
        self.__in_next_step = True
        # try/finally: the guard must be released even if redrawing raises,
        # otherwise every later timer tick would trip the assert above.
        try:
            loop = self.__update_loop
            self.Error.out_of_memory.clear()
            self.Error.optimization_error.clear()
            try:
                projection, progress = next(self.__update_loop)
                assert self.__update_loop is loop
            except StopIteration:
                self.__set_update_loop(None)
                self.unconditional_commit()
            except MemoryError:
                self.Error.out_of_memory()
                self.__set_update_loop(None)
            except Exception as exc:
                self.Error.optimization_error(str(exc))
                self.__set_update_loop(None)
            else:
                self.progressBarSet(100.0 * progress, processEvents=None)
                self.projection = projection
                self.graph.update_coordinates()
                self.graph.update_density()
                # schedule next update
                self.__timer.start()
        finally:
            self.__in_next_step = False

    def setup_plot(self):
        """Set up the plot as inherited, then start the optimization."""
        super().setup_plot()
        self.start()

    def commit(self):
        """Send the inherited outputs and the preprocessor."""
        super().commit()
        self.send_preprocessor()

    def _get_projection_data(self):
        """Return the data with the two embedding columns appended to metas.

        Returns ``None`` when there is no data or no projection.
        """
        if self.data is None or self.projection is None:
            return None
        data = self.data.transform(
            Domain(self.data.domain.attributes,
                   self.data.domain.class_vars,
                   self.data.domain.metas + self.projection.domain.attributes))
        data.metas[:, -2:] = self.get_embedding()
        return data

    def send_preprocessor(self):
        """Send an ``ApplyDomain`` preprocessor, or ``None`` without a projection."""
        prep = None
        if self.data is not None and self.projection is not None:
            prep = ApplyDomain(self.projection.domain, self.projection.name)
        self.Outputs.preprocessor.send(prep)

    def clear(self):
        """Stop any running optimization and drop the cached results."""
        super().clear()
        self.__set_update_loop(None)
        # __set_update_loop(None) leaves the state at Finished
        self.__state = OWtSNE.Waiting
        self.pca_data = None
        self.projection = None

    @classmethod
    def migrate_settings(cls, settings, version):
        """Copy ``selection_indices`` to ``selection`` for settings before v3."""
        if version < 3:
            if "selection_indices" in settings:
                settings["selection"] = settings["selection_indices"]

    @classmethod
    def migrate_context(cls, context, version):
        """Copy the attribute settings from under ``graph`` for contexts before v3.

        Keys missing from the stored context are skipped instead of raising
        ``KeyError``.
        """
        if version < 3:
            values = context.values
            graph_values = values.get("graph", {})
            for key in ("attr_color", "attr_size", "attr_shape", "attr_label"):
                if key in graph_values:
                    values[key] = graph_values[key]
Ejemplo n.º 27
0
class OWFile(widget.OWWidget, RecentPathsWComboMixin):
    """Widget that reads a data table from a local file or a URL and outputs it.

    The loaded domain can be edited in the embedded domain editor before the
    table is sent to the output.
    """
    name = "File"
    id = "orange.widgets.data.file"
    description = "Read data from an input file or network " \
                  "and send a data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["file", "load", "read", "open"]

    class Outputs:
        data = Output("Data",
                      Table,
                      doc="Attribute-valued dataset read from the input file.")

    want_main_area = False

    SEARCH_PATHS = [("sample-datasets", get_sample_datasets_dir())]
    # Local files larger than this (as reported by os.path.getsize) are not
    # loaded automatically when the widget is created.
    SIZE_LIMIT = 1e7
    LOCAL_FILE, URL = range(2)

    settingsHandler = PerfectDomainContextHandler(
        match_values=PerfectDomainContextHandler.MATCH_VALUES_ALL)

    # pylint seems to want declarations separated from definitions
    recent_paths: List[RecentPath]
    recent_urls: List[str]
    variables: list

    # Overload RecentPathsWidgetMixin.recent_paths to set defaults
    recent_paths = Setting([
        RecentPath("", "sample-datasets", "iris.tab"),
        RecentPath("", "sample-datasets", "titanic.tab"),
        RecentPath("", "sample-datasets", "housing.tab"),
        RecentPath("", "sample-datasets", "heart_disease.tab"),
        RecentPath("", "sample-datasets", "brown-selected.tab"),
        RecentPath("", "sample-datasets", "zoo.tab"),
    ])
    recent_urls = Setting([])
    source = Setting(LOCAL_FILE)
    xls_sheet = ContextSetting("")
    sheet_names = Setting({})
    url = Setting("")

    variables = ContextSetting([])

    domain_editor = SettingProvider(DomainEditor)

    class Warning(widget.OWWidget.Warning):
        file_too_big = widget.Msg(
            "The file is too large to load automatically."
            " Press Reload to load.")
        load_warning = widget.Msg("Read warning:\n{}")

    class Error(widget.OWWidget.Error):
        file_not_found = widget.Msg("File not found.")
        missing_reader = widget.Msg("Missing reader.")
        sheet_error = widget.Msg("Error listing available sheets.")
        unknown = widget.Msg("Read error:\n{}")

    class NoFileSelected:
        """Sentinel returned by ``_get_reader`` when there is nothing to read."""

    UserAdviceMessages = [
        widget.Message(
            "Use CSV File Import widget for advanced options "
            "for comma-separated files", "use-csv-file-import"),
        widget.Message(
            "This widget loads only tabular data. Use other widgets to load "
            "other data types like models, distance matrices and networks.",
            "other-data-types")
    ]

    def __init__(self):
        super().__init__()
        RecentPathsWComboMixin.__init__(self)
        self.domain = None
        self.data = None
        self.loaded_file = ""
        self.reader = None

        layout = QGridLayout()
        gui.widgetBox(self.controlArea, margin=0, orientation=layout)
        vbox = gui.radioButtons(None,
                                self,
                                "source",
                                box=True,
                                addSpace=True,
                                callback=self.load_data,
                                addToLayout=False)

        rb_button = gui.appendRadioButton(vbox, "File:", addToLayout=False)
        layout.addWidget(rb_button, 0, 0, Qt.AlignVCenter)

        box = gui.hBox(None, addToLayout=False, margin=0)
        box.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.activated[int].connect(self.select_file)
        box.layout().addWidget(self.file_combo)
        layout.addWidget(box, 0, 1)

        file_button = gui.button(None,
                                 self,
                                 '...',
                                 callback=self.browse_file,
                                 autoDefault=False)
        file_button.setIcon(self.style().standardIcon(QStyle.SP_DirOpenIcon))
        file_button.setSizePolicy(Policy.Maximum, Policy.Fixed)
        layout.addWidget(file_button, 0, 2)

        reload_button = gui.button(None,
                                   self,
                                   "Reload",
                                   callback=self.load_data,
                                   autoDefault=False)
        reload_button.setIcon(self.style().standardIcon(
            QStyle.SP_BrowserReload))
        reload_button.setSizePolicy(Policy.Fixed, Policy.Fixed)
        layout.addWidget(reload_button, 0, 3)

        # Sheet selector; hidden until a reader reports more than one sheet
        self.sheet_box = gui.hBox(None, addToLayout=False, margin=0)
        self.sheet_combo = gui.comboBox(
            None,
            self,
            "xls_sheet",
            callback=self.select_sheet,
            sendSelectedValue=True,
        )
        self.sheet_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_label = QLabel()
        self.sheet_label.setText('Sheet')
        self.sheet_label.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_box.layout().addWidget(self.sheet_label, Qt.AlignLeft)
        self.sheet_box.layout().addWidget(self.sheet_combo, Qt.AlignVCenter)
        layout.addWidget(self.sheet_box, 2, 1)
        self.sheet_box.hide()

        rb_button = gui.appendRadioButton(vbox, "URL:", addToLayout=False)
        layout.addWidget(rb_button, 3, 0, Qt.AlignVCenter)

        self.url_combo = url_combo = QComboBox()
        url_model = NamedURLModel(self.sheet_names)
        url_model.wrap(self.recent_urls)
        url_combo.setLineEdit(LineEditSelectOnFocus())
        url_combo.setModel(url_model)
        url_combo.setSizePolicy(Policy.Ignored, Policy.Fixed)
        url_combo.setEditable(True)
        url_combo.setInsertPolicy(url_combo.InsertAtTop)
        url_edit = url_combo.lineEdit()
        left, top, right, bottom = url_edit.getTextMargins()
        url_edit.setTextMargins(left + 5, top, right, bottom)
        layout.addWidget(url_combo, 3, 1, 3, 3)
        url_combo.activated.connect(self._url_set)
        # with the completer we make the combo box case sensitive when
        # matching the history
        completer = QCompleter()
        completer.setCaseSensitivity(Qt.CaseSensitive)
        url_combo.setCompleter(completer)

        box = gui.vBox(self.controlArea, "Info")
        self.infolabel = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        box = gui.widgetBox(self.controlArea, "Columns (Double click to edit)")
        self.domain_editor = DomainEditor(self)
        self.editor_model = self.domain_editor.model()
        box.layout().addWidget(self.domain_editor)

        box = gui.hBox(self.controlArea)
        gui.button(box,
                   self,
                   "Browse documentation datasets",
                   callback=lambda: self.browse_file(True),
                   autoDefault=False)
        gui.rubber(box)

        gui.button(box, self, "Reset", callback=self.reset_domain_edit)
        self.apply_button = gui.button(box,
                                       self,
                                       "Apply",
                                       callback=self.apply_domain_edit)
        self.apply_button.setEnabled(False)
        self.apply_button.setFixedWidth(170)
        # Any edit in the domain editor enables the Apply button
        self.editor_model.dataChanged.connect(
            lambda: self.apply_button.setEnabled(True))

        self.set_file_list()
        # Must not call open_file from within __init__. open_file
        # explicitly re-enters the event loop (by a progress bar)

        self.setAcceptDrops(True)

        if self.source == self.LOCAL_FILE:
            last_path = self.last_path()
            if last_path and os.path.exists(last_path) and \
                    os.path.getsize(last_path) > self.SIZE_LIMIT:
                # Too big to load automatically: warn and skip the deferred
                # load scheduled below
                self.Warning.file_too_big()
                return

        QTimer.singleShot(0, self.load_data)

    @staticmethod
    def sizeHint():
        """Return the preferred widget size."""
        return QSize(600, 550)

    def select_file(self, n):
        """Select the n-th recent path and load it as a local file."""
        assert n < len(self.recent_paths)
        super().select_file(n)
        if self.recent_paths:
            self.source = self.LOCAL_FILE
            self.load_data()
            self.set_file_list()

    def select_sheet(self):
        """Store the sheet chosen in the combo on the first recent path and reload."""
        self.recent_paths[0].sheet = self.sheet_combo.currentText()
        self.load_data()

    def _url_set(self):
        """Load the URL from the combo, prefixing ``http://`` if it has no scheme."""
        url = self.url_combo.currentText()
        pos = self.recent_urls.index(url)
        url = url.strip()

        if not urlparse(url).scheme:
            url = 'http://' + url
            self.url_combo.setItemText(pos, url)
            self.recent_urls[pos] = url

        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Let the user pick a file in a dialog and load it.

        With ``in_demos`` set, the dialog starts in the sample datasets
        directory; otherwise at the last used path or the home directory.
        """
        if in_demos:
            start_file = get_sample_datasets_dir()
            if not os.path.exists(start_file):
                QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with documentation datasets")
                return
        else:
            start_file = self.last_path() or os.path.expanduser("~/")

        # Offer only formats that can read and declare file extensions
        readers = [
            f for f in FileFormat.formats
            if getattr(f, 'read', None) and getattr(f, "EXTENSIONS", None)
        ]
        filename, reader, _ = open_filename_dialog(start_file, None, readers)
        if not filename:
            return
        self.add_path(filename)
        if reader is not None:
            self.recent_paths[0].file_format = reader.qualified_name()

        self.source = self.LOCAL_FILE
        self.load_data()

    # Open a file, create data from it and send it over the data channel
    def load_data(self):
        """Reset the widget state and load from the current source.

        When loading fails, the returned error is shown and ``None`` is sent
        to the output.
        """
        self.closeContext()
        self.domain_editor.set_domain(None)
        self.apply_button.setEnabled(False)
        self.clear_messages()
        self.set_file_list()

        error = self._try_load()
        if error:
            error()
            self.data = None
            self.sheet_box.hide()
            self.Outputs.data.send(None)
            self.infolabel.setText("No data.")

    def _try_load(self):
        """Read the data and send it; return an error callable on failure.

        Returns ``None`` on success (or when nothing is selected), otherwise
        a callable that shows the matching error message when invoked.
        """
        # We need to catch any exception type since anything can happen in
        # file readers
        # pylint: disable=broad-except
        if self.last_path() and not os.path.exists(self.last_path()):
            return self.Error.file_not_found

        try:
            self.reader = self._get_reader()
            assert self.reader is not None
        except Exception:
            return self.Error.missing_reader

        if self.reader is self.NoFileSelected:
            self.Outputs.data.send(None)
            return None

        try:
            self._update_sheet_combo()
        except Exception:
            return self.Error.sheet_error

        with catch_warnings(record=True) as caught_warnings:
            try:
                data = self.reader.read()
            except Exception as ex:
                log.exception(ex)
                return lambda x=ex: self.Error.unknown(str(x))
            if caught_warnings:
                # only the most recent warning is shown
                self.Warning.load_warning(
                    caught_warnings[-1].message.args[0])

        self.infolabel.setText(self._describe(data))

        self.loaded_file = self.last_path()
        add_origin(data, self.loaded_file)
        self.data = data
        self.openContext(data.domain)
        self.apply_domain_edit()  # sends data
        return None

    def _get_reader(self) -> FileFormat:
        """Return a reader for the current source.

        Returns the ``NoFileSelected`` class when there is no path or URL.
        """
        if self.source == self.LOCAL_FILE:
            path = self.last_path()
            if path is None:
                return self.NoFileSelected
            if self.recent_paths and self.recent_paths[0].file_format:
                # A format stored with the path takes precedence over the
                # reader FileFormat would pick for it
                qname = self.recent_paths[0].file_format
                reader_class = class_from_qualified_name(qname)
                reader = reader_class(path)
            else:
                reader = FileFormat.get_reader(path)
            if self.recent_paths and self.recent_paths[0].sheet:
                reader.select_sheet(self.recent_paths[0].sheet)
            return reader
        else:
            url = self.url_combo.currentText().strip()
            if url:
                return UrlReader(url)
            else:
                return self.NoFileSelected

    def _update_sheet_combo(self):
        """Show the sheet selector when the reader has at least two sheets."""
        if len(self.reader.sheets) < 2:
            self.sheet_box.hide()
            self.reader.select_sheet(None)
            return

        self.sheet_combo.clear()
        self.sheet_combo.addItems(self.reader.sheets)
        self._select_active_sheet()
        self.sheet_box.show()

    def _select_active_sheet(self):
        """Point the sheet combo at the reader's sheet, or at the first one."""
        if self.reader.sheet:
            try:
                idx = self.reader.sheets.index(self.reader.sheet)
                self.sheet_combo.setCurrentIndex(idx)
            except ValueError:
                # Requested sheet does not exist in this file
                self.reader.select_sheet(None)
        else:
            self.sheet_combo.setCurrentIndex(0)

    @staticmethod
    def _describe(table):
        """Return an HTML description of the table for the info label."""
        def missing_prop(prop):
            if prop:
                return f"({prop * 100:.1f}% missing values)"
            else:
                return "(no missing values)"

        domain = table.domain
        text = ""

        attrs = getattr(table, "attributes", {})
        descs = [
            attrs[desc] for desc in ("Name", "Description") if desc in attrs
        ]
        # The name is set in bold only when a description follows it
        if len(descs) == 2:
            descs[0] = f"<b>{descs[0]}</b>"
        if descs:
            text += f"<p>{'<br/>'.join(descs)}</p>"

        text += f"<p>{len(table)} instance(s)"

        missing_in_attr = missing_prop(table.has_missing_attribute()
                                       and table.get_nan_frequency_attribute())
        missing_in_class = missing_prop(table.has_missing_class()
                                        and table.get_nan_frequency_class())
        text += f"<br/>{len(domain.attributes)} feature(s) {missing_in_attr}"
        if domain.has_continuous_class:
            text += f"<br/>Regression; numerical class {missing_in_class}"
        elif domain.has_discrete_class:
            text += "<br/>Classification; categorical class " \
                f"with {len(domain.class_var.values)} values {missing_in_class}"
        elif table.domain.class_vars:
            text += "<br/>Multi-target; " \
                f"{len(table.domain.class_vars)} target variables " \
                f"{missing_in_class}"
        else:
            text += "<br/>Data has no target variable."
        text += f"<br/>{len(domain.metas)} meta attribute(s)"
        text += "</p>"

        if 'Timestamp' in table.domain:
            # Google Forms uses this header to timestamp responses
            text += f"<p>First entry: {table[0, 'Timestamp']}<br/>" \
                f"Last entry: {table[-1, 'Timestamp']}</p>"
        return text

    def storeSpecificSettings(self):
        """Save a copy of ``variables`` on the current context."""
        self.current_context.modified_variables = self.variables[:]

    def retrieveSpecificSettings(self):
        """Restore ``variables`` in place from the current context, if stored."""
        if hasattr(self.current_context, "modified_variables"):
            self.variables[:] = self.current_context.modified_variables

    def reset_domain_edit(self):
        """Reset the domain editor and send the resulting data."""
        self.domain_editor.reset_domain()
        self.apply_domain_edit()

    def apply_domain_edit(self):
        """Build a table from the edited domain and send it to the output.

        ``None`` is sent when there is no data or the edited domain has
        neither variables nor metas.
        """
        if self.data is None:
            table = None
        else:
            domain, cols = self.domain_editor.get_domain(
                self.data.domain, self.data)
            if not (domain.variables or domain.metas):
                table = None
            else:
                X, y, m = cols
                table = Table.from_numpy(domain, X, y, m, self.data.W)
                table.name = self.data.name
                table.ids = np.array(self.data.ids)
                table.attributes = getattr(self.data, 'attributes', {})

        self.Outputs.data.send(table)
        self.apply_button.setEnabled(False)

    def get_widget_name_extension(self):
        """Return the loaded file's base name without its extension."""
        _, name = os.path.split(self.loaded_file)
        return os.path.splitext(name)[0]

    def send_report(self):
        """Report the data source, its format and the loaded data."""
        def get_ext_name(filename):
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            home = os.path.expanduser("~")
            if self.loaded_file.startswith(home):
                # os.path.join does not like ~
                name = "~" + os.path.sep + \
                       self.loaded_file[len(home):].lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            # Resolve the format before the sheet name is appended: with the
            # " (sheet)" suffix the extension lookup would always miss and
            # report "unknown".
            file_format = get_ext_name(name)
            if self.sheet_combo.isVisible():
                name += f" ({self.sheet_combo.currentText()})"
            self.report_items("File", [("File name", name),
                                       ("Format", file_format)])
        else:
            # NOTE(review): self.url is not assigned anywhere in this class;
            # confirm it is kept in sync with the URL combo elsewhere.
            self.report_items("Data", [("Resource", self.url),
                                       ("Format", get_ext_name(self.url))])

        self.report_data("Data", self.data)

    @staticmethod
    def dragEnterEvent(event):
        """Accept drops of valid file urls"""
        urls = event.mimeData().urls()
        if urls:
            try:
                FileFormat.get_reader(urls[0].toLocalFile())
                event.acceptProposedAction()
            except IOError:
                pass

    def dropEvent(self, event):
        """Handle file drops"""
        urls = event.mimeData().urls()
        if urls:
            self.add_path(urls[0].toLocalFile())  # add first file
            self.source = self.LOCAL_FILE
            self.load_data()

    def workflowEnvChanged(self, key, value, oldvalue):
        """
        Function called when the environment changes (e.g. while saving the
        scheme). It makes sure that all environment-connected values are
        modified (e.g. relative file paths are changed).
        """
        self.update_file_list(key, value, oldvalue)
Ejemplo n.º 28
0
class OWHyper(OWWidget):
    """Widget showing an image plot next to a curve plot of the input data.

    The image values come either from an integral over the spectra or from a
    selected feature.
    """
    name = "Hyperspectra"
    inputs = [("Data", Orange.data.Table, 'set_data', Default)]
    outputs = [("Selection", Orange.data.Table), ("Data", Orange.data.Table)]
    icon = "icons/hyper.svg"

    settings_version = 2
    settingsHandler = DomainContextHandler(metas_in_res=True)

    imageplot = SettingProvider(ImagePlot)
    curveplot = SettingProvider(CurvePlotHyper)

    # index into integration_methods
    integration_method = Setting(0)
    integration_methods = [Integrate.Simple, Integrate.Baseline,
                           Integrate.PeakMax, Integrate.PeakBaseline, Integrate.PeakAt]
    # 0: values from spectra, 1: values from a feature
    value_type = Setting(0)
    attr_value = ContextSetting(None)

    # positions of the three movable lines on the curve plot
    lowlim = Setting(None)
    highlim = Setting(None)
    choose = Setting(None)

    class Warning(OWWidget.Warning):
        threshold_error = Msg("Low slider should be less than High")

    # Derive from OWWidget.Error (not Warning) so that these messages are
    # registered as errors.
    class Error(OWWidget.Error):
        image_too_big = Msg("Image for chosen features is too big ({} x {}).")

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Drop the saved ``attr_value`` of the first context for settings before v2."""
        if version < 2:
            # delete the saved attr_value to prevent crashes
            try:
                del settings_["context_settings"][0].values["attr_value"]
            except (KeyError, IndexError, AttributeError, TypeError):
                # Best effort: nothing to delete when the settings hold no
                # such context or value.
                pass

    def __init__(self):
        super().__init__()

        dbox = gui.widgetBox(self.controlArea, "Image values")

        rbox = gui.radioButtons(
            dbox, self, "value_type", callback=self._change_integration)

        gui.appendRadioButton(rbox, "From spectra")

        self.box_values_spectra = gui.indentedBox(rbox)

        gui.comboBox(
            self.box_values_spectra, self, "integration_method", valueType=int,
            items=(a.name for a in self.integration_methods),
            callback=self._change_integral_type)
        gui.rubber(self.controlArea)

        gui.appendRadioButton(rbox, "Use feature")

        self.box_values_feature = gui.indentedBox(rbox)

        self.feature_value_model = DomainModel(DomainModel.METAS | DomainModel.CLASSES,
                                               valid_types=DomainModel.PRIMITIVE)
        self.feature_value = gui.comboBox(
            self.box_values_feature, self, "attr_value",
            callback=self.update_feature_value, model=self.feature_value_model,
            sendSelectedValue=True, valueType=str)

        # Image plot above the curve plot, separated by a splitter
        splitter = QSplitter(self)
        splitter.setOrientation(Qt.Vertical)
        self.imageplot = ImagePlot(self, self.image_selection_changed)
        self.curveplot = CurvePlotHyper(self, select=SELECTONE)
        self.curveplot.plot.vb.x_padding = 0.005  # pad view so that lines are not hidden
        splitter.addWidget(self.imageplot)
        splitter.addWidget(self.curveplot)
        self.mainArea.layout().addWidget(splitter)

        # line1/line2 follow lowlim/highlim, line3 follows choose
        self.line1 = MovableVlineWD(position=self.lowlim, label="", setvalfn=self.set_lowlim,
                                    confirmfn=self.edited, report=self.curveplot)
        self.line2 = MovableVlineWD(position=self.highlim, label="", setvalfn=self.set_highlim,
                                    confirmfn=self.edited, report=self.curveplot)
        self.line3 = MovableVlineWD(position=self.choose, label="", setvalfn=self.set_choose,
                                    confirmfn=self.edited, report=self.curveplot)
        self.curveplot.add_marking(self.line1)
        self.curveplot.add_marking(self.line2)
        self.curveplot.add_marking(self.line3)
        self.line1.hide()
        self.line2.hide()
        self.line3.hide()

        self.data = None

        self.resize(900, 700)
        self.graph_name = "imageplot.plotview"
        self._update_integration_type()

    def image_selection_changed(self, indices):
        """Send the annotated data and the selection; show the selection as curves.

        When nothing is selected, the curve plot shows the whole data.
        """
        annotated = create_annotated_table(self.data, indices)
        self.send("Data", annotated)
        if self.data:
            selected = self.data[indices]
            self.send("Selection", selected if selected else None)
            if selected:
                self.curveplot.set_data(selected)
            else:
                self.curveplot.set_data(self.data)
        else:
            self.send("Selection", None)
            self.curveplot.set_data(None)
        self.curveplot.update_view()

    def selection_changed(self):
        """Redraw the image."""
        self.redraw_data()

    def init_attr_values(self):
        """Reset the feature model to the data domain and pick its first entry."""
        domain = self.data.domain if self.data is not None else None
        self.feature_value_model.set_domain(domain)
        self.attr_value = self.feature_value_model[0] if self.feature_value_model else None

    def set_lowlim(self, v):
        """Store the position of the first line."""
        self.lowlim = v

    def set_highlim(self, v):
        """Store the position of the second line."""
        self.highlim = v

    def set_choose(self, v):
        """Store the position of the third line."""
        self.choose = v

    def redraw_data(self):
        """Ask the image plot to apply the current integral limits."""
        self.imageplot.set_integral_limits()

    def update_feature_value(self):
        """Callback of the feature combo: redraw the image."""
        self.redraw_data()

    def _update_integration_type(self):
        """Enable the controls and show the lines matching the current mode.

        With values from spectra, the third line is shown for the ``PeakAt``
        method and the first two lines for every other method. With values
        from a feature, no line is shown.
        """
        self.line1.hide()
        self.line2.hide()
        self.line3.hide()
        if self.value_type == 0:
            self.box_values_spectra.setDisabled(False)
            self.box_values_feature.setDisabled(True)
            if self.integration_methods[self.integration_method] != Integrate.PeakAt:
                self.line1.show()
                self.line2.show()
            else:
                self.line3.show()
        elif self.value_type == 1:
            self.box_values_spectra.setDisabled(True)
            self.box_values_feature.setDisabled(False)
        QTest.qWait(1)  # first update the interface

    def _change_integration(self):
        """Update the controls for the selected mode and redraw the image."""
        # change what to show on the image
        self._update_integration_type()
        self.redraw_data()

    def edited(self):
        """Callback passed to the lines as ``confirmfn``: redraw the image."""
        self.redraw_data()

    def _change_integral_type(self):
        """Callback of the integration method combo."""
        self._change_integration()

    def set_data(self, data):
        """Input handler: show new data and clamp the lines to its x range."""
        self.closeContext()
        self.curveplot.set_data(data)
        if data is not None:
            same_domain = (self.data and
                           data.domain.checksum() == self.data.domain.checksum())
            self.data = data
            # Keep the chosen feature when the domain did not change
            if not same_domain:
                self.init_attr_values()
        else:
            self.data = None
        if self.curveplot.data_x is not None and len(self.curveplot.data_x):
            minx = self.curveplot.data_x[0]
            maxx = self.curveplot.data_x[-1]

            # Limits outside the new x range are reset to its ends
            if self.lowlim is None or not minx <= self.lowlim <= maxx:
                self.lowlim = minx
            self.line1.setValue(self.lowlim)

            if self.highlim is None or not minx <= self.highlim <= maxx:
                self.highlim = maxx
            self.line2.setValue(self.highlim)

            # The chosen position defaults to the middle and is clamped
            if self.choose is None:
                self.choose = (minx + maxx)/2
            elif self.choose < minx:
                self.choose = minx
            elif self.choose > maxx:
                self.choose = maxx
            self.line3.setValue(self.choose)

        self.imageplot.set_data(data)
        self.openContext(data)
        self.curveplot.update_view()
        self.imageplot.update_view()

    # store selection as a list due to a bug in checking if numpy settings changed
    def storeSpecificSettings(self):
        """Save the image plot selection on the current context as a list."""
        selection = self.imageplot.selection
        if selection is not None:
            selection = list(selection)
        self.current_context.selection = selection

    def retrieveSpecificSettings(self):
        """Restore the image plot selection from the context as a bool array."""
        selection = getattr(self.current_context, "selection", None)
        if selection is not None:
            selection = np.array(selection, dtype="bool")
        self.imageplot.selection = selection
Ejemplo n.º 29
0
class OWAnchorProjectionWidget(OWDataProjectionWidget, openclass=True):
    """Common base for projection widgets whose graph displays anchors.

    Subclasses are expected to provide ``self.projector``; this class fits
    it on the effective data, exposes the resulting components as anchors,
    lets the user move anchors through the graph's view box, and publishes
    the components on an extra output.
    """
    # Sample size handed to the graph while an anchor is being moved.
    SAMPLE_SIZE = 100

    GRAPH_CLASS = OWGraphWithAnchors
    graph = SettingProvider(OWGraphWithAnchors)

    class Outputs(OWDataProjectionWidget.Outputs):
        components = Output("Components", Table)

    class Error(OWDataProjectionWidget.Error):
        sparse_data = Msg("Sparse data is not supported")
        no_valid_data = Msg("No projection due to no valid data")
        no_instances = Msg("At least two data instances are required")
        proj_error = Msg("An error occurred while projecting data.\n{}")

    def __init__(self):
        # Both attributes must exist before the base initialiser runs.
        self.projector = None
        self.projection = None
        super().__init__()
        view_box = self.graph.view_box
        view_box.started.connect(self._manual_move_start)
        view_box.moved.connect(self._manual_move)
        view_box.finished.connect(self._manual_move_finish)

    def check_data(self):
        """Run the base checks, then reject sparse, tiny or all-invalid data.

        When a check fails the matching error is shown and ``self.data`` is
        reset to ``None``.
        """
        super().check_data()
        data = self.data
        if data is None:
            return

        if data.is_sparse():
            problem = self.Error.sparse_data
        elif len(data) < 2:
            problem = self.Error.no_instances
        elif not np.any(np.all(np.isfinite(data.X), axis=1)):
            # No row consists solely of finite values.
            problem = self.Error.no_valid_data
        else:
            return

        problem()
        self.data = None

    def init_projection(self):
        """Fit ``self.projector`` on the effective data.

        ``self.projection`` stays ``None`` when there are no effective
        variables or when fitting raises; in the latter case the exception
        is reported through ``Error.proj_error``.
        """
        self.projection = None
        if self.effective_variables:
            try:
                self.projection = self.projector(self.effective_data)
            except Exception as ex:  # pylint: disable=broad-except
                self.Error.proj_error(ex)

    def get_embedding(self):
        """Return the projected coordinates and update ``valid_data``.

        Returns ``None`` (leaving ``valid_data`` as ``None``) when either
        the data or the projection is missing.
        """
        self.valid_data = None
        if self.data is not None and self.projection is not None:
            coordinates = self.projection(self.data).X
            self.valid_data = np.all(np.isfinite(coordinates), axis=1)
            return coordinates
        return None

    def get_anchors(self):
        """Return ``(anchor coordinates, anchor labels)`` or ``(None, None)``."""
        projection = self.projection
        if projection is None:
            return None, None
        axes = projection.components_
        if axes.shape == (1, 1):
            # Single-variable case: substitute a fixed 2x1 component matrix.
            axes = np.array([[1.], [0.]])
        labels = [var.name for var in self.effective_variables]
        return axes.T, labels

    def _manual_move_start(self):
        """Switch the graph to the reduced sample size for the move."""
        self.graph.set_sample_size(self.SAMPLE_SIZE)

    def _manual_move(self, anchor_idx, x, y):
        """Write the moved anchor's position into the components and redraw."""
        self.projection.components_[:, anchor_idx] = [x, y]
        self.graph.update_coordinates()

    def _manual_move_finish(self, anchor_idx, x, y):
        """Apply the final position, restore full sampling and commit."""
        self._manual_move(anchor_idx, x, y)
        self.graph.set_sample_size(None)
        self.commit()

    def _get_projection_data(self):
        """Return the data extended with the projection attributes as metas.

        Projection attributes whose names clash with the data's domain are
        copied under unique names. Returns ``None`` when either the data or
        the projection is missing.
        """
        if self.data is None or self.projection is None:
            return None

        data_domain = self.data.domain
        projected_attrs = self.projection.domain.attributes
        proposed = [attr.name for attr in projected_attrs]
        unique = get_unique_names(data_domain, proposed)

        if proposed == unique:
            extra_metas = projected_attrs
        else:
            extra_metas = tuple(
                attr.copy(name=new_name)
                for new_name, attr in zip(unique, projected_attrs))

        extended = Domain(data_domain.attributes,
                          data_domain.class_vars,
                          data_domain.metas + extra_metas)
        return self.data.transform(extended)

    def commit(self):
        """Send the regular outputs followed by the components table."""
        super().commit()
        self.send_components()

    def send_components(self):
        """Build the components table and send it (``None`` if unavailable)."""
        table = None
        if self.data is not None and self.projection is not None:
            taken = [var.name for var in self.effective_variables]
            label_var = StringVariable(
                name=get_unique_names(taken, 'component'))
            domain = Domain(self.effective_variables, metas=[label_var])
            table = Table(domain,
                          self._send_components_x().copy(),
                          metas=self._send_components_metas())
            table.name = "components"
        self.Outputs.components.send(table)

    def _send_components_x(self):
        """Return the component matrix used as the table's X."""
        return self.projection.components_

    def _send_components_metas(self):
        """Return the projection attribute names as an (n, 1) object column."""
        names = np.array(
            [attr.name for attr in self.projection.domain.attributes],
            dtype=object)
        return names[:, np.newaxis]

    def clear(self):
        """Clear the base widget state and drop projector and projection."""
        super().clear()
        self.projector = None
        self.projection = None
Ejemplo n.º 30
0
class OWManifoldLearning(OWWidget):
    """Widget that embeds the input table with a selectable manifold method.

    One parameter editor exists per entry of ``MANIFOLD_METHODS``; only the
    editor of the currently selected method is shown. The embedding is
    recomputed in ``apply`` and sent on the "Transformed data" output.
    """
    name = "Manifold Learning"
    description = "Nonlinear dimensionality reduction."
    icon = "icons/Manifold.svg"
    priority = 2200

    inputs = [("Data", Table, "set_data")]
    outputs = [("Transformed data", Table)]

    # Order must match ``parameter_editors`` built in ``__init__``; both are
    # indexed by ``manifold_method_index``.
    MANIFOLD_METHODS = (TSNE, MDS, Isomap, LocallyLinearEmbedding,
                        SpectralEmbedding)

    tsne_editor = SettingProvider(TSNEParametersEditor)
    mds_editor = SettingProvider(MDSParametersEditor)
    isomap_editor = SettingProvider(IsomapParametersEditor)
    lle_editor = SettingProvider(LocallyLinearEmbeddingParametersEditor)
    spectral_editor = SettingProvider(SpectralEmbeddingParametersEditor)

    resizing_enabled = False
    want_main_area = False

    manifold_method_index = Setting(0)
    n_components = Setting(2)
    auto_apply = Setting(True)

    class Error(OWWidget.Error):
        n_neighbors_too_small = Msg("Neighbors must be greater than {}.")
        manifold_error = Msg("{}")
        sparse_not_supported = Msg("Sparse data is not supported.")

    def __init__(self):
        self.data = None

        # GUI
        method_box = gui.vBox(self.controlArea, "Method")
        self.manifold_methods_combo = gui.comboBox(
            method_box, self, "manifold_method_index",
            items=[m.name for m in self.MANIFOLD_METHODS],
            callback=self.manifold_method_changed)

        self.params_box = gui.vBox(self.controlArea, "Parameters")

        self.tsne_editor = TSNEParametersEditor(self)
        self.mds_editor = MDSParametersEditor(self)
        self.isomap_editor = IsomapParametersEditor(self)
        self.lle_editor = LocallyLinearEmbeddingParametersEditor(self)
        self.spectral_editor = SpectralEmbeddingParametersEditor(self)
        self.parameter_editors = [
            self.tsne_editor, self.mds_editor, self.isomap_editor,
            self.lle_editor, self.spectral_editor]

        # All editors live in the same box; only the active one is visible.
        for editor in self.parameter_editors:
            self.params_box.layout().addWidget(editor)
            editor.hide()
        self.params_widget = self.parameter_editors[self.manifold_method_index]
        self.params_widget.show()

        output_box = gui.vBox(self.controlArea, "Output")
        self.n_components_spin = gui.spin(
            output_box, self, "n_components", 1, 10, label="Components:",
            alignment=Qt.AlignRight, callbackOnReturn=True,
            callback=self.settings_changed)
        self.apply_button = gui.auto_commit(
            output_box, self, "auto_apply", "&Apply",
            box=False, commit=self.apply)

    def manifold_method_changed(self):
        """Swap the visible parameter editor and recompute the embedding."""
        self.params_widget.hide()
        self.params_widget = self.parameter_editors[self.manifold_method_index]
        self.params_widget.show()
        self.apply()

    def settings_changed(self):
        """Recompute the embedding after a setting was edited."""
        self.apply()

    def set_data(self, data):
        """Store the input table, adapt the component limit and recompute.

        The spin box maximum follows the number of attributes of the input;
        without (truthy) data it falls back to 10.
        """
        self.data = data
        self.n_components_spin.setMaximum(len(self.data.domain.attributes)
                                          if self.data else 10)
        self.apply()

    def apply(self):
        """Compute the embedding and send it on "Transformed data".

        ``None`` is sent when there is no data, the data is sparse, or the
        selected method raises ``ValueError`` / ``LinAlgError``; the latter
        two are reported through the widget's error messages.
        """
        data = None
        self.clear_messages()
        if self.data:
            if self.data.is_sparse():
                self.Error.sparse_not_supported()
            else:
                with self.progressBar():
                    self.progressBarSet(10)
                    domain = Domain([ContinuousVariable("C{}".format(i))
                                     for i in range(self.n_components)],
                                    self.data.domain.class_vars,
                                    self.data.domain.metas)

                    method = self.MANIFOLD_METHODS[self.manifold_method_index]
                    projector = method(**self.get_method_parameters())
                    try:
                        self.progressBarSet(20)
                        X = projector(self.data).embedding_
                        data = Table(domain, X, self.data.Y, self.data.metas)
                    except ValueError as e:
                        # Guard against a ValueError raised without arguments.
                        message = e.args[0] if e.args else str(e)
                        if message == "for method='hessian', n_neighbors " \
                                      "must be greater than [n_components" \
                                      " * (n_components + 3) / 2]":
                            # n * (n + 3) is always even, so integer division
                            # is exact and the bound is shown without ".0".
                            n = self.n_components \
                                * (self.n_components + 3) // 2
                            self.Error.n_neighbors_too_small("{}".format(n))
                        else:
                            self.Error.manifold_error(message)
                    except np.linalg.LinAlgError as e:
                        # Public alias; ``np.linalg.linalg`` is a private
                        # module that newer NumPy releases deprecate.
                        self.Error.manifold_error(str(e))
        self.send("Transformed data", data)

    def get_method_parameters(self):
        """Return the keyword arguments for the selected method.

        ``n_components`` comes from the widget; everything else from the
        active parameter editor (editor values win on a name clash).
        """
        parameters = dict(n_components=self.n_components)
        parameters.update(self.params_widget.parameters)
        return parameters

    def send_report(self):
        """Report the method name, its parameters and, if present, the data."""
        method_name = self.MANIFOLD_METHODS[self.manifold_method_index].name
        self.report_items((("Method", method_name),))
        parameters = self.get_method_parameters()
        self.report_items("Method parameters", tuple(parameters.items()))
        if self.data:
            self.report_data("Data", self.data)
Ejemplo n.º 31
0
class OWScatterPlot(OWDataProjectionWidget):
    """Scatterplot visualization with explorative analysis and intelligent
    data visualization enhancements."""

    name = 'Scatter Plot'
    description = "Interactive scatter plot visualization with " \
                  "intelligent data visualization enhancements."
    icon = "icons/ScatterPlot.svg"
    priority = 140
    keywords = []

    class Inputs(OWDataProjectionWidget.Inputs):
        features = Input("Features", AttributeList)

    class Outputs(OWDataProjectionWidget.Outputs):
        features = Output("Features", AttributeList, dynamic=False)

    settings_version = 3
    auto_sample = Setting(True)
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    tooltip_shows_all = Setting(True)

    GRAPH_CLASS = OWScatterPlotGraph
    graph = SettingProvider(OWScatterPlotGraph)
    embedding_variables_names = None

    class Warning(OWDataProjectionWidget.Warning):
        missing_coords = Msg("Plot cannot be displayed because '{}' or '{}' "
                             "is missing for all data points")

    class Information(OWDataProjectionWidget.Information):
        sampled_sql = Msg("Large SQL table; showing a sample.")
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        self.sql_data = None  # Orange.data.sql.table.SqlTable
        self.attribute_selection_list = None  # list of Orange.data.Variable
        # Timer that periodically pulls further samples via ``add_data``.
        self.__timer = QTimer(self, interval=1200)
        self.__timer.timeout.connect(self.add_data)
        super().__init__()

        # manually register Matplotlib file writers
        # (copy first so the registration stays local to this instance)
        self.graph_writers = self.graph_writers.copy()
        for w in [MatplotlibFormat, MatplotlibPDFFormat]:
            for ext in w.EXTENSIONS:
                self.graph_writers[ext] = w

    def _add_controls(self):
        """Build the control area: axes, sampling, base controls, extras."""
        self._add_controls_axis()
        self._add_controls_sampling()
        super()._add_controls()
        self.gui.add_widget(self.gui.JitterNumericValues, self._effects_box)
        self.gui.add_widgets([
            self.gui.ShowGridLines, self.gui.ToolTipShowsAll,
            self.gui.RegressionLine
        ], self._plot_box)

    def _add_controls_axis(self):
        """Create the x/y axis combo boxes and the VizRank button."""
        common_options = dict(labelWidth=50,
                              orientation=Qt.Horizontal,
                              sendSelectedValue=True,
                              valueType=str,
                              contentsLength=14)
        box = gui.vBox(self.controlArea, True)
        dmod = DomainModel
        # Both combo boxes share a single model.
        self.xy_model = DomainModel(dmod.MIXED, valid_types=dmod.PRIMITIVE)
        self.cb_attr_x = gui.comboBox(box,
                                      self,
                                      "attr_x",
                                      label="Axis x:",
                                      callback=self.attr_changed,
                                      model=self.xy_model,
                                      **common_options)
        self.cb_attr_y = gui.comboBox(box,
                                      self,
                                      "attr_y",
                                      label="Axis y:",
                                      callback=self.attr_changed,
                                      model=self.xy_model,
                                      **common_options)
        vizrank_box = gui.hBox(box)
        self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank(
            vizrank_box, self, "Find Informative Projections", self.set_attr)

    def _add_controls_sampling(self):
        """Create the (initially hidden) sampling box."""
        self.sampling = gui.auto_commit(self.controlArea,
                                        self,
                                        "auto_sample",
                                        "Sample",
                                        box="Sampling",
                                        callback=self.switch_sampling,
                                        commit=lambda: self.add_data(1))
        self.sampling.setVisible(False)

    @property
    def effective_variables(self):
        """The two variables currently shown on the x and y axes."""
        return [self.attr_x, self.attr_y]

    def _vizrank_color_change(self):
        """Re-initialise VizRank and update its button state and tooltip."""
        self.vizrank.initialize()
        is_enabled = self.data is not None and not self.data.is_sparse() and \
            len(self.xy_model) > 2 and len(self.data[self.valid_data]) > 1 \
            and np.all(np.nan_to_num(np.nanstd(self.data.X, 0)) != 0)
        self.vizrank_button.setEnabled(
            is_enabled and self.attr_color is not None and not np.isnan(
                self.data.get_column_view(
                    self.attr_color)[0].astype(float)).all())
        text = "Color variable has to be selected." \
            if is_enabled and self.attr_color is None else ""
        self.vizrank_button.setToolTip(text)

    def set_data(self, data):
        """Set the input data unless it equals the data already shown.

        After the base class has processed the data, attribute settings
        that were restored as plain names are resolved to variables.
        """
        if self.data and data and self.data.checksum() == data.checksum():
            return
        super().set_data(data)

        def findvar(name, iterable):
            """Find a Orange.data.Variable in `iterable` by name"""
            for el in iterable:
                if isinstance(el, Variable) and el.name == name:
                    return el
            return None

        # handle restored settings from  < 3.3.9 when attr_* were stored
        # by name
        if isinstance(self.attr_x, str):
            self.attr_x = findvar(self.attr_x, self.xy_model)
        if isinstance(self.attr_y, str):
            self.attr_y = findvar(self.attr_y, self.xy_model)
        if isinstance(self.attr_label, str):
            self.attr_label = findvar(self.attr_label, self.gui.label_model)
        if isinstance(self.attr_color, str):
            self.attr_color = findvar(self.attr_color, self.gui.color_model)
        if isinstance(self.attr_shape, str):
            self.attr_shape = findvar(self.attr_shape, self.gui.shape_model)
        if isinstance(self.attr_size, str):
            self.attr_size = findvar(self.attr_size, self.gui.size_model)

    def check_data(self):
        """Prepare ``self.data``: download/sample SQL tables, drop empty data."""
        self.clear_messages()
        self.__timer.stop()
        self.sampling.setVisible(False)
        self.sql_data = None
        if isinstance(self.data, SqlTable):
            if self.data.approx_len() < 4000:
                self.data = Table(self.data)
            else:
                # Large table: show a sample and keep the SQL table around
                # so ``add_data`` can fetch more rows later.
                self.Information.sampled_sql()
                self.sql_data = self.data
                data_sample = self.data.sample_time(0.8, no_cache=True)
                data_sample.download_data(2000, partial=True)
                self.data = Table(data_sample)
                self.sampling.setVisible(True)
                if self.auto_sample:
                    self.__timer.start()

        if self.data is not None and (len(self.data) == 0
                                      or len(self.data.domain) == 0):
            self.data = None

    def get_embedding(self):
        """Return an (n, 2) array of x/y columns and update ``valid_data``.

        Returns ``None`` when there is no data or a column is unavailable.
        Rows with a non-finite coordinate are marked invalid; an
        information (some rows) or warning (all rows) message is shown.
        """
        self.valid_data = None
        if self.data is None:
            return None

        x_data = self.get_column(self.attr_x, filter_valid=False)
        y_data = self.get_column(self.attr_y, filter_valid=False)
        if x_data is None or y_data is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        self.valid_data = np.isfinite(x_data) & np.isfinite(y_data)
        if self.valid_data is not None and not np.all(self.valid_data):
            msg = self.Information if np.any(self.valid_data) else self.Warning
            msg.missing_coords(self.attr_x.name, self.attr_y.name)
        return np.vstack((x_data, y_data)).T

    # Tooltip
    def _point_tooltip(self, point_id, skip_attrs=()):
        """Return the tooltip text: axis values first, then the base text."""
        point_data = self.data[point_id]
        xy_attrs = (self.attr_x, self.attr_y)
        text = "<br/>".join(
            escape('{} = {}'.format(var.name, point_data[var]))
            for var in xy_attrs)
        if self.tooltip_shows_all:
            others = super()._point_tooltip(point_id, skip_attrs=xy_attrs)
            if others:
                text = "<b>{}</b><br/><br/>{}".format(text, others)
        return text

    def can_draw_regresssion_line(self):
        """Return True when data exists and both axes are continuous.

        ``attr_x`` / ``attr_y`` may be ``None`` (``init_attr_values`` sets
        them so when the model is empty); that case yields False instead
        of raising ``AttributeError``.
        """
        return self.data is not None and \
            self.data.domain is not None and \
            self.attr_x is not None and \
            self.attr_y is not None and \
            self.attr_x.is_continuous and \
            self.attr_y.is_continuous

    def add_data(self, time=0.4):
        """Fetch another sample from the SQL table and append it.

        Sampling stops once more than 2000 rows are held.
        """
        if self.data and len(self.data) > 2000:
            self.__timer.stop()
            return
        data_sample = self.sql_data.sample_time(time, no_cache=True)
        if data_sample:
            data_sample.download_data(2000, partial=True)
            data = Table(data_sample)
            self.data = Table.concatenate((self.data, data), axis=0)
            self.handleNewSignals()

    def init_attr_values(self):
        """Reset the axis model and pick the first two variables as axes."""
        super().init_attr_values()
        data = self.data
        domain = data.domain if data and len(data) else None
        self.xy_model.set_domain(domain)
        self.attr_x = self.xy_model[0] if self.xy_model else None
        # With a single variable both axes show the same one.
        self.attr_y = self.xy_model[1] if len(self.xy_model) >= 2 \
            else self.attr_x

    def switch_sampling(self):
        """Restart periodic sampling if it is enabled and SQL data exists."""
        self.__timer.stop()
        if self.auto_sample and self.sql_data:
            self.add_data()
            self.__timer.start()

    def set_subset_data(self, subset_data):
        """Set the data subset; large SQL tables are rejected with a warning."""
        self.warning()
        if isinstance(subset_data, SqlTable):
            if subset_data.approx_len() < AUTO_DL_LIMIT:
                subset_data = Table(subset_data)
            else:
                self.warning("Data subset does not support large Sql tables")
                subset_data = None
        super().set_subset_data(subset_data)

    # called when all signals are received, so the graph is updated only once
    def handleNewSignals(self):
        """Apply a pending feature selection or run the base handling."""
        if self.attribute_selection_list and self.data is not None and \
                self.data.domain is not None and \
                all(attr in self.data.domain for attr
                        in self.attribute_selection_list):
            self.set_attr(self.attribute_selection_list[0],
                          self.attribute_selection_list[1])
            self.attribute_selection_list = None
        else:
            super().handleNewSignals()
        self._vizrank_color_change()
        self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())

    @Inputs.features
    def set_shown_attributes(self, attributes):
        """Remember the first two input features as the requested axes."""
        if attributes and len(attributes) >= 2:
            self.attribute_selection_list = attributes[:2]
        else:
            self.attribute_selection_list = None

    def set_attr(self, attr_x, attr_y):
        """Set both axes at once; updates only if something changed."""
        if attr_x != self.attr_x or attr_y != self.attr_y:
            self.attr_x, self.attr_y = attr_x, attr_y
            self.attr_changed()

    def attr_changed(self):
        """React to an axis change: update controls, replot and commit."""
        self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())
        self.setup_plot()
        self.commit()

    def setup_plot(self):
        """Set up the plot and label both axes for the current variables."""
        super().setup_plot()
        for axis, var in (("bottom", self.attr_x), ("left", self.attr_y)):
            self.graph.set_axis_title(axis, var)
            if var and var.is_discrete:
                self.graph.set_axis_labels(axis,
                                           get_variable_values_sorted(var))
            else:
                self.graph.set_axis_labels(axis, None)

    def colors_changed(self):
        """Propagate a colour change and refresh the VizRank state."""
        super().colors_changed()
        self._vizrank_color_change()

    def commit(self):
        """Send the regular outputs followed by the shown features."""
        super().commit()
        self.send_features()

    def send_features(self):
        """Send the selected axis variables, or ``None`` if there are none."""
        features = [attr for attr in [self.attr_x, self.attr_y] if attr]
        self.Outputs.features.send(features or None)

    def get_widget_name_extension(self):
        """Return "<x> vs <y>" for the current axes, or ``None``.

        ``None`` is also returned when either axis variable is unset, so
        the method never dereferences a missing attribute.
        """
        if self.data is not None and \
                self.attr_x is not None and self.attr_y is not None:
            return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
        return None

    def _get_send_report_caption(self):
        """Return the rendered report caption for the visual settings."""
        return report.render_items_vert(
            (("Color", self._get_caption_var_name(self.attr_color)),
             ("Label", self._get_caption_var_name(self.attr_label)),
             ("Shape", self._get_caption_var_name(self.attr_shape)),
             ("Size", self._get_caption_var_name(self.attr_size)),
             ("Jittering", (self.attr_x.is_discrete or self.attr_y.is_discrete
                            or self.graph.jitter_continuous)
              and self.graph.jitter_size)))

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade stored settings in place to the current layout."""
        if version < 2 and "selection" in settings and settings["selection"]:
            settings["selection_group"] = [(a, 1)
                                           for a in settings["selection"]]
        if version < 3:
            if "auto_send_selection" in settings:
                settings["auto_commit"] = settings["auto_send_selection"]
            if "selection_group" in settings:
                settings["selection"] = settings["selection_group"]

    @classmethod
    def migrate_context(cls, context, version):
        """Copy the ``attr_*`` values out of the nested "graph" entry."""
        if version < 3:
            values = context.values
            values["attr_color"] = values["graph"]["attr_color"]
            values["attr_size"] = values["graph"]["attr_size"]
            values["attr_shape"] = values["graph"]["attr_shape"]
            values["attr_label"] = values["graph"]["attr_label"]
Ejemplo n.º 32
0
class OWMDS(OWDataProjectionWidget):
    name = "多维标度分析"
    description = "由距离矩阵构造的多维标度的二维数据投影。"
    icon = "icons/MDS.svg"
    keywords = ["multidimensional scaling", "multi dimensional scaling"]

    # Extends the inherited inputs with a precomputed distance matrix.
    class Inputs(OWDataProjectionWidget.Inputs):
        distances = Input("Distances", DistMatrix)

    settings_version = 3

    #: Initialization type
    PCA, Random, Jitter = 0, 1, 2

    #: Refresh rate
    RefreshRate = [("Every iteration", 1), ("Every 5 steps", 5),
                   ("Every 10 steps", 10), ("Every 25 steps", 25),
                   ("Every 50 steps", 50), ("None", -1)]

    #: Runtime state
    Running, Finished, Waiting = 1, 2, 3

    max_iter = settings.Setting(300)
    initialization = settings.Setting(PCA)
    refresh_rate = settings.Setting(3)

    GRAPH_CLASS = OWMDSGraph
    graph = SettingProvider(OWMDSGraph)
    embedding_variables_names = ("mds-x", "mds-y")

    # Error messages specific to this widget, added to the inherited ones.
    class Error(OWDataProjectionWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        matrix_too_small = Msg("Input matrix must be at least 2x2")
        no_attributes = Msg("Data has no attributes")
        mismatching_dimensions = \
            Msg("Data and distances dimensions do not match.")
        out_of_memory = Msg("Out of memory")
        # The placeholder receives the text of the raised exception.
        optimization_error = Msg("Error during optimization\n{}")

    def __init__(self):
        """Initialise inputs, optimisation state and the extended size model."""
        super().__init__()

        # Inputs as received from the signals.
        self.matrix = None  # type: Optional[DistMatrix]
        # Table taken from ``self.matrix.row_items`` (if present).
        self.matrix_data = None  # type: Optional[Table]
        self.signal_data = None

        # State of the current embedding.
        self.__invalidated = True
        self.embedding = None
        self.effective_matrix = None

        # Optimisation loop and the zero-interval single-shot timer that
        # schedules its steps.
        self.__update_loop = None
        step_timer = QTimer(self, singleShot=True, interval=0)
        step_timer.timeout.connect(self.__next_step)
        self.__timer = step_timer
        self.__state = OWMDS.Waiting
        self.__in_next_step = False

        self.graph.pause_drawing_pairs()

        # Offer "Stress" right after the first entry of the size model.
        size_model = self.gui.points_models[2]
        self.size_model = size_model
        current_order = size_model.order
        size_model.order = \
            current_order[:1] + ("Stress", ) + current_order[1:]

    def _add_controls(self):
        """Add the optimisation box, the base controls and the pairs slider."""
        self._add_controls_optimization()
        super()._add_controls()

        slider_options = dict(
            master=self.graph,
            value="connected_pairs",
            minValue=0,
            maxValue=20,
            createLabel=False,
            callback=self._on_connected_changed,
        )
        self.gui.add_control(
            self._effects_box, gui.hSlider, "Show similar pairs:",
            **slider_options)

    def _add_controls_optimization(self):
        """Create the run button, refresh selector and initialisation buttons."""
        box = gui.vBox(self.controlArea, box=True)
        self.runbutton = gui.button(
            box, self, "运行优化", callback=self._toggle_run)

        refresh_labels = [label for label, _ in OWMDS.RefreshRate]
        gui.comboBox(
            box, self, "refresh_rate",
            label="刷新: ",
            orientation=Qt.Horizontal,
            items=refresh_labels,
            callback=self.__invalidate_refresh)

        # One button per way of (re)initialising the embedding.
        hbox = gui.hBox(box, margin=0)
        initializers = (
            ("PCA", self.do_PCA),
            ("Randomize", self.do_random),
            ("Jitter", self.do_jitter),
        )
        for caption, handler in initializers:
            gui.button(hbox, self, caption, callback=handler)

    def set_data(self, data):
        """Store the input dataset in ``signal_data``.

        A table with fewer than two rows is rejected: the
        ``not_enough_rows`` error is shown and ``None`` is stored instead.

        Parameters
        ----------
        data : Optional[Table]
        """
        if data is None or len(data) >= 2:
            self.Error.not_enough_rows.clear()
        else:
            self.Error.not_enough_rows()
            data = None

        self.signal_data = data

    @Inputs.distances
    def set_disimilarity(self, matrix):
        """Store the dissimilarity (distance) matrix and its row items.

        A matrix with fewer than two rows is rejected: the
        ``matrix_too_small`` error is shown and ``None`` is stored instead.

        Parameters
        ----------
        matrix : Optional[Orange.misc.DistMatrix]
        """
        if matrix is None or len(matrix) >= 2:
            self.Error.matrix_too_small.clear()
        else:
            self.Error.matrix_too_small()
            matrix = None

        self.matrix = matrix
        self.matrix_data = None if matrix is None else matrix.row_items

    def clear(self):
        """Reset the embedding, the graph's matrix and the optimisation state."""
        super().clear()
        self.embedding = None
        self.graph.set_effective_matrix(None)
        # Interrupts any running update loop; that call leaves the state at
        # ``Finished``, so the state is reset to ``Waiting`` afterwards.
        self.__set_update_loop(None)
        self.__state = OWMDS.Waiting

    def _initialize(self):
        """Derive ``data`` and ``effective_matrix`` from the current inputs.

        The plot is reset (and nothing else happens) when there is no
        input at all, when data and matrix differ in length, or when the
        data has no attributes to compute distances from. Otherwise the
        context is reopened and the embedding is invalidated unless the
        effective matrix equals the previous one.
        """
        previous_matrix = self.effective_matrix
        had_matrix = previous_matrix is not None
        self.__invalidated = True
        self.data = None
        self.effective_matrix = None
        self.closeContext()
        self.clear_messages()

        def reset_plot():
            self.clear()
            self.init_attr_values()

        have_data = self.signal_data is not None
        have_matrix = self.matrix is not None

        # if no data nor matrix is present reset plot
        if not have_data and not have_matrix:
            reset_plot()
            return

        if have_data and have_matrix and \
                len(self.signal_data) != len(self.matrix):
            self.Error.mismatching_dimensions()
            reset_plot()
            return

        # The data signal takes precedence over the matrix' row items.
        if have_data:
            self.data = self.signal_data
        elif self.matrix_data is not None:
            self.data = self.matrix_data

        if have_matrix:
            self.effective_matrix = self.matrix
            if self.matrix.axis == 0 and self.data is self.matrix_data:
                self.data = None
        elif self.data.domain.attributes:
            # No matrix given: compute distances from the preprocessed data.
            self.effective_matrix = Euclidean(MDS().preprocess(self.data))
        else:
            self.Error.no_attributes()
            reset_plot()
            return

        self.init_attr_values()
        self.openContext(self.data)
        unchanged = (
            had_matrix and self.effective_matrix is not None
            and np.array_equal(previous_matrix, self.effective_matrix))
        self.__invalidated = not unchanged
        if self.__invalidated:
            self.clear()
        self.graph.set_effective_matrix(self.effective_matrix)

    def _toggle_run(self):
        """Start the optimisation, or stop it and invalidate the output."""
        if self.__state != OWMDS.Running:
            self.start()
            return
        self.stop()
        self._invalidate_output()

    def start(self):
        """Start or resume the optimisation if it is not already running.

        A finished run is always resumed; a waiting widget starts only
        when an effective matrix is available.
        """
        state = self.__state
        if state == OWMDS.Running:
            return

        resumable = state == OWMDS.Finished
        ready = state == OWMDS.Waiting and self.effective_matrix is not None
        if resumable or ready:
            self.__start()

    def stop(self):
        """Interrupt the update loop; a no-op unless currently running."""
        if self.__state != OWMDS.Running:
            return
        self.__set_update_loop(None)

    def __start(self):
        """Create the optimisation generator and install it as update loop.

        The optimisation continues from ``self.embedding`` and runs on
        ``self.effective_matrix``. The number of MDS iterations per GUI
        update is taken from ``RefreshRate``; the ``-1`` entry means a
        single step covering all ``max_iter`` iterations.
        """
        self.graph.pause_drawing_pairs()
        X = self.effective_matrix
        init = self.embedding

        # number of iterations per single GUI update step
        _, step_size = OWMDS.RefreshRate[self.refresh_rate]
        if step_size == -1:
            step_size = self.max_iter

        def update_loop(X, max_iter, step, init):
            """
            return an iterator over successive improved MDS point embeddings.
            """
            # NOTE: this code MUST NOT call into QApplication.processEvents
            done = False
            iterations_done = 0
            # Builtin ``float``: the ``np.float`` alias no longer exists in
            # current NumPy releases.
            oldstress = np.finfo(float).max
            init_type = "PCA" if self.initialization == OWMDS.PCA else "random"

            while not done:
                step_iter = min(max_iter - iterations_done, step)
                mds = MDS(dissimilarity="precomputed",
                          n_components=2,
                          n_init=1,
                          max_iter=step_iter,
                          init_type=init_type,
                          init_data=init)

                mdsfit = mds(X)
                iterations_done += step_iter

                embedding, stress = mdsfit.embedding_, mdsfit.stress_
                # Scale the stress by the summed point norms before comparing
                # it with the previous step.
                stress /= np.sqrt(np.sum(embedding**2, axis=1)).sum()

                # Stop at the iteration limit or when the improvement drops
                # below the projector's ``eps`` parameter.
                if iterations_done >= max_iter:
                    done = True
                elif (oldstress - stress) < mds.params["eps"]:
                    done = True
                # The next step continues from the embedding just computed.
                init = embedding
                oldstress = stress

                yield embedding, mdsfit.stress_, iterations_done / max_iter

        self.__set_update_loop(update_loop(X, self.max_iter, step_size, init))
        self.progressBarInit(processEvents=None)

    def __set_update_loop(self, loop):
        """
        Install `loop` as the update coroutine, or stop updating on `None`.

        `loop` is a generator yielding `(embedding, stress, progress)`
        tuples: `embedding` is a `(N, 2) ndarray` with the current MDS
        points, `stress` the current stress and `progress` a float ratio
        (0 <= progress <= 1).

        A loop that is already installed is interrupted (i.e. closed)
        before the new one takes its place.

        .. note::
            The `loop` must not explicitly yield control flow to the event
            loop (i.e. call `QApplication.processEvents`)

        """
        active = self.__update_loop
        if active is not None:
            active.close()
            self.__update_loop = None
            self.progressBarFinished(processEvents=None)

        self.__update_loop = loop

        if loop is None:
            # Nothing to run: unblock the widget and mark the run finished.
            self.setBlocking(False)
            self.setStatusMessage("")
            self.runbutton.setText("Start")
            self.__state = OWMDS.Finished
            self.__timer.stop()
            return

        self.setBlocking(True)
        self.progressBarInit(processEvents=None)
        self.setStatusMessage("Running")
        self.runbutton.setText("Stop")
        self.__state = OWMDS.Running
        self.__timer.start()

    def __next_step(self):
        """
        Advance the update loop by one step and refresh the plot.

        Pulls the next `(embedding, stress, progress)` item from the
        active loop. On success the embedding and progress bar are updated
        and the timer is restarted to schedule the following step. When
        the loop is exhausted the result is committed; on `MemoryError` or
        any other exception the matching widget error is shown. In every
        terminating case the loop is cleared and drawing of pairs resumes.

        The re-entrancy flag is reset in a `finally` clause so that an
        exception escaping one of the handlers (e.g. from committing)
        cannot leave the widget permanently flagged as "in a step".
        """
        if self.__update_loop is None:
            return

        assert not self.__in_next_step
        self.__in_next_step = True
        try:
            loop = self.__update_loop
            self.Error.out_of_memory.clear()
            try:
                embedding, _, progress = next(self.__update_loop)
                # the loop must not have been swapped while it was running;
                # a failure here is reported by the generic handler below
                assert self.__update_loop is loop
            except StopIteration:
                # optimization finished normally
                self.__set_update_loop(None)
                self.unconditional_commit()
                self.graph.resume_drawing_pairs()
            except MemoryError:
                self.Error.out_of_memory()
                self.__set_update_loop(None)
                self.graph.resume_drawing_pairs()
            except Exception as exc:
                self.Error.optimization_error(str(exc))
                self.__set_update_loop(None)
                self.graph.resume_drawing_pairs()
            else:
                self.progressBarSet(100.0 * progress, processEvents=None)
                self.embedding = embedding
                self.graph.update_coordinates()
                # schedule next update
                self.__timer.start()
        finally:
            self.__in_next_step = False

    def do_PCA(self):
        """Invalidate the embedding, requesting the `PCA` initialization."""
        self.__invalidate_embedding(self.PCA)

    def do_random(self):
        """Invalidate the embedding, requesting the `Random` initialization."""
        self.__invalidate_embedding(self.Random)

    def do_jitter(self):
        """Invalidate the embedding, requesting the `Jitter` mode."""
        self.__invalidate_embedding(self.Jitter)

    def __invalidate_embedding(self, initialization=PCA):
        """
        Reset the MDS embedding according to `initialization`.

        `OWMDS.PCA` recomputes the embedding with `torgerson`,
        `OWMDS.Random` draws uniform random coordinates, and any other
        value jitters the existing embedding in place. If an optimization
        was running when this was called, it is restarted afterwards.
        Without an effective matrix the graph is simply reset.
        """
        # Record this before stopping the loop: stopping changes the state.
        was_running = self.__state == OWMDS.Running
        if self.__update_loop is not None:
            self.__set_update_loop(None)

        matrix = self.effective_matrix
        if matrix is None:
            self.graph.reset_graph()
            return

        if initialization == OWMDS.PCA:
            self.embedding = torgerson(matrix)
        elif initialization == OWMDS.Random:
            self.embedding = np.random.rand(len(matrix), 2)
        else:
            # Perturb each coordinate column in place by up to 1/20 of
            # its value range in either direction.
            for axis in (0, 1):
                column = self.embedding[:, axis]
                spread = np.max(column) - np.min(column)
                column += np.random.uniform(
                    -spread / 20, spread / 20, len(column))

        self.setup_plot()

        # restart the optimization if it was interrupted.
        if was_running:
            self.__start()

    def __invalidate_refresh(self):
        """Stop the active update loop and restart it if it was running."""
        # Record this before stopping the loop: stopping changes the state.
        was_running = self.__state == OWMDS.Running

        if self.__update_loop is not None:
            self.__set_update_loop(None)

        # restart the optimization if it was interrupted.
        # TODO: decrease the max iteration count by the already
        # completed iterations count.
        if was_running:
            self.__start()

    def handleNewSignals(self):
        """
        React to a new set of input signals.

        After re-initializing, either rebuild the embedding and start the
        optimization (when the widget was flagged as invalidated) or just
        refresh the point properties. The output is committed either way.
        """
        self._initialize()
        if not self.__invalidated:
            self.graph.update_point_props()
        else:
            self.graph.pause_drawing_pairs()
            self.__invalidated = False
            self.__invalidate_embedding()
            self.cb_class_density.setEnabled(self.can_draw_density())
            self.start()
        self.commit()

    def _invalidate_output(self):
        """Refresh the output by delegating to `commit`."""
        self.commit()

    def _on_connected_changed(self):
        """
        Hand the current effective matrix to the graph and update its pairs.

        NOTE(review): presumably invoked when the "connected pairs" setting
        changes -- confirm against the control that references it.
        """
        self.graph.set_effective_matrix(self.effective_matrix)
        # reconnect=True is passed so the graph re-derives the pairs
        # rather than only redrawing them -- TODO confirm in the graph class
        self.graph.update_pairs(reconnect=True)

    def setup_plot(self):
        """Run the base plot setup, then update pairs if an embedding exists."""
        super().setup_plot()
        if self.embedding is None:
            return
        self.graph.update_pairs(reconnect=True)

    def get_size_data(self):
        """
        Return the data used for point sizes.

        When the size attribute is the string "Stress" the values come
        from `stress`; otherwise the base class implementation is used.
        """
        if self.attr_size == "Stress":
            return stress(self.embedding, self.effective_matrix)
        return super().get_size_data()

    def get_embedding(self):
        """
        Return the current embedding and update `valid_data` to match.

        `valid_data` becomes an all-True boolean mask with one entry per
        embedding row, or `None` when there is no embedding.
        """
        if self.embedding is None:
            self.valid_data = None
        else:
            self.valid_data = np.ones(len(self.embedding), dtype=bool)
        return self.embedding

    def _get_projection_data(self):
        """
        Return the projection as a table, or `None` without an embedding.

        With input data present the base class builds the result. Without
        it, a table is built from the embedding alone, using two continuous
        variables named after `embedding_variables_names`.
        """
        if self.embedding is None:
            return None

        if self.data is not None:
            return super()._get_projection_data()

        first_name, second_name = self.embedding_variables_names
        axes = (ContinuousVariable(first_name),
                ContinuousVariable(second_name))
        return Table(Domain(axes), self.embedding)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """
        Upgrade stored widget settings in place to the current layout.

        Args:
            settings_: mutable mapping of stored settings; modified in place.
            version: settings version the mapping was saved with.

        Raises:
            KeyError: if `version < 2` and one of the old keys is missing.
        """
        if version < 2:
            # (old top-level key, new key inside the "graph" sub-settings)
            renamed = (("label_only_selected", "label_only_selected"),
                       ("symbol_opacity", "alpha_value"),
                       ("symbol_size", "point_width"),
                       ("jitter", "jitter_size"))
            settings_["graph"] = {new: settings_[old] for old, new in renamed}
            settings_["auto_commit"] = settings_["autocommit"]

        if version < 3 and "connected_pairs" in settings_:
            # Create the "graph" sub-settings on demand so that stored
            # settings lacking that entry do not fail with a KeyError.
            graph_settings = settings_.setdefault("graph", {})
            graph_settings["connected_pairs"] = settings_["connected_pairs"]

    @classmethod
    def migrate_context(cls, context, version):
        """
        Upgrade a stored context in place to the current layout.

        Args:
            context: object whose `values` mapping holds the stored
                context values; it is replaced or modified in place.
            version: context version the values were saved with.

        Raises:
            KeyError: if an expected old key is absent from `context.values`.
        """
        if version < 2:
            # The old per-domain lists were never consulted by the
            # migration, so only the key renames are kept here.
            renamed = (("color_value", "attr_color"),
                       ("shape_value", "attr_shape"),
                       ("size_value", "attr_size"),
                       ("label_value", "attr_label"))
            migrated = {}
            for old_name, new_name in renamed:
                value = context.values[old_name]
                if value[1] >= 0:
                    # NOTE(review): the +100 offset presumably maps the old
                    # type code onto the newer encoding -- confirm against
                    # the context handler.
                    migrated[new_name] = (value[0], value[1] + 100)
                elif value[0] != "Stress":
                    migrated[new_name] = None
                else:
                    # the special "Stress" entry is carried over unchanged
                    migrated[new_name] = value
            context.values = migrated

        if version < 3 and "graph" in context.values:
            # lift the attribute selections out of the "graph" sub-mapping
            values = context.values
            graph_values = values["graph"]
            for name in ("attr_color", "attr_size",
                         "attr_shape", "attr_label"):
                values[name] = graph_values[name]