Exemple #1
0
class POSTaggingModule(SingleMethodModule):
    """Pipeline stage for choosing a part-of-speech tagger.

    The first ``STANFORD`` entries of ``methods`` are instantiated without
    arguments when the layout is set up. The Stanford tagger is constructed
    from a model path and a tagger path, so its slot keeps the bare class
    until ``set_stanford_tagger`` replaces it with a configured instance.
    """
    title = 'POS Tagger'
    attribute = 'pos_tagger'
    enabled = settings.Setting(False)

    stanford = settings.SettingProvider(ResourceLoader)

    methods = [AveragedPerceptronTagger, MaxEntTagger, StanfordPOSTagger]
    # Index of the Stanford tagger inside ``methods``.
    STANFORD = 2

    initialize_methods = False

    def setup_method_layout(self):
        """Instantiate the simple taggers and add the Stanford resource picker."""
        super().setup_method_layout()
        # Initialize all methods except StanfordPOSTagger; this cannot be
        # done in the superclass because StanfordPOSTagger needs paths.
        # The Stanford slot must be kept in the list (as the class itself),
        # otherwise every ``self.methods[self.STANFORD]`` access would raise
        # an IndexError.
        self.methods = (
            [method() for method in self.methods[:self.STANFORD]]
            + [self.methods[self.STANFORD]]
        )

        self.stanford = ResourceLoader(
            widget=self.master,
            model_format='Stanford model (*.model *.tagger)',
            provider_format='Java file (*.jar)',
            model_button_label='Model',
            provider_button_label='Tagger')
        # Restore a previously configured tagger without reporting errors.
        self.set_stanford_tagger(self.stanford.model_path,
                                 self.stanford.resource_path,
                                 silent=True)

        self.stanford.valueChanged.connect(self.set_stanford_tagger)
        self.method_layout.addWidget(self.stanford, self.STANFORD, 1)

    def set_stanford_tagger(self, model_path, stanford_path, silent=False):
        """Try to configure the Stanford tagger from the two given paths.

        Args:
            model_path: Path to the Stanford model file; may be empty.
            stanford_path: Path to the Stanford tagger jar; may be empty.
            silent: When true, a failed check is not reported to the user.

        The Stanford radio button is checked and enabled only when both
        paths are given and pass ``check``. A ``ValueError`` raised while
        checking or constructing the tagger is shown through the master
        widget's ``stanford_tagger`` error unless ``silent`` is set.
        """
        self.master.Error.stanford_tagger.clear()
        valid = False
        if model_path and stanford_path:
            try:
                self.stanford_tagger.check(model_path, stanford_path)
                self.methods[self.STANFORD] = StanfordPOSTagger(
                    model_path, stanford_path)
                valid = True
                self.update_value()
            except ValueError as e:
                if not silent:
                    # Report through the same message that is cleared above.
                    self.master.Error.stanford_tagger(str(e))

        self.group.button(self.STANFORD).setChecked(valid)
        self.group.button(self.STANFORD).setEnabled(valid)

        # Highlight the tagger browse button in red while no jar is selected.
        if not stanford_path:
            self.stanford.provider_widget.browse_button.setStyleSheet(
                "color:#C00;")
        else:
            self.stanford.provider_widget.browse_button.setStyleSheet(
                "color:black;")

    @property
    def stanford_tagger(self):
        """Current occupant of the Stanford slot (class or instance)."""
        return self.methods[self.STANFORD]
Exemple #2
0
class OWPreprocess(OWWidget):
    """Widget that builds a text pre-processing pipeline and applies it.

    Each entry of ``preprocessors`` becomes one stage widget in a scrollable
    pipeline. Whenever a stage changes or a new corpus arrives, the stage
    values are copied onto a ``Preprocessor`` which is then run on a copy of
    the input corpus.
    """

    name = 'Preprocess Text'
    description = 'Construct a text pre-processing pipeline.'
    icon = 'icons/TextPreprocess.svg'
    priority = 30

    class Inputs:
        corpus = Input("Corpus", Corpus)

    class Outputs:
        corpus = Output("Corpus", Corpus)

    autocommit = settings.Setting(True)

    # Stage classes, in the order in which they appear in the pipeline.
    preprocessors = [
        TransformationModule,
        TokenizerModule,
        NormalizationModule,
        FilteringModule,
        NgramsModule,
        POSTaggingModule,
    ]

    # One settings provider per stage.
    transformers = settings.SettingProvider(TransformationModule)
    tokenizer = settings.SettingProvider(TokenizerModule)
    normalizer = settings.SettingProvider(NormalizationModule)
    filters = settings.SettingProvider(FilteringModule)
    ngrams_range = settings.SettingProvider(NgramsModule)
    pos_tagger = settings.SettingProvider(POSTaggingModule)

    control_area_width = 250
    buttons_area_orientation = Qt.Vertical

    UserAdviceMessages = [
        widget.Message(
            "Some preprocessing methods require data (like word relationships, stop words, "
            "punctuation rules etc.) from the NLTK package. This data was downloaded "
            "to: {}".format(nltk_data_dir()), "nltk_data")
    ]

    class Error(OWWidget.Error):
        stanford_tagger = Msg("Problem while loading Stanford POS Tagger\n{}")

    class Warning(OWWidget.Warning):
        no_token_left = Msg(
            'No tokens on output! Please, change configuration.')

    def __init__(self, parent=None):
        """Create the info box, the pipeline of stage widgets and the buttons."""
        super().__init__(parent)
        self.corpus = None
        # Ngram range of the input corpus; restored before every in-place run.
        self.initial_ngram_range = None
        self.preprocessor = preprocess.Preprocessor()

        # -- INFO --
        info_group = gui.widgetBox(self.controlArea, 'Info')
        info_group.setFixedWidth(self.control_area_width)
        self.controlArea.layout().addStretch()
        self.info_label = gui.label(info_group, self, '')
        self.update_info()

        # -- PIPELINE --
        pipeline_frame = QFrame()
        pipeline_frame.setContentsMargins(0, 0, 0, 0)
        pipeline_frame.setFrameStyle(QFrame.Box)
        pipeline_frame.setStyleSheet('.QFrame { border: 1px solid #B3B3B3; }')
        pipeline_layout = QVBoxLayout()
        pipeline_layout.setContentsMargins(0, 0, 0, 0)
        pipeline_layout.setSpacing(0)
        pipeline_frame.setLayout(pipeline_layout)

        self.stages = []
        for stage_cls in self.preprocessors:
            stage_widget = stage_cls(self)
            self.stages.append(stage_widget)
            # Expose the stage under its attribute name (e.g. self.tokenizer).
            setattr(self, stage_cls.attribute, stage_widget)
            pipeline_layout.addWidget(stage_widget)
            stage_widget.change_signal.connect(self.settings_invalidated)

        pipeline_layout.addStretch()
        self.scroll = QScrollArea()
        self.scroll.setWidget(pipeline_frame)
        self.scroll.setWidgetResizable(True)
        self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.scroll.resize(pipeline_layout.sizeHint())
        self.scroll.setMinimumHeight(500)
        self.set_minimal_width()
        self.mainArea.layout().addWidget(self.scroll)

        # -- BUTTONS --
        self.report_button.setFixedWidth(self.control_area_width)

        commit_btn = gui.auto_commit(
            self.buttonsArea, self, 'autocommit', 'Commit', box=False)
        commit_btn.setFixedWidth(self.control_area_width - 5)

        self.buttonsArea.layout().addWidget(commit_btn)

    @Inputs.corpus
    def set_data(self, data=None):
        """Store a copy of the incoming corpus and re-run the pipeline."""
        if data is None:
            self.corpus = None
            self.initial_ngram_range = None
        else:
            self.corpus = data.copy()
            self.initial_ngram_range = data.ngram_range
        self.commit()

    def update_info(self, corpus=None):
        """Show document, token and type counts of ``corpus`` in the info box."""
        if corpus is None:
            text = 'No corpus.'
        else:
            template = 'Document count: {}\n' \
                       'Total tokens: {}\n' \
                       'Total types: {}'
            n_documents = len(corpus)
            n_tokens = sum(len(doc) for doc in corpus.tokens)
            n_types = len(corpus.dictionary)
            text = template.format(n_documents, n_tokens, n_types)
        self.info_label.setText(text)

    def commit(self):
        """Run the pipeline, or send ``None`` when there is no corpus."""
        self.Warning.no_token_left.clear()
        if self.corpus is None:
            self.update_info()
            self.Outputs.corpus.send(None)
            return
        self.apply()

    def apply(self):
        """Start preprocessing of the stored corpus."""
        self.preprocess()

    @asynchronous
    def preprocess(self):
        """Copy stage values onto the preprocessor and run it in place."""
        for stage_widget in self.stages:
            setattr(self.preprocessor, stage_widget.attribute,
                    stage_widget.value)
        # Reset pos_tags and ngram_range left over from a previous run.
        self.corpus.pos_tags = None
        self.corpus.ngram_range = self.initial_ngram_range
        return self.preprocessor(
            self.corpus, inplace=True, on_progress=self.on_progress)

    @preprocess.on_start
    def on_start(self):
        self.progressBarInit(None)

    @preprocess.callback
    def on_progress(self, i):
        self.progressBarSet(i, None)

    @preprocess.on_result
    def on_result(self, result):
        """Update the info box and send the result, or ``None`` if empty."""
        self.update_info(result)
        nothing_left = result is not None and len(result.dictionary) == 0
        if nothing_left:
            self.Warning.no_token_left()
            result = None
        self.Outputs.corpus.send(result)
        self.progressBarFinished(None)

    def set_minimal_width(self):
        """Make the scroll area wide enough for the widest enabled stage."""
        enabled_widths = [stage_widget.sizeHint().width()
                          for stage_widget in self.stages
                          if stage_widget.enabled]
        widest = max([250] + enabled_widths)
        self.scroll.setMinimumWidth(widest + 20)

    @pyqtSlot()
    def settings_invalidated(self):
        """React to a changed stage: resize the pipeline and recompute."""
        self.set_minimal_width()
        self.commit()

    def send_report(self):
        self.report_items('Preprocessor', self.preprocessor.report())
Exemple #3
0
class OWGenialisExpressions(widget.OWWidget, ConcurrentWidgetMixin):
    """Widget for browsing collections on a Resolwe server and importing
    their expression data as a table.

    The main area lists collections (with filtering and pagination); the
    control area holds sign-in controls, the expression/process/annotation
    choices available for the selected collection, and normalization options.
    """
    name = 'Genialis Expressions'
    priority = 30
    want_main_area = True
    want_control_area = True
    icon = '../widgets/icons/OWGenialisExpressions.svg'

    # Emitted as (next_page_available, previous_page_available).
    pagination_availability = pyqtSignal(bool, bool)

    norm_component = settings.SettingProvider(NormalizationComponent)
    pagination_component = settings.SettingProvider(PaginationComponent)
    filter_component = settings.SettingProvider(CollapsibleFilterComponent)

    exp_type: int
    exp_type = settings.Setting(None, schema_only=True)

    proc_type: int
    proc_type = settings.Setting(None, schema_only=True)

    input_annotation: int
    input_annotation = settings.Setting(None, schema_only=True)

    auto_commit: bool
    auto_commit = settings.Setting(False, schema_only=True)

    class Outputs:
        table = Output('Expressions', Table)

    class Warning(widget.OWWidget.Warning):
        no_expressions = Msg('Expression data objects not found.')
        no_data_objects = Msg('No expression data matches the selected filtering options.')
        unexpected_feature_type = Msg('Can not import expression data, unexpected feature type "{}".')
        multiple_feature_type = Msg('Can not import expression data, multiple feature types found.')

    def __init__(self):
        """Build the control and main areas, then attempt a silent sign-in."""
        super().__init__()
        ConcurrentWidgetMixin.__init__(self)

        self._res = None
        self._data_objects: Optional[List[Data]] = None
        self.data_output_options: Optional[DataOutputOptions] = None
        self.data_table: Optional[Table] = None

        # Control area
        box = gui.widgetBox(self.controlArea, 'Sign in')
        self.user_info = gui.label(box, self, '')
        self.server_info = gui.label(box, self, '')

        box = gui.widgetBox(box, orientation=Qt.Horizontal)
        self.sign_in_btn = gui.button(box, self, 'Sign In', callback=self.sign_in, autoDefault=False)
        self.sign_out_btn = gui.button(box, self, 'Sign Out', callback=self.sign_out, autoDefault=False)

        self.exp_type_box = gui.widgetBox(self.controlArea, 'Expression Type')
        self.exp_type_options = gui.radioButtons(
            self.exp_type_box, self, 'exp_type', callback=self.on_data_output_option_changed
        )

        self.proc_type_box = gui.widgetBox(self.controlArea, 'Process Name')
        self.proc_type_options = gui.radioButtons(
            self.proc_type_box, self, 'proc_type', callback=self.on_data_output_option_changed
        )

        self.input_anno_box = gui.widgetBox(self.controlArea, 'Expression source')
        self.input_anno_options = gui.radioButtons(
            self.input_anno_box, self, 'input_annotation', callback=self.on_data_output_option_changed
        )

        self.norm_component = NormalizationComponent(self, self.controlArea)
        self.norm_component.options_changed.connect(self.on_normalization_changed)

        gui.rubber(self.controlArea)
        self.commit_button = gui.auto_commit(self.controlArea, self, 'auto_commit', '&Commit', box=False)
        self.commit_button.button.setAutoDefault(False)

        # Main area
        self.table_view = QTableView()
        self.table_view.setAlternatingRowColors(True)
        self.table_view.viewport().setMouseTracking(True)
        self.table_view.setShowGrid(False)
        self.table_view.verticalHeader().hide()
        self.table_view.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
        self.table_view.horizontalHeader().setStretchLastSection(True)
        self.table_view.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.table_view.setSelectionMode(QAbstractItemView.SingleSelection)
        # self.table_view.setStyleSheet('QTableView::item:selected{background-color: palette(highlight); color: palette(highlightedText);};')

        self.model = GenialisExpressionsModel(self)
        self.model.setHorizontalHeaderLabels(TableHeader.labels())
        self.table_view.setModel(self.model)
        self.table_view.selectionModel().selectionChanged.connect(self.on_selection_changed)

        # The filter goes above the table and the pagination below it.
        self.filter_component = CollapsibleFilterComponent(self, self.mainArea)
        self.filter_component.options_changed.connect(self.on_filter_changed)
        self.mainArea.layout().addWidget(self.table_view)
        self.pagination_component = PaginationComponent(self, self.mainArea)
        self.pagination_component.options_changed.connect(self.update_collections_view)

        self.sign_in(silent=True)

    def __invalidate(self):
        """Drop the loaded data objects and table and clear related warnings."""
        self.data_objects = None
        self.data_table = None
        self.Warning.no_expressions.clear()
        self.Warning.multiple_feature_type.clear()
        self.Warning.unexpected_feature_type.clear()
        self.info.set_output_summary(StateInfo.NoOutput)

    def set_input_annotation_options(self) -> None:
        """Rebuild the 'Expression source' radio buttons from the current options."""
        for btn in self.input_anno_options.buttons:
            btn.deleteLater()
        self.input_anno_options.buttons = []

        if not self.data_output_options:
            return

        for source, species, build in self.data_output_options.input_annotation:
            tooltip = f'{source}, {species}, {build}'
            text = f'{species}, {build}'
            gui.appendRadioButton(self.input_anno_options, text, tooltip=tooltip)

        if len(self.input_anno_options.buttons):
            self.input_annotation = 0

    def set_proc_type_options(self) -> None:
        """Rebuild the 'Process Name' radio buttons from the current options."""
        for btn in self.proc_type_options.buttons:
            btn.deleteLater()
        self.proc_type_options.buttons = []

        if not self.data_output_options:
            return

        for proc_type, proc_name in self.data_output_options.process:
            gui.appendRadioButton(self.proc_type_options, proc_name, tooltip=proc_type)

        if len(self.proc_type_options.buttons):
            self.proc_type = 0

    def set_exp_type_options(self) -> None:
        """Rebuild the 'Expression Type' radio buttons from the current options."""
        for btn in self.exp_type_options.buttons:
            btn.deleteLater()
        self.exp_type_options.buttons = []

        if not self.data_output_options:
            return

        for _, exp_name in self.data_output_options.expression:
            gui.appendRadioButton(self.exp_type_options, exp_name)

        # Index 0 is always 'Read Counts' (see _available_data_output_options);
        # when another expression type exists, select the first of those.
        if len(self.exp_type_options.buttons) > 1:
            self.exp_type = 1

    @property
    def res(self):
        """The current ``ResolweAPI`` connection, or ``None``."""
        return self._res

    @res.setter
    def res(self, value: ResolweAPI):
        """Switch to a new connection and reset everything derived from it.

        Values that are not ``ResolweAPI`` instances are ignored.
        """
        if isinstance(value, ResolweAPI):
            self._res = value
            self.update_user_status()
            self.update_collections_view()
            self.__invalidate()
            self.Outputs.table.send(None)

    @property
    def data_objects(self):
        """Expression data objects of the selected collection."""
        return self._data_objects

    @data_objects.setter
    def data_objects(self, data_objects: Optional[List[Data]]):
        """Store the data objects and recompute the available output options."""
        self._data_objects = data_objects
        self.data_output_options = self._available_data_output_options()

    def _available_data_output_options(self) -> Optional[DataOutputOptions]:
        """
        Traverse the data objects in the selected collection and store the
        information regarding available expression types, process types and
        input annotations used in the creation of the data object.

        The method returns a named tuple (`DataOutputOptions`) which is used
        for creating radio buttons in the control area, or `None` when there
        are no data objects.
        """
        if not self.data_objects:
            return None

        expression_types = sorted({data.output['exp_type'] for data in self.data_objects})
        # 'Read Counts' is always offered first, ahead of the types found.
        expression_types = (Expression('rc', 'Read Counts'),) + tuple(
            Expression(exp_type, exp_type) for exp_type in expression_types
        )

        process_types = sorted({(data.process.type, data.process.name) for data in self.data_objects})
        process_types = tuple(Process(proc_type, proc_name) for proc_type, proc_name in process_types)

        input_annotations = sorted(
            {(data.output['source'], data.output['species'], data.output['build']) for data in self.data_objects}
        )
        input_annotations = tuple(
            InputAnnotation(source, species, build) for source, species, build in input_annotations
        )

        return DataOutputOptions(
            expression=expression_types, process=process_types, input_annotation=input_annotations
        )

    def update_user_status(self):
        """Refresh the user/server labels and the sign-in/sign-out buttons."""
        user = self.res.get_currently_logged_user()

        if user:
            # Prefer the full name; fall back to the username when it is empty.
            user_info = f"{user[0].get('first_name', '')} {user[0].get('last_name', '')}".strip()
            user_info = f"User: {user_info if user_info else user[0].get('username', '')}"
            self.sign_in_btn.setEnabled(False)
            self.sign_out_btn.setEnabled(True)
        else:
            user_info = 'User: Anonymous'
            self.sign_in_btn.setEnabled(True)
            self.sign_out_btn.setEnabled(False)

        self.user_info.setText(user_info)
        # Show the server address without its URL scheme. Splitting on '://'
        # works for any scheme instead of assuming an 8-character 'https://'.
        server = self.res.url.split('://', 1)[-1]
        self.server_info.setText(f'Server: {server}')

    def sign_in(self, silent=False):
        """Sign in through ``SignInForm``.

        With ``silent`` set, the form's ``sign_in`` is called without showing
        the dialog, and a connection to ``DEFAULT_URL`` is used when that
        yields no instance. Otherwise the dialog is shown and its connection
        is adopted only if the dialog is accepted.
        """
        dialog = SignInForm(self)

        if silent:
            dialog.sign_in()
            if dialog.resolwe_instance is not None:
                self.res = dialog.resolwe_instance
            else:
                self.res = connect(url=DEFAULT_URL)
        elif dialog.exec_():
            self.res = dialog.resolwe_instance

    def sign_out(self):
        """Switch to the default connection and delete the stored credentials."""
        # Use public credentials when user signs out
        self.res = connect(url=DEFAULT_URL)
        # Remove the stored username and password
        cm = CredentialManager(CREDENTIAL_MANAGER_SERVICE)
        del cm.username
        del cm.password

    def on_filter_changed(self):
        """Go back to the first page and reload the collections."""
        self.pagination_component.reset_pagination()
        self.update_collections_view()

    def get_query_parameters(self) -> Dict[str, str]:
        """Translate the pagination and filter settings into query parameters."""
        params = {
            'limit': ItemsPerPage.values()[self.pagination_component.items_per_page],
            'offset': self.pagination_component.offset,
            'ordering': SortBy.values()[self.filter_component.sort_by],
        }

        # Optional filters are only sent when the user has set them.
        if self.filter_component.filter_by_full_text:
            params.update({'text': self.filter_component.filter_by_full_text})

        if self.filter_component.filter_by_name:
            params.update({'name__icontains': self.filter_component.filter_by_name})

        if self.filter_component.filter_by_contrib:
            params.update({'contributor_name': self.filter_component.filter_by_contrib})

        if self.filter_component.filter_by_owner:
            params.update({'owners_name': self.filter_component.filter_by_owner})

        last_modified = FilterByDateModified.values()[self.filter_component.filter_by_modified]
        if last_modified:
            params.update({'modified__gte': last_modified.isoformat()})

        return params

    def get_collections(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """Fetch one page of collections and the species of those collections."""
        # Get response from the server
        collections = self.res.get_collections(**self.get_query_parameters())
        # Loop through collections and store ids
        collection_ids = [collection['id'] for collection in collections.get('results', [])]
        # Get species by collection ids
        collection_to_species = self.res.get_species(collection_ids)

        return collections, collection_to_species

    def update_collections_view(self):
        """Reload the collection table and announce pagination availability."""
        collections, collection_to_species = self.get_collections()

        # Pass the results to data model
        self.model.set_data(collections.get('results', []), collection_to_species)
        self.table_view.setItemDelegateForColumn(TableHeader.id, gui.LinkStyledItemDelegate(self.table_view))
        self.table_view.setColumnHidden(TableHeader.slug, True)
        self.table_view.setColumnHidden(TableHeader.tags, True)

        # Check pagination parameters and emit pagination_availability signal
        next_page = bool(collections.get('next'))
        previous_page = bool(collections.get('previous'))
        self.pagination_availability.emit(next_page, previous_page)

    def normalize(self, table: Table) -> Optional[Table]:
        """Apply the normalization steps enabled in ``norm_component``.

        Steps run in a fixed order: quantile normalization, log scale,
        z-score, quantile transform. Returns ``None`` for a falsy table.
        """
        if not table:
            return None

        if self.norm_component.quantile_norm:
            table = QuantileNormalization()(table)

        if self.norm_component.log_norm:
            table = LogarithmicScale()(table)

        if self.norm_component.z_score_norm:
            table = ZScore(axis=self.norm_component.z_score_axis)(table)

        if self.norm_component.quantile_transform:
            axis = self.norm_component.quantile_transform_axis
            # Cap the quantile count at 100 and at the size of the other axis.
            quantiles = min(table.X.shape[int(not axis)], 100)
            distribution = QuantileTransformDist.values()[self.norm_component.quantile_transform_dist]
            table = QuantileTransform(axis=axis, n_quantiles=quantiles, output_distribution=distribution)(table)

        return table

    def commit(self):
        """Download the table if needed, otherwise re-send the cached one."""
        self.Warning.no_data_objects.clear()
        self.cancel()

        if self.data_objects and not self.data_table:
            # Nothing cached yet: fetch in the background; see on_done.
            self.start(
                runner,
                self.res,
                self.data_objects,
                self.data_output_options,
                self.exp_type,
                self.proc_type,
                self.input_annotation,
            )
        else:
            self.Outputs.table.send(self.normalize(self.data_table))

    def on_data_output_option_changed(self):
        """Discard the cached table, since the choice changes what is fetched."""
        self.data_table = None

        if self.data_objects:
            self.commit()

    def on_normalization_changed(self):
        """Re-send the output with the new normalization settings."""
        if self.data_objects:
            self.commit()

    def on_selection_changed(self):
        """Load the data objects of the newly selected collection and commit."""
        self.__invalidate()

        collection_id: str = self.get_selected_row_data(TableHeader.id)
        if not collection_id:
            return

        self.data_objects = self.res.get_expression_data_objects(collection_id)
        self.set_exp_type_options()
        self.set_proc_type_options()
        self.set_input_annotation_options()

        if not self.data_objects:
            self.Warning.no_expressions()
            return

        # Note: This here is to handle an edge case where we get
        #       different 'feature_type' data object in a collection.
        #       For now we raise a warning, but in the future we should
        #       discuss about how to properly handle different types of features.
        feature_types = {data.output['feature_type'] for data in self.data_objects}

        if len(feature_types) == 1 and 'gene' not in feature_types:
            self.Warning.unexpected_feature_type(feature_types.pop())
            self.data_objects = []
            return

        if len(feature_types) > 1:
            self.Warning.multiple_feature_type()
            self.data_objects = []
            return

        self.commit()

    def get_selected_row_data(self, column: int) -> Optional[str]:
        """Return the value in ``column`` of the first selected row, if any."""
        selection_model = self.table_view.selectionModel()
        rows = selection_model.selectedRows(column=column)
        if not rows:
            return None

        return rows[0].data()

    def on_done(self, table: Table):
        """Cache the downloaded table and send its normalized form."""
        if table:
            samples, genes = table.X.shape
            self.info.set_output_summary(f'Samples: {samples} Genes: {genes}')
            self.data_table = table
            self.Outputs.table.send(self.normalize(table))

    def on_exception(self, ex):
        """Warn when no data objects match; re-raise any other exception."""
        if isinstance(ex, ResolweDataObjectsNotFound):
            self.Warning.no_data_objects()
            self.Outputs.table.send(None)
            self.data_table = None
            self.info.set_output_summary(StateInfo.NoOutput)
        else:
            raise ex

    def on_partial_result(self, result: Any) -> None:
        """Partial results are not used by this widget."""
        pass

    def onDeleteWidget(self):
        self.shutdown()
        super().onDeleteWidget()

    def sizeHint(self):
        return QSize(1280, 620)
Exemple #4
0
class OWFreeViz(OWProjectionWidget):
    """Projection widget that displays a FreeViz projection of the input data."""
    name = "FreeViz"
    description = "Displays FreeViz projection"
    icon = "icons/Freeviz.svg"
    priority = 240
    keywords = ["viz"]

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    settings_version = 3
    settingsHandler = settings.DomainContextHandler()

    # How the initial projection points are chosen (see setup_plot).
    initialization = settings.Setting(InitType.Circular)
    auto_commit = settings.Setting(True)

    # Settings of the plot are stored through the graph component.
    graph = settings.SettingProvider(OWFreeVizGraph)
    graph_name = "graph.plot_widget.plotItem"

    class Error(OWProjectionWidget.Error):
        # Messages shown when the input data cannot be projected; they are
        # raised from the widget's data validation.
        sparse_data = widget.Msg("Sparse data is not supported")
        no_class_var = widget.Msg("Need a class variable")
        # Fixed the "at lest" typo in the user-facing text.
        not_enough_class_vars = widget.Msg(
            "Needs discrete class variable with at least 2 values"
        )
        features_exceeds_instances = widget.Msg(
            "Algorithm should not be used when number of features "
            "exceeds the number of instances."
        )
        too_many_data_instances = widget.Msg("Cannot handle so large data.")
        no_valid_data = widget.Msg("No valid data.")

    def __init__(self):
        """Create the plot, the control panel and the optimization loop."""
        super().__init__()

        # Input tables and the state derived from them.
        self.data = None
        self.subset_data = None
        self.subset_indices = None
        self._embedding_coords = None
        self._X = None
        self._Y = None
        self._rand_indices = None
        self.variable_x = ContinuousVariable("freeviz-x")
        self.variable_y = ContinuousVariable("freeviz-y")

        # Main area: the projection plot.
        plot_box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWFreeVizGraph(self, plot_box)
        plot_box.layout().addWidget(self.graph.plot_widget)

        # Control area: initialization choice and the start/stop button.
        init_box = gui.vBox(self.controlArea, box=True)
        gui.comboBox(
            init_box, self, "initialization",
            label="Initialization:",
            items=InitType.items(),
            orientation=Qt.Horizontal,
            labelWidth=90,
            callback=self.__init_combo_changed,
        )
        self.btn_start = gui.button(
            init_box, self, "Optimize", self.__toggle_start, enabled=False)

        # Control area: plot appearance controls supplied by the graph.
        plot_gui = self.graph.gui
        plot_gui.point_properties_box(self.controlArea)
        effects_box = plot_gui.effects_box(self.controlArea)
        plot_gui.add_control(
            effects_box, gui.hSlider, "Hide radius:",
            master=self.graph, value="radius",
            minValue=0, maxValue=100,
            step=10, createLabel=False,
            callback=self.__radius_slider_changed,
        )
        plot_gui.plot_properties_box(self.controlArea)

        self.controlArea.layout().addStretch(100)
        self.graph.box_zoom_select(self.controlArea)

        gui.auto_commit(
            self.controlArea, self, "auto_commit",
            "Send Selection", "Send Automatically")

        # FreeViz optimization loop and its notifications.
        self._loop = AsyncUpdateLoop(parent=self)
        self._loop.yielded.connect(self.__set_projection)
        self._loop.finished.connect(self.__freeviz_finished)
        self._loop.raised.connect(self.__on_error)

        # Manual moves reported by the plot's view box.
        view_box = self.graph.view_box
        view_box.started.connect(self._randomize_indices)
        view_box.moved.connect(self._manual_move)
        view_box.finished.connect(self._finish_manual_move)

    def __radius_slider_changed(self):
        """Callback of the "Hide radius:" slider; forwards to the graph."""
        self.graph.update_radius()

    def __toggle_start(self):
        """Start the optimization, or cancel it when it is already running."""
        if not self._loop.isRunning():
            self._start()
            return

        # Running: cancel and put the button and progress bar back.
        self._loop.cancel()
        self.btn_start.setText("Optimize")
        self.progressBarFinished(processEvents=False)

    def __init_combo_changed(self):
        """Re-initialize the plot after the initialization type changes.

        A running optimization is cancelled first and restarted afterwards.
        """
        was_running = self._loop.isRunning()
        if was_running:
            self._loop.cancel()

        if self.data is not None:
            self.setup_plot()

        if was_running:
            self._start()

    def _start(self):
        """
        Start the projection optimization.

        The optimization runs as a generator handed to the update loop: each
        step runs FreeViz for 10 iterations from the current anchors, yields
        the first two items of its result, and stops once the anchors no
        longer change within tolerance.
        """
        def optimization_steps(anchors):
            while True:
                result = FreeViz.freeviz(
                    self._X, self._Y,
                    scale=False, center=False,
                    initial=anchors, maxiter=10,
                )
                new_anchors = result[1]
                yield result[0], new_anchors
                converged = np.allclose(
                    anchors, new_anchors, rtol=1e-5, atol=1e-4)
                if converged:
                    return
                anchors = new_anchors

        start_points = self.graph.get_points()
        self._loop.setCoroutine(optimization_steps(start_points))
        self.btn_start.setText("Stop")
        self.progressBarInit()
        self.setBlocking(True)
        self.setStatusMessage("Optimizing")

    def __set_projection(self, projection):
        """Apply one optimization step yielded by the update loop.

        ``projection[0]`` becomes the new embedding coordinates and
        ``projection[1]`` the new graph points.
        """
        step = 100. / MAX_ITERATIONS
        self.progressBarAdvance(step)
        self._embedding_coords = projection[0]
        anchors = projection[1]
        self.graph.set_points(anchors)
        self._update_xy()

    def __freeviz_finished(self):
        """Restore the idle UI state after the update loop finishes and commit."""
        # Projection optimization has finished
        self.btn_start.setText("Optimize")
        self.setStatusMessage("")
        self.setBlocking(False)
        self.progressBarFinished()
        self.commit()

    def __on_error(self, err):
        """Pass an exception raised inside the update loop to ``sys.excepthook``."""
        traceback = err.__traceback__
        sys.excepthook(type(err), err, traceback)

    def _update_xy(self):
        """Rescale the embedding in place by its largest row norm and redraw."""
        row_norms = np.linalg.norm(self._embedding_coords, axis=1)
        largest = np.max(row_norms)
        self._embedding_coords /= largest
        self.graph.update_coordinates()

    def clear(self):
        """Cancel any running optimization and reset data, state and graph."""
        self._loop.cancel()

        # Forget the data and everything derived from it.
        self.data = self.valid_data = None
        self._embedding_coords = None
        self._X = self._Y = None
        self._rand_indices = None

        # Empty the graph.
        graph = self.graph
        graph.set_attributes(())
        graph.set_points([])
        graph.update_coordinates()
        graph.clear()

    @Inputs.data
    def set_data(self, data):
        """Handle a new table on the "Data" input.

        Messages are cleared, the current settings context is closed and the
        widget state is reset before ``data`` is stored and validated.
        ``_check_data`` sets ``self.data`` back to ``None`` when the table
        cannot be used, which is why the start button is enabled from
        ``self.data`` rather than from the ``data`` argument.
        """
        self.clear_messages()
        self.closeContext()
        self.clear()
        self.data = data
        self._check_data()
        # NOTE(review): init_attr_values, cb_class_density and
        # can_draw_density are not defined in the visible part of this
        # class; presumably inherited - confirm.
        self.init_attr_values()
        self.openContext(data)
        self.btn_start.setEnabled(self.data is not None)
        self.cb_class_density.setEnabled(self.can_draw_density())

    def _check_data(self):
        """Validate ``self.data``; on failure show an error and drop the data.

        Checks run in order and stop at the first failure: sparse data,
        missing class variable, discrete class with fewer than two values,
        more features than instances. If those pass, the FreeViz arrays are
        prepared and checked for emptiness, size and constant rows.
        """
        if self.data is None:
            return

        error = None
        if self.data.is_sparse():
            error = self.Error.sparse_data
        elif self.data.domain.class_var is None:
            error = self.Error.no_class_var
        elif self.data.domain.class_var.is_discrete and \
                len(self.data.domain.class_var.values) < 2:
            error = self.Error.not_enough_class_vars
        elif len(self.data.domain.attributes) > self.data.X.shape[0]:
            error = self.Error.features_exceeds_instances
        else:
            self._prepare_freeviz_data()
            if self._X is None:
                # No row survived the validity mask.
                error = self.Error.no_valid_data
            elif len(self._X) > MAX_INSTANCES:
                error = self.Error.too_many_data_instances
            elif np.allclose(np.nan_to_num(self._X - self._X[0]), 0) \
                    or not len(self._X):
                # All rows are (numerically) identical to the first one.
                error = self.Error.no_valid_data

        if error is not None:
            error()
            self.data = None

    def _prepare_freeviz_data(self):
        """Build the arrays used by FreeViz from the rows without missing values.

        Rows with a non-finite feature or class value are dropped. Features
        are mean-centered and each non-constant column is divided by its
        range. When no row is left, only ``valid_data`` is reset to ``None``.
        """
        data = self.data
        finite_rows = np.all(np.isfinite(data.X), axis=1)
        valid_mask = finite_rows & np.isfinite(data.Y)
        X = data.X[valid_mask]
        Y = data.Y[valid_mask]
        if len(X) == 0:
            self.valid_data = None
            return

        if data.domain.class_var.is_discrete:
            Y = Y.astype(int)

        X = X - np.mean(X, axis=0)
        span = np.ptp(X, axis=0)
        # Constant columns (span 0) are left as they are to avoid dividing by 0.
        varying = span > 0
        X[:, varying] /= span[varying].reshape(1, -1)

        self._X = X
        self._Y = Y
        self.valid_data = valid_mask

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """Store the subset input and the ids of its instances.

        ``self.subset_indices`` becomes the set of instance ids, or an empty
        dict when there is no subset. The graph's alpha control is enabled
        only while no subset is given.
        """
        self.subset_data = subset
        if subset is None:
            self.subset_indices = {}
        else:
            self.subset_indices = {instance.id for instance in subset}
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """Rebuild the plot when usable data is present, then commit."""
        has_valid_rows = \
            self.data is not None and self.valid_data is not None
        if has_valid_rows:
            self.setup_plot()
        self.commit()

    def get_coordinates_data(self):
        """Return the two embedding columns as ``(x, y)``.

        ``(None, None)`` is returned while no embedding has been computed.
        """
        coords = self._embedding_coords
        if coords is None:
            return None, None
        return coords[:, 0], coords[:, 1]

    def get_subset_mask(self):
        """Return a boolean array marking valid rows that are in the subset.

        Returns ``None`` when no subset ids are stored.
        """
        if not self.subset_indices:
            return None
        valid_rows = self.data[self.valid_data]
        return np.array([row.id in self.subset_indices
                         for row in valid_rows])

    def setup_plot(self):
        """Initialise the anchor points and redraw the graph from scratch.

        Anchors are placed radially for ``InitType.Circular`` and randomly
        otherwise; there is one anchor per column of ``self._X``.
        """
        n_features = self._X.shape[1]
        if self.initialization == InitType.Circular:
            anchors = FreeViz.init_radial(n_features)
        else:
            anchors = FreeViz.init_random(n_features, 2)
        self.graph.set_points(anchors)
        self.__set_embedding_coords()
        self.graph.set_attributes(self.data.domain.attributes)
        self.graph.reset_graph()

    def _randomize_indices(self):
        """Pick a sorted random sample of row indices for large data.

        Nothing happens when ``self._X`` has at most ``MAX_POINTS`` rows;
        otherwise ``MAX_POINTS`` distinct indices are drawn and stored,
        sorted, in ``self._rand_indices``.
        """
        n_rows = len(self._X)
        if n_rows <= MAX_POINTS:
            return
        sample = np.random.choice(n_rows, MAX_POINTS, replace=False)
        self._rand_indices = sorted(sample)

    def _manual_move(self):
        """Refresh the projection while an anchor is being dragged.

        Without a random sample only the coordinates are updated. With a
        sample, the widget state is temporarily narrowed to the sampled rows
        for drawing and then put back.
        """
        self.__set_embedding_coords()
        if self._rand_indices is None:
            self.graph.update_coordinates()
            return

        # Snapshot everything that plotting the sample overwrites.
        saved_selection = self.graph.selection
        saved_mask = self.valid_data.copy()
        saved_data = self.data.copy()
        saved_coords = self._embedding_coords.copy()

        self.__plot_random_subset(saved_selection)

        # Put the full-data state back.
        self.graph.selection = saved_selection
        self.valid_data = saved_mask
        self.data = saved_data
        self._embedding_coords = saved_coords

    def __plot_random_subset(self, selection):
        """Narrow the widget state to the sampled rows and redraw.

        ``selection`` is the selection array for the full data, or ``None``;
        when given, its sampled part is applied to the redrawn graph.
        """
        sample = self._rand_indices
        self._embedding_coords = self._embedding_coords[sample]
        self.data = self.data[sample]
        self.valid_data = self.valid_data[sample]
        self.graph.reset_graph()
        if selection is None:
            return
        self.graph.selection = selection[self._rand_indices]
        self.graph.update_selection_colors()

    def _finish_manual_move(self):
        if self._rand_indices is not None:
            selection = self.graph.selection
            self.graph.reset_graph()
            if selection is not None:
                self.graph.selection = selection
                self.graph.select_by_index(self.graph.get_selection())

    def __set_embedding_coords(self):
        points = self.graph.get_points()
        ex = np.dot(self._X, points)
        self._embedding_coords = (ex / np.max(np.linalg.norm(ex, axis=1)))

    def selection_changed(self):
        """Re-send the outputs; the selection is part of what is committed."""
        self.commit()

    def commit(self):
        """Send the selected, annotated and component tables to the outputs.

        Without data or valid rows, ``None`` is sent on all three outputs.
        Otherwise the two projection coordinates are added to the data as the
        meta columns ``self.variable_x`` and ``self.variable_y``; rows that
        ``self.valid_data`` excludes keep coordinates of zero.
        """
        selected = annotated = components = None
        if self.data is not None and self.valid_data is not None:
            name = self.data.name
            domain = self.data.domain
            metas = domain.metas + (self.variable_x, self.variable_y)
            domain = Domain(domain.attributes, domain.class_vars, metas)
            # Builtin `float` replaces the `np.float` alias, which was
            # deprecated in NumPy 1.20 and removed in 1.24; the resulting
            # dtype is unchanged.
            embedding_coords = np.zeros((len(self.data), 2), dtype=float)
            embedding_coords[self.valid_data] = self._embedding_coords

            data = self.data.transform(domain)
            data[:, self.variable_x] = embedding_coords[:, 0][:, None]
            data[:, self.variable_y] = embedding_coords[:, 1][:, None]

            selection = self.graph.get_selection()
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes
            # A selection value above 1 means more than one selection group
            # is in use, so the rows are annotated by group.
            if self.graph.selection is not None and \
                    np.max(self.graph.selection) > 1:
                annotated = create_groups_table(data, self.graph.selection)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

            # One row per projection axis, one column per feature.
            comp_domain = Domain(
                self.data.domain.attributes,
                metas=[StringVariable(name='component')])

            metas = np.array([["FreeViz 1"], ["FreeViz 2"]])
            components = Table.from_numpy(
                comp_domain,
                X=self.graph.get_points().T,
                metas=metas)

            components.name = name + ": components"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        """Add the plot and a caption listing its visual settings to the report.

        Does nothing when there is no data.
        """
        if self.data is None:
            return

        def label_of(variable):
            # Keeps a falsy variable (e.g. None) as it is.
            return variable and variable.name

        jitter = self.graph.jitter_size
        items = (
            ("Color", label_of(self.attr_color)),
            ("Label", label_of(self.attr_label)),
            ("Shape", label_of(self.attr_shape)),
            ("Size", label_of(self.attr_size)),
            ("Jittering", jitter != 0 and "{} %".format(jitter)),
        )
        caption = report.render_items_vert(items)
        self.report_plot()
        if caption:
            self.report_caption(caption)

    @classmethod
    def migrate_settings(cls, _settings, version):
        """Copy a pre-version-3 top-level ``radius`` into the graph settings."""
        if version < 3 and "radius" in _settings:
            _settings["graph"]["radius"] = _settings["radius"]

    @classmethod
    def migrate_context(cls, context, version):
        """Lift the ``attr_*`` values of pre-version-3 contexts out of ``graph``."""
        if version >= 3:
            return
        values = context.values
        graph_values = values["graph"]
        for key in ("attr_color", "attr_size", "attr_shape", "attr_label"):
            values[key] = graph_values[key]
Exemple #5
0
class OWPredictions(OWWidget):
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display predictions of models for an input dataset."
    keywords = []

    class Inputs:
        # Table the connected models are applied to (handled by set_data).
        data = Input("Data", Orange.data.Table)
        # Models to apply; declared with multiple=True and handled by
        # set_predictor, which keys them by the connection id.
        predictors = Input("Predictors", Model, multiple=True)

    class Outputs:
        # Input data extended with prediction meta columns
        # (sent by _commit_predictions).
        predictions = Output("Predictions", Orange.data.Table)
        # Results object built in _commit_evaluation_results.
        evaluation_results = Output("Evaluation Results", Results)

    class Warning(OWWidget.Warning):
        # Shown by set_data when a table is given but it is empty.
        empty_data = Msg("Empty dataset")
        # Shown by _set_errors; `{}` receives one line per model whose
        # results have no mapped probabilities.
        wrong_targets = Msg(
            "Some model(s) predict a different target (see more ...)\n{}")

    class Error(OWWidget.Error):
        # Shown by _set_errors; `{}` receives one line per model whose call
        # raised an error.
        predictor_failed = Msg("Some predictor(s) failed (see more ...)\n{}")
        # Shown by _update_scores; `{}` receives the messages of failed
        # scorers whose column is currently shown.
        scorer_failed = Msg("Some scorer(s) failed (see more ...)\n{}")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: List of selected class value indices in the `class_values` list
    selected_classes = settings.ContextSetting([])
    selection = settings.Setting([], schema_only=True)

    def __init__(self):
        """Create the widget state, the control area and the two table views.

        The predictions view (left) and the data view (right) sit in a
        horizontal splitter and scroll vertically in lockstep; the score
        table is placed below them.
        """
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        self.class_values = []  # type: List[str]
        # Delegates are kept here because the view does not own them.
        self._delegates = []
        self.left_width = 10
        self.selection_store = None
        # Selection loaded from settings; applied by set_data on first use.
        self.__pending_selection = self.selection

        self._set_input_summary()
        self._set_output_summary(None)

        # The box title previously read "probabibilities" (typo).
        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="Show probabilities for",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.ExtendedSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        self.reset_button = gui.button(
            self.controlArea,
            self,
            "Restore Original Order",
            callback=self._reset_order,
            tooltip="Show rows in the original order")

        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.ExtendedSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(sortingEnabled=True,
                                  verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Keep the vertical scroll positions of both views in sync.
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Same row height in both views, including after manual resizing
        # of a row in the data view.
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.dataview.setItemDelegate(DataItemDelegate(self.dataview))

        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.splitterMoved.connect(self.splitter_resized)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)

    def get_selection_store(self, proxy):
        """Return the shared selection store, creating it on first use.

        Both proxies map rows the same way, so whichever one arrives first
        can be used to initialise the store.
        """
        store = self.selection_store
        if store is None:
            store = self.selection_store = SharedSelectionStore(proxy)
        return store

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle a new table on the "Data" input.

        Stores the table, discards the shared selection store and rebuilds
        the data view's model and selection model. For missing or empty data
        both views have their models removed instead. Cached prediction
        results are invalidated in either case.
        """
        # Warn only for a table that exists but has no rows.
        self.Warning.empty_data(shown=data is not None and not data)
        self.data = data
        self.selection_store = None
        if not data:
            self.dataview.setModel(None)
            self.predictionsview.setModel(None)
        else:
            # force full reset of the view's HeaderView state
            self.dataview.setModel(None)
            model = TableModel(data, parent=None)
            modelproxy = SortProxyModel()
            modelproxy.setSourceModel(model)
            self.dataview.setModel(modelproxy)
            sel_model = SharedSelectionModel(
                self.get_selection_store(modelproxy), modelproxy,
                self.dataview)
            self.dataview.setSelectionModel(sel_model)
            # The selection remembered from settings is applied once, to
            # the first non-empty data, and then forgotten.
            if self.__pending_selection is not None:
                self.selection = self.__pending_selection
                self.__pending_selection = None
                self.selection_store.select_rows(
                    set(self.selection), QItemSelectionModel.ClearAndSelect)
            # Connected after restoring, so the restore itself does not
            # trigger a commit through this signal.
            sel_model.selectionChanged.connect(self.commit)
            sel_model.selectionChanged.connect(self._store_selection)

            # Sorting the data view re-orders the predictions view.
            self.dataview.model().list_sorted.connect(
                partial(self._update_data_sort_order, self.dataview,
                        self.predictionsview))

        self._invalidate_predictions()

    def _store_selection(self):
        self.selection = list(self.selection_store.rows)

    @property
    def class_var(self):
        """Class variable of the input data.

        When the data is missing or empty, that falsy data value itself is
        returned instead.
        """
        data = self.data
        if not data:
            return data
        return data.domain.class_var

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """Add, replace or remove the model connected under ``id``.

        A ``None`` predictor removes a known slot and is ignored for an
        unknown one. A replaced slot loses its cached results.
        """
        known = id in self.predictors
        if predictor is None:
            if known:
                del self.predictors[id]
            return

        if known:
            self.predictors[id] = self.predictors[id]._replace(
                predictor=predictor, name=predictor.name, results=None)
        else:
            self.predictors[id] = PredictorSlot(predictor, predictor.name,
                                                None)

    def _set_class_values(self):
        class_values = []
        for slot in self.predictors.values():
            class_var = slot.predictor.domain.class_var
            if class_var and class_var.is_discrete:
                for value in class_var.values:
                    if value not in class_values:
                        class_values.append(value)

        if self.class_var and self.class_var.is_discrete:
            values = self.class_var.values
            self.class_values = sorted(class_values,
                                       key=lambda val: val not in values)
            self.selected_classes = [
                i for i, name in enumerate(class_values) if name in values
            ]
        else:
            self.class_values = class_values  # This assignment updates listview
            self.selected_classes = []

    def handleNewSignals(self):
        """Recompute everything that depends on the inputs, then commit."""
        # The steps run in this fixed order.
        steps = (
            self._set_class_values,
            self._call_predictors,
            self._update_scores,
            self._update_predictions_model,
            self._update_prediction_delegate,
            self._set_errors,
            self._set_input_summary,
            self.commit,
        )
        for step in steps:
            step()

    def _call_predictors(self):
        """Run every model that has no cached results on the current data.

        A slot's ``results`` becomes a ``Results`` object on success, or an
        error string when the call raises ``ValueError`` or
        ``DomainTransformationError``. ``predicted`` and ``probabilities``
        are filled only when the model's target equals the data's class.
        """
        if not self.data:
            return
        # Models are called on data without the class column.
        if self.class_var:
            domain = self.data.domain
            classless_data = self.data.transform(
                Domain(domain.attributes, None, domain.metas))
        else:
            classless_data = self.data

        for inputid, slot in self.predictors.items():
            # Skip models whose results are already cached.
            if isinstance(slot.results, Results):
                continue

            predictor = slot.predictor
            try:
                if predictor.domain.class_var.is_discrete:
                    pred, prob = predictor(classless_data, Model.ValueProbs)
                else:
                    pred = predictor(classless_data, Model.Value)
                    # No probabilities for numeric targets: zero columns.
                    prob = numpy.zeros((len(pred), 0))
            except (ValueError, DomainTransformationError) as err:
                # The error text is stored in place of the results.
                self.predictors[inputid] = \
                    slot._replace(results=f"{predictor.name}: {err}")
                continue

            results = Results()
            results.data = self.data
            results.domain = self.data.domain
            results.row_indices = numpy.arange(len(self.data))
            results.folds = (Ellipsis, )
            results.actual = self.data.Y
            results.unmapped_probabilities = prob
            results.unmapped_predicted = pred
            results.probabilities = results.predicted = None
            self.predictors[inputid] = slot._replace(results=results)

            # A different target leaves predicted/probabilities as None.
            target = predictor.domain.class_var
            if target != self.class_var:
                continue

            # Equal but not the very same variable object: map the model's
            # values and probabilities back with the predictor's backmappers.
            if target is not self.class_var and target.is_discrete:
                backmappers, n_values = predictor.get_backmappers(self.data)
                prob = predictor.backmap_probs(prob, n_values, backmappers)
                pred = predictor.backmap_value(pred, prob, n_values,
                                               backmappers)
            results.predicted = pred.reshape((1, len(self.data)))
            results.probabilities = prob.reshape((1, ) + prob.shape)

    def _update_scores(self):
        """Refill the score table with one row per scorable model.

        Models without mapped predictions are skipped. A failing scorer
        leaves its cell empty with the error as tooltip; errors of scorers
        that are currently shown are also reported in the widget. The view
        is hidden when the table has no rows.
        """
        model = self.score_table.model
        model.clear()
        if self.class_var:
            scorers = usable_scorers(self.class_var)
        else:
            scorers = []
        self.score_table.update_header(scorers)

        failures = []
        for slot in self.predictors.values():
            results = slot.results
            if not isinstance(results, Results) or results.predicted is None:
                continue
            cells = [QStandardItem(learner_name(slot.predictor))]
            cells.extend([QStandardItem("N/A"), QStandardItem("N/A")])
            for scorer in scorers:
                cell = QStandardItem()
                try:
                    value = scorer_caller(scorer, results)()[0]
                    cell.setText(f"{value:.3f}")
                except Exception as exc:  # pylint: disable=broad-except
                    cell.setToolTip(str(exc))
                    if scorer.name in self.score_table.shown_scores:
                        failures.append(str(exc))
                cells.append(cell)
            self.score_table.model.appendRow(cells)

        view = self.score_table.view
        if not model.rowCount():
            view.setVisible(False)
        else:
            view.setVisible(True)
            view.ensurePolished()
            # Header height plus all rows, with a small margin.
            header_height = view.horizontalHeader().height()
            rows_height = \
                view.verticalHeader().sectionSize(0) * model.rowCount()
            view.setFixedHeight(5 + header_height + rows_height)

        self.Error.scorer_failed("\n".join(failures), shown=bool(failures))

    def _set_errors(self):
        """Show or clear the failed-model error and wrong-target warning.

        Not every model is run on each update, so the messages are gathered
        from the stored slots here rather than in ``_call_predictors``.
        """
        failed_lines = [
            f"- {slot.predictor.name}: {slot.results}"
            for slot in self.predictors.values()
            if isinstance(slot.results, str) and slot.results
        ]
        errors = "\n".join(failed_lines)
        self.Error.predictor_failed(errors, shown=bool(errors))

        if not self.class_var:
            self.Warning.wrong_targets.clear()
            return

        # Models that ran but have no mapped probabilities.
        mismatched = [
            slot.predictor for slot in self.predictors.values()
            if isinstance(slot.results, Results)
            and slot.results.probabilities is None
        ]
        inv_targets = "\n".join(
            f"- {model.name} predicts '{model.domain.class_var.name}'"
            for model in mismatched)
        self.Warning.wrong_targets(inv_targets, shown=bool(inv_targets))

    def _set_input_summary(self):
        if not self.data and not self.predictors:
            self.info.set_input_summary(self.info.NoInput)
            return

        summary = len(self.data) if self.data else 0
        details = self._get_details()
        self.info.set_input_summary(summary, details, format=Qt.RichText)

    def _get_details(self):
        """Return an HTML description of the data and the connected models."""
        if self.data:
            data_part = \
                format_summary_details(self.data).replace('\n', '<br>')
        else:
            data_part = "No data on input."
        parts = ["Data:<br>", data_part, "<hr>"]

        names = [slot.name for slot in self.predictors.values()]
        total = len(self.predictors)
        if not total:
            parts.append("Model:<br>No model on input.")
            return "".join(parts)

        working = len(self._non_errored_predictors())
        parts.append(plural("Model: {number} model{s}", total))
        if working != total:
            parts.append(f" ({total - working} failed)")
        parts.append("<ul>")
        parts.extend([f"<li>{name}</li>" for name in names])
        parts.append("</ul>")
        return "".join(parts)

    def _set_output_summary(self, output):
        summary = len(output) if output else self.info.NoOutput
        details = format_summary_details(output) if output else ""
        self.info.set_output_summary(summary, details)

    def _invalidate_predictions(self):
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _non_errored_predictors(self):
        """Return the slots whose stored results are a ``Results`` object."""
        usable = []
        for slot in self.predictors.values():
            if isinstance(slot.results, Results):
                usable.append(slot)
        return usable

    def _reordered_probabilities(self, prediction):
        cur_values = prediction.predictor.domain.class_var.values
        new_ind = [self.class_values.index(x) for x in cur_values]
        probs = prediction.results.unmapped_probabilities
        new_probs = numpy.full((probs.shape[0], len(self.class_values)),
                               numpy.nan)
        new_probs[:, new_ind] = probs
        return new_probs

    def _update_predictions_model(self):
        """Rebuild the predictions view's model from the non-errored slots.

        A fresh sort proxy and shared selection model are installed every
        time; the proxy's source model is ``None`` when there are no results.
        """
        results = []
        headers = []
        for p in self._non_errored_predictors():
            values = p.results.unmapped_predicted
            target = p.predictor.domain.class_var
            if target.is_discrete:
                # order probabilities in order from Show prob. for
                prob = self._reordered_probabilities(p)
                values = [Value(target, v) for v in values]
            else:
                prob = numpy.zeros((len(values), 0))
            results.append((values, prob))
            headers.append(p.predictor.name)

        if results:
            # Transpose from per-model columns to per-row tuples of
            # (value, probabilities) pairs.
            results = list(zip(*(zip(*res) for res in results)))
            model = PredictionsModel(results, headers)
        else:
            model = None

        # Detach the old selection model from the shared store before the
        # view gets a new one.
        if self.selection_store is not None:
            self.selection_store.unregister(
                self.predictionsview.selectionModel())

        predmodel = PredictionsSortProxyModel()
        predmodel.setSourceModel(model)
        predmodel.setDynamicSortFilter(True)
        self.predictionsview.setModel(predmodel)

        self.predictionsview.setSelectionModel(
            SharedSelectionModel(self.get_selection_store(predmodel),
                                 predmodel, self.predictionsview))

        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(predmodel.rowCount() < 20000)

        # Sorting the predictions view re-orders the data view.
        self.predictionsview.model().list_sorted.connect(
            partial(self._update_data_sort_order, self.predictionsview,
                    self.dataview))

        self.predictionsview.resizeColumnsToContents()

    def _update_data_sort_order(self, sort_source_view, sort_dest_view):
        """Make the destination view follow the row order of the source view.

        The row permutation is taken from the source view's proxy model and
        handed to the destination proxy via ``setSortIndices``; ``None`` is
        handed over when the source is missing or unsorted. The sort
        indicator is shown only on the view that drives the order, and the
        outputs are re-committed.
        """
        sort_dest = sort_dest_view.model()
        sort_source = sort_source_view.model()
        sortindicatorshown = False
        if sort_dest is not None:
            assert isinstance(sort_dest, QSortFilterProxyModel)
            n = sort_dest.rowCount()
            if sort_source is not None and sort_source.sortColumn() >= 0:
                sortind = numpy.argsort([
                    sort_source.mapToSource(sort_source.index(i, 0)).row()
                    for i in range(n)
                ])
                # Builtin `int` replaces the `numpy.int` alias, which was
                # deprecated in NumPy 1.20 and removed in 1.24; the
                # resulting dtype is unchanged.
                sortind = numpy.array(sortind, int)
                sortindicatorshown = True
            else:
                sortind = None

            sort_dest.setSortIndices(sortind)

        sort_dest_view.horizontalHeader().setSortIndicatorShown(False)
        sort_source_view.horizontalHeader().setSortIndicatorShown(
            sortindicatorshown)
        self.commit()

    def _reset_order(self):
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.setSortIndices(None)
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.setSortIndices(None)
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)
        self.dataview.horizontalHeader().setSortIndicatorShown(False)

    def _all_color_values(self):
        """
        Return list of colors together with their values from all predictors
        classes. Colors and values are sorted according to the values order
        for simpler comparison.
        """
        predictors = self._non_errored_predictors()
        color_values = [
            list(
                zip(*sorted(zip(p.predictor.domain.class_var.colors,
                                p.predictor.domain.class_var.values),
                            key=itemgetter(1)))) for p in predictors
            if p.predictor.domain.class_var.is_discrete
        ]
        return color_values if color_values else [([], [])]

    @staticmethod
    def _colors_match(colors1, values1, color2, values2):
        """
        Test whether colors for values match. Colors matches when all
        values match for shorter list and colors match for shorter list.
        It is assumed that values will be sorted together with their colors.
        """
        shorter_length = min(len(colors1), len(color2))
        return (values1[:shorter_length] == values2[:shorter_length]
                and (numpy.array(colors1[:shorter_length]) == numpy.array(
                    color2[:shorter_length])).all())

    def _get_colors(self):
        """
        Defines colors for values. If colors match in all models use the union
        otherwise use standard colors.
        """
        all_colors_values = self._all_color_values()
        base_color, base_values = all_colors_values[0]
        for c, v in all_colors_values[1:]:
            if not self._colors_match(base_color, base_values, c, v):
                base_color = []
                break
            # replace base_color if longer
            if len(v) > len(base_color):
                base_color = c
                base_values = v

        if len(base_color) != len(self.class_values):
            return LimitedDiscretePalette(len(self.class_values)).palette
        # reorder colors to widgets order
        colors = [None] * len(self.class_values)
        for c, v in zip(base_color, base_values):
            colors[self.class_values.index(v)] = c
        return colors

    def _update_prediction_delegate(self):
        """Install a fresh item delegate on each prediction column.

        Columns are numbered by enumerating ``self.predictors``. Afterwards
        the columns and the splitter are resized and the selected classes
        are handed to the predictions model.
        """
        self._delegates.clear()
        colors = self._get_colors()
        view = self.predictionsview
        for column, slot in enumerate(self.predictors.values()):
            target = slot.predictor.domain.class_var
            if target.is_continuous:
                class_values = None
                shown_probs = ()
                value_format = target.format_str
            else:
                class_values = self.class_values
                # Selected classes the model does not predict become None.
                shown_probs = [
                    index if self.class_values[index] in target.values
                    else None
                    for index in self.selected_classes
                ]
                value_format = None
            delegate = PredictionsItemDelegate(
                class_values, colors, shown_probs, value_format,
                parent=self.predictionsview)
            # QAbstractItemView does not take ownership of delegates, so we must
            self._delegates.append(delegate)
            view.setItemDelegateForColumn(column, delegate)
            view.setColumnHidden(column, False)

        self.predictionsview.resizeColumnsToContents()
        self._recompute_splitter_sizes()
        predictions_model = self.predictionsview.model()
        if predictions_model is not None:
            predictions_model.setProbInd(self.selected_classes)

    def _recompute_splitter_sizes(self):
        if not self.data:
            return
        view = self.predictionsview
        self.left_width = \
            view.horizontalHeader().length() + view.verticalHeader().width()
        self._update_splitter()

    def _update_splitter(self):
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([self.left_width, w1 + w2 - self.left_width])

    def splitter_resized(self):
        """Remember the left pane's width after the user moves the splitter."""
        self.left_width = self.splitter.sizes()[0]

    def commit(self):
        """Send both outputs: the predictions table and evaluation results."""
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        """Send evaluation results for models with mapped predictions.

        Rows whose class value is NaN are left out. ``None`` is sent when no
        model has mapped predictions.
        """
        scored = [
            slot for slot in self._non_errored_predictors()
            if slot.results.predicted is not None
        ]
        if not scored:
            self.Outputs.evaluation_results.send(None)
            return

        class_column = self.data.get_column_view(self.class_var)[0]
        known = ~numpy.isnan(class_column)
        data = self.data[known]

        results = Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(slot.results.predicted[0][known] for slot in scored))
        if self.class_var and self.class_var.is_discrete:
            results.probabilities = numpy.array(
                [slot.results.probabilities[0][known] for slot in scored])
        results.learner_names = [slot.name for slot in scored]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        """Send the input data extended with prediction meta columns.

        Each non-errored model contributes its columns, renamed when a name
        is already taken. With a selection only the selected rows are sent;
        without one, a sorted view determines the row order. ``None`` is
        sent when there is no data.
        """
        if not self.data:
            self._set_output_summary(None)
            self.Outputs.predictions.send(None)
            return

        newmetas = []
        newcolumns = []
        for slot in self._non_errored_predictors():
            if slot.predictor.domain.class_var.is_discrete:
                self._add_classification_out_columns(slot, newmetas,
                                                     newcolumns)
            else:
                self._add_regression_out_columns(slot, newmetas, newcolumns)

        attrs = list(self.data.domain.attributes)
        metas = list(self.data.domain.metas)
        names = [
            var.name
            for var in chain(attrs, self.data.domain.class_vars, metas) if var
        ]
        # Rename new columns that would clash with an existing or an
        # earlier new column.
        uniq_newmetas = []
        for new_ in newmetas:
            uniq = get_unique_names(names, new_.name)
            if uniq != new_.name:
                new_ = new_.copy(name=uniq)
            uniq_newmetas.append(new_)
            names.append(uniq)

        metas += uniq_newmetas
        domain = Orange.data.Domain(attrs, self.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            # The new metas were appended last, so they are the trailing
            # columns of the metas array.
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns

        index = self.dataview.model().index
        map_to = self.dataview.model().mapToSource
        assert self.selection_store is not None
        rows = None
        if self.selection_store.rows:
            rows = [
                ind.row()
                for ind in self.dataview.selectionModel().selectedRows(0)
            ]
            rows.sort()
        elif self.dataview.model().isSorted() \
                or self.predictionsview.model().isSorted():
            rows = list(range(len(self.data)))
        if rows:
            # Translate view rows into rows of the source table.
            source_rows = [map_to(index(row, 0)).row() for row in rows]
            predictions = predictions[source_rows]
        self.Outputs.predictions.send(predictions)
        self._set_output_summary(predictions)

    @staticmethod
    def _add_classification_out_columns(slot, newmetas, newcolumns):
        """Append a classifier's output variables and columns to the lists.

        Adds one discrete variable with the (unmapped) predictions and one
        continuous variable per class value with its probability.
        """
        # Mapped or unmapped predictions?!
        # Or provide a checkbox so the user decides?
        model = slot.predictor
        model_name = model.name
        target_values = model.domain.class_var.values

        newmetas.append(DiscreteVariable(name=model_name,
                                         values=target_values))
        newcolumns.append(slot.results.unmapped_predicted.reshape(-1, 1))

        probability_vars = [
            ContinuousVariable(name=f"{model_name} ({value})")
            for value in target_values
        ]
        newmetas.extend(probability_vars)
        newcolumns.append(slot.results.unmapped_probabilities)

    @staticmethod
    def _add_regression_out_columns(slot, newmetas, newcolumns):
        """Append a regressor's output variable and prediction column."""
        variable = ContinuousVariable(name=slot.predictor.name)
        newmetas.append(variable)
        column = slot.results.unmapped_predicted.reshape((-1, 1))
        newcolumns.append(column)

    def send_report(self):
        """Report the input details, the merged table and the scores.

        Nothing is reported when there is no data.
        """
        def merge_data_with_predictions():
            # Yields the report table row by row: a header row first, then
            # one row per data row with predictions before data columns.
            data_model = self.dataview.model()
            predictions_view = self.predictionsview
            predictions_model = predictions_view.model()

            # use ItemDelegate to style prediction values
            delegates = [
                predictions_view.itemDelegateForColumn(i)
                for i in range(predictions_model.columnCount())
            ]

            # iterate only over visible columns of data's QTableView
            iter_data_cols = list(
                filter(lambda x: not self.dataview.isColumnHidden(x),
                       range(data_model.columnCount())))

            # print header
            yield [''] + \
                  [predictions_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in range(predictions_model.columnCount())] + \
                  [data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                   for col in iter_data_cols]

            # print data & predictions
            for i in range(data_model.rowCount()):
                yield [data_model.headerData(i, Qt.Vertical, Qt.DisplayRole)] + \
                      [delegate.displayText(
                          predictions_model.data(predictions_model.index(i, j)),
                          QLocale())
                       for j, delegate in enumerate(delegates)] + \
                      [data_model.data(data_model.index(i, j))
                       for j in iter_data_cols]

        if self.data:
            text = self._get_details().replace('\n', '<br>')
            if self.selected_classes:
                text += '<br>Showing probabilities for: '
                text += ', '.join(
                    [self.class_values[i] for i in self.selected_classes])
            self.report_paragraph('Info', text)
            self.report_table("Data & Predictions",
                              merge_data_with_predictions(),
                              header_rows=1,
                              header_columns=1)

            self.report_table("Scores", self.score_table.view)

    def resizeEvent(self, event):
        """Re-apply the stored left pane width after the widget is resized."""
        super().resizeEvent(event)
        self._update_splitter()

    def showEvent(self, event):
        """Schedule a splitter update once the widget has been shown."""
        super().showEvent(event)
        # Zero-delay single-shot timer: the update is queued, not run here.
        QTimer.singleShot(0, self._update_splitter)
Exemple #6
0
class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin):
    """Anchor-projection widget that shows and optimizes a FreeViz projection.

    The optimization is started with ``self.start(run_freeviz, ...)``; its
    results are received in ``on_partial_result``, ``on_done`` and
    ``on_exception``.
    """
    # check_data rejects data with more all-finite rows than this
    MAX_INSTANCES = 10000

    name = "FreeViz"
    description = "Displays FreeViz projection"
    icon = "icons/Freeviz.svg"
    priority = 240
    keywords = ["viz"]

    settings_version = 3
    initialization = settings.Setting(InitType.Circular)
    GRAPH_CLASS = OWFreeVizGraph
    graph = settings.SettingProvider(OWFreeVizGraph)

    class Error(OWAnchorProjectionWidget.Error):
        no_class_var = widget.Msg("Data must have a target variable.")
        multiple_class_vars = widget.Msg(
            "Data must have a single target variable.")
        not_enough_class_vars = widget.Msg(
            "Target variable must have at least two unique values.")
        features_exceeds_instances = widget.Msg(
            "Number of features exceeds the number of instances.")
        too_many_data_instances = widget.Msg("Data is too large.")
        constant_data = widget.Msg("All data columns are constant.")
        not_enough_features = widget.Msg("At least two features are required.")

    class Warning(OWAnchorProjectionWidget.Warning):
        removed_features = widget.Msg("Categorical features with more than"
                                      " two values are not shown.")

    def __init__(self):
        """Initialize both base classes explicitly."""
        OWAnchorProjectionWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)

    def _add_controls(self):
        """Add the "Optimize" box, the inherited controls and the radius slider.

        The slider is bound to ``hide_radius`` on ``self.graph``.
        """
        self.__add_controls_start_box()
        super()._add_controls()
        self.gui.add_control(self._effects_box,
                             gui.hSlider,
                             "Hide radius:",
                             master=self.graph,
                             value="hide_radius",
                             minValue=0,
                             maxValue=100,
                             step=10,
                             createLabel=False,
                             callback=self.__radius_slider_changed)

    def __add_controls_start_box(self):
        """Create the "Optimize" box with the initialization combo and the
        start button (stored as ``self.run_button``)."""
        box = gui.vBox(self.controlArea, box="Optimize", spacing=0)
        gui.comboBox(box,
                     self,
                     "initialization",
                     label="Initialization:",
                     items=InitType.items(),
                     orientation=Qt.Horizontal,
                     labelWidth=90,
                     callback=self.__init_combo_changed)
        self.run_button = gui.button(box, self, "Start", self._toggle_run)

    @property
    def effective_variables(self):
        """Attributes that are continuous, or discrete with exactly two values."""
        return [
            a for a in self.data.domain.attributes
            if a.is_continuous or a.is_discrete and len(a.values) == 2
        ]

    def __radius_slider_changed(self):
        """Slider callback: let the graph apply the new radius."""
        self.graph.update_radius()

    def __init_combo_changed(self):
        """Combo callback: rebuild the initial projection and replot.

        If a task is currently set, the optimization is started again.
        """
        self.Error.proj_error.clear()
        self.init_projection()
        self.setup_plot()
        self.commit.deferred()
        if self.task is not None:
            self._run()

    def _toggle_run(self):
        """Button callback: cancel the current task, or start one if none is set."""
        if self.task is not None:
            self.cancel()
            self.graph.set_sample_size(None)
            self.run_button.setText("Resume")
            self.commit.deferred()
        else:
            self._run()

    def _run(self):
        """Start ``run_freeviz`` on the effective data; do nothing without data."""
        if self.data is None:
            return
        self.graph.set_sample_size(self.SAMPLE_SIZE)
        self.run_button.setText("Stop")
        self.start(run_freeviz, self.effective_data, self.projector)

    # ConcurrentWidgetMixin
    def on_partial_result(self, result: Result):
        """Adopt an intermediate projector/projection and refresh the graph."""
        assert isinstance(result.projector, FreeViz)
        assert isinstance(result.projection, FreeVizModel)
        self.projector = result.projector
        self.projection = result.projection
        self.graph.update_coordinates()
        self.graph.update_density()

    def on_done(self, result: Result):
        """Adopt the final projector/projection, reset the UI and commit."""
        assert isinstance(result.projector, FreeViz)
        assert isinstance(result.projection, FreeVizModel)
        self.projector = result.projector
        self.projection = result.projection
        self.graph.set_sample_size(None)
        self.run_button.setText("Start")
        self.commit.deferred()

    def on_exception(self, ex: Exception):
        """Show the exception as a projection error and reset the UI."""
        self.Error.proj_error(ex)
        self.graph.set_sample_size(None)
        self.run_button.setText("Start")

    # OWAnchorProjectionWidget
    def set_data(self, data):
        """Set the data through the base class and, when the widget is
        invalidated, build a fresh initial projection."""
        super().set_data(data)
        self.graph.set_sample_size(None)
        if self._invalidated:
            self.init_projection()

    def init_projection(self):
        """Create the projector and projection model from initial anchors.

        Anchors are radial for ``InitType.Circular`` and random otherwise,
        one per effective variable. Does nothing without data.
        """
        if self.data is None:
            return
        anchors = FreeViz.init_radial(len(self.effective_variables)) \
            if self.initialization == InitType.Circular \
            else FreeViz.init_random(len(self.effective_variables), 2)
        self.projector = FreeViz(scale=False,
                                 center=False,
                                 initial=anchors,
                                 maxiter=10)
        data = self.projector.preprocess(self.effective_data)
        self.projector.domain = data.domain
        self.projector.components_ = anchors.T
        self.projection = FreeVizModel(self.projector, self.projector.domain,
                                       2)
        self.projection.pre_domain = data.domain
        self.projection.name = self.projector.name

    def check_data(self):
        """Validate the data after the base-class checks.

        The first failing check shows its error message and sets
        ``self.data`` to None. When all checks pass, a warning is shown if
        some attributes are excluded from ``effective_variables``.
        """
        def error(err):
            err()
            self.data = None

        super().check_data()
        if self.data is not None:
            class_vars, domain = self.data.domain.class_vars, self.data.domain
            if not class_vars:
                error(self.Error.no_class_var)
            elif len(class_vars) > 1:
                error(self.Error.multiple_class_vars)
            elif class_vars[0].is_discrete and len(np.unique(self.data.Y)) < 2:
                error(self.Error.not_enough_class_vars)
            elif len(self.data.domain.attributes) < 2:
                error(self.Error.not_enough_features)
            elif len(self.data.domain.attributes) > self.data.X.shape[0]:
                error(self.Error.features_exceeds_instances)
            elif not np.sum(np.std(self.data.X, axis=0)):
                error(self.Error.constant_data)
            # count only rows in which every X value is finite
            elif np.sum(np.all(np.isfinite(self.data.X),
                               axis=1)) > self.MAX_INSTANCES:
                error(self.Error.too_many_data_instances)
            else:
                if len(self.effective_variables) < len(domain.attributes):
                    self.Warning.removed_features()

    def enable_controls(self):
        """Enable the run button only when data is present; reset its text."""
        super().enable_controls()
        self.run_button.setEnabled(self.data is not None)
        self.run_button.setText("Start")

    def get_coordinates_data(self):
        """Return the embedding of the valid rows, transposed and normalized.

        The coordinates are divided by the largest row norm (or by 1 when
        that norm is 0). Returns ``(None, None)`` when there is no embedding.
        """
        embedding = self.get_embedding()
        if embedding is None:
            return None, None
        valid_emb = embedding[self.valid_data]
        return valid_emb.T / (np.max(np.linalg.norm(valid_emb, axis=1)) or 1)

    def _manual_move(self, anchor_idx, x, y):
        """Record the moved anchor in ``projector.initial``, then delegate."""
        self.projector.initial[anchor_idx] = [x, y]
        super()._manual_move(anchor_idx, x, y)

    def clear(self):
        """Clear the widget through the base class and cancel the task."""
        super().clear()
        self.cancel()

    def onDeleteWidget(self):
        """Shut down the concurrent machinery before the widget is deleted."""
        self.shutdown()
        super().onDeleteWidget()

    @classmethod
    def migrate_settings(cls, _settings, version):
        """Move a top-level "radius" setting to ``graph.hide_radius``.

        Applies to settings older than version 3. The "graph" sub-dict is
        created when absent so that old settings without it do not raise
        KeyError.
        """
        if version < 3 and "radius" in _settings:
            _settings.setdefault("graph", {})["hide_radius"] = \
                _settings["radius"]

    @classmethod
    def migrate_context(cls, context, version):
        """Copy attr_* values from the context's "graph" entry to top level.

        Applies to contexts older than version 3. Entries that are missing
        (no "graph" dict, or a key absent from it) are skipped instead of
        raising KeyError.
        """
        if version < 3:
            values = context.values
            graph_values = values.get("graph") or {}
            for key in ("attr_color", "attr_size", "attr_shape",
                        "attr_label"):
                if key in graph_values:
                    values[key] = graph_values[key]
class OWTopicModeling(OWWidget):
    """Topic modelling widget with one options panel per available method.

    The selected method's model is fitted on the input corpus in an
    asynchronous task; the transformed corpus and the topic tables are sent
    to the outputs.
    """
    name = "Topic Modelling"
    description = "Uncover the hidden thematic structure in a corpus."
    icon = "icons/TopicModeling.svg"
    priority = 400

    settingsHandler = DomainContextHandler()

    # Input/output
    class Inputs:
        corpus = Input("Corpus", Corpus)

    class Outputs:
        corpus = Output("Corpus", Table)
        selected_topic = Output("Selected Topic", Topic)
        all_topics = Output("All Topics", Table)

    want_main_area = True

    # (options widget class, attribute name under which it is stored)
    methods = [
        (LsiWidget, 'lsi'),
        (LdaWidget, 'lda'),
        (HdpWidget, 'hdp'),
    ]

    # Settings
    autocommit = settings.Setting(True)
    method_index = settings.Setting(0)

    lsi = settings.SettingProvider(LsiWidget)
    hdp = settings.SettingProvider(HdpWidget)
    lda = settings.SettingProvider(LdaWidget)

    control_area_width = 300

    def __init__(self):
        """Build the commit button, the method panels and the topic viewer."""
        super().__init__()
        self.corpus = None
        self.learning_thread = None

        # Commit button
        gui.auto_commit(self.buttonsArea, self, 'autocommit', 'Commit',
                        box=False)

        radio_group = QButtonGroup(self, exclusive=True)
        radio_group.buttonClicked[int].connect(self.change_method)

        # One radio button plus one options panel per method
        self.widgets = []
        panels_layout = QVBoxLayout()
        self.controlArea.layout().addLayout(panels_layout)
        for index, (panel_class, attr_name) in enumerate(self.methods):
            panel = panel_class(self, title='Options')
            panel.setFixedWidth(self.control_area_width)
            panel.valueChanged.connect(self.commit)
            self.widgets.append(panel)
            setattr(self, attr_name, panel)

            radio = QRadioButton(text=panel.Model.name)
            radio_group.addButton(radio, index)
            panels_layout.addWidget(radio)
            panels_layout.addWidget(panel)

        radio_group.button(self.method_index).setChecked(True)
        self.toggle_widgets()
        panels_layout.addStretch()

        # Topics description
        self.topic_desc = TopicViewer()
        self.topic_desc.topicSelected.connect(self.send_topic_by_id)
        self.mainArea.layout().addWidget(self.topic_desc)
        self.topic_desc.setFocus()

    @Inputs.corpus
    def set_data(self, data=None):
        """Store the input corpus and (re)run the model."""
        self.corpus = data
        self.apply()

    def commit(self):
        """Re-run the model, but only when a corpus is present."""
        if self.corpus is None:
            return
        self.apply()

    @property
    def model(self):
        """Model of the currently selected method's panel."""
        active_panel = self.widgets[self.method_index]
        return active_panel.model

    def change_method(self, new_index):
        """Switch to another method; ignored when the index is unchanged."""
        if self.method_index == new_index:
            return
        self.method_index = new_index
        self.toggle_widgets()
        self.commit()

    def toggle_widgets(self):
        """Show only the panel of the selected method."""
        selected = self.method_index
        for index, panel in enumerate(self.widgets):
            panel.setVisible(index == selected)

    def apply(self):
        """Stop a running task, then fit on the corpus or clear the outputs."""
        self.learning_task.stop()
        if self.corpus is None:
            self.on_result(None)
        else:
            self.learning_task()

    @asynchronous
    def learning_task(self):
        """Fit the selected model on a copy of the corpus and transform it."""
        fit_transform = self.model.fit_transform
        return fit_transform(self.corpus.copy(),
                             chunk_number=100,
                             on_progress=self.on_progress)

    @learning_task.on_start
    def on_start(self):
        """Task start hook: initialize the progress bar and clear topics."""
        self.progressBarInit(None)
        self.topic_desc.clear()

    @learning_task.on_result
    def on_result(self, corpus):
        """Task result hook: send the corpus and update topic outputs.

        With a ``None`` corpus the topic view and both topic outputs are
        cleared; otherwise the model is shown and all topics are sent.
        """
        self.progressBarFinished(None)
        self.Outputs.corpus.send(corpus)
        if corpus is not None:
            fitted = self.model
            self.topic_desc.show_model(fitted)
            self.Outputs.all_topics.send(fitted.get_all_topics_table())
            return
        self.topic_desc.clear()
        self.Outputs.selected_topic.send(None)
        self.Outputs.all_topics.send(None)

    @learning_task.callback
    def on_progress(self, p):
        """Task progress hook: ``p`` is scaled by 100 for the progress bar."""
        self.progressBarSet(100 * p, processEvents=None)

    def send_report(self):
        """Report the selected model's items and, with a corpus, the topics."""
        active_panel = self.widgets[self.method_index]
        self.report_items(*active_panel.report_model())
        if self.corpus is not None:
            self.report_items('Topics', self.topic_desc.report())

    def send_topic_by_id(self, topic_id=None):
        """Send the table of the given topic when a fitted model exists."""
        if not self.model.model or topic_id is None:
            return
        topic_table = self.model.get_topics_table_by_id(topic_id)
        self.Outputs.selected_topic.send(topic_table)
Exemple #8
0
class OWRadviz(OWProjectionWidget):
    """Radviz projection widget with variable selection and VizRank.

    The projection is computed by ``radviz`` from the selected variables.
    The widget outputs the selected rows, the annotated data and a table of
    anchor components.
    """
    name = "Radviz"
    description = "Display Radviz projection"
    icon = "icons/Radviz.svg"
    priority = 241
    keywords = ["viz"]

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    settings_version = 2
    settingsHandler = settings.DomainContextHandler()

    variable_state = settings.ContextSetting({})
    auto_commit = settings.Setting(True)

    vizrank = settings.SettingProvider(RadvizVizRank)
    graph = settings.SettingProvider(OWRadvizGraph)
    graph_name = "graph.plot_widget.plotItem"

    # custom event type posted by invalidate_plot, handled in customEvent
    ReplotRequest = QEvent.registerEventType()

    class Information(OWProjectionWidget.Information):
        sql_sampled_data = widget.Msg("Data has been sampled")

    class Warning(OWProjectionWidget.Warning):
        no_features = widget.Msg("At least 2 features have to be chosen")
        invalid_embedding = widget.Msg("No projection for selected features")

    class Error(OWProjectionWidget.Error):
        sparse_data = widget.Msg("Sparse data is not supported")
        no_features = widget.Msg(
            "At least 3 numeric or categorical variables are required")
        no_instances = widget.Msg("At least 2 data instances are required")

    def __init__(self):
        """Create the graph, the variable selection and the control boxes."""
        super().__init__()

        self.data = None
        self.subset_data = None
        self.subset_indices = None
        self._embedding_coords = None
        self._rand_indices = None

        self.__replot_requested = False

        # meta variables that carry the projected coordinates in commit()
        self.variable_x = ContinuousVariable("radviz-x")
        self.variable_y = ContinuousVariable("radviz-y")

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWRadvizGraph(self, box)
        box.layout().addWidget(self.graph.plot_widget)

        self.variables_selection = VariablesSelection()
        self.model_selected = selected = VariableListModel(enable_dnd=True)
        self.model_other = other = VariableListModel(enable_dnd=True)
        self.variables_selection(self, selected, other, self.controlArea)

        self.vizrank, self.btn_vizrank = RadvizVizRank.add_vizrank(
            None, self, "Suggest features", self.vizrank_set_attrs)
        # Todo: this button introduces some margin at the bottom?!
        self.variables_selection.add_remove.layout().addWidget(
            self.btn_vizrank)

        g = self.graph.gui
        g.point_properties_box(self.controlArea)
        g.effects_box(self.controlArea)
        g.plot_properties_box(self.controlArea)

        self.graph.box_zoom_select(self.controlArea)

        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")

        self.graph.view_box.started.connect(self._randomize_indices)
        self.graph.view_box.moved.connect(self._manual_move)
        self.graph.view_box.finished.connect(self._finish_manual_move)

    def vizrank_set_attrs(self, attrs):
        """Use ``attrs`` as the selected variables; ignored when empty.

        The same variables are removed from the list of other variables.
        """
        if not attrs:
            return
        self.variables_selection.display_none()
        self.model_selected[:] = attrs[:]
        self.model_other[:] = [v for v in self.model_other if v not in attrs]

    def update_colors(self):
        """Refresh the VizRank button state and the class-density checkbox."""
        self._vizrank_color_change()
        self.cb_class_density.setEnabled(self.can_draw_density())

    def invalidate_plot(self):
        """
        Schedule a delayed replot.
        """
        if not self.__replot_requested:
            self.__replot_requested = True
            QApplication.postEvent(self, QEvent(self.ReplotRequest),
                                   Qt.LowEventPriority - 10)

    def _vizrank_color_change(self):
        """Enable or disable the VizRank button and re-initialize VizRank.

        The button is enabled when there is dense data, more than three
        variables over both lists, more than one valid row, no column of X
        with zero standard deviation, and a color attribute whose column is
        not entirely NaN.
        """
        is_enabled = self.data is not None and not self.data.is_sparse() and \
            len(self.model_other) + len(self.model_selected) > 3 and \
            len(self.data[self.valid_data]) > 1 and \
            np.all(np.nan_to_num(np.nanstd(self.data.X, 0)) != 0)
        self.btn_vizrank.setEnabled(
            is_enabled and self.attr_color is not None and not np.isnan(
                self.data.get_column_view(
                    self.attr_color)[0].astype(float)).all())
        self.vizrank.initialize()

    def clear(self):
        """Reset data, embedding state, both variable lists and the graph."""
        self.data = None
        self.valid_data = None
        self._embedding_coords = None
        self._rand_indices = None
        self.model_selected.clear()
        self.model_other.clear()

        self.graph.set_attributes(())
        self.graph.set_points(None)
        self.graph.update_coordinates()
        self.graph.clear()

    @Inputs.data
    def set_data(self, data):
        """Set and validate the input data, then restore variable lists."""
        self.clear_messages()
        self.btn_vizrank.setEnabled(False)
        self.closeContext()
        self.clear()
        self.data = data
        self._check_data()
        self.init_attr_values()
        self.openContext(self.data)
        if self.data is not None:
            self.model_selected[:], self.model_other[:] = self._load_settings()

    def _check_data(self):
        """Validate ``self.data``; set it to None when it cannot be used.

        The checks form a single ``elif`` chain, so an ``SqlTable`` input is
        only downloaded (or sampled) and is not subjected to the instance and
        variable count checks that follow.
        """
        if self.data is not None:
            domain = self.data.domain
            if self.data.is_sparse():
                self.Error.sparse_data()
                self.data = None
            elif isinstance(self.data, SqlTable):
                if self.data.approx_len() < 4000:
                    self.data = Table(self.data)
                else:
                    self.Information.sql_sampled_data()
                    data_sample = self.data.sample_time(1, no_cache=True)
                    data_sample.download_data(2000, partial=True)
                    self.data = Table(data_sample)
            elif len(self.data) < 2:
                self.Error.no_instances()
                self.data = None
            elif len([
                    v for v in domain.variables + domain.metas
                    if v.is_primitive()
            ]) < 3:
                self.Error.no_features()
                self.data = None

    def _load_settings(self):
        """Return the (selected, other) variable lists for the current data.

        Defaults are the first five primitive attributes/metas as selected
        and the rest plus class variables as other. The default state is
        encoded with ``np.inf`` in place of positions and then updated with
        the stored ``variable_state``, so stored entries override defaults.
        """
        domain = self.data.domain
        variables = [
            v for v in domain.attributes + domain.metas if v.is_primitive()
        ]
        self.model_selected[:] = variables[:5]
        self.model_other[:] = variables[5:] + list(domain.class_vars)

        state = VariablesSelection.encode_var_state(
            [list(self.model_selected),
             list(self.model_other)])
        state = {key: (ind, np.inf) for key, (ind, _) in state.items()}
        state.update(self.variable_state)
        return VariablesSelection.decode_var_state(
            state, [list(self.model_selected),
                    list(self.model_other)])

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """Store the subset and the ids of its rows.

        The alpha control is enabled only when there is no subset.
        """
        self.subset_data = subset
        # NOTE: the fallback ``{}`` is an empty dict, not a set; within this
        # class it is only tested for truthiness and membership.
        self.subset_indices = {e.id for e in subset} \
            if subset is not None else {}
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """Replot, refresh the VizRank button and commit."""
        self.setup_plot()
        self._vizrank_color_change()
        self.commit()

    def get_coordinates_data(self):
        """Return the x and y embedding columns.

        Returns ``(None, None)`` when there is no embedding or it contains
        NaN.
        """
        ec = self._embedding_coords
        if ec is None or np.any(np.isnan(ec)):
            return None, None
        return ec[:, 0], ec[:, 1]

    def get_subset_mask(self):
        """Return a boolean array marking valid rows whose id is in the subset.

        Returns None when no subset ids are stored.
        """
        if self.subset_indices:
            return np.array([
                ex.id in self.subset_indices
                for ex in self.data[self.valid_data]
            ])
        return None

    def customEvent(self, event):
        """Handle a ``ReplotRequest`` by replotting; forward other events."""
        if event.type() == OWRadviz.ReplotRequest:
            self.__replot_requested = False
            self.setup_plot()
        else:
            super().customEvent(event)

    def closeContext(self):
        """Encode both variable lists into ``variable_state`` before closing."""
        self.variable_state = VariablesSelection.encode_var_state(
            [list(self.model_selected),
             list(self.model_other)])
        super().closeContext()

    def setup_plot(self):
        """Compute the projection for the selected variables and reset the graph.

        Does nothing without data. With fewer than two selected variables a
        warning is shown and the graph is cleared.
        """
        if self.data is None:
            return
        self.__replot_requested = False

        self.clear_messages()
        if len(self.model_selected) < 2:
            self.Warning.no_features()
            self.graph.clear()
            return

        # radviz returns (embedding coordinates, anchor points, valid mask)
        r = radviz(self.data, self.model_selected)
        self._embedding_coords = r[0]
        self.graph.set_points(r[1])
        self.valid_data = r[2]
        if self._embedding_coords is None or \
                np.any(np.isnan(self._embedding_coords)):
            self.Warning.invalid_embedding()
        self.graph.reset_graph()

    def _randomize_indices(self):
        """Pick a sorted random sample of row indices for large embeddings.

        Only applies when there are more than ``MAX_POINTS`` embedded rows;
        the sample holds ``MAX_POINTS`` distinct indices.
        """
        n = len(self._embedding_coords)
        if n > MAX_POINTS:
            self._rand_indices = np.random.choice(n, MAX_POINTS, replace=False)
            self._rand_indices = sorted(self._rand_indices)

    def _manual_move(self):
        """Recompute the embedding for the graph's current anchor points.

        When a random sample of indices is set, only that sample is plotted
        and the full widget state is restored afterwards.
        """
        self.__replot_requested = False

        res = radviz(self.data, self.model_selected, self.graph.get_points())
        self._embedding_coords = res[0]
        if self._rand_indices is not None:
            # save widget state
            selection = self.graph.selection
            valid_data = self.valid_data.copy()
            data = self.data.copy()
            ec = self._embedding_coords.copy()

            # plot subset
            self.__plot_random_subset(selection)

            # restore widget state
            self.graph.selection = selection
            self.valid_data = valid_data
            self.data = data
            self._embedding_coords = ec
        else:
            self.graph.update_coordinates()

    def __plot_random_subset(self, selection):
        """Restrict the widget state to the sampled rows and reset the graph.

        This overwrites ``_embedding_coords``, ``data`` and ``valid_data``;
        the caller is expected to restore them.
        """
        self._embedding_coords = self._embedding_coords[self._rand_indices]
        self.data = self.data[self._rand_indices]
        self.valid_data = self.valid_data[self._rand_indices]
        self.graph.reset_graph()
        if selection is not None:
            self.graph.selection = selection[self._rand_indices]
            self.graph.update_selection_colors()

    def _finish_manual_move(self):
        """After a sampled move, reset the graph and re-apply the selection."""
        if self._rand_indices is not None:
            selection = self.graph.selection
            self.graph.reset_graph()
            if selection is not None:
                self.graph.selection = selection
                self.graph.select_by_index(self.graph.get_selection())

    def selection_changed(self):
        """Commit whenever the selection changes."""
        self.commit()

    def commit(self):
        """Send selected data, annotated data and the components table.

        All three outputs are None when there is no data or no valid row.
        The projected coordinates are added as two meta columns; rows that
        are not valid keep zeros.
        """
        selected = annotated = components = None
        if self.data is not None and np.sum(self.valid_data):
            name = self.data.name
            domain = self.data.domain
            metas = domain.metas + (self.variable_x, self.variable_y)
            domain = Domain(domain.attributes, domain.class_vars, metas)
            # builtin float: the np.float alias was removed in NumPy 1.24
            embedding_coords = np.zeros((len(self.data), 2), dtype=float)
            embedding_coords[self.valid_data] = self._embedding_coords

            data = self.data.transform(domain)
            data[:, self.variable_x] = embedding_coords[:, 0][:, None]
            data[:, self.variable_y] = embedding_coords[:, 1][:, None]

            selection = self.graph.get_selection()
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes
            # selection values above 1 indicate more than one group
            if self.graph.selection is not None and \
                    np.max(self.graph.selection) > 1:
                annotated = create_groups_table(data, self.graph.selection)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

            points = self.graph.get_points()
            # the third column of points is used as the domain's attributes
            comp_domain = Domain(points[:, 2],
                                 metas=[StringVariable(name='component')])

            metas = np.array([["RX"], ["RY"], ["angle"]])
            angle = np.arctan2(np.array(points[:, 1].T, dtype=float),
                               np.array(points[:, 0].T, dtype=float))
            # np.vstack is the non-deprecated equivalent of np.row_stack
            components = Table.from_numpy(comp_domain,
                                          X=np.vstack(
                                              (points[:, :2].T, angle)),
                                          metas=metas)
            components.name = name + ": components"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        """Report the plot with a caption of the point properties in use."""
        if self.data is None:
            return

        def name(var):
            return var and var.name

        caption = report.render_items_vert(
            (("Color", name(self.attr_color)), ("Label",
                                                name(self.attr_label)),
             ("Shape", name(self.attr_shape)), ("Size", name(self.attr_size)),
             ("Jittering", self.graph.jitter_size != 0
              and "{} %".format(self.graph.jitter_size))))
        self.report_plot()
        if caption:
            self.report_caption(caption)

    @classmethod
    def migrate_context(cls, context, version):
        """Copy attr_* values from the context's "graph" entry to top level.

        Entries that are missing (no "graph" dict, or a key absent from it)
        are skipped instead of raising KeyError.

        NOTE(review): ``settings_version`` is 2 while this check is
        ``version < 3``, so contexts at version 2 also take this branch;
        confirm whether ``settings_version`` should be 3.
        """
        if version < 3:
            values = context.values
            graph_values = values.get("graph") or {}
            for key in ("attr_color", "attr_size", "attr_shape",
                        "attr_label"):
                if key in graph_values:
                    values[key] = graph_values[key]
class OWRadviz(widget.OWWidget):
    name = "Radviz"
    description = "Radviz"

    icon = "icons/Radviz.svg"
    priority = 240

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    settings_version = 1
    settingsHandler = settings.DomainContextHandler()

    variable_state = settings.ContextSetting({})

    auto_commit = settings.Setting(True)
    graph = settings.SettingProvider(OWRadvizGraph)
    vizrank = settings.SettingProvider(RadvizVizRank)

    jitter_sizes = [0, 0.1, 0.5, 1.0, 2.0]

    ReplotRequest = QEvent.registerEventType()

    graph_name = "graph.plot_widget.plotItem"

    class Information(widget.OWWidget.Information):
        sql_sampled_data = widget.Msg("Data has been sampled")

    class Warning(widget.OWWidget.Warning):
        no_features = widget.Msg("At least 2 features have to be chosen")

    class Error(widget.OWWidget.Error):
        sparse_data = widget.Msg("Sparse data is not supported")
        no_features = widget.Msg(
            "At least 3 numeric or categorical variables are required"
        )
        no_instances = widget.Msg("At least 2 data instances are required")

    def __init__(self):
        super().__init__()

        self.data = None
        self.subset_data = None
        self._subset_mask = None
        self._selection = None  # np.array
        self.__replot_requested = False
        self._new_plotdata()

        self.variable_x = ContinuousVariable("radviz-x")
        self.variable_y = ContinuousVariable("radviz-y")

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWRadvizGraph(self, box, "Plot", view_box=RadvizInteractiveViewBox)
        self.graph.hide_axes()

        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        SIZE_POLICY = (QSizePolicy.Minimum, QSizePolicy.Maximum)

        self.variables_selection = VariablesSelection()
        self.model_selected = VariableListModel(enable_dnd=True)
        self.model_other = VariableListModel(enable_dnd=True)
        self.variables_selection(self, self.model_selected, self.model_other)

        self.vizrank, self.btn_vizrank = RadvizVizRank.add_vizrank(
            self.controlArea, self, "Suggest features", self.vizrank_set_attrs
        )
        self.btn_vizrank.setSizePolicy(*SIZE_POLICY)
        self.variables_selection.add_remove.layout().addWidget(self.btn_vizrank)

        self.viewbox = plot.getViewBox()
        self.replot = None

        g = self.graph.gui
        pp_box = g.point_properties_box(self.controlArea)
        pp_box.setSizePolicy(*SIZE_POLICY)
        self.models = g.points_models

        box = gui.vBox(self.controlArea, "Plot Properties")
        box.setSizePolicy(*SIZE_POLICY)
        g.add_widget(g.JitterSizeSlider, box)

        g.add_widgets([g.ShowLegend, g.ClassDensity, g.LabelOnlySelected], box)

        zoom_select = self.graph.box_zoom_select(self.controlArea)
        zoom_select.setSizePolicy(*SIZE_POLICY)

        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(
            self.controlArea,
            self,
            "auto_commit",
            "Send Selection",
            auto_label="Send Automatically",
        )

        self.graph.zoom_actions(self)

        self._circle = QGraphicsEllipseItem()
        self._circle.setRect(QRectF(-1.0, -1.0, 2.0, 2.0))
        self._circle.setPen(pg.mkPen(QColor(0, 0, 0), width=2))

    def resizeEvent(self, event):
        self._update_points_labels()

    def keyPressEvent(self, event):
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def vizrank_set_attrs(self, attrs):
        if not attrs:
            return
        self.variables_selection.display_none()
        self.model_selected[:] = attrs[:]
        self.model_other[:] = [v for v in self.model_other if v not in attrs]

    def _new_plotdata(self):
        self.plotdata = namespace(
            valid_mask=None,
            embedding_coords=None,
            points=None,
            arcarrows=[],
            point_labels=[],
            rand=None,
            data=None,
        )

    def update_colors(self):
        self._vizrank_color_change()
        self.cb_class_density.setEnabled(self.graph.can_draw_density())

    def sizeHint(self):
        return QSize(800, 500)

    def clear(self):
        """
        Clear/reset the widget state
        """
        self.data = None
        self.model_selected.clear()
        self.model_other.clear()
        self._clear_plot()

    def _clear_plot(self):
        self._new_plotdata()
        self.graph.plot_widget.clear()

    def invalidate_plot(self):
        """
        Schedule a delayed replot.
        """
        if not self.__replot_requested:
            self.__replot_requested = True
            QApplication.postEvent(
                self, QEvent(self.ReplotRequest), Qt.LowEventPriority - 10
            )

    def init_attr_values(self):
        self.graph.set_domain(self.data)

    def _vizrank_color_change(self):
        attr_color = self.graph.attr_color
        is_enabled = (
            self.data is not None
            and not self.data.is_sparse()
            and (len(self.model_other) + len(self.model_selected)) > 3
            and len(self.data) > 1
        )
        self.btn_vizrank.setEnabled(
            is_enabled
            and attr_color is not None
            and not np.isnan(
                self.data.get_column_view(attr_color)[0].astype(float)
            ).all()
        )
        self.vizrank.initialize()

    @Inputs.data
    def set_data(self, data):
        """
        Set the input dataset and check if data is valid.

        The previous context is closed and the widget state cleared first.
        The data then passes through a chain of checks (SQL download,
        sparsity, number of primitive variables, number of instances);
        a failing check raises the matching widget error and the data is
        treated as missing.

        Args:
            data (Orange.data.table): data instances
        """

        def sql(data):
            # Small SQL tables are downloaded whole; larger ones are sampled
            # and the user is informed about it.
            self.Information.sql_sampled_data.clear()
            if isinstance(data, SqlTable):
                if data.approx_len() < 4000:
                    data = Table(data)
                else:
                    self.Information.sql_sampled_data()
                    data_sample = data.sample_time(1, no_cache=True)
                    data_sample.download_data(2000, partial=True)
                    data = Table(data_sample)
            return data

        def restore_variable_state(data):
            # get the default encoded state, replacing the position with Inf
            state = VariablesSelection.encode_var_state(
                [list(self.model_selected), list(self.model_other)]
            )
            state = {
                key: (source_ind, np.inf) for key, (source_ind, _) in state.items()
            }

            self.openContext(data.domain)

            # update the defaults state (the encoded state must contain
            # all variables in the input domain)
            state.update(self.variable_state)
            # ... and restore it with saved positions taking precedence over
            # the defaults
            selected, other = VariablesSelection.decode_var_state(
                state, [list(self.model_selected), list(self.model_other)]
            )
            return selected, other

        def is_sparse(data):
            if data.is_sparse():
                self.Error.sparse_data()
                data = None
            return data

        def are_features(data):
            domain = data.domain
            primitive_vars = [
                var
                for var in chain(domain.class_vars, domain.metas, domain.attributes)
                if var.is_primitive()
            ]
            if len(primitive_vars) < 3:
                self.Error.no_features()
                data = None
            return data

        def are_instances(data):
            if len(data) < 2:
                self.Error.no_instances()
                data = None
            return data

        self.clear_messages()
        self.btn_vizrank.setEnabled(False)
        # closeContext() encodes the current models, so it must run before
        # clear() empties them.
        self.closeContext()
        self.clear()
        self.information()
        self.Error.clear()
        for check in [sql, is_sparse, are_features, are_instances]:
            if data is None:
                break
            data = check(data)

        if data is not None:
            self.data = data
            self.init_attr_values()
            domain = data.domain
            primitive_vars = [
                v for v in chain(domain.metas, domain.attributes) if v.is_primitive()
            ]
            # Defaults: the first five primitive variables are selected; the
            # stored context may then override this split.
            self.model_selected[:] = primitive_vars[:5]
            self.model_other[:] = primitive_vars[5:] + list(domain.class_vars)
            self.model_selected[:], self.model_other[:] = restore_variable_state(data)
            self._selection = np.zeros(len(data), dtype=np.uint8)
            self.invalidate_plot()
        else:
            self.data = None

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """
        Set the supplementary input subset dataset.

        The cached subset mask is dropped, and the graph's alpha control
        is enabled only when no subset is given.

        Args:
            subset (Orange.data.table): subset of data instances
        """
        self.subset_data, self._subset_mask = subset, None
        alpha_control = self.controls.graph.alpha_value
        alpha_control.setEnabled(subset is None)

    def handleNewSignals(self):
        """
        React to a completed batch of input changes.

        With data present, the plot is rebuilt; the subset mask is computed
        first when a subset is set but no mask is cached. Without data the
        graph is reset. In both cases the VizRank button state is refreshed
        and outputs are committed.
        """
        if self.data is not None:
            self._clear_plot()
            if self.subset_data is not None and self._subset_mask is None:
                dataids = self.data.ids.ravel()
                subsetids = np.unique(self.subset_data.ids)
                # np.isin replaces np.in1d, which is deprecated in NumPy 2.0;
                # both inputs are 1-D here, so the result is the same.
                self._subset_mask = np.isin(dataids, subsetids, assume_unique=True)
            self.setup_plot(reset_view=True)
            self.cb_class_density.setEnabled(self.graph.can_draw_density())
        else:
            self.init_attr_values()
            self.graph.new_data(None)
        self._vizrank_color_change()
        self.commit()

    def customEvent(self, event):
        """Replot on a ``ReplotRequest`` event; pass any other event to the base class."""
        if event.type() != OWRadviz.ReplotRequest:
            super().customEvent(event)
            return
        self.__replot_requested = False
        self._clear_plot()
        self.setup_plot(reset_view=True)

    def closeContext(self):
        """Encode both variable models into ``variable_state``, then close the context."""
        models = [list(self.model_selected), list(self.model_other)]
        self.variable_state = VariablesSelection.encode_var_state(models)
        super().closeContext()

    def prepare_radviz_data(self, variables):
        """
        Run ``radviz`` on the current data and store its three results
        (embedding coordinates, points, valid mask) in ``self.plotdata``.
        """
        result = radviz(self.data, variables, self.plotdata.points)
        coords, anchors, valid = result
        plotdata = self.plotdata
        plotdata.embedding_coords = coords
        plotdata.points = anchors
        plotdata.valid_mask = valid

    def setup_plot(self, reset_view=True):
        """
        Compute the projection for the selected variables and draw it.

        Does nothing without data. With fewer than two selected variables
        a warning is shown and the graph is emptied. Otherwise the embedding
        coordinates are appended to the data as two extra meta columns
        (``variable_x`` and ``variable_y``) and the valid rows are plotted
        together with the circle and the anchor points.

        Args:
            reset_view (bool): forwarded to ``graph.update_data``
        """
        if self.data is None:
            return
        self.graph.jitter_continuous = True
        self.__replot_requested = False

        variables = list(self.model_selected)
        if len(variables) < 2:
            self.Warning.no_features()
            self.graph.new_data(None)
            return

        self.Warning.clear()
        self.prepare_radviz_data(variables)

        if self.plotdata.embedding_coords is None:
            return

        domain = self.data.domain
        new_metas = domain.metas + (self.variable_x, self.variable_y)
        domain = Domain(
            attributes=domain.attributes, class_vars=domain.class_vars, metas=new_metas
        )
        mask = self.plotdata.valid_mask
        # The ``np.float`` alias was removed in NumPy 1.24; the builtin
        # ``float`` yields the same float64 dtype.
        array = np.zeros((len(self.data), 2), dtype=float)
        # Rows outside the valid mask keep zero coordinates.
        array[mask] = self.plotdata.embedding_coords
        data = self.data.transform(domain)
        data[:, self.variable_x] = array[:, 0].reshape(-1, 1)
        data[:, self.variable_y] = array[:, 1].reshape(-1, 1)
        subset_data = (
            data[self._subset_mask & mask]
            if self._subset_mask is not None and len(self._subset_mask)
            else None
        )
        self.plotdata.data = data
        self.graph.new_data(data[mask], subset_data)
        if self._selection is not None:
            self.graph.selection = self._selection[self.plotdata.valid_mask]
        self.graph.update_data(self.variable_x, self.variable_y, reset_view=reset_view)
        self.graph.plot_widget.addItem(self._circle)
        self.graph.scatterplot_points = ScatterPlotItem(
            x=self.plotdata.points[:, 0], y=self.plotdata.points[:, 1]
        )
        self._update_points_labels()
        self.graph.plot_widget.addItem(self.graph.scatterplot_points)

    def randomize_indices(self):
        """
        Pick a random sample of row indices when there are many points.

        ``self.plotdata.rand`` becomes ``MAX_POINTS`` indices drawn without
        replacement when the embedding has more than ``MAX_POINTS`` rows,
        and ``None`` otherwise.
        """
        coords = self.plotdata.embedding_coords
        n_points = len(coords)
        if n_points > MAX_POINTS:
            sample = np.random.choice(n_points, MAX_POINTS, replace=False)
        else:
            sample = None
        self.plotdata.rand = sample

    def manual_move(self):
        """
        Redraw the projection from the current anchor points.

        When ``self.plotdata.rand`` holds sampled indices, only those rows
        (taken from the valid rows) are projected and drawn; otherwise the
        projection is recomputed for all data and every valid row is drawn.
        """
        self.__replot_requested = False

        sample = self.plotdata.rand
        if sample is None:
            # No sampling: recompute the stored projection and use every
            # valid row.
            self.prepare_radviz_data(list(self.model_selected))
            coords = self.plotdata.embedding_coords
            valid = self.plotdata.valid_mask
            x_part = self.data.X[valid]
            y_part = self.data.Y[valid]
            meta_part = self.data.metas[valid]
            picked = self._selection[valid]
        else:
            valid = self.plotdata.valid_mask
            valid_rows = self.data[valid]
            picked = self._selection[valid][sample]
            coords, _, valid = radviz(
                valid_rows, list(self.model_selected), self.plotdata.points
            )
            # Rows were already filtered by the valid mask, so all of them
            # are expected to stay valid.
            assert sum(valid) == len(valid_rows)
            sampled_rows = valid_rows[sample]
            coords = coords[sample]
            x_part = sampled_rows.X
            y_part = sampled_rows.Y
            meta_part = sampled_rows.metas

        source_domain = self.data.domain
        domain = Domain(
            attributes=(self.variable_x, self.variable_y) + source_domain.attributes,
            class_vars=source_domain.class_vars,
            metas=source_domain.metas,
        )
        projected = Table.from_numpy(
            domain, X=np.hstack((coords, x_part)), Y=y_part, metas=meta_part
        )
        self.graph.new_data(projected, None)
        self.graph.selection = picked
        self.graph.update_data(self.variable_x, self.variable_y, reset_view=True)
        self.graph.plot_widget.addItem(self._circle)
        anchors = self.plotdata.points
        self.graph.scatterplot_points = ScatterPlotItem(
            x=anchors[:, 0], y=anchors[:, 1]
        )
        self._update_points_labels()
        self.graph.plot_widget.addItem(self.graph.scatterplot_points)

    def _update_points_labels(self):
        if self.plotdata.points is None:
            return
        for point_label in self.plotdata.point_labels:
            self.graph.plot_widget.removeItem(point_label)
        self.plotdata.point_labels = []
        sx, sy = self.graph.view_box.viewPixelSize()

        for row in self.plotdata.points:
            ti = TextItem()
            metrics = QFontMetrics(ti.textItem.font())
            text_width = ((RANGE.width()) / 2.0 - np.abs(row[0])) / sx
            name = row[2].name
            ti.setText(name)
            ti.setTextWidth(text_width)
            ti.setColor(QColor(0, 0, 0))
            br = ti.boundingRect()
            width = (
                metrics.width(name) if metrics.width(name) < br.width() else br.width()
            )
            width = sx * (width + 5)
            height = sy * br.height()
            ti.setPos(row[0] - (row[0] < 0) * width, row[1] + (row[1] > 0) * height)
            self.plotdata.point_labels.append(ti)
            self.graph.plot_widget.addItem(ti)

    def _update_jitter(self):
        """Request a delayed replot via ``invalidate_plot``."""
        self.invalidate_plot()

    def reset_graph_data(self, *_):
        """Rescale the graph data and update the graph; no-op without data.

        Any positional arguments are accepted and ignored.
        """
        if self.data is None:
            return
        self.graph.rescale_data()
        self._update_graph()

    def _update_graph(self, reset_view=True, **_):
        self.graph.zoomStack = []
        if self.graph.data is None:
            return
        self.graph.update_data(self.variable_x, self.variable_y, reset_view=reset_view)

    def update_density(self):
        """Update the graph with ``reset_view=True``."""
        self._update_graph(reset_view=True)

    def selection_changed(self):
        """
        Copy the graph's selection into ``self._selection`` and commit.

        The graph selection is written to the rows marked by the valid
        mask; when the graph has no selection only the commit happens.
        """
        current = self.graph.selection
        if current is not None:
            self._selection[self.plotdata.valid_mask] = current
        self.commit()

    def prepare_data(self):
        """No-op: this implementation does nothing."""
        pass

    def commit(self):
        """
        Send the selected data, annotated data and components outputs.

        When no plot data is available all three outputs receive ``None``.
        Otherwise:

        * *selected* holds the rows with a non-zero selection (``None`` if
          there are none),
        * *annotated* is the plot data annotated with the selection, built
          as a groups table when any selection index is greater than 1,
        * *components* has one column per anchor variable and three rows:
          the anchor x coordinates, y coordinates and angles.
        """
        selected = annotated = components = None
        graph = self.graph
        if self.plotdata.data is not None:
            name = self.data.name
            data = self.plotdata.data
            # graph.selection only covers valid rows; expand it to one entry
            # per data row, with invalid rows left at 0.
            mask = self.plotdata.valid_mask.astype(int)
            mask[mask == 1] = graph.selection if graph.selection is not None else 0
            selection = np.flatnonzero(mask)
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes
            # np.any (unlike np.max) also copes with an empty selection array.
            if graph.selection is not None and np.any(
                np.asarray(graph.selection) > 1
            ):
                annotated = create_groups_table(data, mask)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

            # The third column of the points array holds the anchor variables.
            comp_domain = Domain(
                self.plotdata.points[:, 2], metas=[StringVariable(name="component")]
            )

            metas = np.array([["RX"], ["RY"], ["angle"]])
            angle = np.arctan2(
                np.array(self.plotdata.points[:, 1].T, dtype=float),
                np.array(self.plotdata.points[:, 0].T, dtype=float),
            )
            # np.vstack replaces the deprecated alias np.row_stack.
            components = Table.from_numpy(
                comp_domain,
                X=np.vstack((self.plotdata.points[:, :2].T, angle)),
                metas=metas,
            )
            components.name = name + ": components"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        """
        Add the plot and a caption with the graph settings to the report.

        The caption lists the color, label, shape and size attributes and
        the jitter size. Nothing is reported without data.
        """
        if self.data is None:
            return

        def name(var):
            return var.name if var else var

        graph = self.graph
        jitter = graph.jitter_size != 0 and "{} %".format(graph.jitter_size)
        items = (
            ("Color", name(graph.attr_color)),
            ("Label", name(graph.attr_label)),
            ("Shape", name(graph.attr_shape)),
            ("Size", name(graph.attr_size)),
            ("Jittering", jitter),
        )
        caption = report.render_items_vert(items)
        self.report_plot()
        if caption:
            self.report_caption(caption)
class OWTopicModeling(OWWidget, ConcurrentWidgetMixin):
    """
    Widget that fits a topic model on the input corpus in a background task.

    One of the models listed in ``methods`` is selected with radio buttons;
    the fitted topics are shown in a ``TopicViewer`` and sent to the outputs
    together with the transformed corpus.
    """
    name = "Topic Modelling"
    description = "Uncover the hidden thematic structure in a corpus."
    icon = "icons/TopicModeling.svg"
    priority = 400
    keywords = ["LDA"]

    settingsHandler = DomainContextHandler()

    # Input/output
    class Inputs:
        corpus = Input("Corpus", Corpus)

    class Outputs:
        corpus = Output("Corpus", Table, default=True)
        selected_topic = Output("Selected Topic", Topic)
        all_topics = Output("All Topics", Topics)

    want_main_area = True

    # (options widget class, attribute name under which it is stored)
    methods = [(LsiWidget, 'lsi'), (LdaWidget, 'lda'), (HdpWidget, 'hdp'),
               (NmfWidget, 'nmf')]

    # Settings
    autocommit = settings.Setting(True)
    method_index = settings.Setting(0)

    lsi = settings.SettingProvider(LsiWidget)
    hdp = settings.SettingProvider(HdpWidget)
    lda = settings.SettingProvider(LdaWidget)
    nmf = settings.SettingProvider(NmfWidget)

    selection = settings.Setting(None, schema_only=True)

    control_area_width = 300

    class Warning(OWWidget.Warning):
        less_topics_found = Msg('Less topics found than requested.')

    class Error(OWWidget.Error):
        unexpected_error = Msg("{}")

    def __init__(self):
        super().__init__()
        ConcurrentWidgetMixin.__init__(self)

        self.corpus = None
        self.learning_thread = None
        # Selection restored from settings; applied once a model is shown.
        self.__pending_selection = self.selection
        self.perplexity = "n/a"
        self.coherence = "n/a"

        # Commit button
        gui.auto_commit(self.buttonsArea,
                        self,
                        'autocommit',
                        'Commit',
                        box=False)

        button_group = QButtonGroup(self, exclusive=True)
        button_group.buttonClicked[int].connect(self.change_method)

        self.widgets = []
        method_layout = QVBoxLayout()
        self.controlArea.layout().addLayout(method_layout)
        for i, (method, attr_name) in enumerate(self.methods):
            widget = method(self, title='Options')
            widget.setFixedWidth(self.control_area_width)
            widget.valueChanged.connect(self.commit.deferred)
            self.widgets.append(widget)
            setattr(self, attr_name, widget)

            rb = QRadioButton(text=widget.Model.name)
            button_group.addButton(rb, i)
            method_layout.addWidget(rb)
            method_layout.addWidget(widget)

        button_group.button(self.method_index).setChecked(True)
        self.toggle_widgets()
        method_layout.addStretch()

        box = gui.vBox(self.controlArea, "Topic evaluation")
        gui.label(box, self, "Log perplexity: %(perplexity)s")
        gui.label(box, self, "Topic coherence: %(coherence)s")
        self.controlArea.layout().insertWidget(1, box)

        # Topics description
        self.topic_desc = TopicViewer()
        self.topic_desc.topicSelected.connect(self.send_topic_by_id)
        self.mainArea.layout().addWidget(self.topic_desc)
        self.topic_desc.setFocus()

    @Inputs.corpus
    def set_data(self, data=None):
        """Store the input corpus (or ``None``) and refit immediately."""
        self.corpus = data
        self.apply()

    @gui.deferred
    def commit(self):
        """Refit the model when a corpus is present."""
        if self.corpus is not None:
            self.apply()

    @property
    def model(self):
        """Model of the currently selected options widget."""
        return self.widgets[self.method_index].model

    def change_method(self, new_index):
        """Switch to the method with index ``new_index`` and commit."""
        if self.method_index != new_index:
            self.method_index = new_index
            self.toggle_widgets()
            self.commit.deferred()

    def toggle_widgets(self):
        """Show only the options widget of the selected method."""
        for i, widget in enumerate(self.widgets):
            widget.setVisible(i == self.method_index)

    def apply(self):
        """
        Cancel any running task and start fitting on the current corpus.

        Messages, the topic view and the evaluation labels from a previous
        run are reset first so they never describe an outdated model.
        Without a corpus all outputs receive ``None``.
        """
        self.cancel()
        self.Warning.less_topics_found.clear()
        self.Error.unexpected_error.clear()
        self.topic_desc.clear()
        self.perplexity = "n/a"
        self.coherence = "n/a"
        if self.corpus is not None:
            self.start(_run, self.corpus, self.model)
        else:
            self.Outputs.corpus.send(None)
            self.Outputs.selected_topic.send(None)
            self.Outputs.all_topics.send(None)

    def on_done(self, corpus):
        """
        Show the fitted model and send outputs once the task finishes.

        Args:
            corpus: result of the background task
        """
        self.Outputs.corpus.send(corpus)
        pos_tags = self.corpus.pos_tags is not None
        self.topic_desc.show_model(self.model, pos_tags=pos_tags)
        if self.__pending_selection:
            self.topic_desc.select(self.__pending_selection)
            self.__pending_selection = None

        if self.model.actual_topics != self.model.num_topics:
            self.Warning.less_topics_found()

        # Perplexity is computed for LDA only; other models show "n/a"
        # instead of a value left over from an earlier LDA run.
        if self.model.name == "Latent Dirichlet Allocation":
            bound = self.model.model.log_perplexity(
                infer_ngrams_corpus(corpus))
            self.perplexity = "{:.5f}".format(np.exp2(-bound))
        else:
            self.perplexity = "n/a"
        cm = CoherenceModel(model=self.model.model,
                            texts=corpus.tokens,
                            corpus=corpus,
                            coherence="c_v")
        coherence = cm.get_coherence()
        self.coherence = "{:.5f}".format(coherence)

        self.Outputs.all_topics.send(self.model.get_all_topics_table())

    def on_exception(self, ex: Exception):
        """Show the exception text as a widget error."""
        self.Error.unexpected_error(str(ex))

    def on_partial_result(self, result: Any) -> None:
        """Partial results are ignored."""
        pass

    def send_report(self):
        """Report the model settings and, with a corpus, the topics."""
        self.report_items(*self.widgets[self.method_index].report_model())
        if self.corpus is not None:
            self.report_items('Topics', self.topic_desc.report())

    def send_topic_by_id(self, topic_id=None):
        """Remember the selected topic and send it when a model is fitted."""
        self.selection = topic_id
        if self.model.model and topic_id is not None:
            self.Outputs.selected_topic.send(
                self.model.get_topics_table_by_id(topic_id))
Exemple #11
0
class OWPreprocess(OWWidget):
    """
    Text preprocessing widget.

    Builds a pipeline from the stage modules listed in ``preprocessors``
    and applies it to a copy of the input corpus in an asynchronous task.
    """

    name = '文本预处理'
    description = '构建文本预处理的管道'
    icon = 'icons/TextPreprocess.svg'
    priority = 200

    class Inputs:
        corpus = Input("Corpus", Corpus)

    class Outputs:
        corpus = Output("Corpus", Corpus)

    autocommit = settings.Setting(True)

    # Stage modules, in the order they are shown in the pipeline frame.
    preprocessors = [
        TransformationModule,
        TokenizerModule,
        NormalizationModule,
        FilteringModule,
        NgramsModule,
        POSTaggingModule,
    ]

    transformers = settings.SettingProvider(TransformationModule)
    tokenizer = settings.SettingProvider(TokenizerModule)
    normalizer = settings.SettingProvider(NormalizationModule)
    filters = settings.SettingProvider(FilteringModule)
    ngrams_range = settings.SettingProvider(NgramsModule)
    pos_tagger = settings.SettingProvider(POSTaggingModule)

    control_area_width = 250
    buttons_area_orientation = Qt.Vertical

    # The whole sentence is the message text; the second argument of
    # ``widget.Message`` is the persistent id, not a continuation of the text.
    UserAdviceMessages = [
        widget.Message(
            "部分预处理所需要的数据(例如词汇关系、停用词、标点符号规则等)是从NLTK包中获取的,"
            "这些数据可以从{}下载。".format(nltk_data_dir()),
            "nltk_data")
    ]

    class Error(OWWidget.Error):
        stanford_tagger = Msg("无法加载Stanford POS Tagger\n{}")
        stopwords_encoding = Msg("停用词表编码不正确,请使用 UTF-8 再试一次。")
        lexicon_encoding = Msg("词典编码不正确,请使用 UTF-8 再试一次。")
        error_reading_stopwords = Msg("读取文件错误: {}")
        error_reading_lexicon = Msg("读取文件错误: {}")

    class Warning(OWWidget.Warning):
        no_token_left = Msg('没有标记输出,请重新配置')
        udpipe_offline = Msg('没有网络连接,UDPipe 只加载本地模型')
        udpipe_offline_no_models = Msg('没有网络连接,UDPipe无本地模型')

    def __init__(self, parent=None):
        super().__init__(parent)
        self.corpus = None
        self.initial_ngram_range = None  # initial range of input corpus — used for inplace
        self.preprocessor = preprocess.Preprocessor()

        # -- INFO --
        info_box = gui.widgetBox(self.controlArea, '基本信息')
        info_box.setFixedWidth(self.control_area_width)
        self.controlArea.layout().addStretch()
        self.info_label = gui.label(info_box, self, '')
        self.update_info()

        # -- PIPELINE --
        frame = QFrame()
        frame.setContentsMargins(0, 0, 0, 0)
        frame.setFrameStyle(QFrame.Box)
        frame.setStyleSheet('.QFrame { border: 1px solid #B3B3B3; }')
        frame_layout = QVBoxLayout()
        frame_layout.setContentsMargins(0, 0, 0, 0)
        frame_layout.setSpacing(0)
        frame.setLayout(frame_layout)

        self.stages = []
        for stage in self.preprocessors:
            widget = stage(self)
            self.stages.append(widget)
            # Expose each stage under its own attribute name.
            setattr(self, stage.attribute, widget)
            frame_layout.addWidget(widget)
            widget.change_signal.connect(self.settings_invalidated)

        frame_layout.addStretch()
        self.scroll = QScrollArea()
        self.scroll.setWidget(frame)
        self.scroll.setWidgetResizable(True)
        self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.scroll.resize(frame_layout.sizeHint())
        self.scroll.setMinimumHeight(500)
        self.set_minimal_width()
        self.mainArea.layout().addWidget(self.scroll)

        # Buttons area
        self.report_button.setFixedWidth(self.control_area_width)

        commit_button = gui.auto_commit(self.buttonsArea,
                                        self,
                                        'autocommit',
                                        '提交',
                                        '自动提交',
                                        box=False)
        commit_button.setFixedWidth(self.control_area_width - 5)

        self.buttonsArea.layout().addWidget(commit_button)

    @Inputs.corpus
    def set_data(self, data=None):
        """Store a copy of the input corpus and its n-gram range, then commit."""
        self.corpus = data.copy() if data is not None else None
        self.initial_ngram_range = data.ngram_range if data is not None else None
        self.commit()

    def update_info(self, corpus=None):
        """Show document, token and type counts of ``corpus`` in the info label."""
        if corpus is not None:
            info = '文档数量: {}\n' \
                   '标记数量: {}\n'\
                   '类型数量: {}'\
                   .format(len(corpus), sum(map(len, corpus.tokens)), len(corpus.dictionary))
        else:
            info = '没有数据集'
        self.info_label.setText(info)

    def commit(self):
        """Run preprocessing when a corpus is set; otherwise send ``None``."""
        self.Warning.no_token_left.clear()
        if self.corpus is not None:
            self.apply()
        else:
            self.update_info()
            self.Outputs.corpus.send(None)

    def apply(self):
        """Start the asynchronous preprocessing task."""
        self.preprocess()

    @asynchronous
    def preprocess(self):
        """Copy the stage values onto the preprocessor and run it on the corpus."""
        for module in self.stages:
            setattr(self.preprocessor, module.attribute, module.value)
        self.corpus.pos_tags = None  # reset pos_tags and ngrams_range
        self.corpus.ngram_range = self.initial_ngram_range
        return self.preprocessor(self.corpus,
                                 inplace=True,
                                 on_progress=self.on_progress)

    @preprocess.on_start
    def on_start(self):
        self.progressBarInit(None)

    @preprocess.callback
    def on_progress(self, i):
        self.progressBarSet(i, None)

    @preprocess.on_result
    def on_result(self, result):
        """Update the info label and send the result.

        A result with an empty dictionary raises a warning and is replaced
        by ``None`` on the output.
        """
        self.update_info(result)
        if result is not None and len(result.dictionary) == 0:
            self.Warning.no_token_left()
            result = None
        self.Outputs.corpus.send(result)
        self.progressBarFinished(None)

    def set_minimal_width(self):
        """Set the scroll area's minimum width from the widest enabled stage."""
        max_width = 250
        for widget in self.stages:
            if widget.enabled:
                max_width = max(max_width, widget.sizeHint().width())
        self.scroll.setMinimumWidth(max_width + 20)

    @pyqtSlot()
    def settings_invalidated(self):
        """Slot for stage changes: readjust the width and commit."""
        self.set_minimal_width()
        self.commit()

    def send_report(self):
        self.report_items('Preprocessor', self.preprocessor.report())
Exemple #12
0
class OWTopicModeling(OWWidget):
    """
    Topic modelling widget that fits the selected model asynchronously.

    The method is chosen with radio buttons; fitted topics are displayed
    in a ``TopicViewer`` and sent to the outputs with the transformed corpus.
    """
    name = "主题模型(Topic Modelling)"
    description = "展现语料库隐藏的语义结构."
    icon = "icons/TopicModeling.svg"
    priority = 400
    keywords = ["LDA", 'zhutimoxing']
    category = 'text'

    settingsHandler = DomainContextHandler()

    # Input/output
    class Inputs:
        corpus = Input('语料库(Corpus)', Corpus, replaces=['Corpus'])

    class Outputs:
        corpus = Output('语料库(Corpus)', Table, replaces=['Corpus'])
        selected_topic = Output("选中的主题(Selected Topic)", Topic, replaces=['Selected Topic'])
        all_topics = Output("所有主题(All Topics)", Table, replaces=['All Topics'])

    want_main_area = True

    # (options widget class, attribute name under which it is stored)
    methods = [
        (LsiWidget, 'lsi'),
        (LdaWidget, 'lda'),
        (HdpWidget, 'hdp'),
    ]

    # Settings
    autocommit = settings.Setting(True)
    method_index = settings.Setting(0)

    lsi = settings.SettingProvider(LsiWidget)
    hdp = settings.SettingProvider(HdpWidget)
    lda = settings.SettingProvider(LdaWidget)

    selection = settings.Setting(None, schema_only=True)

    control_area_width = 300

    class Warning(OWWidget.Warning):
        less_topics_found = Msg('比设定的主题少.')

    def __init__(self):
        super().__init__()
        self.corpus = None
        self.learning_thread = None
        # Selection restored from settings; applied once a model is shown.
        self.__pending_selection = self.selection

        # Commit button
        gui.auto_commit(self.buttonsArea, self, 'autocommit', 'Commit', box=False)

        method_buttons = QButtonGroup(self, exclusive=True)
        method_buttons.buttonClicked[int].connect(self.change_method)

        self.widgets = []
        options_layout = QVBoxLayout()
        self.controlArea.layout().addLayout(options_layout)
        for index, (widget_cls, attr_name) in enumerate(self.methods):
            options = widget_cls(self, title='选项')
            options.setFixedWidth(self.control_area_width)
            options.valueChanged.connect(self.commit)
            self.widgets.append(options)
            setattr(self, attr_name, options)

            radio = QRadioButton(text=options.Model.name)
            method_buttons.addButton(radio, index)
            options_layout.addWidget(radio)
            options_layout.addWidget(options)

        method_buttons.button(self.method_index).setChecked(True)
        self.toggle_widgets()
        options_layout.addStretch()

        # Topics description
        self.topic_desc = TopicViewer()
        self.topic_desc.topicSelected.connect(self.send_topic_by_id)
        self.mainArea.layout().addWidget(self.topic_desc)
        self.topic_desc.setFocus()

    @Inputs.corpus
    def set_data(self, data=None):
        """Store the input corpus (or ``None``) and refit immediately."""
        self.Warning.less_topics_found.clear()
        self.corpus = data
        self.apply()

    def commit(self):
        """Refit the model when a corpus is present."""
        if self.corpus is None:
            return
        self.apply()

    @property
    def model(self):
        """Model of the currently selected options widget."""
        selected_widget = self.widgets[self.method_index]
        return selected_widget.model

    def change_method(self, new_index):
        """Switch to the method with index ``new_index`` and commit."""
        if self.method_index == new_index:
            return
        self.method_index = new_index
        self.toggle_widgets()
        self.commit()

    def toggle_widgets(self):
        """Show only the options widget of the selected method."""
        for index, options in enumerate(self.widgets):
            options.setVisible(index == self.method_index)

    def apply(self):
        """Stop the running task, then refit or clear the outputs."""
        self.learning_task.stop()
        if self.corpus is None:
            self.on_result(None)
        else:
            self.learning_task()

    @asynchronous
    def learning_task(self):
        fit_input = self.corpus.copy()
        return self.model.fit_transform(fit_input, chunk_number=100,
                                        on_progress=self.on_progress)

    @learning_task.on_start
    def on_start(self):
        self.Warning.less_topics_found.clear()
        self.progressBarInit()
        self.topic_desc.clear()

    @learning_task.on_result
    def on_result(self, corpus):
        """Send the outputs and show the model; clear everything for ``None``."""
        self.progressBarFinished()
        self.Outputs.corpus.send(corpus)
        if corpus is None:
            self.topic_desc.clear()
            self.Outputs.selected_topic.send(None)
            self.Outputs.all_topics.send(None)
            return

        self.topic_desc.show_model(self.model)
        pending = self.__pending_selection
        if pending:
            self.topic_desc.select(pending)
            self.__pending_selection = None
        if self.model.actual_topics != self.model.num_topics:
            self.Warning.less_topics_found()
        self.Outputs.all_topics.send(self.model.get_all_topics_table())

    @learning_task.callback
    def on_progress(self, p):
        self.progressBarSet(100 * p)

    def send_report(self):
        """Report the model settings and, with a corpus, the topics."""
        active_widget = self.widgets[self.method_index]
        self.report_items(*active_widget.report_model())
        if self.corpus is not None:
            self.report_items('Topics', self.topic_desc.report())

    def send_topic_by_id(self, topic_id=None):
        """Remember the selected topic and send it when a model is fitted."""
        self.selection = topic_id
        if self.model.model and topic_id is not None:
            topic_table = self.model.get_topics_table_by_id(topic_id)
            self.Outputs.selected_topic.send(topic_table)
Exemple #13
0
class OWFreeViz(widget.OWWidget):
    """Widget that shows a FreeViz projection of a data table.

    One anchor per feature is optimized with ``FreeViz.freeviz`` inside an
    ``AsyncUpdateLoop`` coroutine; each intermediate result is plotted.
    The widget outputs the selected rows, the annotated data and the
    anchor components.
    """
    name = "FreeViz"
    description = "Displays FreeViz projection"
    icon = "icons/Freeviz.svg"
    priority = 240

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    #: Initialization type
    Circular, Random = 0, 1

    jitter_sizes = [0, 0.1, 0.5, 1, 2]

    settings_version = 2
    settingsHandler = settings.DomainContextHandler()

    radius = settings.Setting(0)
    initialization = settings.Setting(Circular)
    auto_commit = settings.Setting(True)

    resolution = 256
    graph = settings.SettingProvider(OWFreeVizGraph)

    ReplotRequest = QEvent.registerEventType()

    graph_name = "graph.plot_widget.plotItem"

    class Warning(widget.OWWidget.Warning):
        sparse_not_supported = widget.Msg("Sparse data is ignored.")

    class Error(widget.OWWidget.Error):
        no_class_var = widget.Msg("Need a class variable")
        not_enough_class_vars = widget.Msg("Needs discrete class variable " \
                                          "with at lest 2 values")
        features_exceeds_instances = widget.Msg("Algorithm should not be used when " \
                                                "number of features exceeds the number " \
                                                "of instances.")
        too_many_data_instances = widget.Msg("Cannot handle so large data.")
        no_valid_data = widget.Msg("No valid data.")

    def __init__(self):
        """Build the plot, the control-area widgets and the update loop."""
        super().__init__()

        self.data = None
        self.subset_data = None
        self._subset_mask = None
        self._validmask = None
        self._X = None
        self._Y = None
        self._selection = None
        self.__replot_requested = False

        # Variables that hold the two embedding coordinates in plotted data.
        self.variable_x = ContinuousVariable("freeviz-x")
        self.variable_y = ContinuousVariable("freeviz-y")

        box0 = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWFreeVizGraph(self,
                                    box0,
                                    "Plot",
                                    view_box=FreeVizInteractiveViewBox)
        box0.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        box = gui.widgetBox(self.controlArea, "Optimization", spacing=10)
        form = QFormLayout(labelAlignment=Qt.AlignLeft,
                           formAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
                           verticalSpacing=10)
        form.addRow(
            "Initialization",
            gui.comboBox(box,
                         self,
                         "initialization",
                         items=["Circular", "Random"],
                         callback=self.reset_initialization))
        box.layout().addLayout(form)

        # Disabled until valid data arrives (see set_data).
        self.btn_start = gui.button(widget=box,
                                    master=self,
                                    label="Optimize",
                                    callback=self.toogle_start,
                                    enabled=False)

        self.viewbox = plot.getViewBox()
        self.replot = None

        g = self.graph.gui
        g.point_properties_box(self.controlArea)
        self.models = g.points_models

        box = gui.widgetBox(self.controlArea, "Show anchors")
        self.rslider = gui.hSlider(box,
                                   self,
                                   "radius",
                                   minValue=0,
                                   maxValue=100,
                                   step=5,
                                   label="Radius",
                                   createLabel=False,
                                   ticks=True,
                                   callback=self.update_radius)
        self.rslider.setTickInterval(0)
        self.rslider.setPageStep(10)

        box = gui.vBox(self.controlArea, "Plot Properties")

        g.add_widgets([g.JitterSizeSlider], box)
        g.add_widgets([g.ShowLegend, g.ClassDensity, g.LabelOnlySelected], box)

        self.graph.box_zoom_select(self.controlArea)
        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)

        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")
        self.graph.zoom_actions(self)
        # FreeViz
        self._loop = AsyncUpdateLoop(parent=self)
        self._loop.yielded.connect(self.__set_projection)
        self._loop.finished.connect(self.__freeviz_finished)
        self._loop.raised.connect(self.__on_error)

        self._new_plotdata()

    def keyPressEvent(self, event):
        """Forward the event and refresh the tooltip for the new modifiers."""
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Forward the event and refresh the tooltip for the new modifiers."""
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def update_radius(self):
        """Update anchor visibility and the hide circle from ``self.radius``."""
        # Update the anchor/axes visibility
        assert self.plotdata is not None
        if self.plotdata.hidecircle is None:
            return

        minradius = self.radius / 100 + 1e-5
        for anchor, item in zip(self.plotdata.anchors,
                                self.plotdata.anchoritem):
            item.setVisible(np.linalg.norm(anchor) > minradius)
        self.plotdata.hidecircle.setRect(
            QRectF(-minradius, -minradius, 2 * minradius, 2 * minradius))

    def toogle_start(self):
        """Cancel the optimization if it is running, otherwise start it."""
        if self._loop.isRunning():
            self._loop.cancel()
            if isinstance(self, OWFreeViz):
                self.btn_start.setText("Optimize")
            self.progressBarFinished(processEvents=False)
        else:
            self._start()

    def _start(self):
        """
        Start the projection optimization.
        """
        assert self.plotdata is not None

        X, Y = self.plotdata.X, self.plotdata.Y
        anchors = self.plotdata.anchors

        def update_freeviz(interval, initial):
            # Run FreeViz in bursts of `interval` iterations, yielding after
            # each burst, until the anchors stop changing.
            anchors = initial
            while True:
                res = FreeViz.freeviz(X,
                                      Y,
                                      scale=False,
                                      center=False,
                                      initial=anchors,
                                      maxiter=interval)
                _, anchors_new = res[:2]
                yield res[:2]
                if np.allclose(anchors, anchors_new, rtol=1e-5, atol=1e-4):
                    return

                anchors = anchors_new

        interval = 10  # TODO

        self._loop.setCoroutine(update_freeviz(interval, anchors))
        self.btn_start.setText("Stop")
        self.progressBarInit(processEvents=False)
        self.setBlocking(True)
        self.setStatusMessage("Optimizing")

    def reset_initialization(self):
        """
        Reset the current 'anchor' initialization, and restart the
        optimization if necessary.
        """
        running = self._loop.isRunning()

        if running:
            self._loop.cancel()

        if self.data is not None:
            self._clear_plot()
            self.setup_plot()

        if running:
            self._start()

    def __set_projection(self, res):
        """Store one optimization result ``(coords, anchors)`` and replot."""
        # Set/update the projection matrix and coordinate embeddings
        # assert self.plotdata is not None, "__set_projection call unexpected"
        assert self.plotdata is not None
        increment = 1  # TODO
        self.progressBarAdvance(increment * 100. / MAX_ITERATIONS,
                                processEvents=False)  # TODO
        embedding_coords, projection = res
        self.plotdata.embedding_coords = embedding_coords
        self.plotdata.anchors = projection
        self._update_xy()
        self.update_radius()
        self.update_density()

    def __freeviz_finished(self):
        """Restore the idle UI state and commit the outputs."""
        # Projection optimization has finished
        self.btn_start.setText("Optimize")
        self.setStatusMessage("")
        self.setBlocking(False)
        self.progressBarFinished(processEvents=False)
        self.commit()

    def __on_error(self, err):
        """Pass an exception raised in the update loop to ``sys.excepthook``."""
        sys.excepthook(type(err), err, getattr(err, "__traceback__"))

    def _update_xy(self):
        """Rescale the embedding to at most unit radius and replot it."""
        # Update the plotted embedding coordinates
        self.graph.plot_widget.clear()
        coords = self.plotdata.embedding_coords
        # `or 1` avoids a division by zero when every point lies at the origin.
        radius = np.max(np.linalg.norm(coords, axis=1)) or 1
        self.plotdata.embedding_coords = coords / radius
        self.plot(
            show_anchors=(len(self.data.domain.attributes) < MAX_ANCHORS))

    def _new_plotdata(self):
        """Replace ``self.plotdata`` with a fresh, empty namespace."""
        self.plotdata = namespace(
            validmask=None,
            embedding_coords=None,
            anchors=[],
            anchoritem=[],
            X=None,
            Y=None,
            indicators=[],
            hidecircle=None,
            data=None,
            items=[],
            topattrs=None,
            rand=None,
            selection=None,  # np.array
        )

    def _anchor_circle(self):
        """Recreate the anchor items and the hide circle in the view box."""
        # minimum visible anchor radius (radius)
        minradius = self.radius / 100 + 1e-5
        for item in chain(self.plotdata.anchoritem, self.plotdata.items):
            self.viewbox.removeItem(item)
        self.plotdata.anchoritem = []
        self.plotdata.items = []
        for anchor, var in zip(self.plotdata.anchors,
                               self.data.domain.attributes):
            # Every anchor gets an item; short ones are only hidden so that
            # update_radius() can show them again later.
            axitem = AnchorItem(
                line=QLineF(0, 0, *anchor),
                text=var.name,
            )
            axitem.setVisible(np.linalg.norm(anchor) > minradius)
            axitem.setPen(pg.mkPen((100, 100, 100)))
            axitem.setArrowVisible(True)
            self.plotdata.anchoritem.append(axitem)
            self.viewbox.addItem(axitem)

        hidecircle = QGraphicsEllipseItem()
        hidecircle.setRect(
            QRectF(-minradius, -minradius, 2 * minradius, 2 * minradius))

        _pen = QPen(Qt.lightGray, 1)
        _pen.setCosmetic(True)
        hidecircle.setPen(_pen)
        self.viewbox.addItem(hidecircle)
        self.plotdata.items.append(hidecircle)
        self.plotdata.hidecircle = hidecircle

    def update_colors(self):
        """Do nothing; present as a no-op hook."""
        pass

    def sizeHint(self):
        """Return the preferred widget size (800 x 500)."""
        return QSize(800, 500)

    def _clear(self):
        """
        Clear/reset the widget state
        """
        self._loop.cancel()
        self.data = None
        self._selection = None
        self._clear_plot()

    def _clear_plot(self):
        """Remove all plotted items and start with empty plot data."""
        for item in chain(self.plotdata.anchoritem, self.plotdata.items):
            self.viewbox.removeItem(item)
        self.graph.plot_widget.clear()
        self._new_plotdata()

    def init_attr_values(self):
        """Pass the current data to the graph as its domain source."""
        self.graph.set_domain(self.data)

    @Inputs.data
    def set_data(self, data):
        """Validate and store the input data.

        Data that is sparse, has no class variable, has a discrete class
        with fewer than two values, has more features than rows, has too
        many valid rows or has no valid rows is rejected with the matching
        message and treated as ``None``.
        """
        self.clear_messages()
        self._clear()
        self.closeContext()
        if data is not None:
            if data and data.is_sparse():
                self.Warning.sparse_not_supported()
                data = None
            elif data.domain.class_var is None:
                self.Error.no_class_var()
                data = None
            elif data.domain.class_var.is_discrete and \
                            len(data.domain.class_var.values) < 2:
                self.Error.not_enough_class_vars()
                data = None
            if data and len(data.domain.attributes) > data.X.shape[0]:
                self.Error.features_exceeds_instances()
                data = None
        if data is not None:
            valid_instances_count = self._prepare_freeviz_data(data)
            if valid_instances_count > MAX_INSTANCES:
                self.Error.too_many_data_instances()
                data = None
            elif valid_instances_count == 0:
                self.Error.no_valid_data()
                data = None
        self.data = data
        self.init_attr_values()
        if data is not None:
            self.cb_class_density.setEnabled(data.domain.has_discrete_class)
            self.openContext(data)
            self.btn_start.setEnabled(True)
        else:
            self.btn_start.setEnabled(False)
            self._X = self._Y = None
            self.graph.new_data(None, None)

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """Store the subset input; the mask is rebuilt in handleNewSignals."""
        self.subset_data = subset
        self.plotdata.subset_mask = None
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """Recompute the subset mask, replot if data is ready, and commit."""
        if self.data is not None and self.subset_data is not None:
            dataids = self.data.ids.ravel()
            subsetids = np.unique(self.subset_data.ids)
            self._subset_mask = np.in1d(dataids, subsetids, assume_unique=True)
        else:
            # Without both inputs there is no subset to highlight; drop any
            # mask left over from earlier inputs so plot() does not reuse it.
            self._subset_mask = None
        if self._X is not None:
            self.setup_plot(True)
        self.commit()

    def customEvent(self, event):
        """Replot on a ``ReplotRequest`` event; defer others to the base."""
        if event.type() == OWFreeViz.ReplotRequest:
            self.__replot_requested = False
            self.setup_plot()
        else:
            super().customEvent(event)

    def _prepare_freeviz_data(self, data):
        """Prepare ``data`` for FreeViz and return the number of valid rows.

        Rows with a NaN in ``X`` or ``Y`` are dropped.  The remaining ``X``
        is centered per column and every column with a non-zero range is
        divided by that range.  Results are stored in ``self._X``,
        ``self._Y`` and ``self._validmask``; with no valid rows only
        ``self._X`` is reset and 0 is returned.
        """
        X = data.X
        Y = data.Y
        mask = np.bitwise_or.reduce(np.isnan(X), axis=1)
        mask |= np.isnan(Y)
        validmask = ~mask
        X = X[validmask, :]
        Y = Y[validmask]

        if not len(X):
            self._X = None
            return 0

        if data.domain.class_var.is_discrete:
            Y = Y.astype(int)
        X = (X - np.mean(X, axis=0))
        span = np.ptp(X, axis=0)
        X[:, span > 0] /= span[span > 0].reshape(1, -1)
        self._X = X
        self._Y = Y
        self._validmask = validmask
        return len(X)

    def setup_plot(self, reset_view=True):
        """Initialize anchors if needed, project the data and plot it."""
        assert self._X is not None

        self.graph.jitter_continuous = True
        self.__replot_requested = False

        X = self.plotdata.X = self._X
        self.plotdata.Y = self._Y
        self.plotdata.validmask = self._validmask
        self.plotdata.selection = self._selection if self._selection is not None else \
            np.zeros(len(self._validmask), dtype=np.uint8)
        anchors = self.plotdata.anchors
        if len(anchors) == 0:
            if self.initialization == self.Circular:
                anchors = FreeViz.init_radial(X.shape[1])
            else:
                anchors = FreeViz.init_random(X.shape[1], 2)

        EX = np.dot(X, anchors)
        # For every row, order the attributes by the squared length of their
        # contribution to the projected point (largest first); keep ten.
        c = np.zeros((X.shape[0], X.shape[1]))
        for i in range(X.shape[0]):
            c[i] = np.argsort((np.power(X[i] * anchors[:, 0], 2) +
                               np.power(X[i] * anchors[:, 1], 2)))[::-1]
        self.plotdata.topattrs = np.array(c, dtype=int)[:, :10]
        # `or 1` avoids a division by zero when every point lies at the origin.
        radius = np.max(np.linalg.norm(EX, axis=1)) or 1

        self.plotdata.anchors = anchors

        coords = (EX / radius)
        self.plotdata.embedding_coords = coords
        if reset_view:
            self.viewbox.setRange(RANGE)
            self.viewbox.setAspectLocked(True, 1)
        self.plot(reset_view=reset_view)

    def randomize_indices(self):
        """Pick ``MAX_POINTS`` random row indices when there are more rows."""
        X = self._X
        self.plotdata.rand = np.random.choice(len(X), MAX_POINTS, replace=False) \
            if len(X) > MAX_POINTS else None

    def manual_move_anchor(self, show_anchors=True):
        """Re-project the (optionally sampled) valid rows and update the graph."""
        self.__replot_requested = False
        X = self.plotdata.X = self._X
        anchors = self.plotdata.anchors
        validmask = self.plotdata.validmask
        EX = np.dot(X, anchors)
        data_x = self.data.X[validmask]
        data_y = self.data.Y[validmask]
        # `or 1` avoids a division by zero when every point lies at the origin.
        radius = np.max(np.linalg.norm(EX, axis=1)) or 1
        if self.plotdata.rand is not None:
            rand = self.plotdata.rand
            EX = EX[rand]
            data_x = data_x[rand]
            data_y = data_y[rand]
            selection = self.plotdata.selection[validmask]
            selection = selection[rand]
        else:
            selection = self.plotdata.selection[validmask]
        coords = (EX / radius)

        if show_anchors:
            self._anchor_circle()
        attributes = () + self.data.domain.attributes + (self.variable_x,
                                                         self.variable_y)
        domain = Domain(attributes=attributes,
                        class_vars=self.data.domain.class_vars)
        data = Table.from_numpy(domain,
                                X=np.hstack((data_x, coords)),
                                Y=data_y)
        self.graph.new_data(data, None)
        self.graph.selection = selection
        self.graph.update_data(self.variable_x,
                               self.variable_y,
                               reset_view=False)

    def plot(self, reset_view=False, show_anchors=True):
        """Send the data extended with the embedding columns to the graph."""
        if show_anchors:
            self._anchor_circle()
        attributes = () + self.data.domain.attributes + (self.variable_x,
                                                         self.variable_y)
        domain = Domain(attributes=attributes,
                        class_vars=self.data.domain.class_vars,
                        metas=self.data.domain.metas)
        mask = self.plotdata.validmask
        # Invalid rows keep (0, 0); they are filtered out with `mask` below.
        # The builtin `float` replaces the removed `np.float` alias.
        array = np.zeros((len(self.data), 2), dtype=float)
        array[mask] = self.plotdata.embedding_coords
        data = self.data.transform(domain)
        data[:, self.variable_x] = array[:, 0].reshape(-1, 1)
        data[:, self.variable_y] = array[:, 1].reshape(-1, 1)
        subset_data = data[self._subset_mask & mask]\
            if self._subset_mask is not None and len(self._subset_mask) else None
        self.plotdata.data = data
        self.graph.new_data(data[mask], subset_data)
        if self.plotdata.selection is not None:
            self.graph.selection = self.plotdata.selection[
                self.plotdata.validmask]
        self.graph.update_data(self.variable_x,
                               self.variable_y,
                               reset_view=reset_view)

    def reset_graph_data(self, *_):
        """Rescale the graph data and redraw, if there is data."""
        if self.data is not None:
            self.graph.rescale_data()
            self._update_graph()

    def _update_graph(self, reset_view=True, **_):
        """Clear the zoom stack and redraw the graph."""
        self.graph.zoomStack = []
        assert self.graph.data is not None
        self.graph.update_data(self.variable_x, self.variable_y, reset_view)

    def update_density(self):
        """Redraw the graph without resetting the view, if it has data."""
        if self.graph.data is None:
            return
        self._update_graph(reset_view=False)

    def selection_changed(self):
        """Copy the graph selection into the full-length selection and commit."""
        if self.graph.selection is not None:
            pd = self.plotdata
            pd.selection[pd.validmask] = self.graph.selection
            self._selection = pd.selection
        self.commit()

    def prepare_data(self):
        """Do nothing; present as a no-op hook."""
        pass

    def commit(self):
        """Send the selected data, the annotated data and the components.

        All three outputs are ``None`` when there is no data or nothing has
        been plotted yet.
        """
        selected = annotated = components = None
        graph = self.graph
        if self.data is not None and self.plotdata.validmask is not None:
            name = self.data.name
            metas = () + self.data.domain.metas + (self.variable_x,
                                                   self.variable_y)
            domain = Domain(attributes=self.data.domain.attributes,
                            class_vars=self.data.domain.class_vars,
                            metas=metas)
            data = self.plotdata.data.transform(domain)
            validmask = self.plotdata.validmask
            # Expand the selection of valid rows to all rows; invalid rows
            # and, without a selection, all rows stay at 0.
            mask = np.array(validmask, dtype=int)
            mask[mask == 1] = graph.selection if graph.selection is not None \
                else 0
            selection = np.flatnonzero(mask)
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes
            if graph.selection is not None and np.max(graph.selection) > 1:
                annotated = create_groups_table(data, mask)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

            comp_domain = Domain(self.data.domain.attributes,
                                 metas=[StringVariable(name='component')])

            metas = np.array([["FreeViz 1"], ["FreeViz 2"]])
            components = Table.from_numpy(comp_domain,
                                          X=self.plotdata.anchors.T,
                                          metas=metas)

            components.name = name + ": components"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        """Report the plot with a caption listing the point properties."""
        if self.data is None:
            return

        def name(var):
            return var and var.name

        caption = report.render_items_vert(
            (("Color", name(self.graph.attr_color)),
             ("Label", name(self.graph.attr_label)),
             ("Shape", name(self.graph.attr_shape)),
             ("Size", name(self.graph.attr_size)),
             ("Jittering", self.graph.jitter_size != 0
              and "{} %".format(self.graph.jitter_size))))
        self.report_plot()
        if caption:
            self.report_caption(caption)
Exemple #14
0
class OWFreeViz(OWAnchorProjectionWidget):
    """FreeViz projection widget built on ``OWAnchorProjectionWidget``.

    The projection matrix is optimized with ``FreeViz.freeviz`` inside an
    ``AsyncUpdateLoop`` coroutine; the graph is updated after each step.
    """
    MAX_ITERATIONS = 1000
    MAX_INSTANCES = 10000

    name = "FreeViz"
    description = "Displays FreeViz projection"
    icon = "icons/Freeviz.svg"
    priority = 240
    keywords = ["viz"]

    settings_version = 3
    initialization = settings.Setting(InitType.Circular)
    GRAPH_CLASS = OWFreeVizGraph
    graph = settings.SettingProvider(OWFreeVizGraph)
    embedding_variables_names = ("freeviz-x", "freeviz-y")

    class Error(OWAnchorProjectionWidget.Error):
        no_class_var = widget.Msg("Data has no target variable")
        not_enough_class_vars = widget.Msg(
            "Target variable is not at least binary")
        features_exceeds_instances = widget.Msg(
            "Number of features exceeds the number of instances.")
        too_many_data_instances = widget.Msg("Data is too large.")

    def __init__(self):
        """Initialize the widget and wire up the asynchronous update loop."""
        super().__init__()
        self._X = None
        self._Y = None

        # FreeViz
        self._loop = AsyncUpdateLoop(parent=self)
        self._loop.yielded.connect(self.__set_projection)
        self._loop.finished.connect(self.__freeviz_finished)
        self._loop.raised.connect(self.__on_error)

    def _add_controls(self):
        """Add the start box, the inherited controls and the radius slider."""
        self.__add_controls_start_box()
        super()._add_controls()
        self.graph.gui.add_control(self._effects_box,
                                   gui.hSlider,
                                   "Hide radius:",
                                   master=self.graph,
                                   value="hide_radius",
                                   minValue=0,
                                   maxValue=100,
                                   step=10,
                                   createLabel=False,
                                   callback=self.__radius_slider_changed)

    def __add_controls_start_box(self):
        """Create the initialization combo box and the Optimize button."""
        box = gui.vBox(self.controlArea, box=True)
        gui.comboBox(box,
                     self,
                     "initialization",
                     label="Initialization:",
                     items=InitType.items(),
                     orientation=Qt.Horizontal,
                     labelWidth=90,
                     callback=self.__init_combo_changed)
        # Disabled until check_data() accepts the input.
        self.btn_start = gui.button(box,
                                    self,
                                    "Optimize",
                                    self.__toggle_start,
                                    enabled=False)

    def __radius_slider_changed(self):
        """Ask the graph to apply the new hide radius."""
        self.graph.update_radius()

    def __toggle_start(self):
        """Cancel the optimization if it is running, otherwise start it."""
        if self._loop.isRunning():
            self._loop.cancel()
            self.btn_start.setText("Optimize")
            self.progressBarFinished(processEvents=False)
        else:
            self._start()

    def __init_combo_changed(self):
        """Re-initialize the projection; restart a running optimization."""
        if self.data is None:
            return
        running = self._loop.isRunning()
        if running:
            self._loop.cancel()
        self.init_embedding_coords()
        self.graph.update_coordinates()
        if running:
            self._start()

    def _start(self):
        """Start optimizing the projection in the update loop."""
        def update_freeviz(anchors):
            # Run FreeViz in bursts of 10 iterations, yielding the projection
            # after each burst, until it stops changing.
            while True:
                _, projection, *_ = FreeViz.freeviz(self._X,
                                                    self._Y,
                                                    scale=False,
                                                    center=False,
                                                    initial=anchors,
                                                    maxiter=10)
                yield projection
                if np.allclose(anchors, projection, rtol=1e-5, atol=1e-4):
                    return
                anchors = projection

        self.graph.set_sample_size(self.SAMPLE_SIZE)
        self._loop.setCoroutine(update_freeviz(self.projection))
        self.btn_start.setText("Stop")
        self.progressBarInit()
        self.setBlocking(True)
        self.setStatusMessage("Optimizing")

    def __set_projection(self, projection):
        """Store a new projection matrix and refresh the plotted points."""
        # Set/update the projection matrix and coordinate embeddings
        self.progressBarAdvance(100. / self.MAX_ITERATIONS)
        self.projection = projection
        self.graph.update_coordinates()

    def __freeviz_finished(self):
        """Restore the idle UI state and commit the outputs."""
        self.graph.set_sample_size(None)
        self.btn_start.setText("Optimize")
        self.setStatusMessage("")
        self.setBlocking(False)
        self.progressBarFinished()
        self.commit()

    def __on_error(self, err):
        """Pass an exception raised in the update loop to ``sys.excepthook``."""
        sys.excepthook(type(err), err, getattr(err, "__traceback__"))

    def check_data(self):
        """Validate ``self.data``; on failure show an error and drop the data.

        On success ``self.valid_data`` marks the rows whose features and
        target are all finite.  The Optimize button is enabled only when
        the data was accepted.
        """
        def error(err):
            err()
            self.data = None

        super().check_data()
        if self.data is not None:
            class_var = self.data.domain.class_var
            if class_var is None:
                error(self.Error.no_class_var)
            elif class_var.is_discrete and len(np.unique(self.data.Y)) < 2:
                error(self.Error.not_enough_class_vars)
            elif len(self.data.domain.attributes) < 2:
                error(self.Error.not_enough_features)
            elif len(self.data.domain.attributes) > self.data.X.shape[0]:
                error(self.Error.features_exceeds_instances)
            else:
                self.valid_data = np.all(np.isfinite(self.data.X), axis=1) & \
                                  np.isfinite(self.data.Y)
                n_valid = np.sum(self.valid_data)
                if n_valid > self.MAX_INSTANCES:
                    error(self.Error.too_many_data_instances)
                elif n_valid == 0:
                    error(self.Error.no_valid_data)
        self.btn_start.setEnabled(self.data is not None)

    def set_data(self, data):
        """Set the data and, if it was accepted, prepare the projection."""
        super().set_data(data)
        if self.data is not None:
            self.prepare_projection_data()
            self.init_embedding_coords()

    def prepare_projection_data(self):
        """Compute the centered, range-scaled ``_X`` and the target ``_Y``.

        Columns are centered with ``nanmean`` over all rows; the range used
        for scaling is taken from the valid rows only, and columns with a
        zero range are left unscaled.
        """
        if not np.any(self.valid_data):
            self._X = self._Y = self.valid_data = None
            return

        self._X = self.data.X.copy()
        self._X -= np.nanmean(self._X, axis=0)
        span = np.ptp(self._X[self.valid_data], axis=0)
        self._X[:, span > 0] /= span[span > 0].reshape(1, -1)

        self._Y = self.data.Y
        if self.data.domain.class_var.is_discrete:
            self._Y = self._Y.astype(int)

    def init_embedding_coords(self):
        """Set the initial projection according to ``self.initialization``."""
        self.projection = FreeViz.init_radial(self._X.shape[1]) \
            if self.initialization == InitType.Circular \
            else FreeViz.init_random(self._X.shape[1], 2)

    def get_embedding(self):
        """Return the projected data scaled to at most unit radius, or None."""
        if self.data is None:
            return None
        embedding = np.dot(self._X, self.projection)
        # `or 1` avoids a division by zero when all valid points are at 0.
        embedding /= \
            np.max(np.linalg.norm(embedding[self.valid_data], axis=1)) or 1
        return embedding

    def get_anchors(self):
        """Return ``(projection, attribute names)`` or ``(None, None)``."""
        if self.projection is None:
            return None, None
        return self.projection, [a.name for a in self.data.domain.attributes]

    def send_components(self):
        """Send the projection matrix as a two-row components table."""
        components = None
        if self.data is not None and self.valid_data is not None:
            meta_attrs = [StringVariable(name='component')]
            domain = Domain(self.data.domain.attributes, metas=meta_attrs)
            metas = np.array([["FreeViz 1"], ["FreeViz 2"]])
            components = Table(domain, self.projection.T, metas=metas)
            components.name = self.data.name
        self.Outputs.components.send(components)

    def clear(self):
        """Clear inherited state, cancel the loop and drop prepared data."""
        super().clear()
        self._loop.cancel()
        self._X = None
        self._Y = None

    @classmethod
    def migrate_settings(cls, _settings, version):
        """Move the pre-version-3 ``radius`` setting to ``graph.hide_radius``."""
        if version < 3:
            if "radius" in _settings:
                # Create the "graph" sub-dict if the stored settings lack it
                # instead of failing with a KeyError.
                _settings.setdefault("graph", {})["hide_radius"] = \
                    _settings["radius"]

    @classmethod
    def migrate_context(cls, context, version):
        """Copy pre-version-3 point attributes from ``graph`` to the top level.

        Attributes missing from the stored context are skipped.
        """
        if version < 3:
            values = context.values
            graph_values = values.get("graph", {})
            for attr in ("attr_color", "attr_size",
                         "attr_shape", "attr_label"):
                if attr in graph_values:
                    values[attr] = graph_values[attr]
# Widget that lists collections from a Resolwe server, lets the user pick one
# and outputs its expression data as a table.
class OWGenialisExpressions(widget.OWWidget, ConcurrentWidgetMixin):
    # Widget metadata.
    name = 'Genialis Expressions'
    priority = 30
    want_main_area = True
    want_control_area = True
    icon = '../widgets/icons/OWGenialisExpressions.svg'

    # Emitted by update_collections_view() as (next_page, previous_page).
    pagination_availability = pyqtSignal(bool, bool)

    # Settings of the embedded components are stored with the widget.
    norm_component = settings.SettingProvider(NormalizationComponent)
    pagination_component = settings.SettingProvider(PaginationComponent)
    filter_component = settings.SettingProvider(CollapsibleFilterComponent)

    # Index into data_output_options.expression_type (combo box position).
    exp_type: int
    exp_type = settings.Setting(1, schema_only=True)

    # Index into data_output_options.process (combo box position).
    proc_slug: int
    proc_slug = settings.Setting(0, schema_only=True)

    # Index into data_output_options.expression_sources (combo box position).
    exp_source: int
    exp_source = settings.Setting(0, schema_only=True)

    # When set, runner() merges the collection's QC table into the metas.
    append_qc_data: bool
    append_qc_data = settings.Setting(False, schema_only=True)

    # Bound to the commit button created with gui.auto_commit.
    auto_commit: bool
    auto_commit = settings.Setting(False, schema_only=True)

    class Outputs:
        table = Output('Expressions', Table)

    class Warning(widget.OWWidget.Warning):
        # Raised by on_selection_changed() when the collection has no
        # expression data objects.
        no_expressions = Msg('Expression data objects not found.')
        no_data_objects = Msg(
            'No expression data matches the selected options.')
        # The placeholder receives the offending feature type.
        unexpected_feature_type = Msg(
            'Can not import expression data, unexpected feature type "{}".')
        multiple_feature_type = Msg(
            'Can not import expression data, multiple feature types found.')

    def __init__(self):
        """Build the control area and main area, then sign in silently."""
        super().__init__()
        ConcurrentWidgetMixin.__init__(self)

        # Backing field of the ``res`` property (the server connection).
        self._res: Optional[resolwe.resapi.ResolweAPI] = None

        # Store collection ID from currently selected row
        self.selected_collection_id: Optional[str] = None
        # Store data output options
        self.data_output_options: Optional[DataOutputOptions] = None
        # Cache output data table
        self.data_table: Optional[Table] = None
        # Cache clinical metadata
        self.clinical_metadata: Optional[Table] = None

        # Control area
        self.info_box = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Info", margin=3),
            'No data on output.')

        # The three combos are bound to the index settings of the same name;
        # their items are filled in on_selection_changed().
        self.exp_type_combo = gui.comboBox(
            self.controlArea,
            self,
            'exp_type',
            label='Expression Type',
            callback=self.on_output_option_changed)
        self.proc_slug_combo = gui.comboBox(
            self.controlArea,
            self,
            'proc_slug',
            label='Process Name',
            callback=self.on_output_option_changed)
        self.exp_source_combo = gui.comboBox(
            self.controlArea,
            self,
            'exp_source',
            label='Expression source',
            callback=self.on_output_option_changed,
        )

        self.norm_component = NormalizationComponent(self, self.controlArea)
        self.norm_component.options_changed.connect(
            self.on_normalization_changed)

        box = gui.widgetBox(self.controlArea, 'Sample QC')
        gui.checkBox(box,
                     self,
                     'append_qc_data',
                     'Append QC data',
                     callback=self.on_output_option_changed)

        gui.rubber(self.controlArea)
        # Sign-in box; the labels are filled by update_user_status().
        box = gui.widgetBox(self.controlArea, 'Sign in')
        self.user_info = gui.label(box, self, '')
        self.server_info = gui.label(box, self, '')

        box = gui.widgetBox(box, orientation=Qt.Horizontal)
        self.sign_in_btn = gui.button(box,
                                      self,
                                      'Sign In',
                                      callback=self.sign_in,
                                      autoDefault=False)
        self.sign_out_btn = gui.button(box,
                                       self,
                                       'Sign Out',
                                       callback=self.sign_out,
                                       autoDefault=False)

        self.commit_button = gui.auto_commit(self.controlArea,
                                             self,
                                             'auto_commit',
                                             '&Commit',
                                             box=False)
        self.commit_button.button.setAutoDefault(False)

        # Main area
        self.table_view = QTableView()
        self.table_view.setAlternatingRowColors(True)
        self.table_view.viewport().setMouseTracking(True)
        self.table_view.setShowGrid(False)
        self.table_view.verticalHeader().hide()
        self.table_view.horizontalHeader().setSectionResizeMode(
            QHeaderView.ResizeToContents)
        self.table_view.horizontalHeader().setStretchLastSection(True)
        # Exactly one whole row can be selected at a time.
        self.table_view.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.table_view.setSelectionMode(QAbstractItemView.SingleSelection)
        # self.table_view.setStyleSheet('QTableView::item:selected{background-color: palette(highlight); color: palette(highlightedText);};')

        self.model = GenialisExpressionsModel(self)
        self.model.setHorizontalHeaderLabels(TableHeader.labels())
        self.table_view.setModel(self.model)
        self.table_view.selectionModel().selectionChanged.connect(
            self.on_selection_changed)

        # Layout order in the main area: filter, table, pagination.
        self.filter_component = CollapsibleFilterComponent(self, self.mainArea)
        self.filter_component.options_changed.connect(self.on_filter_changed)
        self.mainArea.layout().addWidget(self.table_view)
        self.pagination_component = PaginationComponent(self, self.mainArea)
        self.pagination_component.options_changed.connect(
            self.update_collections_view)

        # Connect without showing the dialog; see sign_in().
        self.sign_in(silent=True)

    @property
    def res(self):
        """The current server connection, or ``None`` if none was set yet."""
        return self._res

    @res.setter
    def res(self, value: resolwe.resapi.ResolweAPI):
        """Switch to a new connection and reset everything derived from it.

        Values that are not ``ResolweAPI`` instances are ignored. On a valid
        value the user labels and the collections table are refreshed and the
        cached selection, options and output are invalidated.
        """
        if not isinstance(value, resolwe.resapi.ResolweAPI):
            return

        self._res = value
        self.update_user_status()
        self.update_collections_view()
        # __invalidate() already sends None on the output, so the extra
        # send that used to follow it here was redundant.
        self.__invalidate()

    def __invalidate(self):
        """Forget the current selection and reset the output-related state.

        Clears the cached table, collection id, clinical metadata and output
        options, empties the three option combos, sends ``None`` on the
        output, clears the widget warnings and refreshes the info texts.
        """
        self.data_table = None
        self.selected_collection_id = None
        self.clinical_metadata = None
        self.data_output_options = None

        for combo in (self.exp_type_combo,
                      self.proc_slug_combo,
                      self.exp_source_combo):
            combo.clear()

        self.Outputs.table.send(None)
        for warning in (self.Warning.no_expressions,
                        self.Warning.multiple_feature_type,
                        self.Warning.unexpected_feature_type,
                        self.Warning.no_data_objects):
            warning.clear()
        self.info.set_output_summary(StateInfo.NoOutput)
        self.update_info_box()

    def update_user_status(self):
        """Refresh the sign-in box: user label, server label, button states.

        When a user is logged in, the label shows the full name, or the
        username if both name parts are empty, and only "Sign Out" is
        enabled. Otherwise the label shows "Anonymous" and only "Sign In"
        is enabled.
        """
        user = self.res.get_currently_logged_user()

        if user:
            profile = user[0]
            first_name = profile.get('first_name', '')
            last_name = profile.get('last_name', '')
            full_name = f"{first_name} {last_name}".strip()
            user_info = f"User: {full_name or profile.get('username', '')}"
        else:
            user_info = 'User: Anonymous'

        signed_in = bool(user)
        self.sign_in_btn.setEnabled(not signed_in)
        self.sign_out_btn.setEnabled(signed_in)

        self.user_info.setText(user_info)
        # Drop the scheme whatever it is; a fixed [8:] slice is only right
        # for "https://" and mangles e.g. "http://..." URLs.
        server = self.res.url.split('://', 1)[-1]
        self.server_info.setText(f'Server: {server}')

    def update_info_box(self):
        """Show gene counts for the cached table in the info label.

        A column counts as matched when its ``attributes`` mapping is
        non-empty. Without a cached table the label reads
        "No data on output.".
        """
        if not self.data_table:
            self.info_box.setText('No data on output.')
            return

        columns = self.data_table.domain.attributes
        total_genes = len(columns)
        known_genes = sum(1 for col in columns if len(col.attributes))
        template = ('{} genes on output\n'
                    '{} genes match Entrez database\n'
                    '{} genes with match conflicts\n')
        self.info_box.setText(
            template.format(total_genes, known_genes,
                            total_genes - known_genes))

    def sign_in(self, silent=False):
        """Set ``self.res`` from the sign-in dialog.

        In silent mode the dialog is not shown: its ``sign_in()`` is called
        and, if that yields no instance, a connection to the default URL is
        used instead. Otherwise the dialog is executed and its instance is
        taken only when the dialog is accepted.
        """
        dialog = SignIn(self, server_type=resolwe.RESOLWE_PLATFORM)

        if not silent:
            if dialog.exec_():
                self.res = dialog.resolwe_instance
            return

        dialog.sign_in()
        if dialog.resolwe_instance is None:
            self.res = resolwe.connect(
                url=resolwe.resapi.DEFAULT_URL,
                server_type=resolwe.RESOLWE_PLATFORM)
        else:
            self.res = dialog.resolwe_instance

    def sign_out(self):
        """Reconnect to the default URL and delete any stored credentials."""
        # Use public credentials when user signs out
        self.res = resolwe.connect(url=resolwe.resapi.DEFAULT_URL,
                                   server_type=resolwe.RESOLWE_PLATFORM)
        # Remove username and password
        credentials = get_credential_manager(resolwe.RESOLWE_PLATFORM)
        for field in ('username', 'password'):
            if getattr(credentials, field):
                delattr(credentials, field)

    def on_filter_changed(self):
        """Reset pagination and reload the collections after a filter change."""
        # Reset before reloading: get_query_parameters() reads the offset
        # from the pagination component.
        self.pagination_component.reset_pagination()
        self.update_collections_view()

    def get_query_parameters(self) -> Dict[str, str]:
        """Build the keyword arguments for the collections request.

        ``limit``, ``offset`` and ``ordering`` are always present. The text,
        name, contributor, owner and modification-date filters are added
        only when the corresponding filter option has a truthy value.
        """
        filters = self.filter_component
        pagination = self.pagination_component

        params = {
            'limit': ItemsPerPage.values()[pagination.items_per_page],
            'offset': pagination.offset,
            'ordering': SortBy.values()[filters.sort_by],
        }

        optional_filters = (
            ('text', filters.filter_by_full_text),
            ('name__icontains', filters.filter_by_name),
            ('contributor_name', filters.filter_by_contrib),
            ('owners_name', filters.filter_by_owner),
        )
        for key, value in optional_filters:
            if value:
                params[key] = value

        modified_options = FilterByDateModified.values()
        last_modified = modified_options[filters.filter_by_modified]
        if last_modified:
            params['modified__gte'] = last_modified.isoformat()

        return params

    def get_collections(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """Fetch one page of collections and the species of each of them.

        Returns the server response for the current query parameters and the
        result of ``res.get_species`` for the ids found under ``'results'``.
        """
        query = self.get_query_parameters()
        collections = self.res.get_collections(**query)
        # Collect the ids of the returned collections
        collection_ids = []
        for collection in collections.get('results', []):
            collection_ids.append(collection['id'])
        # Look up species for exactly those ids
        collection_to_species = self.res.get_species(collection_ids)

        return collections, collection_to_species

    def update_collections_view(self):
        """Reload the collections table and announce pagination availability."""
        collections, collection_to_species = self.get_collections()

        # Hand the page of results to the table model
        rows = collections.get('results', [])
        self.model.set_data(rows, collection_to_species)
        self.table_view.setItemDelegateForColumn(
            TableHeader.id, gui.LinkStyledItemDelegate(self.table_view))
        for hidden_column in (TableHeader.slug, TableHeader.tags):
            self.table_view.setColumnHidden(hidden_column, True)

        # Tell listeners whether a next / previous page exists
        has_next = bool(collections.get('next'))
        has_previous = bool(collections.get('previous'))
        self.pagination_availability.emit(has_next, has_previous)

    def normalize(self, table: Table) -> Optional[Table]:
        """Apply the normalizations enabled in ``norm_component`` to *table*.

        The steps run in a fixed order: quantile normalization, logarithmic
        scale, z-score, quantile transform. Returns ``None`` for a falsy
        table.
        """
        if not table:
            return None

        options = self.norm_component

        if options.quantile_norm:
            table = QuantileNormalization()(table)

        if options.log_norm:
            table = LogarithmicScale()(table)

        if options.z_score_norm:
            table = ZScore(axis=options.z_score_axis)(table)

        if options.quantile_transform:
            axis = options.quantile_transform_axis
            # Number of quantiles is the length of the other axis of X
            n_quantiles = table.X.shape[int(not axis)]
            distributions = QuantileTransformDist.values()
            transform = QuantileTransform(
                axis=axis,
                n_quantiles=n_quantiles,
                output_distribution=distributions[
                    options.quantile_transform_dist])
            table = transform(table)

        return table

    def commit(self):
        """Restart the background task that produces the output table."""
        self.Warning.no_data_objects.clear()
        # Cancel a task that may still be running before starting runner().
        self.cancel()
        self.start(self.runner)

    def on_output_option_changed(self):
        """Drop the cached table and recompute the output."""
        # runner() only downloads again when data_table is empty.
        self.data_table = None
        self.commit()

    def on_clinical_data_changed(self):
        """Refresh the cached clinical metadata and recompute the output."""
        # NOTE(review): fetch_clinical_metadata is not defined in this class
        # as far as this file shows - confirm it exists on a base class,
        # otherwise this handler raises AttributeError when called.
        self.clinical_metadata = self.fetch_clinical_metadata()
        self.commit()

    def on_normalization_changed(self):
        """Recompute the output; the cached table is kept and re-normalized."""
        self.commit()

    def on_selection_changed(self):
        """Load the output options of the newly selected collection.

        Resets the widget state, fills the three option combos from the
        collection's expression data objects (falling back to index 0 when a
        stored index is out of range) and triggers a commit. A warning is
        shown instead of committing when the collection has no expression
        data, has a single non-'gene' feature type, or mixes feature types.
        """
        self.__invalidate()

        collection_id: str = self.get_selected_row_data(TableHeader.id)
        if not collection_id:
            return

        self.selected_collection_id = collection_id
        data_objects = self.res.get_expression_data_objects(collection_id)
        options = available_data_output_options(data_objects)
        self.data_output_options = options

        self.exp_type_combo.addItems(
            exp_name for _, exp_name in options.expression_type)
        if self.exp_type >= len(options.expression_type):
            self.exp_type = 0
        self.exp_type_combo.setCurrentIndex(self.exp_type)

        self.proc_slug_combo.addItems(
            proc_name for _, proc_name in options.process)
        if self.proc_slug >= len(options.process):
            self.proc_slug = 0
        self.proc_slug_combo.setCurrentIndex(self.proc_slug)

        self.exp_source_combo.addItems(options.expression_sources)
        if self.exp_source >= len(options.expression_sources):
            self.exp_source = 0
        self.exp_source_combo.setCurrentIndex(self.exp_source)

        if not data_objects:
            self.Warning.no_expressions()
            return

        # Edge case: a collection may hold data objects with different
        # 'feature_type' values. For now this only raises a warning; how to
        # handle mixed feature types properly is still to be decided.
        feature_types = {data.output['feature_type'] for data in data_objects}

        if len(feature_types) > 1:
            self.Warning.multiple_feature_type()
            return

        if len(feature_types) == 1 and 'gene' not in feature_types:
            self.Warning.unexpected_feature_type(next(iter(feature_types)))
            return

        self.on_output_option_changed()

    def get_selected_row_data(self, column: int) -> Optional[str]:
        """Return the data of *column* in the first selected row, if any."""
        selected = self.table_view.selectionModel().selectedRows(column=column)
        if selected:
            return selected[0].data()
        return None

    def on_done(self, table: Table):
        """Publish a finished table: summary, info box and output signal.

        Nothing happens for a falsy table.
        """
        if not table:
            return
        n_samples, n_genes = table.X.shape
        self.info.set_output_summary(f'Samples: {n_samples} Genes: {n_genes}')
        self.update_info_box()
        self.Outputs.table.send(table)

    def on_exception(self, ex):
        """Re-raise the exception reported for the background task."""
        # Disabled handling for ResolweDataObjectsNotFound, kept for
        # reference; while it is disabled every exception is re-raised.
        # if isinstance(ex, ResolweDataObjectsNotFound):
        #     self.Warning.no_data_objects()
        #     self.Outputs.table.send(None)
        #     self.data_table = None
        #     self.info.set_output_summary(StateInfo.NoOutput)
        #     self.update_info_box()
        # else:
        raise ex

    def on_partial_result(self, result: Any) -> None:
        """Ignore partial results; this widget does not use them."""
        pass

    def onDeleteWidget(self):
        """Call shutdown() before the parent's onDeleteWidget()."""
        self.shutdown()
        super().onDeleteWidget()

    def sizeHint(self):
        """Return the preferred widget size, 1280 x 620."""
        return QSize(1280, 620)

    def runner(self, state: TaskState) -> Table:
        """Build and normalize the output table for the selected collection.

        When no table is cached in ``self.data_table``, the expression and
        metadata frames (and optionally the QC frame) of the collection are
        read, converted to a table with genes as attributes and sample
        metadata as metas, matched against the gene database and cached.
        The cached or new table is then passed through ``normalize``.

        Args:
            state: task state used for progress, status and interruption.

        Returns:
            The normalized table, as returned by ``normalize``.
        """
        exp_type = self.data_output_options.expression_type[self.exp_type].type
        exp_source = self.data_output_options.expression_sources[
            self.exp_source]
        proc_slug = self.data_output_options.process[self.proc_slug].slug
        collection_id = self.selected_collection_id

        table = self.data_table
        progress_steps_download = iter(np.linspace(0, 50, 2))

        def callback(i: float, status=""):
            state.set_progress_value(i * 100)
            if status:
                state.set_status(status)
            # Abort the download by raising once interruption is requested.
            if state.is_interruption_requested():
                raise Exception

        if not table:
            collection = self.res.get_collection_by_id(collection_id)
            coll_table = resdk.tables.RNATables(
                collection,
                expression_source=exp_source,
                expression_process_slug=proc_slug,
                progress_callable=wrap_callback(callback, end=0.5),
            )
            species = coll_table._data[0].output['species']
            sample = coll_table._samples[0]

            state.set_status('Downloading ...')
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                df_exp = coll_table.exp if exp_type != 'rc' else coll_table.rc
                df_exp = df_exp.rename(index=coll_table.readable_index)
                df_metas = coll_table.meta
                df_metas = df_metas.rename(index=coll_table.readable_index)
                df_qc = None
                if self.append_qc_data:
                    # TODO: check if there is a way to detect if collection
                    #       table contains QC data
                    try:
                        df_qc = coll_table.qc
                        df_qc = df_qc.rename(index=coll_table.readable_index)
                    except ValueError:
                        # Best effort: continue without QC columns.
                        pass
            finally:
                # Close the loop even when a download fails, so the loop is
                # not leaked.
                loop.close()

            state.set_status('To data table ...')

            # Field names that occur in more than one section of the
            # metadata columns ("section.field").
            duplicates = {
                item
                for item, count in Counter([
                    label.split('.')[1]
                    for label in df_metas.columns.to_list() if '.' in label
                ]).items() if count > 1
            }

            # what happens if there is more nested sections?
            section_name_to_label = {
                section['name']: section['label']
                for section in sample.descriptor_schema.schema
            }

            # Map "section.field" column names to readable labels; fields
            # with duplicated names are prefixed with their section label.
            column_labels = {}
            for field_schema, fields, path in iterate_schema(
                    sample.descriptor, sample.descriptor_schema.schema,
                    path=''):
                path = path[1:]  # drop the first character of the path
                if path not in df_metas.columns:
                    continue
                label = field_schema['label']
                section_name, field_name = path.split('.')
                column_labels[path] = (
                    label if field_name not in duplicates else
                    f'{section_name_to_label[section_name]} - {label}')

            df_exp = df_exp.reset_index(drop=True)
            df_metas = df_metas.astype('object')
            df_metas = df_metas.fillna(np.nan)
            df_metas = df_metas.replace('nan', np.nan)
            df_metas = df_metas.rename(columns=column_labels)
            if df_qc is not None:
                df_metas = pd.merge(df_metas,
                                    df_qc,
                                    left_index=True,
                                    right_index=True)

            # All metadata columns end up as metas of the output table,
            # whatever role vars_from_df assigned to them.
            xym, domain_metas = vars_from_df(df_metas)
            x, _, m = xym
            x_metas = np.hstack((x, m))
            attrs = [ContinuousVariable(col) for col in df_exp.columns]
            metas = domain_metas.attributes + domain_metas.metas
            domain = Domain(attrs, metas=metas)
            table = Table(domain, df_exp.to_numpy(), metas=x_metas)
            state.set_progress_value(next(progress_steps_download))

            state.set_status('Matching genes ...')
            progress_steps_gm = iter(
                np.linspace(50, 99, len(coll_table.gene_ids)))

            def gm_callback():
                state.set_progress_value(next(progress_steps_gm))

            tax_id = species_name_to_taxid(species)
            gm = GeneMatcher(tax_id, progress_callback=gm_callback)
            table = gm.match_table_attributes(table, rename=True)
            table.attributes[TableAnnotation.tax_id] = tax_id
            table.attributes[TableAnnotation.gene_as_attr_name] = True
            table.attributes[TableAnnotation.gene_id_attribute] = 'Entrez ID'
            # Cache the un-normalized table so that changing only the
            # normalization options does not download again.
            self.data_table = table

        state.set_status('Normalizing ...')
        table = self.normalize(table)
        state.set_progress_value(100)

        return table
Exemple #16
0
# Widget that shows the predictions of the connected models for the input
# data, next to the data itself and a table of scores.
class OWPredictions(OWWidget):
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display predictions of models for an input dataset."
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)
        # Several models can be connected at once.
        predictors = Input("Predictors", Model, multiple=True)

    class Outputs:
        predictions = Output("Predictions", Orange.data.Table)
        evaluation_results = Output("Evaluation Results", Results)

    class Warning(OWWidget.Warning):
        empty_data = Msg("Empty dataset")
        # The placeholder receives one line per mismatching model.
        wrong_targets = Msg(
            "Some model(s) predict a different target (see more ...)\n{}")

    class Error(OWWidget.Error):
        # The placeholders receive the collected error messages.
        predictor_failed = Msg("Some predictor(s) failed (see more ...)\n{}")
        scorer_failed = Msg("Some scorer(s) failed (see more ...)\n{}")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: List of selected class value indices in the `class_values` list
    selected_classes = settings.ContextSetting([])

    def __init__(self):
        """Build the class-value list, the two linked tables and the scores.

        The predictions view and the data view sit side by side in a
        splitter; their vertical scroll bars and row heights are kept in
        sync. The score table view is placed below them.
        """
        super().__init__()

        self.data = None  # type: Optional[Orange.data.Table]
        self.predictors = {}  # type: Dict[object, PredictorSlot]
        self.class_values = []  # type: List[str]
        self._delegates = []

        # Box title previously contained the typo "probabibilities".
        gui.listBox(self.controlArea,
                    self,
                    "selected_classes",
                    "class_values",
                    box="Show probabilities for",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False,
                    sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred))
        gui.rubber(self.controlArea)
        gui.button(self.controlArea,
                   self,
                   "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                          horizontalScrollMode=QTableView.ScrollPerPixel,
                          selectionMode=QTableView.NoSelection,
                          focusPolicy=Qt.StrongFocus)
        self.dataview = TableView(verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                  **table_opts)
        self.predictionsview = TableView(
            sortingEnabled=True,
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            **table_opts)
        self.dataview.verticalHeader().hide()
        # Scrolling either view vertically scrolls the other one too.
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()
        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        # Row resizes in the data view are mirrored in the predictions view.
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.splitter = QSplitter(orientation=Qt.Horizontal,
                                  childrenCollapsible=False,
                                  handleWidth=2)
        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter = gui.vBox(self.mainArea)
        self.vsplitter.layout().addWidget(self.splitter)
        self.vsplitter.layout().addWidget(self.score_table.view)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Store the input data and rebuild the data view's model.

        The empty-data warning is shown for a table that is given but has no
        rows. Stored prediction results are invalidated in every case.
        """
        is_empty = data is not None and not data
        self.Warning.empty_data(shown=is_empty)
        self.data = data

        # Detaching the old model first also forces a full reset of the
        # view's HeaderView state before a new model is attached.
        self.dataview.setModel(None)
        if data:
            source_model = TableModel(data, parent=None)
            sort_proxy = TableSortProxyModel()
            sort_proxy.setSourceModel(source_model)
            self.dataview.setModel(sort_proxy)
        else:
            self.predictionsview.setModel(None)

        self._invalidate_predictions()

    @property
    def class_var(self):
        """Class variable of the input data; falsy when there is no data.

        Because of the ``and``, a falsy ``self.data`` is returned as is
        instead of a variable.
        """
        return self.data and self.data.domain.class_var

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """Add, replace or remove the predictor connected under *id*.

        ``None`` removes a known slot and is ignored for an unknown id.
        Replacing a predictor keeps the slot but resets its results.
        """
        if predictor is None:
            self.predictors.pop(id, None)
        elif id in self.predictors:
            self.predictors[id] = self.predictors[id]._replace(
                predictor=predictor, name=predictor.name, results=None)
        else:
            self.predictors[id] = PredictorSlot(predictor, predictor.name,
                                                None)

    def _set_class_values(self):
        """Collect the class values offered in the probabilities list.

        The values of all discrete predictor targets are gathered in order
        of first appearance, without duplicates. When the data has a
        discrete class, its values are moved to the front and pre-selected;
        otherwise nothing is selected.
        """
        class_values = []
        for slot in self.predictors.values():
            class_var = slot.predictor.domain.class_var
            if class_var and class_var.is_discrete:
                for value in class_var.values:
                    if value not in class_values:
                        class_values.append(value)

        if self.class_var and self.class_var.is_discrete:
            values = self.class_var.values
            # Stable sort: values of the data's class first, rest after.
            ordered = sorted(class_values, key=lambda val: val not in values)
            self.class_values = ordered
            # Indices must refer to the sorted list that is displayed, not
            # to the unsorted one, or the wrong rows get selected.
            self.selected_classes = [
                i for i, name in enumerate(ordered) if name in values
            ]
        else:
            self.class_values = class_values  # This assignment updates listview
            self.selected_classes = []

    def handleNewSignals(self):
        """Recompute everything that depends on the inputs, then commit."""
        self._set_class_values()
        # Predictions come first: the scores, the predictions model and the
        # error messages below all read the stored results.
        self._call_predictors()
        self._update_scores()
        self._update_predictions_model()
        self._update_prediction_delegate()
        self._set_errors()
        self._update_info()
        self.commit()

    def _call_predictors(self):
        """Run every predictor that has no stored results on the data.

        Each slot ends up with either a ``Results`` object or, on failure, an
        error string. ``predicted`` and ``probabilities`` are filled in only
        when the predictor's target equals the data's class variable;
        otherwise only the unmapped values are available.
        """
        if not self.data:
            return
        # Predictors are called on data without the class column.
        if self.class_var:
            domain = self.data.domain
            classless_data = self.data.transform(
                Domain(domain.attributes, None, domain.metas))
        else:
            classless_data = self.data

        for inputid, slot in self.predictors.items():
            # Results from an earlier call are still valid; skip.
            if isinstance(slot.results, Results):
                continue

            predictor = slot.predictor
            try:
                if predictor.domain.class_var.is_discrete:
                    pred, prob = predictor(classless_data, Model.ValueProbs)
                else:
                    pred = predictor(classless_data, Model.Value)
                    # No probabilities for non-discrete targets: an array
                    # with zero columns.
                    prob = numpy.zeros((len(pred), 0))
            except (ValueError, DomainTransformationError) as err:
                # Keep the message in the slot; _set_errors() reports it.
                self.predictors[inputid] = \
                    slot._replace(results=f"{predictor.name}: {err}")
                continue

            results = Results()
            results.data = self.data
            results.domain = self.data.domain
            results.row_indices = numpy.arange(len(self.data))
            results.folds = (Ellipsis, )
            results.actual = self.data.Y
            results.unmapped_probabilities = prob
            results.unmapped_predicted = pred
            results.probabilities = results.predicted = None
            self.predictors[inputid] = slot._replace(results=results)

            # A different target: leave predicted/probabilities as None.
            target = predictor.domain.class_var
            if target != self.class_var:
                continue

            # Equal but not the very same variable object: map values and
            # probabilities back onto the data's class variable.
            if target is not self.class_var and target.is_discrete:
                backmappers, n_values = predictor.get_backmappers(self.data)
                prob = predictor.backmap_probs(prob, n_values, backmappers)
                pred = predictor.backmap_value(pred, prob, n_values,
                                               backmappers)
            # The stored Results object is updated in place.
            results.predicted = pred.reshape((1, len(self.data)))
            results.probabilities = prob.reshape((1, ) + prob.shape)

    def _update_scores(self):
        """Refill the score table with one row per scorable predictor.

        Only slots whose results carry mapped predictions are scored. A
        failing scorer leaves its cell empty with the error as tooltip;
        errors of currently shown scores are also reported in the widget.
        The view is hidden when the table has no rows.
        """
        model = self.score_table.model
        model.clear()
        scorers = usable_scorers(self.class_var) if self.class_var else []
        self.score_table.update_header(scorers)

        errors = []
        for slot in self.predictors.values():
            results = slot.results
            scorable = (isinstance(results, Results)
                        and results.predicted is not None)
            if not scorable:
                continue

            row = [QStandardItem(learner_name(slot.predictor))]
            row.extend(QStandardItem("N/A") for _ in range(2))
            for scorer in scorers:
                cell = QStandardItem()
                try:
                    score = scorer_caller(scorer, results)()[0]
                    cell.setText(f"{score:.3f}")
                except Exception as exc:  # pylint: disable=broad-except
                    message = str(exc)
                    cell.setToolTip(message)
                    if scorer.name in self.score_table.shown_scores:
                        errors.append(message)
                row.append(cell)
            self.score_table.model.appendRow(row)

        view = self.score_table.view
        if not model.rowCount():
            view.setVisible(False)
        else:
            view.setVisible(True)
            view.ensurePolished()
            # Header height plus all rows, with 5 extra pixels.
            view.setFixedHeight(5 + view.horizontalHeader().height() +
                                view.verticalHeader().sectionSize(0) *
                                model.rowCount())

        self.Error.scorer_failed("\n".join(errors), shown=bool(errors))

    def _set_errors(self):
        """Show predictor failures and target mismatches in the widget.

        Errors are read from the slots rather than collected in
        ``_call_predictors`` because not every predictor is run each time.
        """
        slots = list(self.predictors.values())

        failures = [f"- {slot.predictor.name}: {slot.results}"
                    for slot in slots
                    if isinstance(slot.results, str) and slot.results]
        errors = "\n".join(failures)
        self.Error.predictor_failed(errors, shown=bool(errors))

        if not self.class_var:
            self.Warning.wrong_targets.clear()
            return

        # Results without mapped probabilities belong to predictors whose
        # target differs from the data's class variable.
        mismatched = [slot.predictor for slot in slots
                      if isinstance(slot.results, Results)
                      and slot.results.probabilities is None]
        inv_targets = "\n".join(
            f"- {pred.name} predicts '{pred.domain.class_var.name}'"
            for pred in mismatched)
        self.Warning.wrong_targets(inv_targets, shown=bool(inv_targets))

    def _update_info(self):
        """Update the input summary with instance and model counts.

        With neither data nor models the "no input" summary is set.
        Otherwise the details hold two lines: the instance count (or
        "No data") and the model count (or "No models"), followed by the
        number of failed models when there are any.
        """
        n_predictors = len(self.predictors)
        if not self.data and not n_predictors:
            self.info.set_input_summary(self.info.NoInput)
            return

        n_valid = len(self._non_errored_predictors())
        summary = str(len(self.data)) if self.data else "0"
        details = f"{len(self.data)} instances" if self.data else "No data"
        # The newline has to be added in both cases; previously it was part
        # of the first branch only, giving e.g. "5 instancesNo models".
        models = f"{n_predictors} models" if n_predictors else "No models"
        details += f"\n{models}"
        if n_valid != n_predictors:
            details += f" ({n_predictors - n_valid} failed)"
        self.info.set_input_summary(summary, details)

    def _invalidate_predictions(self):
        """Reset the stored ``results`` of every predictor slot to ``None``.

        The slots themselves (and their keys) stay in ``self.predictors``.
        """
        cleared = {
            key: slot._replace(results=None)
            for key, slot in self.predictors.items()
        }
        # All keys already exist, so this only swaps the values in place.
        self.predictors.update(cleared)

    def _non_errored_predictors(self):
        """Return the predictor slots whose ``results`` is a ``Results``.

        Slots holding ``None`` or an error string are left out.
        """
        usable = []
        for slot in self.predictors.values():
            if isinstance(slot.results, Results):
                usable.append(slot)
        return usable

    def _update_predictions_model(self):
        """Rebuild the model behind the predictions view.

        One column is created per non-errored predictor.  Discrete targets
        contribute ``Value`` objects together with their probabilities;
        other targets contribute the raw predictions and an empty
        (zero-column) probability array.  The model is wrapped in a fresh
        sort proxy, installed on the view, and the data view's sort order is
        synchronised with it.
        """
        columns = []
        names = []
        for slot in self._non_errored_predictors():
            predicted = slot.results.unmapped_predicted
            target = slot.predictor.domain.class_var
            if target.is_discrete:
                probabilities = slot.results.unmapped_probabilities
                predicted = [Value(target, val) for val in predicted]
            else:
                probabilities = numpy.zeros((len(predicted), 0))
            columns.append((predicted, probabilities))
            names.append(slot.predictor.name)

        source = None
        if columns:
            # Per predictor, pair each prediction with its probabilities;
            # then regroup so that every entry holds one row across all
            # predictors.
            rows = list(zip(*(zip(*column) for column in columns)))
            source = PredictionsModel(rows, names)

        proxy = PredictionsSortProxyModel()
        proxy.setSourceModel(source)
        proxy.setDynamicSortFilter(True)

        view = self.predictionsview
        view.setModel(proxy)
        header = view.horizontalHeader()
        header.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # `model.rowCount(...)`, `model.parent`, ... calls), so sorting by
        # clicking the header is only offered for smaller tables.
        header.setSectionsClickable(proxy.rowCount() < 20000)

        proxy.layoutChanged.connect(self._update_data_sort_order)
        self._update_data_sort_order()
        view.resizeColumnsToContents()

    def _update_data_sort_order(self):
        """Make the data view follow the row order of the predictions view.

        When the predictions proxy is sorted by some column, the permutation
        it applies is converted into sort indices for the data proxy model;
        otherwise the data proxy's sort indices are reset to ``None``.  The
        sort indicator of the predictions header is shown only while such a
        sort is in effect.
        """
        datamodel = self.dataview.model()  # data model proxy
        predmodel = self.predictionsview.model()  # predictions model proxy
        sortindicatorshown = False
        if datamodel is not None:
            assert isinstance(datamodel, TableSortProxyModel)
            n = datamodel.rowCount()
            if predmodel is not None and predmodel.sortColumn() >= 0:
                # Source row shown at each proxy row; argsort inverts that
                # mapping.
                sortind = numpy.argsort([
                    predmodel.mapToSource(predmodel.index(i, 0)).row()
                    for i in range(n)
                ])
                # `numpy.int` was only an alias of the builtin `int` and has
                # been removed from NumPy, so use the builtin directly.
                sortind = numpy.array(sortind, dtype=int)
                sortindicatorshown = True
            else:
                sortind = None

            datamodel.setSortIndices(sortind)

        self.predictionsview.horizontalHeader() \
            .setSortIndicatorShown(sortindicatorshown)

    def _reset_order(self):
        """Undo any sorting in the data and predictions views.

        Each view's model (if it has one) is sorted by column ``-1`` and the
        sort indicator on the predictions header is hidden.
        """
        models = (self.dataview.model(), self.predictionsview.model())
        for model in models:
            if model is not None:
                model.sort(-1)
        header = self.predictionsview.horizontalHeader()
        header.setSortIndicatorShown(False)

    def _update_prediction_delegate(self):
        """Install a fresh item delegate on every prediction column.

        For a continuous target no probabilities are shown; otherwise the
        delegate is given the indices of those target values whose names are
        among the currently selected classes.  Afterwards the columns are
        resized and the splitter is updated.
        """
        shown_names = {self.class_values[idx] for idx in self.selected_classes}
        self._delegates.clear()
        view = self.predictionsview
        for column, slot in enumerate(self.predictors.values()):
            target = slot.predictor.domain.class_var
            if target.is_continuous:
                shown_probs = ()
            else:
                shown_probs = [
                    idx for idx, value in enumerate(target.values)
                    if value in shown_names
                ]
            delegate = PredictionsItemDelegate(target, shown_probs)
            # QAbstractItemView does not take ownership of delegates, so we
            # must keep a reference to each of them ourselves.
            self._delegates.append(delegate)
            view.setItemDelegateForColumn(column, delegate)
            view.setColumnHidden(column, False)

        view.resizeColumnsToContents()
        self._update_spliter()

    def _update_spliter(self):
        """Give the predictions pane just enough splitter width for its table.

        Does nothing without data.  The wanted width is the length of the
        predictions view's horizontal header plus the width of its vertical
        header plus 4; the other pane receives whatever remains of the
        current total.
        """
        if not self.data:
            return

        view = self.predictionsview
        table_width = (view.horizontalHeader().length()
                       + view.verticalHeader().width())
        wanted = table_width + 4
        first, second = self.splitter.sizes()
        self.splitter.setSizes([wanted, first + second - wanted])

    def commit(self):
        """Send both outputs: predictions first, then evaluation results."""
        for send_output in (self._commit_predictions,
                            self._commit_evaluation_results):
            send_output()

    def _commit_evaluation_results(self):
        """Send a combined ``Results`` object for all usable predictors.

        Only non-errored predictors that actually have predictions are
        included; if there are none, ``None`` is sent.  Rows whose value in
        the class column is NaN are dropped from the data, the predictions
        and (for a discrete class) the probabilities.
        """
        usable = [
            slot for slot in self._non_errored_predictors()
            if slot.results.predicted is not None
        ]
        if not usable:
            self.Outputs.evaluation_results.send(None)
            return

        class_column = self.data.get_column_view(self.class_var)[0]
        known = ~numpy.isnan(class_column)
        data = self.data[known]

        results = Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            [slot.results.predicted[0][known] for slot in usable])
        if self.class_var and self.class_var.is_discrete:
            results.probabilities = numpy.array(
                [slot.results.probabilities[0][known] for slot in usable])
        results.learner_names = [slot.name for slot in usable]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        """Send the input data extended with the predictions as meta columns.

        Without data ``None`` is sent.  Otherwise every non-errored predictor
        contributes its output variables and columns (classification or
        regression flavour, depending on its class variable); the new
        variables are appended to the metas of the data's domain and their
        values are written into the trailing meta columns.
        """
        if not self.data:
            self.Outputs.predictions.send(None)
            return

        extra_metas = []
        extra_columns = []
        for slot in self._non_errored_predictors():
            if slot.predictor.domain.class_var.is_discrete:
                add_columns = self._add_classification_out_columns
            else:
                add_columns = self._add_regression_out_columns
            add_columns(slot, extra_metas, extra_columns)

        source_domain = self.data.domain
        domain = Orange.data.Domain(
            list(source_domain.attributes),
            self.class_var,
            metas=list(source_domain.metas) + extra_metas)
        predictions = self.data.transform(domain)
        if extra_columns:
            stacked = numpy.hstack(
                [numpy.atleast_2d(column) for column in extra_columns])
            # The new variables were appended last, so they occupy the
            # trailing meta columns.
            predictions.metas[:, -stacked.shape[1]:] = stacked
        self.Outputs.predictions.send(predictions)

    @staticmethod
    def _add_classification_out_columns(slot, newmetas, newcolumns):
        """Append a classifier's output variables and columns in place.

        Adds one discrete variable holding the predicted value, followed by
        one continuous variable per class value for the probabilities.  The
        matching arrays are appended to ``newcolumns`` in the same order.
        """
        # TODO: mapped or unmapped predictions?! Or provide a checkbox so
        # the user decides?
        predictor = slot.predictor
        label = predictor.name
        class_values = predictor.domain.class_var.values

        newmetas.append(DiscreteVariable(name=label, values=class_values))
        newcolumns.append(slot.results.unmapped_predicted.reshape(-1, 1))

        probability_vars = [
            ContinuousVariable(name=f"{label} ({value})")
            for value in class_values
        ]
        newmetas.extend(probability_vars)
        newcolumns.append(slot.results.unmapped_probabilities)

    @staticmethod
    def _add_regression_out_columns(slot, newmetas, newcolumns):
        """Append a regressor's output variable and column in place.

        Adds one continuous variable named after the predictor and its
        predictions reshaped into a single column.
        """
        newmetas.append(ContinuousVariable(name=slot.predictor.name))
        predicted = slot.results.unmapped_predicted
        newcolumns.append(predicted.reshape((-1, 1)))

    def send_report(self):
        """Add the info text and the data-with-predictions table to the report.

        Nothing is reported without data.  The table has one header row and
        one header column; predictions come first, followed by the columns
        of the data view that are not hidden.
        """
        if not self.data:
            return

        def report_rows():
            data_model = self.dataview.model()
            pred_model = self.predictionsview.model()

            def styled(value):
                # Use the item delegate to format prediction values.
                delegate = self.predictionsview.itemDelegate()
                return delegate.displayText(value, QLocale())

            # Only the visible columns of the data view are reported.
            visible_cols = [
                col for col in range(data_model.columnCount())
                if not self.dataview.isColumnHidden(col)
            ]

            # Header row.
            header = ['']
            header += [
                pred_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                for col in range(pred_model.columnCount())
            ]
            header += [
                data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                for col in visible_cols
            ]
            yield header

            # One row per instance: row label, predictions, data values.
            for row in range(data_model.rowCount()):
                cells = [
                    data_model.headerData(row, Qt.Vertical, Qt.DisplayRole)
                ]
                cells += [
                    styled(pred_model.data(pred_model.index(row, col)))
                    for col in range(pred_model.columnCount())
                ]
                cells += [
                    data_model.data(data_model.index(row, col))
                    for col in visible_cols
                ]
                yield cells

        text = self.infolabel.text().replace('\n', '<br>')
        if self.selected_classes:
            shown = ', '.join(
                [self.class_values[i] for i in self.selected_classes])
            text += '<br>Showing probabilities for: ' + shown
        self.report_paragraph('Info', text)
        self.report_table("Data & Predictions",
                          report_rows(),
                          header_rows=1,
                          header_columns=1)
Exemple #17
0
class OWFreeViz(OWAnchorProjectionWidget):
    """Anchor-projection widget that shows and optimizes a FreeViz projection.

    Anchors are initialised according to ``initialization`` and can then be
    refined by pressing the "Optimize" button, which runs the projector
    step by step from an ``AsyncUpdateLoop``.
    """
    # Each optimisation step advances the progress bar by
    # 100 / MAX_ITERATIONS percent.
    MAX_ITERATIONS = 1000
    # check_data rejects inputs with more valid instances than this.
    MAX_INSTANCES = 10000

    name = "FreeViz"
    description = "Displays FreeViz projection"
    icon = "icons/Freeviz.svg"
    priority = 240
    keywords = ["viz"]

    settings_version = 3
    initialization = settings.Setting(InitType.Circular)
    GRAPH_CLASS = OWFreeVizGraph
    graph = settings.SettingProvider(OWFreeVizGraph)

    class Error(OWAnchorProjectionWidget.Error):
        no_class_var = widget.Msg("Data has no target variable")
        not_enough_class_vars = widget.Msg(
            "Target variable is not at least binary")
        features_exceeds_instances = widget.Msg(
            "Number of features exceeds the number of instances.")
        too_many_data_instances = widget.Msg("Data is too large.")
        constant_data = widget.Msg("All data columns are constant.")
        not_enough_features = widget.Msg("At least two features are required")

    class Warning(OWAnchorProjectionWidget.Warning):
        removed_features = widget.Msg("Categorical features with more than"
                                      " two values are not shown.")

    def __init__(self):
        """Create the widget and wire up the asynchronous optimisation loop."""
        super().__init__()
        self._loop = AsyncUpdateLoop(parent=self)
        self._loop.yielded.connect(self.__set_projection)
        self._loop.finished.connect(self.__freeviz_finished)
        self._loop.raised.connect(self.__on_error)

    def _add_controls(self):
        """Add the start box, the inherited controls and the radius slider."""
        self.__add_controls_start_box()
        super()._add_controls()
        self.graph.gui.add_control(self._effects_box,
                                   gui.hSlider,
                                   "Hide radius:",
                                   master=self.graph,
                                   value="hide_radius",
                                   minValue=0,
                                   maxValue=100,
                                   step=10,
                                   createLabel=False,
                                   callback=self.__radius_slider_changed)

    def __add_controls_start_box(self):
        """Create the initialization combo box and the "Optimize" button."""
        box = gui.vBox(self.controlArea, box=True)
        gui.comboBox(box,
                     self,
                     "initialization",
                     label="Initialization:",
                     items=InitType.items(),
                     orientation=Qt.Horizontal,
                     labelWidth=90,
                     callback=self.__init_combo_changed)
        # Disabled until check_data accepts an input.
        self.btn_start = gui.button(box,
                                    self,
                                    "Optimize",
                                    self.__toggle_start,
                                    enabled=False)

    @property
    def effective_variables(self):
        """Attributes used for the projection.

        These are the continuous attributes and the discrete attributes
        that have exactly two values.
        """
        return [
            a for a in self.data.domain.attributes
            if a.is_continuous or (a.is_discrete and len(a.values) == 2)
        ]

    def __radius_slider_changed(self):
        self.graph.update_radius()

    def __toggle_start(self):
        """Start the optimisation, or cancel it if it is already running."""
        if self._loop.isRunning():
            self._loop.cancel()
            self.btn_start.setText("Optimize")
            self.progressBarFinished(processEvents=False)
        else:
            self._start()

    def __init_combo_changed(self):
        """Re-initialise the projection after the initialization type changed.

        A running optimisation is cancelled first and restarted afterwards.
        """
        if self.data is None:
            return
        running = self._loop.isRunning()
        if running:
            self._loop.cancel()
        self.init_projection()
        self.graph.update_coordinates()
        self.commit()
        if running:
            self._start()

    def _start(self):
        """Hand the optimisation coroutine to the loop and mark the widget busy."""
        def update_freeviz(anchors):
            # Each pass re-fits the projector starting from the previous
            # anchors and yields the new ones; the generator stops once two
            # consecutive anchor sets are numerically close.
            while True:
                self.projection = self.projector(self.effective_data)
                _anchors = self.projector.components_.T
                self.projector.initial = _anchors
                yield _anchors
                if np.allclose(anchors, _anchors, rtol=1e-5, atol=1e-4):
                    return
                anchors = _anchors

        self.graph.set_sample_size(self.SAMPLE_SIZE)
        self._loop.setCoroutine(update_freeviz(self.projector.components_.T))
        self.btn_start.setText("Stop")
        self.progressBarInit()
        self.setBlocking(True)
        self.setStatusMessage("Optimizing")

    def __set_projection(self, _):
        # Set/update the projection matrix and coordinate embeddings
        self.progressBarAdvance(100. / self.MAX_ITERATIONS)
        self.graph.update_coordinates()

    def __freeviz_finished(self):
        """Restore the idle state of the widget and send the outputs."""
        self.graph.set_sample_size(None)
        self.btn_start.setText("Optimize")
        self.setStatusMessage("")
        self.setBlocking(False)
        self.progressBarFinished()
        self.commit()

    def __on_error(self, err):
        """Forward an exception raised inside the loop to ``sys.excepthook``."""
        sys.excepthook(type(err), err, err.__traceback__)

    def check_data(self):
        """Validate the input; on the first failed check show the error and
        discard the data.

        The "Optimize" button is enabled only if the data survives.
        """
        def error(err):
            err()
            self.data = None

        super().check_data()
        if self.data is not None:
            class_var, domain = self.data.domain.class_var, self.data.domain
            if class_var is None:
                error(self.Error.no_class_var)
            elif class_var.is_discrete and len(np.unique(self.data.Y)) < 2:
                error(self.Error.not_enough_class_vars)
            elif len(self.data.domain.attributes) < 2:
                error(self.Error.not_enough_features)
            elif len(self.data.domain.attributes) > self.data.X.shape[0]:
                error(self.Error.features_exceeds_instances)
            elif not np.sum(np.std(self.data.X, axis=0)):
                error(self.Error.constant_data)
            elif np.sum(self.valid_data) > self.MAX_INSTANCES:
                error(self.Error.too_many_data_instances)
            else:
                if len(self.effective_variables) < len(domain.attributes):
                    self.Warning.removed_features()
        self.btn_start.setEnabled(self.data is not None)

    def set_data(self, data):
        """Set the input data and, if it was accepted, initialise the projection."""
        super().set_data(data)
        if self.data is not None:
            self.init_projection()

    def init_projection(self):
        """Create a new projector and projection model from initial anchors.

        The anchors are laid out radially for ``InitType.Circular`` and
        randomly otherwise.
        """
        anchors = FreeViz.init_radial(len(self.effective_variables)) \
            if self.initialization == InitType.Circular \
            else FreeViz.init_random(len(self.effective_variables), 2)
        self.projector = FreeViz(scale=False,
                                 center=False,
                                 initial=anchors,
                                 maxiter=10)
        data = self.projector.preprocess(self.effective_data)
        self.projector.domain = data.domain
        self.projector.components_ = anchors.T
        self.projection = FreeVizModel(self.projector, self.projector.domain)
        self.projection.pre_domain = data.domain
        self.projection.name = self.projector.name

    def get_coordinates_data(self):
        """Return the valid embedding, transposed and scaled by its largest
        row norm.

        Returns ``(None, None)`` when there is no embedding.  The divisor
        falls back to 1 when the largest norm is zero.
        """
        embedding = self.get_embedding()
        if embedding is None:
            return None, None
        valid_emb = embedding[self.valid_data]
        return valid_emb.T / (np.max(np.linalg.norm(valid_emb, axis=1)) or 1)

    def _manual_move(self, anchor_idx, x, y):
        """Store the manually moved anchor as the projector's starting point."""
        self.projector.initial[anchor_idx] = [x, y]
        super()._manual_move(anchor_idx, x, y)

    def clear(self):
        """Clear the widget and cancel a possibly running optimisation."""
        super().clear()
        self._loop.cancel()

    @classmethod
    def migrate_settings(cls, _settings, version):
        """Move the pre-version-3 ``radius`` setting to ``graph.hide_radius``."""
        if version < 3 and "radius" in _settings:
            # Old settings may lack the "graph" entry altogether; create it
            # instead of failing with a KeyError.
            _settings.setdefault("graph", {})["hide_radius"] = \
                _settings["radius"]

    @classmethod
    def migrate_context(cls, context, version):
        """Lift pre-version-3 graph attribute settings to the widget level."""
        if version < 3:
            values = context.values
            # Copy only what is actually stored, so that contexts without
            # (complete) graph values do not break the migration.
            graph_values = values.get("graph") or {}
            for key in ("attr_color", "attr_size", "attr_shape",
                        "attr_label"):
                if key in graph_values:
                    values[key] = graph_values[key]
Exemple #18
0
class OW1ka(widget.OWWidget):
    """Widget that imports survey data from a public EnKlikAnketa (1ka) link.

    The page behind the URL is loaded in a web view, its last HTML table is
    parsed into a data table, the user may edit the resulting domain, and
    the table is sent to the ``Data`` output.  The URL can optionally be
    reloaded periodically.
    """
    name = "EnKlik Anketa"
    description = "Import data from EnKlikAnketa (1ka.si) public URL."
    icon = "icons/1ka.svg"
    priority = 200

    class Outputs:
        data = Output("Data", Table)

    want_main_area = False
    resizing_enabled = False

    settingsHandler = settings.PerfectDomainContextHandler(
        match_values=settings.PerfectDomainContextHandler.MATCH_VALUES_ALL)

    recent = settings.Setting([])
    reload_idx = settings.Setting(0)
    autocommit = settings.Setting(True)
    domain_editor = settings.SettingProvider(DomainEditor)

    UserAdviceMessages = [
        widget.Message(
            'You can import data from public links to 1ka surveys results. '
            'Click to learn more on how to get a shareable public link URL for '
            '1ka surveys that you manage.',
            'public-link',
            icon=widget.Message.Information,
            moreurl=
            'http://english.1ka.si/db/24/468/Guides/Public_link_to_access_data_and_analysis/'
        ),
    ]

    class Error(widget.OWWidget.Error):
        net_error = widget.Msg(
            "Couldn't load data: {}. Ensure network connection, firewall ...")
        parse_error = widget.Msg(
            "Couldn't parse data: {}. Ensure well-formatted data or submit a bug report."
        )
        invalid_url = widget.Msg(
            'Invalid URL. Public shareable link should match: ' +
            VALID_URL_HELP)
        data_is_anal = widget.Msg(
            "The provided URL is a public link to 'Analysis'. Need public link to 'Data'."
        )

    class Information(widget.OWWidget.Information):
        response_data_empty = widget.Msg(
            'Response data is empty. Get some responses first.')

    def __init__(self):
        """Build the GUI: URL combo, reload selector, domain editor and info."""
        super().__init__()
        self.table = None
        # Table as parsed from the page, before any domain edits.  It is
        # initialised here so that apply_domain_edit is safe to call before
        # the first load_url.
        self._orig_table = None
        self._html = None

        def _loadFinished(is_ok):
            # Store the page HTML shortly after a successful load; load_url
            # waits for `_html` to become non-None.
            if is_ok:
                QTimer.singleShot(
                    1, lambda: setattr(self, '_html', self.webview.html()))

        self.webview = WebviewWidget(loadFinished=_loadFinished)

        vb = gui.vBox(self.controlArea, 'Import Data')
        hb = gui.hBox(vb)
        self.combo = combo = URLComboBox(
            hb,
            self.recent,
            editable=True,
            minimumWidth=400,
            insertPolicy=QComboBox.InsertAtTop,
            toolTip='Format: ' + VALID_URL_HELP,
            editTextChanged=self.is_valid_url,
            # Indirect via QTimer because calling wait() -> processEvents,
            # while our currentIndexChanged event hadn't yet finished.
            # Avoids calling handler twice.
            currentIndexChanged=lambda: QTimer.singleShot(1, self.load_url))
        hb.layout().addWidget(QLabel('Public link URL:', hb))
        hb.layout().addWidget(combo)
        hb.layout().setStretch(1, 2)

        # (label, interval) pairs; the interval is passed to QTimer.start.
        # The first entry has no interval and stops the timer.
        RELOAD_TIMES = (
            ('No reload', ),
            ('5 s', 5000),
            ('10 s', 10000),
            ('30 s', 30000),
            ('1 min', 60 * 1000),
            ('2 min', 2 * 60 * 1000),
            ('5 min', 5 * 60 * 1000),
        )

        reload_timer = QTimer(self,
                              timeout=lambda: self.load_url(from_reload=True))

        def _on_reload_changed():
            if self.reload_idx == 0:
                reload_timer.stop()
                return
            reload_timer.start(RELOAD_TIMES[self.reload_idx][1])

        gui.comboBox(vb,
                     self,
                     'reload_idx',
                     label='Reload every:',
                     orientation=Qt.Horizontal,
                     items=[i[0] for i in RELOAD_TIMES],
                     callback=_on_reload_changed)

        box = gui.widgetBox(self.controlArea, "Columns (Double-click to edit)")
        self.domain_editor = DomainEditor(self)
        editor_model = self.domain_editor.model()

        def editorDataChanged():
            self.apply_domain_edit()
            self.commit()

        editor_model.dataChanged.connect(editorDataChanged)
        box.layout().addWidget(self.domain_editor)

        box = gui.widgetBox(self.controlArea, "Info", addSpace=True)
        info = self.data_info = gui.widgetLabel(box, '')
        info.setWordWrap(True)

        self.controlArea.layout().addStretch(1)
        gui.auto_commit(self.controlArea, self, 'autocommit', label='Commit')

        self.set_info()

    def set_combo_items(self):
        """Repopulate the URL combo box from ``self.recent``."""
        # NOTE(review): this reads `.name` / `.url` from each entry, while
        # load_url inserts plain URL strings into `self.recent` -- confirm
        # which entry type is intended before calling this.
        self.combo.clear()
        for sheet in self.recent:
            self.combo.addItem(sheet.name, sheet.url)

    def commit(self):
        """Send the current table to the ``Data`` output."""
        self.Outputs.data.send(self.table)

    def is_valid_url(self, url):
        """Check `url` and update the invalid-URL error accordingly.

        For an invalid URL the error is raised and the expected format is
        shown as a tooltip next to the combo box.

        Returns
        -------
        bool
            True if the module-level ``is_valid_url`` accepts the URL.
        """
        if is_valid_url(url):
            self.Error.invalid_url.clear()
            return True
        self.Error.invalid_url()
        QToolTip.showText(self.combo.mapToGlobal(QPoint(0, 0)),
                          self.combo.toolTip())
        return False

    def load_url(self, from_reload=False):
        """Load the URL from the combo box, parse it and send the result.

        Parameters
        ----------
        from_reload : bool
            True when triggered by the reload timer; in that case the
            output is re-sent only if the data differs from the previously
            loaded table.
        """
        self.closeContext()
        self.domain_editor.set_domain(None)

        url = self.combo.currentText()
        if not self.is_valid_url(url):
            self.table = None
            self.commit()
            return

        if url not in self.recent:
            self.recent.insert(0, url)

        prev_table = self.table
        with self.progressBar(3) as progress:
            try:
                self._html = None
                self.webview.setUrl(url)
                wait(until=lambda: self._html is not None)
                progress.advance()
                # Wait some seconds for discrete labels to have loaded via AJAX,
                # then re-query HTML.
                # *Webview.loadFinished doesn't guarantee it sufficiently
                try:
                    wait(until=lambda: False, timeout=1200)
                except TimeoutError:
                    pass
                progress.advance()
                html = self.webview.html()
            except Exception as e:
                log.exception("Couldn't load data from: %s", url)
                self.Error.net_error(try_(lambda: e.args[0], ''))
                self.table = None
            else:
                self.Error.clear()
                self.Information.clear()
                self.table = None
                try:
                    table = self.table = self.table_from_html(html)
                except DataEmptyError:
                    self.Information.response_data_empty()
                except DataIsAnalError:
                    self.Error.data_is_anal()
                except Exception as e:
                    log.exception('Parsing error: %s', url)
                    self.Error.parse_error(try_(lambda: e.args[0], ''))
                else:
                    self.openContext(table.domain)
                    self.combo.setTitleFor(self.combo.currentIndex(),
                                           table.name)

        def _equal(data1, data2):
            # A failing checksum (e.g. a None table) yields NaN, which never
            # compares equal, so such tables always count as different.
            NAN = float('nan')
            return (try_(lambda: data1.checksum(),
                         NAN) == try_(lambda: data2.checksum(), NAN))

        self._orig_table = self.table
        self.apply_domain_edit()

        if not (from_reload and _equal(prev_table, self.table)):
            self.commit()

    def apply_domain_edit(self):
        """Rebuild ``self.table`` from the originally loaded table using the
        domain currently set in the domain editor."""
        data = self._orig_table
        if data is None:
            self.set_info()
            return

        domain, cols = self.domain_editor.get_domain(data.domain, data)

        # Copied verbatim from OWFile
        if not (domain.variables or domain.metas):
            table = None
        else:
            X, y, m = cols
            table = Table.from_numpy(domain, X, y, m, data.W)
            table.name = data.name
            table.ids = np.array(data.ids)
            table.attributes = getattr(data, 'attributes', {})

        self.table = table
        self.set_info()

    # Name of the column that set_info reads the first/last entry from.
    DATETIME_VAR = 'Paradata (insert)'

    def table_from_html(self, html):
        """Parse the last ``<table>`` of the page into a data table.

        Column names are built from the first and third header rows, the
        cell values are written to an in-memory tab-separated buffer and
        that buffer is read with ``TabReader``.  The table is named after
        the page's first ``<h2>`` (the text after the first ``': '``).

        Raises
        ------
        DataEmptyError
            If the page contains no table.
        DataIsAnalError
            If the page looks like an 'Analysis' page rather than 'Data'.
        """
        soup = BeautifulSoup(html, 'html.parser')
        try:
            html_table = soup.find_all('table')[-1]
        except IndexError:
            # The IndexError is an implementation detail; don't chain it.
            raise DataEmptyError from None

        if '<h2>Anal' in html or 'div_analiza_' in html:
            raise DataIsAnalError

        def _header_row_strings(row):
            # Texts of the titled header cells in the given header row, each
            # repeated `colspan` times.
            return chain.from_iterable(
                repeat(th.get_text(), int(th.get('colspan') or 1)) for th in
                html_table.select('thead tr:nth-of-type(%d) th[title]' % row))

        # self.DATETIME_VAR (available when Paradata is enabled in 1ka UI)
        # should match this variable name format
        header = [
            th1.rstrip(':') +
            ('' if th3 == th1 else ' ({})').format(th3.rstrip(':'))
            for th1, th3 in zip(_header_row_strings(1), _header_row_strings(3))
        ]
        values = [
            [
                (  # If no span, feature is a number or a text field
                    td.get_text() if td.span is None else
                    # If have span, it's a number, but if negative, replace with NaN
                    '' if td.contents[0].strip().startswith('-') else
                    # Else if span, the number is its code, but we want its value
                    td.span.get_text()[1:-1]) for td in tr.select('td')
                if 'data_uid' not in td.get('class', ())
            ] for tr in html_table.select('tbody tr')
        ]

        # Save parsed values into in-mem file for default values processing
        buffer = StringIO()
        writer = csv.writer(buffer, delimiter='\t')
        writer.writerow(header)
        writer.writerows(values)
        buffer.flush()
        buffer.seek(0)

        data = TabReader(buffer).read()

        title = soup.select('body h2:nth-of-type(1)')[0].get_text().split(
            ': ', maxsplit=1)[-1]
        data.name = title

        return data

    def set_info(self):
        """Show a short description of the current table in the info label."""
        data = self.table
        if data is None:
            self.data_info.setText('No spreadsheet loaded.')
            return
        text = "{}\n\n{} instance(s), {} feature(s), {} meta attribute(s)\n".format(
            data.name, len(data), len(data.domain.attributes),
            len(data.domain.metas))
        # The first/last entry lines are added only when they can be read
        # from the table; otherwise try_ falls back to the empty string.
        text += try_(
            lambda: '\nFirst entry: {}'
            '\nLast entry: {}'.format(data[0, self.DATETIME_VAR], data[
                -1, self.DATETIME_VAR]), '')
        self.data_info.setText(text)
Exemple #19
0
class OWLinearProjection(widget.OWWidget):
    """
    Widget showing a linear projection of data onto a two-dimensional plane.

    The projection axes come from one of four placements: circular
    placement, linear discriminant analysis, principal component analysis
    or a projection table supplied on the "Projection" input.  The widget
    outputs the selected rows, the data annotated with the selection and
    a table of the projection components.
    """
    # Widget metadata.
    name = "Linear Projection"
    description = "A multi-axis projection of data onto " \
                  "a two-dimensional plane."
    icon = "icons/LinearProjection.svg"
    priority = 240
    keywords = []

    # Saved point selection; __init__ stashes it so that it can be restored
    # once data arrives.
    selection_indices = settings.Setting(None, schema_only=True)

    class Inputs:
        data = Input("Data", Table, default=True)
        data_subset = Input("Data Subset", Table)
        projection = Input("Projection", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        components = Output("Components", Table)

    # The integer values double as indices of the placement radio buttons.
    Placement = Enum("Placement",
                     dict(Circular=0, LDA=1, PCA=2, Projection=3),
                     type=int,
                     qualname="OWLinearProjection.Placement")

    # Prefix of the component labels written to the "Components" output.
    Component_name = {
        Placement.Circular: "C",
        Placement.LDA: "LD",
        Placement.PCA: "PC"
    }
    # Stem of the names given to the two projected coordinate variables
    # ("<stem>-x" and "<stem>-y").
    Variable_name = {
        Placement.Circular: "circular",
        Placement.LDA: "lda",
        Placement.PCA: "pca",
        Placement.Projection: "projection"
    }

    jitter_sizes = [0, 0.1, 0.5, 1.0, 2.0]

    settings_version = 3
    settingsHandler = settings.DomainContextHandler()

    # Encoded split of variables into the "selected" and "other" lists;
    # written in closeContext and read back in set_data.
    variable_state = settings.ContextSetting({})
    placement = settings.Setting(Placement.Circular)
    # Value of the "Hide axes" slider (0-100).
    radius = settings.Setting(0)
    auto_commit = settings.Setting(True)

    resolution = 256

    graph = settings.SettingProvider(OWLinProjGraph)
    # Custom event type posted by invalidate_plot and handled in customEvent.
    ReplotRequest = QEvent.registerEventType()
    vizrank = settings.SettingProvider(LinearProjectionVizRank)
    graph_name = "graph.plot_widget.plotItem"

    class Warning(widget.OWWidget.Warning):
        no_cont_features = widget.Msg("Plotting requires numeric features")
        not_enough_components = widget.Msg(
            "Input projection has less than 2 components")
        trivial_components = widget.Msg(
            "All components of the PCA are trivial (explain 0 variance). "
            "Input data is constant (or near constant).")

    class Error(widget.OWWidget.Error):
        proj_and_domain_match = widget.Msg(
            "Projection and Data domains do not match")
        no_valid_data = widget.Msg("No projection due to invalid data")

    def __init__(self):
        """Initialise widget state and build the plot and control-area GUI."""
        super().__init__()

        # Input signals and state derived from them.
        self.data = None
        self.projection = None
        self.subset_data = None
        self._subset_mask = None
        self._selection = None
        self.__replot_requested = False
        self.n_cont_var = 0
        #: Remember the saved state to restore
        self.__pending_selection_restore = self.selection_indices
        self.selection_indices = None

        # Variables holding the projected x/y coordinates; created in
        # _setup_plot.
        self.variable_x = None
        self.variable_y = None

        # Main area: the projection plot.
        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWLinProjGraph(self,
                                    box,
                                    "Plot",
                                    view_box=LinProjInteractiveViewBox)
        box.layout().addWidget(self.graph.plot_widget)
        plot = self.graph.plot_widget

        SIZE_POLICY = (QSizePolicy.Minimum, QSizePolicy.Maximum)

        # Control area: lists of selected and remaining variables.
        self.variables_selection = VariablesSelection()
        self.model_selected = VariableListModel(enable_dnd=True)
        self.model_other = VariableListModel(enable_dnd=True)
        self.variables_selection(self, self.model_selected, self.model_other)

        # "Suggest Features" button; self._vizrank receives the chosen
        # attributes.
        self.vizrank, self.btn_vizrank = LinearProjectionVizRank.add_vizrank(
            self.controlArea, self, "Suggest Features", self._vizrank)
        self.variables_selection.add_remove.layout().addWidget(
            self.btn_vizrank)

        # Radio buttons bound to self.placement; the label order follows
        # the integer values of the Placement enum.
        box = gui.widgetBox(self.controlArea,
                            "Placement",
                            sizePolicy=SIZE_POLICY)
        self.radio_placement = gui.radioButtonsInBox(
            box,
            self,
            "placement",
            btnLabels=[
                "Circular Placement", "Linear Discriminant Analysis",
                "Principal Component Analysis", "Use input projection"
            ],
            callback=self._change_placement)

        self.viewbox = plot.getViewBox()
        self.replot = None

        # Point properties and jittering controls provided by the graph.
        g = self.graph.gui
        box = g.point_properties_box(self.controlArea)
        self.models = g.points_models
        g.add_widget(g.JitterSizeSlider, box)
        box.setSizePolicy(*SIZE_POLICY)

        # Slider bound to self.radius; update_radius hides the axes that
        # are shorter than the chosen radius.
        box = gui.widgetBox(self.controlArea,
                            "Hide axes",
                            sizePolicy=SIZE_POLICY)
        self.rslider = gui.hSlider(box,
                                   self,
                                   "radius",
                                   minValue=0,
                                   maxValue=100,
                                   step=5,
                                   label="Radius",
                                   createLabel=False,
                                   ticks=True,
                                   callback=self.update_radius)
        self.rslider.setTickInterval(0)
        self.rslider.setPageStep(10)

        box = gui.vBox(self.controlArea, "Plot Properties")
        box.setSizePolicy(*SIZE_POLICY)

        g.add_widgets([
            g.ShowLegend, g.ToolTipShowsAll, g.ClassDensity,
            g.LabelOnlySelected
        ], box)

        box = self.graph.box_zoom_select(self.controlArea)
        box.setSizePolicy(*SIZE_POLICY)

        self.icons = gui.attributeIconDict

        p = self.graph.plot_widget.palette()
        self.graph.set_palette(p)
        gui.auto_commit(self.controlArea,
                        self,
                        "auto_commit",
                        "Send Selection",
                        auto_label="Send Automatically")
        self.graph.zoom_actions(self)

        # Start from an empty plot state and sync the controls with the
        # stored placement.
        self._new_plotdata()
        self._change_placement()
        self.graph.jitter_continuous = True

    def reset_graph_data(self):
        """Rescale the graph data and redraw with a reset view; no-op without data."""
        if self.data is None:
            return
        self.graph.rescale_data()
        self._update_graph(reset_view=True)

    def keyPressEvent(self, event):
        """Run the base-class handler, then refresh the graph tooltip for the pressed modifiers."""
        super().keyPressEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def keyReleaseEvent(self, event):
        """Run the base-class handler, then refresh the graph tooltip for the remaining modifiers."""
        super().keyReleaseEvent(event)
        self.graph.update_tooltip(event.modifiers())

    def _vizrank(self, attrs):
        """
        Apply a feature suggestion: make `attrs` the selected variables.

        Passed to the VizRank dialog as its callback in __init__.  The
        suggested variables replace the selected list and are removed from
        the list of other variables.
        """
        self.variables_selection.display_none()
        self.model_selected[:] = attrs[:]
        self.model_other[:] = [
            var for var in self.model_other if var not in attrs
        ]

    def _change_placement(self):
        """
        React to a change of the placement setting.

        Variable selection is only enabled for the Circular and LDA
        placements, and the radius slider is disabled for Circular.
        The plot is then rebuilt and the outputs are committed.
        """
        circular = self.Placement.Circular
        lda = self.Placement.LDA
        current = self.placement
        manual_selection = current in [circular, lda]
        self.variables_selection.set_enabled(manual_selection)
        self._vizrank_color_change()
        self.rslider.setEnabled(current != circular)
        self._setup_plot()
        self.commit()

    def _get_min_radius(self):
        return self.radius * np.max(np.linalg.norm(self.plotdata.axes,
                                                   axis=1)) / 100 + 1e-5

    def update_radius(self):
        """
        Update axis visibility and the hide-circle after a radius change.

        Does nothing when the current plot has no hide-circle.  Otherwise
        every axis item shorter than the minimum radius is hidden and the
        circle is resized to that radius.
        """
        plotdata = self.plotdata
        assert plotdata is not None
        circle = plotdata.hidecircle
        if circle is None:
            return
        limit = self._get_min_radius()
        for axis, axis_item in zip(plotdata.axes, plotdata.axisitems):
            axis_item.setVisible(np.linalg.norm(axis) > limit)
        diameter = 2 * limit
        circle.setRect(QRectF(-limit, -limit, diameter, diameter))

    def _new_plotdata(self):
        """Replace ``self.plotdata`` with a fresh, empty plot-state container."""
        # Fields are filled in by _setup_plot, _anchor_circle and _plot.
        self.plotdata = namespace(valid_mask=None,
                                  embedding_coords=None,
                                  axisitems=[],
                                  axes=[],
                                  variables=[],
                                  data=None,
                                  hidecircle=None)

    def _anchor_circle(self, variables):
        """
        Add the axis anchors, and the hide-circle where used, to the view box.

        One anchor item is created per (axis, variable) pair; anchors that
        are not longer than the minimum radius start out hidden.  For every
        placement except Circular a light-gray circle of the minimum radius
        is added as well.

        Args:
            variables: variables matching the rows of ``self.plotdata.axes``
        """
        # minimum visible anchor radius (radius)
        threshold = self._get_min_radius()
        anchors = []
        for axis, variable in zip(self.plotdata.axes, variables[:]):
            anchor_item = AnchorItem(line=QLineF(0, 0, *axis),
                                     text=variable.name)
            anchor_item.setVisible(np.linalg.norm(axis) > threshold)
            anchor_item.setPen(pg.mkPen((100, 100, 100)))
            anchor_item.setArrowVisible(True)
            self.viewbox.addItem(anchor_item)
            anchors.append(anchor_item)
        self.plotdata.axisitems = anchors

        if self.placement != self.Placement.Circular:
            circle = QGraphicsEllipseItem()
            circle.setRect(
                QRectF(-threshold, -threshold, 2 * threshold, 2 * threshold))

            outline = QPen(Qt.lightGray, 1)
            outline.setCosmetic(True)
            circle.setPen(outline)

            self.viewbox.addItem(circle)
            self.plotdata.hidecircle = circle

    def update_colors(self):
        """Re-evaluate whether the "Suggest Features" button can be used."""
        self._vizrank_color_change()

    def clear(self):
        """Drop the data, empty both variable lists, clear the plot and forget the selection."""
        # Clear/reset the widget state
        self.data = None
        self.model_selected.clear()
        self.model_other.clear()
        self._clear_plot()
        self.selection_indices = None

    def _clear_plot(self):
        """
        Remove the plot decorations and reset the plot state.

        Clears the trivial-components warning, removes the axis items and
        the hide-circle from the view box, installs a fresh plot-state
        container and hides the graph axes.
        """
        self.Warning.trivial_components.clear()
        stale = self.plotdata
        for axis_item in stale.axisitems:
            self.viewbox.removeItem(axis_item)
        if stale.hidecircle:
            self.viewbox.removeItem(stale.hidecircle)
        self._new_plotdata()
        self.graph.hide_axes()

    def invalidate_plot(self):
        """
        Schedule a delayed replot.

        A low-priority ``ReplotRequest`` event is posted to the widget
        unless one is already pending; ``customEvent`` handles it.
        """
        if self.__replot_requested:
            return
        self.__replot_requested = True
        request = QEvent(self.ReplotRequest)
        QApplication.postEvent(self, request, Qt.LowEventPriority - 10)

    def init_attr_values(self):
        """Hand the current data (possibly None) to the graph via ``set_domain``."""
        self.graph.set_domain(self.data)

    def _vizrank_color_change(self):
        """
        Enable or disable the "Suggest Features" button.

        The button's tooltip explains why it is disabled.  Without data
        only the tooltip is changed.  Otherwise the button is enabled when
        the placement is Circular or LDA, a color variable is selected
        (and is not continuous under LDA) and at least three candidate
        variables are available; the VizRank state is reset accordingly.
        """
        if self.data is None:
            self.btn_vizrank.setToolTip("There is no data.")
            return

        domain = self.data.domain
        color = self.graph.attr_color
        # NOTE(review): `is_primitive` is referenced without being called;
        # confirm that it is a property and not a method.
        candidates = [
            var for var in chain(domain.variables, domain.metas)
            if var.is_primitive and var is not color
        ]
        self.n_cont_var = len(candidates)

        supported = [self.Placement.Circular, self.Placement.LDA]
        if self.placement not in supported:
            msg = ("Suggest Features works only for Circular and "
                   "Linear Discriminant Analysis Projection")
        elif color is None:
            msg = "Color variable has to be selected"
        elif color.is_continuous and self.placement == self.Placement.LDA:
            msg = ("Suggest Features does not work for Linear Discriminant Analysis Projection "
                   "when continuous color variable is selected.")
        elif len(candidates) < 3:
            msg = "Not enough available continuous variables"
        else:
            msg = ""
        # An empty message means that no obstacle was found.
        is_enabled = not msg
        self.btn_vizrank.setToolTip(msg)
        self.btn_vizrank.setEnabled(is_enabled)
        self.vizrank.stop_and_reset(is_enabled)

    @Inputs.projection
    def set_projection(self, projection):
        """
        Set the input projection.

        A projection with fewer than two rows cannot span a plane: it is
        rejected with a warning and handled as if no projection were
        given.  A usable projection switches the placement to
        ``Placement.Projection``.

        Args:
            projection (Orange.data.table): projection components, or None
        """
        self.Warning.not_enough_components.clear()
        # Test against None rather than truthiness: a table with no rows is
        # presumably falsy, which used to let it skip the warning and still
        # switch the placement.
        if projection is not None and len(projection) < 2:
            self.Warning.not_enough_components()
            projection = None
        if projection is not None:
            self.placement = self.Placement.Projection
        self.projection = projection

    @Inputs.data
    def set_data(self, data):
        """
        Set the input dataset.

        The previous context is closed and the widget is cleared.  SQL
        tables are downloaded (sampled when large), data without any
        continuous variable is rejected with a warning, and for non-empty
        data the selected/other variable lists are initialised and then
        overridden by the state stored in the matching context.

        Args:
            data (Orange.data.table): data instances
        """
        def sql(data):
            # Materialise SQL tables; large ones are sampled first.
            if isinstance(data, SqlTable):
                if data.approx_len() < 4000:
                    data = Table(data)
                else:
                    self.information("Data has been sampled")
                    data_sample = data.sample_time(1, no_cache=True)
                    data_sample.download_data(2000, partial=True)
                    data = Table(data_sample)
            return data

        def restore_settings(data):
            # get the default encoded state, replacing the position with Inf
            state = VariablesSelection.encode_var_state(
                [list(self.model_selected),
                 list(self.model_other)])
            state = {
                key: (source_ind, np.inf)
                for key, (source_ind, _) in state.items()
            }

            self.openContext(data.domain)

            if self.__pending_selection_restore is not None:
                self._selection = np.array(self.__pending_selection_restore,
                                           dtype=int)
                self.__pending_selection_restore = None

            # update the defaults state (the encoded state must contain
            # all variables in the input domain)
            state.update(self.variable_state)
            # ... and restore it with saved positions taking precedence over
            # the defaults
            selected, other = VariablesSelection.decode_var_state(
                state, [list(self.model_selected),
                        list(self.model_other)])
            return selected, other

        self.closeContext()
        self.clear()
        self.Warning.no_cont_features.clear()
        self.information()
        data = sql(data)
        if data is not None:
            domain = data.domain
            continuous_vars = [
                var for var in chain(domain.variables, domain.metas)
                if var.is_continuous
            ]
            if not continuous_vars:
                self.Warning.no_cont_features()
                data = None
        self.data = data
        self.init_attr_values()
        if data is not None and len(data):
            self._initialize(data)
            self.model_selected[:], self.model_other[:] = \
                restore_settings(data)
            self.vizrank.stop_and_reset()
            self.vizrank.attrs = self.data.domain.attributes if self.data is not None else []

    def _check_possible_opt(self):
        """
        Enable the placement options that the current inputs allow.

        Without data everything is disabled and the graph is cleared.
        With data, LDA is disabled unless the domain has a discrete class
        with at least two values, and "Use input projection" is disabled
        without a projection; a placement that became unavailable falls
        back to Circular.  The outputs are committed at the end.
        """
        def toggle_controls(flag):
            for button in self.radio_placement.buttons:
                button.setEnabled(flag)
            self.variables_selection.set_enabled(flag)

        circular = self.Placement.Circular
        lda = self.Placement.LDA
        from_input = self.Placement.Projection
        if not self.data:
            self.graph.new_data(None)
            self.rslider.setEnabled(False)
            toggle_controls(False)
        else:
            toggle_controls(True)
            domain = self.data.domain
            lda_possible = domain.has_discrete_class and \
                len(domain.class_var.values) >= 2
            if not lda_possible:
                self.radio_placement.buttons[lda].setEnabled(False)
                if self.placement == lda:
                    self.placement = circular
            if not self.projection:
                self.radio_placement.buttons[from_input].setEnabled(False)
                if self.placement == from_input:
                    self.placement = circular
            self._setup_plot()
        self.commit()

    @Inputs.data_subset
    def set_subset_data(self, subset):
        """
        Set the supplementary input subset dataset.

        The cached subset mask is invalidated; handleNewSignals recomputes
        it.  The graph's alpha-value control is enabled only while no
        subset is given.

        Args:
            subset (Orange.data.table): subset of data instances
        """
        self.subset_data = subset
        self._subset_mask = None
        self.controls.graph.alpha_value.setEnabled(subset is None)

    def handleNewSignals(self):
        """
        Update the widget after new input signals have been set.

        Recomputes the subset mask by matching instance ids, refreshes the
        available placement options, rebuilds the plot and commits.
        """
        if self.data is not None and self.subset_data is not None:
            # Update the plot's highlight items
            dataids = self.data.ids.ravel()
            subsetids = np.unique(self.subset_data.ids)
            self._subset_mask = np.in1d(dataids, subsetids, assume_unique=True)
        # NOTE(review): both helpers below already end with self.commit(),
        # so the outputs are sent more than once here; confirm whether the
        # repetition is needed.
        self._check_possible_opt()
        self._change_placement()
        self.commit()

    def customEvent(self, event):
        """
        Handle the delayed replot request; defer other events to the base class.
        """
        if event.type() != OWLinearProjection.ReplotRequest:
            super().customEvent(event)
            return
        self.__replot_requested = False
        self._setup_plot()
        self.commit()

    def closeContext(self):
        """Store the selected/other variable lists in ``variable_state`` before closing the context."""
        self.variable_state = VariablesSelection.encode_var_state(
            [list(self.model_selected),
             list(self.model_other)])
        super().closeContext()

    def _initialize(self, data):
        """
        Initialize the variable lists from the data's domain.

        The continuous metas and attributes are collected (metas first);
        the first three become the selected variables and the remainder
        go to the list of other variables.
        """
        domain = data.domain
        continuous = [
            var for var in chain(domain.metas, domain.attributes)
            if var.is_continuous
        ]
        chosen, remaining = continuous[:3], continuous[3:]
        self.model_other[:] = remaining
        self.model_selected[:] = chosen

    def prepare_plot_data(self, variables):
        """
        Project the data columns of `variables` onto two axes.

        The axes depend on the current placement: default circular axes,
        LDA scalings, or the first two rows of the input projection.
        Instances with a missing value in any of the variables are left
        out of the projection.

        Args:
            variables: non-empty sequence of variables to project

        Returns:
            tuple: ``(valid_mask, coords, axes)`` where `valid_mask` marks
            the instances without missing values, `coords` holds one
            (x, y) row per valid instance and `axes` holds one (x, y) row
            per variable; ``(None, None, None)`` when no axes are
            available.
        """
        def projection(variables):
            # Match the projection's columns by variable, then by name.
            if set(self.projection.domain.attributes).issuperset(variables):
                axes = self.projection[:2, variables].X
            elif set(f.name
                     for f in self.projection.domain.attributes).issuperset(
                         f.name for f in variables):
                axes = self.projection[:2, [f.name for f in variables]].X
            else:
                self.Error.proj_and_domain_match()
                axes = None
            return axes

        def get_axes(variables):
            self.Error.proj_and_domain_match.clear()
            axes = None
            if self.placement == self.Placement.Circular:
                axes = LinProj.defaultaxes(len(variables))
            elif self.placement == self.Placement.LDA:
                axes = self._get_lda(self.data, variables)
            elif self.placement == self.Placement.Projection and self.projection:
                axes = projection(variables)
            return axes

        coords = np.vstack([
            column_data(self.data, var, dtype=float) for var in variables
        ])
        p, N = coords.shape
        # Both conditions must hold.  Written as `assert a, b` the second
        # comparison would only serve as the assertion message.
        assert N == len(self.data) and p == len(variables)

        axes = get_axes(variables)
        if axes is None:
            return None, None, None
        assert axes.shape == (2, p)

        # Keep only the instances (columns) without missing values.
        valid_mask = ~np.isnan(coords).any(axis=0)
        coords = coords[:, valid_mask]

        X, Y = np.dot(axes, coords)
        if X.size and Y.size:
            X = normalized(X)
            Y = normalized(Y)

        return valid_mask, np.stack((X, Y), axis=1), axes.T

    def _setup_plot(self):
        """
        Rebuild the plot for the current data, placement and variables.

        Clears the previous plot, creates the two coordinate variables,
        computes the projection and stores it in ``self.plotdata``, then
        draws the anchors and the points.  Returns early when there is no
        data, no variable to project, no projection, or no valid instance.
        """
        self._clear_plot()
        if self.data is None:
            return
        self.__replot_requested = False
        # Names for the coordinate variables that do not clash with the
        # names already present in the domain.
        names = get_unique_names([
            v.name
            for v in chain(self.data.domain.variables, self.data.domain.metas)
        ], [
            "{}-x".format(self.Variable_name[self.placement]), "{}-y".format(
                self.Variable_name[self.placement])
        ])
        self.variable_x = ContinuousVariable(names[0])
        self.variable_y = ContinuousVariable(names[1])
        # Choose the variables to project according to the placement.
        if self.placement in [self.Placement.Circular, self.Placement.LDA]:
            variables = list(self.model_selected)
        elif self.placement == self.Placement.Projection:
            variables = self.model_selected[:] + self.model_other[:]
        elif self.placement == self.Placement.PCA:
            variables = [
                var for var in self.data.domain.attributes if var.is_continuous
            ]
        if not variables:
            self.graph.new_data(None)
            return
        if self.placement == self.Placement.PCA:
            valid_mask, ec, axes = self._get_pca()
            # Use the attributes of the domain the PCA model was fitted on.
            variables = self._pca.orig_domain.attributes
        else:
            valid_mask, ec, axes = self.prepare_plot_data(variables)

        self.plotdata.variables = variables
        self.plotdata.valid_mask = valid_mask
        self.plotdata.embedding_coords = ec
        self.plotdata.axes = axes
        # prepare_plot_data returns None values when no axes are available.
        if any(e is None for e in (valid_mask, ec, axes)):
            return

        if not sum(valid_mask):
            self.Error.no_valid_data()
            self.graph.new_data(None, None)
            return
        self.Error.no_valid_data.clear()

        self._anchor_circle(variables=variables)
        self._plot()

    def _plot(self):
        """
        Send the projected data to the graph.

        The two coordinate variables are appended to the metas of the data
        domain and filled with the embedding coordinates (zeros for the
        instances excluded by the valid mask).  The graph receives only
        the valid instances, the matching part of the data subset and the
        stored selection.
        """
        domain = self.data.domain
        new_metas = domain.metas + (self.variable_x, self.variable_y)
        domain = Domain(attributes=domain.attributes,
                        class_vars=domain.class_vars,
                        metas=new_metas)
        valid_mask = self.plotdata.valid_mask
        # Use the builtin float: the `np.float` alias was deprecated in
        # NumPy 1.20 and later removed.
        array = np.zeros((len(self.data), 2), dtype=float)
        array[valid_mask] = self.plotdata.embedding_coords
        self.plotdata.data = data = self.data.transform(domain)
        data[:, self.variable_x] = array[:, 0].reshape(-1, 1)
        data[:, self.variable_y] = array[:, 1].reshape(-1, 1)
        subset_data = data[self._subset_mask & valid_mask]\
            if self._subset_mask is not None and len(self._subset_mask) else None
        self.graph.new_data(data[valid_mask], subset_data)
        if self._selection is not None:
            self.graph.selection = self._selection[valid_mask]
        self.graph.update_data(self.variable_x, self.variable_y, False)

    def _get_lda(self, data, variables):
        """
        Return LDA scalings of `variables` as a two-row axes array.

        The data is restricted to `variables` plus the class variables and
        an eigen-solver LDA is fitted.  A degenerate 1x1 result is
        replaced by a fixed single axis.
        """
        lda_domain = Domain(attributes=variables,
                            class_vars=data.domain.class_vars)
        restricted = data.transform(lda_domain)
        model = LinearDiscriminantAnalysis(solver='eigen', n_components=2)
        model.fit(restricted.X, restricted.Y)
        axes = model.scalings_[:, :2].T
        if axes.shape == (1, 1):
            return np.array([[1.], [0.]])
        return axes

    def _get_pca(self):
        """
        Fit a two-component PCA on the data and return the plot inputs.

        The fitted model is kept in ``self._pca``.  A warning is shown
        when the cumulative explained variance is not finite.

        Returns:
            tuple: ``(valid_mask, coords, axes)`` with the mask of rows
            free of NaN, the transformed coordinates and the component
            axes scaled to the extent of the coordinates.
        """
        data = self.data
        MAX_COMPONENTS = 2
        ncomponents = 2
        DECOMPOSITIONS = [PCA]  # TruncatedSVD
        cls = DECOMPOSITIONS[0]
        pca_projector = cls(n_components=MAX_COMPONENTS)
        pca_projector.component = ncomponents
        # Normalize the data in addition to the projector's own
        # preprocessors.
        pca_projector.preprocessors = cls.preprocessors + [Normalize()]

        pca = pca_projector(data)
        variance_ratio = pca.explained_variance_ratio_
        cumulative = np.cumsum(variance_ratio)

        self._pca = pca
        if not np.isfinite(cumulative[-1]):
            self.Warning.trivial_components()

        coords = pca(data).X
        valid_mask = ~np.isnan(coords).any(axis=1)
        # NOTE(review): `coords` is returned without applying valid_mask,
        # and np.min/np.max below propagate NaN; confirm that the
        # transformed data cannot contain NaN rows.
        # scale axes
        max_radius = np.min(
            [np.abs(np.min(coords, axis=0)),
             np.max(coords, axis=0)])
        axes = pca.components_.T.copy()
        axes *= max_radius / np.max(np.linalg.norm(axes, axis=1))
        return valid_mask, coords, axes

    def _update_graph(self, reset_view=False):
        """
        Reset the zoom stack and redraw the graph when it holds data.

        Args:
            reset_view (bool): passed on to the graph's ``update_data``
        """
        self.graph.zoomStack = []
        if self.graph.data is not None:
            self.graph.update_data(self.variable_x, self.variable_y,
                                   reset_view)

    def update_density(self):
        """Redraw the graph without resetting the view."""
        self._update_graph(reset_view=False)

    def selection_changed(self):
        """
        Store the graph's selection and commit the outputs.

        The graph only knows the valid instances, so its selection is
        expanded to one entry per data instance (zero for the invalid
        ones) and also saved as a plain list in ``selection_indices``.
        """
        graph_selection = self.graph.selection
        if graph_selection is None:
            self._selection = self.selection_indices = None
        else:
            expanded = np.zeros(len(self.data), dtype=np.uint8)
            expanded[self.plotdata.valid_mask] = graph_selection
            self._selection = expanded
            self.selection_indices = expanded.tolist()
        self.commit()

    def prepare_data(self):
        """Do nothing; intentionally empty."""
        pass

    def commit(self):
        """
        Send the selected data, the annotated data and the components.

        All three outputs are None unless both the data and the plotted
        data exist.  The graph's selection covers only the valid
        instances, so it is expanded to the full data before the rows are
        picked.
        """
        def prepare_components():
            if self.placement == self.Placement.Projection:
                # The input projection is passed through unchanged.
                return self.projection
            if self.placement == self.Placement.PCA:
                axes = self._pca.components_.T
                attrs = list(self._pca.orig_domain.attributes)
            else:
                # Circular and LDA placements use the selected variables.
                attrs = list(self.model_selected[:])
                axes = self.plotdata.axes
            domain = Domain([
                ContinuousVariable(a.name, compute_value=lambda _: None)
                for a in attrs
            ],
                            metas=[StringVariable(name='component')])
            # One label per component, e.g. "PC1" and "PC2".
            metas = np.array([[
                "{}{}".format(self.Component_name[self.placement], i + 1)
                for i in range(axes.shape[1])
            ]],
                             dtype=object).T
            components = Table(domain, axes.T, metas=metas)
            components.name = 'components'
            return components

        selected = annotated = components = None
        if self.data is not None and self.plotdata.data is not None:
            components = prepare_components()

            graph = self.graph
            # Expand the selection over the valid instances to all
            # instances; without a selection every entry is zero.
            mask = self.plotdata.valid_mask.astype(int)
            mask[mask == 1] = \
                graph.selection if graph.selection is not None else 0

            selection = np.flatnonzero(mask)
            name = self.data.name
            data = self.plotdata.data
            if len(selection):
                selected = data[selection]
                selected.name = name + ": selected"
                selected.attributes = self.data.attributes

            # Selection values above 1 denote several selection groups.
            if graph.selection is not None and np.max(graph.selection) > 1:
                annotated = create_groups_table(data, mask)
            else:
                annotated = create_annotated_table(data, selection)
            annotated.attributes = self.data.attributes
            annotated.name = name + ": annotated"

        self.Outputs.selected_data.send(selected)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.components.send(components)

    def send_report(self):
        """
        Add the plot and a caption describing its settings to the report.

        Nothing is reported without data.  The caption lists the
        placement, the variables used for color, label, shape and size,
        and the jittering when it is not zero.
        """
        if self.data is None:
            return

        def name(var):
            return var and var.name

        placement_titles = ("Circular Placement",
                            "Linear Discriminant Analysis",
                            "Principal Component Analysis",
                            "Input projection")
        graph = self.graph
        items = (
            ("Projection", placement_titles[self.placement]),
            ("Color", name(graph.attr_color)),
            ("Label", name(graph.attr_label)),
            ("Shape", name(graph.attr_shape)),
            ("Size", name(graph.attr_size)),
            ("Jittering", graph.jitter_size != 0
             and "{} %".format(graph.jitter_size)),
        )
        caption = report.render_items_vert(items)
        self.report_plot()
        if caption:
            self.report_caption(caption)

    @classmethod
    def migrate_settings(cls, settings_, version):
        if version < 2:
            settings_["point_width"] = settings_["point_size"]
        if version < 3:
            settings_graph = {}
            settings_graph["jitter_size"] = settings_["jitter_value"]
            settings_graph["point_width"] = settings_["point_width"]
            settings_graph["alpha_value"] = settings_["alpha_value"]
            settings_graph["class_density"] = settings_["class_density"]
            settings_["graph"] = settings_graph

    @classmethod
    def migrate_context(cls, context, version):
        if version < 2:
            domain = context.ordered_domain
            c_domain = [t for t in context.ordered_domain if t[1] == 2]
            d_domain = [t for t in context.ordered_domain if t[1] == 1]
            for d, old_val, new_val in ((domain, "color_index", "attr_color"),
                                        (d_domain, "shape_index",
                                         "attr_shape"),
                                        (c_domain, "size_index", "attr_size")):
                index = context.values[old_val][0] - 1
                context.values[new_val] = (d[index][0], d[index][1] + 100) \
                    if 0 <= index < len(d) else None
        if version < 3:
            context.values["graph"] = {
                "attr_color": context.values["attr_color"],
                "attr_shape": context.values["attr_shape"],
                "attr_size": context.values["attr_size"]
            }
Exemple #20
0
class OWMap(OWDataProjectionWidget):
    """
    Projection widget that plots data points by their latitude and longitude
    over a geographic map background.
    """

    name = 'Geo Map'
    description = 'Show data points on a world map.'
    icon = "icons/GeoMap.svg"
    priority = 100

    replaces = [
        "Orange.widgets.visualize.owmap.OWMap",
    ]

    settings_version = 3

    # Latitude and longitude variables chosen in the combo boxes
    attr_lat = settings.ContextSetting(None)
    attr_lon = settings.ContextSetting(None)

    GRAPH_CLASS = OWScatterPlotMapGraph
    graph = settings.SettingProvider(OWScatterPlotMapGraph)
    embedding_variables_names = None

    class Error(OWDataProjectionWidget.Error):
        no_lat_lon_vars = Msg("Data has no latitude and longitude variables.")

    class Warning(OWDataProjectionWidget.Warning):
        # placeholders receive the names of the latitude/longitude variables
        missing_coords = Msg(
            "Plot cannot be displayed because '{}' or '{}' "
            "is missing for all data points")
        out_of_range = Msg(
            "Points with out of range latitude or longitude are not displayed."
        )
        no_internet = Msg(
            "Cannot fetch map from the internet. "
            "Displaying only cached parts.")

    class Information(OWDataProjectionWidget.Information):
        # placeholders receive the names of the latitude/longitude variables
        missing_coords = Msg(
            "Points with missing '{}' or '{}' are not displayed")

    def __init__(self):
        """Initialize the widget and listen for the graph's internet errors."""
        super().__init__()
        self._attr_lat = None
        self._attr_lon = None
        self.graph.show_internet_error.connect(self._show_internet_error)

    def _show_internet_error(self, show):
        """Show or clear the `no_internet` warning so it matches `show`."""
        already_shown = self.Warning.no_internet.is_shown()
        if show and not already_shown:
            self.Warning.no_internet()
        elif already_shown and not show:
            self.Warning.no_internet.clear()

    def _add_controls(self):
        """
        Add the map provider and latitude/longitude combo boxes above the
        controls of the base class, and a "Freeze map" check box below them.
        """
        self.lat_lon_model = DomainModel(DomainModel.MIXED,
                                         valid_types=ContinuousVariable)

        lat_lon_box = gui.vBox(self.controlArea, True)
        # keyword arguments shared by all three combo boxes
        options = {
            "labelWidth": 75,
            "orientation": Qt.Horizontal,
            "sendSelectedValue": True,
            "valueType": str,
            "contentsLength": 14,
        }

        gui.comboBox(lat_lon_box, self, 'graph.tile_provider_key',
                     label='Map:',
                     items=list(TILE_PROVIDERS),
                     callback=self.graph.update_tile_provider,
                     **options)

        coordinate_combos = (('attr_lat', 'Latitude:'),
                             ('attr_lon', 'Longitude:'))
        for setting_name, caption in coordinate_combos:
            gui.comboBox(lat_lon_box, self, setting_name,
                         label=caption,
                         callback=self.setup_plot,
                         model=self.lat_lon_model,
                         searchable=True,
                         **options)

        super()._add_controls()

        gui.checkBox(
            self._plot_box, self,
            value="graph.freeze",
            label="Freeze map",
            tooltip="If checked, the map won't change position to fit new data."
        )

    def get_embedding(self):
        """
        Return an (n, 2) array of normalized x, y coordinates, or None.

        None is returned when there is no data, when either coordinate
        column is unavailable, or when no point can be shown. As a side
        effect `self.valid_data` is set to the mask of points that have
        finite coordinates within the allowed range, and the related
        messages are updated.
        """
        self.valid_data = None
        if self.data is None:
            return None

        latitudes = self.get_column(self.attr_lat, filter_valid=False)
        longitudes = self.get_column(self.attr_lon, filter_valid=False)
        if latitudes is None or longitudes is None:
            return None

        self.Warning.missing_coords.clear()
        self.Information.missing_coords.clear()
        finite = np.isfinite(latitudes) & np.isfinite(longitudes)
        self.valid_data = finite
        if not np.all(finite):
            # some points missing -> information; all missing -> warning
            group = self.Information if np.any(finite) else self.Warning
            group.missing_coords(self.attr_lat.name, self.attr_lon.name)

        lon_ok = (-MAX_LONGITUDE <= longitudes) & (longitudes <= MAX_LONGITUDE)
        lat_ok = (-MAX_LATITUDE <= latitudes) & (latitudes <= MAX_LATITUDE)
        # False only for points that are finite but outside of the range
        in_range = ~np.bitwise_xor(lon_ok & lat_ok, self.valid_data)
        self.Warning.out_of_range.clear()
        n_in_range = in_range.sum()
        if n_in_range != len(longitudes):
            self.Warning.out_of_range()
        if n_in_range == 0:
            return None
        self.valid_data &= in_range

        x, y = deg2norm(longitudes, latitudes)
        # flip y so that it grows from bottom to top
        y = 1 - y
        return np.vstack((x, y)).T

    def check_data(self):
        """Run the base checks, then discard data without rows or columns."""
        super().check_data()

        data = self.data
        if data is None:
            return
        if len(data) == 0 or len(data.domain) == 0:
            self.data = None

    def init_attr_values(self):
        """
        Find the latitude/longitude variables in the data and select them.

        When they cannot be found, an error is shown and the data is
        dropped.
        """
        lat = lon = None
        if self.data is not None:
            lat, lon = find_lat_lon(self.data, filter_hidden=True)
            if lat is None or lon is None:
                # either both are found or the data is not usable
                self.Error.no_lat_lon_vars()
                self.data = None
                lat = lon = None

        super().init_attr_values()

        domain = self.data.domain if self.data else None
        self.lat_lon_model.set_domain(domain)
        self.attr_lat, self.attr_lon = lat, lon

    @property
    def effective_variables(self):
        """Latitude and longitude variables, or [] unless both are set."""
        if self.attr_lat and self.attr_lon:
            return [self.attr_lat, self.attr_lon]
        return []

    @property
    def effective_data(self):
        """Data transformed to a domain with only the coordinate variables."""
        variables = self.effective_variables
        if variables and self.attr_lat.name == self.attr_lon.name:
            # same variable chosen twice; keep a single copy
            variables = [self.attr_lat]
        return self.data.transform(Domain(variables))

    def showEvent(self, ev):
        """Reset the view range once the widget is shown."""
        super().showEvent(ev)
        # the right resolution is not known before the widget is shown,
        # so the map is reset here
        self.graph.update_view_range()

    def resizeEvent(self, ev):
        """Update the view range on resize without refitting to the data."""
        super().resizeEvent(ev)
        # the map has to be reset on every resize so that newly exposed
        # portions are drawn
        self.graph.update_view_range(match_data=False)

    @classmethod
    def migrate_settings(cls, _settings, version):
        """
        Upgrade stored settings older than version 3 in place.

        A new "graph" dictionary is created and filled from those legacy
        keys that are present; legacy keys are kept.
        """
        if version >= 3:
            return

        graph = _settings["graph"] = {}
        if "tile_provider" in _settings:
            if _settings["tile_provider"] == "Watercolor":
                _settings["tile_provider"] = DEFAULT_TILE_PROVIDER
            graph["tile_provider_key"] = _settings["tile_provider"]
        if "opacity" in _settings:
            graph["alpha_value"] = round(_settings["opacity"] * 2.55)
        if "zoom" in _settings:
            graph["point_width"] = round(_settings["zoom"] * 0.02)
        if "jittering" in _settings:
            graph["jitter_size"] = _settings["jittering"]
        if "show_legend" in _settings:
            graph["show_legend"] = _settings["show_legend"]

    @classmethod
    def migrate_context(cls, context, version):
        """
        Upgrade a stored context in place.

        Below version 2 the attribute settings are migrated with
        `settings.migrate_str_to_variable`; below version 3 they are
        renamed from `*_attr` to `attr_*`.
        """
        if version < 2:
            single_placeholder = (("lat_attr", ""),
                                  ("lon_attr", ""),
                                  ("class_attr", "(None)"))
            for setting_name, placeholder in single_placeholder:
                settings.migrate_str_to_variable(context,
                                                 names=setting_name,
                                                 none_placeholder=placeholder)

            # these settings can have two different none placeholders;
            # the descriptive one is first replaced by ""
            two_placeholders = (("color_attr", "(Same color)"),
                                ("label_attr", "(No labels)"),
                                ("shape_attr", "(Same shape)"),
                                ("size_attr", "(Same size)"))
            for setting_name, placeholder in two_placeholders:
                stored = context.values[setting_name]
                if stored[0] == placeholder:
                    context.values[setting_name] = ("", stored[1])

                settings.migrate_str_to_variable(context,
                                                 names=setting_name,
                                                 none_placeholder="")
        if version < 3:
            renames = (("lat_attr", "attr_lat"),
                       ("lon_attr", "attr_lon"),
                       ("color_attr", "attr_color"),
                       ("label_attr", "attr_label"),
                       ("shape_attr", "attr_shape"),
                       ("size_attr", "attr_size"))
            for old_name, new_name in renames:
                settings.rename_setting(context, old_name, new_name)
Exemple #21
0
class OWTestLearners(OWWidget):
    name = "Test and Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100
    keywords = ['Cross Validation', 'CV']

    class Inputs:
        # Input signals of the widget: name shown to the user and the type
        # of the accepted object.
        # "Data" is declared as the default input.
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        # declared with multiple=True
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    class Outputs:
        # Output signals of the widget: name shown to the user and the type
        # of the emitted object.
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    settings_version = 3
    UserAdviceMessages = [
        widget.Message("Click on the table header to select shown columns",
                       "click_header")
    ]

    settingsHandler = settings.PerfectDomainContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: Resampling/testing types
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    fold_feature_selected = settings.ContextSetting(False)

    use_rope = settings.Setting(False)
    rope = settings.Setting(0.1)
    comparison_criterion = settings.Setting(0, schema_only=True)

    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    class Error(OWWidget.Error):
        # Error messages the widget can show.
        train_data_empty = Msg("Train dataset is empty.")
        test_data_empty = Msg("Test dataset is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg(
            "Test data input requires a target variable."
        )
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg(
            "Test and train datasets have different target variables."
        )
        memory_error = Msg("Not enough memory.")
        no_class_values = Msg("Target variable has no values.")
        only_one_class_var_value = Msg("Target variable has only one value.")
        test_data_incompatible = Msg(
            "Test data may be incompatible with train data."
        )

    class Warning(OWWidget.Warning):
        # Warning messages the widget can show.
        # The placeholder is filled with " ", " train " or " test ".
        missing_data = Msg(
            "Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg(
            "Test data is present but unused. "
            "Select 'Test on test data' to use it.")

    class Information(OWWidget.Information):
        # Informational messages the widget can show.
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")
        test_data_transformed = Msg(
            "Test data has been transformed "
            "to match the train data.")

    def __init__(self):
        """
        Initialize the widget state and build the controls and tables.

        The control area gets the "Sampling", "Target Class" and
        "Model Comparison" boxes; the main area gets the score table and
        the model comparison table.
        """
        super().__init__()

        # Current inputs and flags derived from them
        self.data = None
        self.test_data = None
        self.preprocessor = None
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []
        # Stored criterion index; applied (and reset to None) in
        # _update_scorers
        self.__pending_comparison_criterion = self.comparison_criterion

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[TaskState]
        self.__executor = ThreadExecutor()

        # --- Control area: radio buttons bound to the `resampling` setting
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        # Cross validation and its options
        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        # Cross validation by feature; the combo lists discrete meta
        # attributes
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           callback=self.fold_feature_changed)

        # Random sampling and its options
        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     maximumContentsLength=5,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        # Sampling types without additional options
        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # --- Control area: target class selection; created without items
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        # --- Control area: model comparison; the combo's list model is
        # filled in _update_scorers
        self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
        gui.comboBox(box,
                     self,
                     "comparison_criterion",
                     model=PyListModel(),
                     callback=self.update_comparison_table)

        hbox = gui.hBox(box)
        gui.checkBox(hbox,
                     self,
                     "use_rope",
                     "Negligible difference: ",
                     callback=self._on_use_rope_changed)
        gui.lineEdit(hbox,
                     self,
                     "rope",
                     validator=QDoubleValidator(),
                     controlWidth=70,
                     callback=self.update_comparison_table,
                     alignment=Qt.AlignRight)
        # the line edit is enabled only while `use_rope` is checked
        self.controls.rope.setEnabled(self.use_rope)

        gui.rubber(self.controlArea)
        # --- Main area: table with scores
        self.score_table = ScoreTable(self)
        self.score_table.shownScoresChanged.connect(self.update_stats_model)
        view = self.score_table.view
        view.setSizeAdjustPolicy(view.AdjustToContents)

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.score_table.view)

        # --- Main area: read-only table with pairwise model comparison
        self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
        table = self.comparison_table = QTableWidget(
            wordWrap=False,
            editTriggers=QTableWidget.NoEditTriggers,
            selectionMode=QTableWidget.NoSelection)
        table.setSizeAdjustPolicy(table.AdjustToContents)
        header = table.verticalHeader()
        header.setSectionResizeMode(QHeaderView.Fixed)
        header.setSectionsClickable(False)

        header = table.horizontalHeader()
        header.setTextElideMode(Qt.ElideRight)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setSectionsClickable(False)
        header.setStretchLastSection(False)
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        # column widths are bounded in units of the average character width
        avg_width = self.fontMetrics().averageCharWidth()
        header.setMinimumSectionSize(8 * avg_width)
        header.setMaximumSectionSize(15 * avg_width)
        header.setDefaultSectionSize(15 * avg_width)
        box.layout().addWidget(table)
        # explanatory note below the comparison table
        box.layout().addWidget(
            QLabel(
                "<small>Table shows probabilities that the score for the model in "
                "the row is higher than that of the model in the column. "
                "Small numbers show the probability that the difference is "
                "negligible.</small>",
                wordWrap=True))

    @staticmethod
    def sizeHint():
        """Return the preferred size of the widget as a QSize."""
        width, height = 780, 1
        return QSize(width, height)

    def _update_controls(self):
        """
        Refresh the "Cross validation by feature" controls for current data.

        The feature model is reset and refilled from the data domain. The
        radio button and the combo are enabled only if the model contains
        any features; if it does not and this sampling type is selected,
        the widget falls back to ordinary cross validation.
        """
        model = self.feature_model
        self.fold_feature = None
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if self.fold_feature is None and model:
                self.fold_feature = model[0]

        has_features = bool(model)
        feature_fold_button = \
            self.controls.resampling.buttons[OWTestLearners.FeatureFold]
        feature_fold_button.setEnabled(has_features)
        self.features_combo.setEnabled(has_features)
        if not has_features \
                and self.resampling == OWTestLearners.FeatureFold:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Add, replace or remove the learner stored under `key`.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
            The learner; None removes the learner stored under `key`.
        key : Any
            Identifier of the input.
        """
        if learner is None:
            if key in self.learners:
                # the input was removed
                self._invalidate([key])
                del self.learners[key]
            return

        self.learners[key] = InputLearner(learner, None, None)
        self._invalidate([key])

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Handle a new training dataset.

        The data is rejected (treated as None, with an error shown) when it
        is empty or its target is unsuitable. SQL tables are downloaded,
        sampled if large. Instances with unknown target values are removed.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.cancel()
        stale_messages = (
            self.Information.data_sampled,
            self.Error.train_data_empty,
            self.Error.class_required,
            self.Error.too_many_classes,
            self.Error.no_class_values,
            self.Error.only_one_class_var_value,
        )
        for message in stale_messages:
            message.clear()

        if data is not None and not data:
            self.Error.train_data_empty()
            data = None
        if data:
            domain = data.domain
            # all conditions are evaluated; the first one that holds decides
            # which error is shown
            checks = [
                (not domain.class_vars,
                 self.Error.class_required),
                (len(domain.class_vars) > 1,
                 self.Error.too_many_classes),
                (np.isnan(data.Y).all(),
                 self.Error.no_class_values),
                (domain.has_discrete_class
                 and len(domain.class_var.values) == 1,
                 self.Error.only_one_class_var_value),
            ]
            for failed, show_error in checks:
                if failed:
                    show_error()
                    data = None
                    break

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                to_download = data
            else:
                self.Information.data_sampled()
                to_download = data.sample_time(1, no_cache=True)
                to_download.download_data(AUTO_DL_LIMIT, partial=True)
            data = Table(to_download)

        self.train_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if not (self.train_data_missing_vals or self.test_data_missing_vals):
            self.Warning.missing_data.clear()
        else:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            if self.fold_feature_selected and self.feature_model:
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Orange.data.Table) -> None
        """
        Handle a new separate testing dataset.

        The data is rejected (treated as None, with an error shown) when it
        is empty or has no target variable. SQL tables are downloaded,
        sampled if large. Instances with unknown target values are removed.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not data:
            self.Error.test_data_empty()
            data = None

        if data and not data.domain.class_var:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                to_download = data
            else:
                self.Information.test_data_sampled()
                to_download = data.sample_time(1, no_cache=True)
                to_download.download_data(AUTO_DL_LIMIT, partial=True)
            data = Table(to_download)

        self.test_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if not (self.train_data_missing_vals or self.test_data_missing_vals):
            self.Warning.missing_data.clear()
        else:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)

        self.test_data = data
        # results depend on the test data only for "Test on test data"
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        return {
            (True, True): " ",  # both, don't specify
            (True, False): " train ",
            (False, True): " test "
        }[(self.train_data_missing_vals, self.test_data_missing_vals)]

    # The list of scorers shouldn't be retrieved globally while the module
    # is loading, since add-ons could have registered additional scorers.
    # It could have been cached, but
    # - we wouldn't gain much from it
    # - it would complicate the unit tests
    def _update_scorers(self):
        """
        Recompute the scorers usable for the current target variable.

        When the list of scorers changes, the model of the comparison
        criterion combo is refilled and the criterion reset to the first
        one. A pending criterion, if any, is applied once and then dropped.
        """
        class_var = self.data.domain.class_var if self.data else None
        new_scorers = usable_scorers(class_var) if class_var else []

        # Don't unnecessarily reset the model because this would always reset
        # comparison_criterion; we also set it explicitly, though, for clarity
        if new_scorers != self.scorers:
            self.scorers = new_scorers
            criterion_names = [scorer.long_name or scorer.name
                               for scorer in self.scorers]
            self.controls.comparison_criterion.model()[:] = criterion_names
            self.comparison_criterion = 0

        pending = self.__pending_comparison_criterion
        if pending is not None:
            # Check for the unlikely case that some scorers have been removed
            # from modules
            if pending < len(self.scorers):
                self.comparison_criterion = pending
            self.__pending_comparison_criterion = None
        self._update_compbox_title()

    def _update_compbox_title(self):
        criterion = self.comparison_criterion
        if criterion < len(self.scorers):
            scorer = self.scorers[criterion]()
            self.compbox.setTitle(f"Model Comparison by {scorer.name}")
        else:
            self.compbox.setTitle(f"Model Comparison")

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Set the input preprocessor to apply on the training data.

        The preprocessor is stored and `_invalidate` is called without
        arguments.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """
        Reimplemented from OWWidget.handleNewSignals.

        Refreshes the class selection, the score table header, the enabled
        state of the views and the score rows, and then runs an update if
        one is needed.
        """
        self._update_class_selection()
        self.score_table.update_header(self.scorers)
        self._update_view_enabled()
        self.update_stats_model()
        # the flag is set by _invalidate
        if self.__needupdate:
            self.__update()

    def kfold_changed(self):
        """Select cross validation as sampling type and apply the change."""
        self.resampling = OWTestLearners.KFold
        self._param_changed()

    def fold_feature_changed(self):
        """Select cross validation by feature and apply the change."""
        self.resampling = OWTestLearners.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        """Select random sampling as sampling type and apply the change."""
        self.resampling = OWTestLearners.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        """
        React to a change of sampling parameters.

        The model comparison box is enabled only for cross validation;
        the results are then invalidated and an update is started.
        """
        cross_validation = self.resampling == OWTestLearners.KFold
        self.modcompbox.setEnabled(cross_validation)
        self._update_view_enabled()
        self._invalidate()
        self.__update()

    def _update_view_enabled(self):
        """
        Enable or disable the two tables.

        The score table needs data; the comparison table additionally
        needs cross validation and more than one learner.
        """
        has_data = self.data is not None
        can_compare = (self.resampling == OWTestLearners.KFold
                       and len(self.learners) > 1
                       and has_data)
        self.comparison_table.setEnabled(can_compare)
        self.score_table.view.setEnabled(has_data)

    def update_stats_model(self):
        """
        Rebuild the rows of the score table from the current results.

        Each learner gets a row with its name, the train and test times (if
        its results are available) and one cell per scorer. Failed learners
        are marked in the table and reported in the widget's error message.
        When a target class is selected, scores are computed here from
        one-vs-rest results. The headers of the comparison table are
        updated as well.
        """
        model = self.score_table.model
        # empty the model row by row; this preserves the header labels
        for row_index in reversed(range(model.rowCount())):
            model.takeRow(row_index)

        class_var = None
        target_index = None
        if self.data is not None:
            domain = self.data.domain
            class_var = domain.class_var
            if domain.has_discrete_class and \
                    self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)

        # scores for a specific target class are computed below as needed
        target_specific = class_var is not None and class_var.is_discrete \
            and target_index is not None
        right_aligned = Qt.AlignRight | Qt.AlignVCenter

        def timing_item(seconds, learner_key):
            cell = QStandardItem("{:.3f}".format(seconds))
            cell.setTextAlignment(right_aligned)
            cell.setData(learner_key, Qt.UserRole)
            return cell

        errors = []
        names = []
        has_missing_scores = False

        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            names.append(name)
            head = QStandardItem(name)
            head.setData(key, Qt.UserRole)
            row = [head]

            results = slot.results
            succeeded = results is not None and results.success
            if succeeded:
                timings = results.value
                row.append(timing_item(timings.train_time, key))
                row.append(timing_item(timings.test_time, key))

            if isinstance(results, Try.Fail):
                exc = results.exception
                head.setToolTip(str(exc))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                if isinstance(exc, DomainTransformationError) \
                        and self.resampling == self.TestOnTest:
                    self.Error.test_data_incompatible()
                    self.Information.test_data_transformed.clear()
                else:
                    errors.append("{name} failed with error:\n"
                                  "{exc.__class__.__name__}: {exc!s}".format(
                                      name=name, exc=exc))

            if not target_specific:
                stats = slot.stats
            elif succeeded:
                ovr_results = results_one_vs_rest(results.value, target_index)

                # The cell variable is used immediately, it is not stored
                # pylint: disable=cell-var-from-loop
                stats = [
                    Try(scorer_caller(scorer, ovr_results, target=1))
                    for scorer in self.scorers
                ]
            else:
                stats = None

            if stats is not None:
                for stat, scorer in zip(stats, self.scorers):
                    item = QStandardItem()
                    item.setTextAlignment(right_aligned)
                    if stat.success:
                        item.setData(float(stat.value[0]), Qt.DisplayRole)
                    else:
                        item.setToolTip(str(stat.exception))
                        if scorer.name in self.score_table.shown_scores:
                            has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        # restore the sorting that is currently set in the header
        header = self.score_table.view.horizontalHeader()
        model.sort(header.sortIndicatorSection(), header.sortIndicatorOrder())
        self._set_comparison_headers(names)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _on_use_rope_changed(self):
        """
        Enable the `rope` line edit only while `use_rope` is set and refresh
        the comparison table.
        """
        self.controls.rope.setEnabled(self.use_rope)
        self.update_comparison_table()

    def update_comparison_table(self):
        """
        Clear the comparison table and refill it.

        Nothing further is done without successfully tested learners or
        without scorers. The table is filled only for cross validation;
        for other sampling types only the headers are set.
        """
        self.comparison_table.clearContents()
        slots = self._successful_slots()
        if not slots or not self.scorers:
            return

        names = [learner_name(slot.learner) for slot in slots]
        self._set_comparison_headers(names)
        if self.resampling != OWTestLearners.KFold:
            return
        self._fill_table(names, self._scores_by_folds(slots))

    def _successful_slots(self):
        """
        Return slots of successfully tested learners.

        The slots are listed in the order of rows of the score table's
        sorted model.
        """
        model = self.score_table.model
        proxy = self.score_table.sorted_model

        successful = []
        for row in range(proxy.rowCount()):
            # the learner's key is stored in the first column of each row
            source_index = proxy.mapToSource(proxy.index(row, 0))
            slot = self.learners[model.data(source_index, Qt.UserRole)]
            if slot.results is not None and slot.results.success:
                successful.append(slot)
        return successful

    def _set_comparison_headers(self, names):
        """
        Resize the comparison table to len(names) x len(names) and label it.

        `names` is used for both the vertical and the horizontal header.
        Columns are stretched when there are more than two names and have
        a fixed size otherwise.
        """
        table = self.comparison_table
        try:
            # Prevent glitching during update
            table.setUpdatesEnabled(False)
            size = len(names)
            resize_mode = \
                QHeaderView.Stretch if size > 2 else QHeaderView.Fixed
            table.horizontalHeader().setSectionResizeMode(resize_mode)
            table.setRowCount(size)
            table.setColumnCount(size)
            table.setVerticalHeaderLabels(names)
            table.setHorizontalHeaderLabels(names)
        finally:
            table.setUpdatesEnabled(True)

    def _scores_by_folds(self, slots):
        """
        Compute per-fold scores of the selected comparison scorer.

        Returns a list parallel to `slots`; an entry is the flattened
        array of fold scores, or None when scoring that slot failed. The
        `scores_not_computed` warning is shown if any entry is None.
        """
        scorer = self.scorers[self.comparison_criterion]()
        self._update_compbox_title()

        # Binary scorers need either a target class or an averaging mode
        scorer_kwargs = {}
        if scorer.is_binary:
            if self.class_selection == self.TARGET_AVERAGE:
                scorer_kwargs = dict(average='weighted')
            else:
                class_var = self.data.domain.class_var
                target_index = class_var.values.index(self.class_selection)
                scorer_kwargs = dict(target=target_index)

        def make_thunk(results):
            # Bind `results` per slot; `Try` receives a zero-argument callable
            def compute():
                return scorer.scores_by_folds(
                    results.value, **scorer_kwargs).flatten()

            return compute

        fold_scores = []
        for slot in slots:
            attempt = Try(make_thunk(slot.results))
            fold_scores.append(attempt.value if attempt.success else None)

        # `None in fold_scores` doesn't work -- the entries are np.arrays
        if any(score is None for score in fold_scores):
            self.Warning.scores_not_computed()
        return fold_scores

    def _fill_table(self, names, scores):
        """
        Fill the comparison table with pairwise probabilities.

        Every pair (row, col) with col < row is compared once with
        `baycomp.two_on_single`, and both mirrored cells are written.
        Pairs where either model has no scores are left empty; pairs whose
        result contains NaN are marked "NA". When ROPE is in use the cell
        also shows the probability that the two models are equivalent.
        """
        table = self.comparison_table
        for row, (row_name, row_scores) in enumerate(zip(names, scores)):
            # range(row) restricts the inner loop to the lower triangle
            for col, col_name, col_scores in zip(range(row), names, scores):
                if row_scores is None or col_scores is None:
                    continue

                if self.use_rope and self.rope:
                    p0, rope, p1 = baycomp.two_on_single(
                        row_scores, col_scores, self.rope)
                    if any(np.isnan(p) for p in (p0, rope, p1)):
                        self._set_cells_na(table, row, col)
                        continue
                    self._set_cell(
                        table, row, col,
                        f"{p0:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({row_name} > {col_name}) = {p0:.3f}\n"
                        f"p({row_name} = {col_name}) = {rope:.3f}")
                    self._set_cell(
                        table, col, row,
                        f"{p1:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({col_name} > {row_name}) = {p1:.3f}\n"
                        f"p({col_name} = {row_name}) = {rope:.3f}")
                    continue

                p0, p1 = baycomp.two_on_single(row_scores, col_scores)
                if any(np.isnan(p) for p in (p0, p1)):
                    self._set_cells_na(table, row, col)
                    continue
                self._set_cell(table, row, col, f"{p0:.3f}",
                               f"p({row_name} > {col_name}) = {p0:.3f}")
                self._set_cell(table, col, row, f"{p1:.3f}",
                               f"p({col_name} > {row_name}) = {p1:.3f}")

    @classmethod
    def _set_cells_na(cls, table, row, col):
        cls._set_cell(table, row, col, "NA", "comparison cannot be computed")
        cls._set_cell(table, col, row, "NA", "comparison cannot be computed")

    @staticmethod
    def _set_cell(table, row, col, label, tooltip):
        """Put a centered QLabel with `label` and `tooltip` into a table cell."""
        cell = QLabel(label)
        cell.setAlignment(Qt.AlignCenter)
        cell.setToolTip(tooltip)
        table.setCellWidget(row, col, cell)

    def _update_class_selection(self):
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if self.data.domain.has_discrete_class:
            self.cbox.setVisible(True)
            class_var = self.data.domain.class_var
            items = [self.TARGET_AVERAGE] + class_var.values
            self.class_selection_combo.addItems(items)

            class_index = 0
            if self.class_selection in class_var.values:
                class_index = class_var.values.index(self.class_selection) + 1

            self.class_selection_combo.setCurrentIndex(class_index)
            self.class_selection = items[class_index]
        else:
            self.cbox.setVisible(False)

    def _on_target_class_changed(self):
        self.update_stats_model()
        self.update_comparison_table()

    def _invalidate(self, which=None):
        """
        Discard stored results and scores for the given learner keys.

        Any running evaluation is cancelled first. `which` is an iterable of
        input keys; None invalidates all learners. The matching rows of the
        score table have their score cells blanked, the comparison table is
        cleared and the widget is flagged as needing an update.
        """
        self.cancel()
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        keys = self.learners.keys() if which is None else which

        model = self.score_table.model
        # Keys stored on the first column, in row order
        row_keys = [
            model.item(row, 0).data(Qt.UserRole)
            for row in range(model.rowCount())
        ]

        for key in keys:
            slot = self.learners[key]
            self.learners[key] = slot._replace(results=None, stats=None)
            if key not in row_keys:
                continue

            row = row_keys.index(key)
            # Column 0 holds the learner name; blank only the score cells
            for column in range(1, model.columnCount()):
                item = model.item(row, column)
                if item is None:
                    continue
                for role in (Qt.DisplayRole, Qt.ToolTipRole):
                    item.setData(None, role)

        self.comparison_table.clearContents()

        self.__needupdate = True

    def commit(self):
        """
        Commit the results to output.

        Results of all successfully evaluated learners are merged and sent
        together with the augmented prediction data. None is sent on both
        outputs when there are no successful results; predictions are None
        (and the memory error is shown) if building them runs out of memory.
        """
        self.Error.memory_error.clear()
        successful = [
            slot for slot in self.learners.values()
            if slot.results is not None and slot.results.success
        ]

        combined = predictions = None
        if successful:
            # Evaluation results
            combined = results_merge(
                [slot.results.value for slot in successful])
            combined.learner_names = [
                learner_name(slot.learner) for slot in successful
            ]

            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """Report on the testing schema and results"""
        if not self.data or not self.learners:
            return

        # Description of the sampling method; stays None for methods that
        # have no description here
        sampling = None
        if self.resampling == self.KFold:
            prefix = 'Stratified ' if self.cv_stratified else ''
            sampling = "{}{}-fold Cross validation".format(
                prefix, self.NFolds[self.n_folds])
        elif self.resampling == self.LeaveOneOut:
            sampling = "Leave one out"
        elif self.resampling == self.ShuffleSplit:
            prefix = 'Stratified ' if self.shuffle_stratified else ''
            sampling = \
                "{}Shuffle split, {} random samples with {}% data ".format(
                    prefix, self.NRepeats[self.n_repeats],
                    self.SampleSizes[self.sample_size])
        elif self.resampling == self.TestOnTrain:
            sampling = "No sampling, test on training data"
        elif self.resampling == self.TestOnTest:
            sampling = "No sampling, test on testing data"

        items = [] if sampling is None else [("Sampling type", sampling)]
        if self.data.domain.has_discrete_class:
            items.append(("Target class", self.class_selection.strip("()")))
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.score_table.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """
        Upgrade stored settings from an older `version`, in place.

        Parameters
        ----------
        settings_ : dict
            Stored settings; modified in place.
        version : int
            Version of the settings being migrated.
        """
        if version < 2:
            # Non-zero resampling indices are shifted up by one (presumably
            # a new option was inserted after the first one -- confirm).
            # A missing key is left alone instead of raising KeyError, in
            # line with the tolerant lookup used for context settings below.
            if settings_.get("resampling", 0) > 0:
                settings_["resampling"] += 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')
            ]

    @Slot(float)
    def setProgressValue(self, value):
        """Slot forwarding a progress `value` to the widget's progress bar."""
        self.progressBarSet(value)

    def __update(self):
        """
        Validate the inputs and settings and start a new evaluation.

        A running task is cancelled and the messages raised here are
        cleared. If a precondition fails (no data, no learners, too many
        folds, missing or inconsistent test data) the state is set to
        `State.Waiting`, the current results are committed and the method
        returns. Otherwise a test function is built for the learners that
        have no results yet and handed to `__submit`.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        # Reset the messages that may be (re)raised below
        self.Warning.test_data_unused.clear()
        self.Error.test_data_incompatible.clear()
        self.Warning.test_data_missing.clear()
        # Shown only for test-on-test when train and test attributes differ
        self.Information.test_data_transformed(
            shown=self.resampling == self.TestOnTest and self.data is not None
            and self.test_data is not None and
            self.data.domain.attributes != self.test_data.domain.attributes)
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            # Fewer instances than the requested number of folds
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            # Test data is connected but the chosen method does not use it
            self.Warning.test_data_unused()

        # Fixed random_state passed to the samplers below
        rstate = 42
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # replace_learners below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData(store_data=True,
                                                 store_models=True), self.data,
                self.test_data, learners_c, self.preprocessor)
        else:
            if self.resampling == OWTestLearners.KFold:
                sampler = Orange.evaluation.CrossValidation(
                    k=self.NFolds[self.n_folds], random_state=rstate)
            elif self.resampling == OWTestLearners.FeatureFold:
                sampler = Orange.evaluation.CrossValidationFeature(
                    feature=self.fold_feature)
            elif self.resampling == OWTestLearners.LeaveOneOut:
                sampler = Orange.evaluation.LeaveOneOut()
            elif self.resampling == OWTestLearners.ShuffleSplit:
                sampler = Orange.evaluation.ShuffleSplit(
                    n_resamples=self.NRepeats[self.n_repeats],
                    train_size=self.SampleSizes[self.sample_size] / 100,
                    test_size=None,
                    stratified=self.shuffle_stratified,
                    random_state=rstate)
            elif self.resampling == OWTestLearners.TestOnTrain:
                sampler = Orange.evaluation.TestOnTrainingData(
                    store_models=True)
            else:
                assert False, "self.resampling %s" % self.resampling

            sampler.store_data = True
            test_f = partial(sampler, self.data, learners_c, self.preprocessor)

        def replace_learners(evalfunc, *args, **kwargs):
            # Run the evaluation, check that the results refer to the copied
            # learners, then put the original learner objects back in place
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[[float], None]], Results]) -> None
        """
        Submit a testing function for evaluation

        MUST not be called if an evaluation is already pending/running.
        Cancel the existing task first.

        Parameters
        ----------
        testfunc : Callable[[Callable[[float], None]], Results]
            Must be a callable taking a single `callback` argument and
            returning a Results instance
        """
        assert self.__state != State.Running
        # Setup the task
        task = TaskState()

        def progress_callback(finished):
            # Called by the test function: aborts the evaluation once an
            # interruption was requested, otherwise reports progress scaled
            # by 100 (presumably `finished` is a fraction in [0, 1])
            if task.is_interruption_requested():
                raise UserInterrupt()
            task.set_progress_value(100 * finished)

        testfunc = partial(testfunc, callback=progress_callback)
        task.start(self.__executor, testfunc)

        task.progress_changed.connect(self.setProgressValue)
        task.watcher.finished.connect(self.__task_complete)

        # Mark the outputs as invalid while the new results are computed
        self.Outputs.evaluations_results.invalidate()
        self.Outputs.predictions.invalidate()
        self.progressBarInit()
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = task

    @Slot(object)
    def __task_complete(self, f: 'Future[Results]'):
        """
        Handle a completed evaluation task.

        The progress bar and status message are reset and the task state is
        set to `State.Done`. If the future raised, the error is logged and
        shown and nothing else happens. Otherwise the results are split by
        model, stored (with their scores) on the matching learner slots,
        both tables are refreshed and the outputs are committed.

        Parameters
        ----------
        f : Future[Results]
            The finished future of the current task.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        assert self.__task is not None and self.__task.future is f
        self.progressBarFinished()
        self.setStatusMessage("")
        assert f.done()
        self.__task = None
        self.__state = State.Done
        try:
            results = f.result()  # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:  # pylint: disable=broad-except
            log.exception("testing error (in __task_complete):", exc_info=True)
            self.error("\n".join(traceback.format_exception_only(type(er),
                                                                 er)))
            return

        # Map learner objects back to the input keys of their slots
        learner_key = {
            slot.learner: key
            for key, slot in self.learners.items()
        }
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            # NOTE: when the class variable is not primitive, `stats` stays
            # None and `result` is stored as is, not wrapped in a Try
            stats = None
            if class_var.is_primitive():
                ex = result.failed[0]
                if ex:
                    # The learner failed: every score and the result itself
                    # carry the same exception
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    stats = [
                        Try(scorer_caller(scorer, result))
                        for scorer in self.scorers
                    ]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self.score_table.update_header(self.scorers)
        self.update_stats_model()
        self.update_comparison_table()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation (if any).

        Does nothing when no task is set. Otherwise the task is cancelled
        and detached from the widget's slots, and the progress bar and
        status message are reset.
        """
        task = self.__task
        if task is None:
            return

        assert self.__state == State.Running
        self.__state = State.Cancelled
        self.__task = None
        task.cancel()
        # Disconnect so the cancelled task no longer reports to this widget
        task.progress_changed.disconnect(self.setProgressValue)
        task.watcher.finished.disconnect(self.__task_complete)
        self.progressBarFinished()
        self.setStatusMessage("")

    def onDeleteWidget(self):
        """Cancel any running evaluation before the base class clean-up."""
        self.cancel()
        super().onDeleteWidget()
class OWPredictions(OWWidget):
    """Widget that displays the predictions of models for an input dataset."""
    name = "Predictions"
    icon = "icons/Predictions.svg"
    priority = 200
    description = "Display the predictions of models for an input dataset."
    keywords = []

    # Input signals: one data table and any number of predictors (models)
    class Inputs:
        data = Input("Data", Orange.data.Table)
        predictors = Input("Predictors", Model, multiple=True)

    # Output signals: a predictions table and the evaluation results
    class Outputs:
        predictions = Output("Predictions", Orange.data.Table)
        evaluation_results = Output("Evaluation Results",
                                    Orange.evaluation.Results,
                                    dynamic=False)

    class Warning(OWWidget.Warning):
        empty_data = Msg("Empty dataset")

    # The `{}` placeholders are filled with newline-joined failure messages
    class Error(OWWidget.Error):
        predictor_failed = \
            Msg("One or more predictors failed (see more...)\n{}")
        scorer_failed = \
            Msg("One or more scorers failed (see more...)\n{}")
        predictors_target_mismatch = \
            Msg("Predictors do not have the same target.")
        data_target_mismatch = \
            Msg("Data does not have the same target as predictors.")

    settingsHandler = settings.ClassValuesContextHandler()
    score_table = settings.SettingProvider(ScoreTable)

    #: Display the full input dataset or only the target variable columns (if
    #: available)
    show_attrs = settings.Setting(True)
    #: Show predicted values (for discrete target variable)
    show_predictions = settings.Setting(True)
    #: Show predictions probabilities (for discrete target variable)
    show_probabilities = settings.Setting(True)
    #: List of selected class value indices in the "Show probabilities" list
    selected_classes = settings.ContextSetting([])
    #: Draw colored distribution bars
    draw_dist = settings.Setting(True)

    #: Which parts go into the output table; bound to the check boxes of the
    #: "Output" box
    output_attrs = settings.Setting(True)
    output_predictions = settings.Setting(True)
    output_probabilities = settings.Setting(True)

    def __init__(self):
        """Initialize the widget state and build the control and main areas."""
        super().__init__()

        #: Input data table
        self.data = None  # type: Optional[Orange.data.Table]
        #: A dict mapping input ids to PredictorSlot
        self.predictors = OrderedDict()  # type: Dict[object, PredictorSlot]
        #: A class variable (prediction target)
        self.class_var = None  # type: Optional[Orange.data.Variable]
        #: List of (discrete) class variable's values
        self.class_values = []  # type: List[str]

        box = gui.vBox(self.controlArea, "Info")
        self.infolabel = gui.widgetLabel(
            box, "No data on input.\nPredictors: 0\nTask: N/A")
        self.infolabel.setMinimumWidth(150)
        gui.button(box,
                   self,
                   "Restore Original Order",
                   callback=self._reset_order,
                   tooltip="Show rows in the original order")

        # Options that apply only to discrete targets; the visibility of
        # this box is toggled in _set_class_var
        self.classification_options = box = gui.vBox(self.controlArea,
                                                     "Show",
                                                     spacing=-1,
                                                     addSpace=False)

        gui.checkBox(box,
                     self,
                     "show_predictions",
                     "Predicted class",
                     callback=self._update_prediction_delegate)
        b = gui.checkBox(box,
                         self,
                         "show_probabilities",
                         "Predicted probabilities for:",
                         callback=self._update_prediction_delegate)
        ibox = gui.indentedBox(box,
                               sep=gui.checkButtonOffsetHint(b),
                               addSpace=False)
        gui.listBox(ibox,
                    self,
                    "selected_classes",
                    "class_values",
                    callback=self._update_prediction_delegate,
                    selectionMode=QListWidget.MultiSelection,
                    addSpace=False)
        gui.checkBox(box,
                     self,
                     "draw_dist",
                     "Draw distribution bars",
                     callback=self._update_prediction_delegate)

        box = gui.vBox(self.controlArea, "Data View")
        gui.checkBox(box,
                     self,
                     "show_attrs",
                     "Show full dataset",
                     callback=self._update_column_visibility)

        box = gui.vBox(self.controlArea, "Output", spacing=-1)
        # The "Original data" check box is not kept on `self`: the attribute
        # it used to be stored in was overwritten right away by the
        # "Predictions" check box below
        gui.checkBox(box,
                     self,
                     "output_attrs",
                     "Original data",
                     callback=self.commit)
        self.checkbox_class = gui.checkBox(box,
                                           self,
                                           "output_predictions",
                                           "Predictions",
                                           callback=self.commit)
        self.checkbox_prob = gui.checkBox(box,
                                          self,
                                          "output_probabilities",
                                          "Probabilities",
                                          callback=self.commit)

        gui.rubber(self.controlArea)

        # Main area: predictions and data side by side (horizontal
        # splitter) above the score table (vertical splitter)
        self.vsplitter = QSplitter(orientation=Qt.Vertical,
                                   childrenCollapsible=True,
                                   handleWidth=2)

        self.splitter = QSplitter(
            orientation=Qt.Horizontal,
            childrenCollapsible=False,
            handleWidth=2,
        )
        self.dataview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QTableView.ScrollPerPixel,
            selectionMode=QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus)
        self.predictionsview = TableView(
            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
            horizontalScrollMode=QTableView.ScrollPerPixel,
            selectionMode=QTableView.NoSelection,
            focusPolicy=Qt.StrongFocus,
            sortingEnabled=True,
        )

        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.dataview.verticalHeader().hide()

        # Keep the vertical scrolling of the two views in sync
        dsbar = self.dataview.verticalScrollBar()
        psbar = self.predictionsview.verticalScrollBar()

        psbar.valueChanged.connect(dsbar.setValue)
        dsbar.valueChanged.connect(psbar.setValue)

        # Same row heights in both views; row resizes in the data view are
        # mirrored to the predictions view
        self.dataview.verticalHeader().setDefaultSectionSize(22)
        self.predictionsview.verticalHeader().setDefaultSectionSize(22)
        self.dataview.verticalHeader().sectionResized.connect(
            lambda index, _, size: self.predictionsview.verticalHeader(
            ).resizeSection(index, size))

        self.splitter.addWidget(self.predictionsview)
        self.splitter.addWidget(self.dataview)

        self.score_table = ScoreTable(self)
        self.vsplitter.addWidget(self.splitter)
        self.vsplitter.addWidget(self.score_table.view)
        self.vsplitter.setStretchFactor(0, 5)
        self.vsplitter.setStretchFactor(1, 1)
        self.mainArea.layout().addWidget(self.vsplitter)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """
        Set the input dataset.

        An empty dataset is treated like no data, with a warning shown. The
        data view model is replaced and all stored predictions are
        invalidated.
        """
        is_empty = data is not None and not data
        if is_empty:
            data = None
            self.Warning.empty_data()
        else:
            self.Warning.empty_data.clear()

        self.data = data
        # Dropping the model first forces a full reset of the view's
        # HeaderView state when a new model is set below
        self.dataview.setModel(None)
        if data is None:
            self.predictionsview.setModel(None)
            self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        else:
            source_model = TableModel(data, parent=None)
            sort_proxy = TableSortProxyModel()
            sort_proxy.setSourceModel(source_model)
            self.dataview.setModel(sort_proxy)
            self._update_column_visibility()

        self._invalidate_predictions()

    # pylint: disable=redefined-builtin
    @Inputs.predictors
    def set_predictor(self, predictor=None, id=None):
        """
        Add, replace or remove the predictor connected on input `id`.

        A None predictor removes the slot (if there is one). Replacing a
        predictor resets the stored results of its slot.
        """
        if predictor is None:
            # Input disconnected: forget its slot, if any
            self.predictors.pop(id, None)
            return

        if id in self.predictors:
            slot = self.predictors[id]._replace(
                predictor=predictor, name=predictor.name, results=None)
        else:
            slot = PredictorSlot(predictor, predictor.name, None)
        self.predictors[id] = slot

    def _set_class_var(self):
        pred_classes = set(pred.predictor.domain.class_var
                           for pred in self.predictors.values())
        self.Error.predictors_target_mismatch.clear()
        self.Error.data_target_mismatch.clear()
        self.class_var = None
        if len(pred_classes) > 1:
            self.Error.predictors_target_mismatch()
        if len(pred_classes) == 1:
            self.class_var = pred_classes.pop()
            if self.data is not None and \
                    self.data.domain.class_var is not None and \
                    self.class_var != self.data.domain.class_var:
                self.Error.data_target_mismatch()
                self.class_var = None

        discrete_class = self.class_var is not None \
                         and self.class_var.is_discrete
        self.classification_options.setVisible(discrete_class)
        self.closeContext()
        if discrete_class:
            self.class_values = list(self.class_var.values)
            self.selected_classes = list(range(len(self.class_values)))
            self.openContext(self.class_var)
        else:
            self.class_values = []
            self.selected_classes = []

    def handleNewSignals(self):
        """Run the full update sequence and commit the outputs."""
        steps = (
            self._set_class_var,
            self._call_predictors,
            self._update_scores,
            self._update_predictions_model,
            self._update_prediction_delegate,
            self._set_errors,
            self._update_info,
            self.commit,
        )
        # The order matters: later steps use what earlier ones computed
        for step in steps:
            step()

    def _call_predictors(self):
        """
        Compute predictions for every predictor that has no usable results.

        Nothing is done without data. Slots whose stored results are real
        (not an error string) and whose predictions are not all NaN are
        skipped. A failed prediction stores an error string as the slot's
        results; a successful one stores a `Results` object.
        """
        if not self.data:
            return
        for inputid, slot in self.predictors.items():
            stored = slot.results
            up_to_date = stored is not None \
                and not isinstance(stored, str) \
                and not numpy.isnan(stored.predicted[0]).all()
            if up_to_date:
                continue

            try:
                pred, prob = self.predict(slot.predictor, self.data)
            except (ValueError, DomainTransformationError) as err:
                results = "{}: {}".format(slot.predictor.name, err)
            else:
                n_rows = len(self.data)
                results = Orange.evaluation.Results()
                results.data = self.data
                results.domain = self.data.domain
                results.row_indices = numpy.arange(n_rows)
                results.folds = (Ellipsis, )
                results.actual = self.data.Y
                # A leading axis of length 1 is added to both arrays
                results.predicted = pred.reshape((1, n_rows))
                results.probabilities = prob.reshape((1, ) + prob.shape)
            self.predictors[inputid] = slot._replace(results=results)

    def _update_scores(self):
        """
        Rebuild the score table with one row per predictor.

        The scorers are those usable for the class variable of the data
        (none without data or class). A predictor whose results are an
        error string gets a red "(error)" label and no score cells. The
        messages of failed scorers that are among the shown scores are
        collected into the `scorer_failed` error.
        """
        model = self.score_table.model
        model.clear()

        target = None if self.data is None else self.data.domain.class_var
        scorers = [] if target is None else usable_scorers(target)
        self.score_table.update_header(scorers)

        errors = []
        for slot in self.predictors.values():
            name = learner_name(slot.predictor)
            head = QStandardItem(name)
            row = [head]
            results = slot.results
            if isinstance(results, str):
                head.setToolTip(results)
                head.setText("{} (error)".format(name))
                head.setForeground(QBrush(Qt.red))
            else:
                for scorer in scorers:
                    item = QStandardItem()
                    try:
                        score = scorer_caller(scorer, results)()[0]
                        item.setText(f"{score:.3f}")
                    except Exception as exc:  # pylint: disable=broad-except
                        message = str(exc)
                        item.setToolTip(message)
                        if scorer.name in self.score_table.shown_scores:
                            errors.append(message)
                    row.append(item)
            self.score_table.model.appendRow(row)
        self.Error.scorer_failed("\n".join(errors), shown=bool(errors))

    def _set_errors(self):
        # Not all predictors are run every time, so errors can't be collected
        # in _call_predictors
        errors = "\n".join(p.results for p in self.predictors.values()
                           if isinstance(p.results, str))
        self.Error.predictor_failed(errors, shown=bool(errors))

    def _update_info(self):
        info = []
        if self.data is not None:
            info.append("Data: {} instances.".format(len(self.data)))
        else:
            info.append("Data: N/A")

        n_predictors = len(self.predictors)
        n_valid = len(self._valid_predictors())
        if n_valid != n_predictors:
            info.append("Predictors: {} (+ {} failed)".format(
                n_valid, n_predictors - n_valid))
        else:
            info.append("Predictors: {}".format(n_predictors or "N/A"))

        if self.class_var is None:
            info.append("Task: N/A")
        elif self.class_var.is_discrete:
            info.append("Task: Classification")
            self.checkbox_class.setEnabled(True)
            self.checkbox_prob.setEnabled(True)
        else:
            info.append("Task: Regression")
            self.checkbox_class.setEnabled(False)
            self.checkbox_prob.setEnabled(False)

        self.infolabel.setText("\n".join(info))

    def _invalidate_predictions(self):
        for inputid, pred in list(self.predictors.items()):
            self.predictors[inputid] = pred._replace(results=None)

    def _valid_predictors(self):
        if self.class_var is not None and self.data is not None:
            return [
                p for p in self.predictors.values()
                if p.results is not None and not isinstance(p.results, str)
            ]
        else:
            return []

    def _update_predictions_model(self):
        """
        Update the prediction view model.

        With data and a target, a `PredictionsModel` is built from the
        predicted values and probabilities of all valid predictors;
        otherwise the source model is None. The model is wrapped in a fresh
        sort proxy, which is set on the predictions view.
        """
        model = None
        if self.data is not None and self.class_var is not None:
            class_var = self.class_var
            slots = self._valid_predictors()
            per_model = []
            for slot in slots:
                if isinstance(slot.results, str):
                    continue
                values = slot.results.predicted[0]
                if class_var.is_discrete:
                    # if values were added to class_var between building the
                    # model and predicting, add zeros for new class values,
                    # which are always at the end
                    prob = slot.results.probabilities[0]
                    n_added = len(class_var.values) - prob.shape[1]
                    prob = numpy.c_[prob,
                                    numpy.zeros((prob.shape[0], n_added))]
                    values = [Value(class_var, v) for v in values]
                else:
                    prob = numpy.zeros((len(values), 0))
                per_model.append((values, prob))
            # Regroup from one (values, probs) pair per model into one
            # tuple per row holding a (value, probs) pair for each model
            rows = list(zip(*(zip(*res) for res in per_model)))
            model = PredictionsModel(rows, [slot.name for slot in slots])

        sort_proxy = PredictionsSortProxyModel()
        sort_proxy.setSourceModel(model)
        sort_proxy.setDynamicSortFilter(True)
        self.predictionsview.setItemDelegate(PredictionsItemDelegate())
        self.predictionsview.setModel(sort_proxy)
        hheader = self.predictionsview.horizontalHeader()
        hheader.setSortIndicatorShown(False)
        # SortFilterProxyModel is slow due to large abstraction overhead
        # (every comparison triggers multiple `model.index(...)`,
        # model.rowCount(...), `model.parent`, ... calls)
        hheader.setSectionsClickable(sort_proxy.rowCount() < 20000)

        sort_proxy.layoutChanged.connect(self._update_data_sort_order)
        self._update_data_sort_order()
        self.predictionsview.resizeColumnsToContents()

    def _update_column_visibility(self):
        """Update data column visibility."""
        if self.data is not None and self.class_var is not None:
            domain = self.data.domain
            first_attr = len(domain.class_vars) + len(domain.metas)

            for i in range(first_attr, first_attr + len(domain.attributes)):
                self.dataview.setColumnHidden(i, not self.show_attrs)
            if domain.class_var:
                self.dataview.setColumnHidden(0, False)

    def _update_data_sort_order(self):
        """Update data row order to match the current predictions view order"""
        datamodel = self.dataview.model()  # data model proxy
        predmodel = self.predictionsview.model()  # predictions model proxy
        sortindicatorshown = False
        if datamodel is not None:
            assert isinstance(datamodel, TableSortProxyModel)
            n = datamodel.rowCount()
            if predmodel is not None and predmodel.sortColumn() >= 0:
                sortind = numpy.argsort([
                    predmodel.mapToSource(predmodel.index(i, 0)).row()
                    for i in range(n)
                ])
                sortind = numpy.array(sortind, numpy.int)
                sortindicatorshown = True
            else:
                sortind = None

            datamodel.setSortIndices(sortind)

        self.predictionsview.horizontalHeader() \
            .setSortIndicatorShown(sortindicatorshown)

    def _reset_order(self):
        """Reset the row sorting to original input order."""
        datamodel = self.dataview.model()
        predmodel = self.predictionsview.model()
        if datamodel is not None:
            datamodel.sort(-1)
        if predmodel is not None:
            predmodel.sort(-1)
        self.predictionsview.horizontalHeader().setSortIndicatorShown(False)

    def _update_prediction_delegate(self):
        """Update the predicted probability visibility state"""
        if self.class_var is not None:
            delegate = PredictionsItemDelegate()
            if self.class_var.is_continuous:
                self._setup_delegate_continuous(delegate)
            else:
                self._setup_delegate_discrete(delegate)
                proxy = self.predictionsview.model()
                if proxy is not None:
                    proxy.setProbInd(
                        numpy.array(self.selected_classes, dtype=int))
            self.predictionsview.setItemDelegate(delegate)
            self.predictionsview.resizeColumnsToContents()
        self._update_spliter()

    def _setup_delegate_discrete(self, delegate):
        colors = [QColor(*rgb) for rgb in self.class_var.colors]
        fmt = []
        if self.show_probabilities:
            fmt.append(" : ".join("{{dist[{}]:.2f}}".format(i)
                                  for i in sorted(self.selected_classes)))
        if self.show_predictions:
            fmt.append("{value!s}")
        delegate.setFormat(" \N{RIGHTWARDS ARROW} ".join(fmt))
        if self.draw_dist and colors is not None:
            delegate.setColors(colors)
        return delegate

    def _setup_delegate_continuous(self, delegate):
        delegate.setFormat("{{value:{}}}".format(
            self.class_var.format_str[1:]))

    def _update_spliter(self):
        if self.data is None:
            return

        def width(view):
            h_header = view.horizontalHeader()
            v_header = view.verticalHeader()
            return h_header.length() + v_header.width()

        w = width(self.predictionsview) + 4
        w1, w2 = self.splitter.sizes()
        self.splitter.setSizes([w, w1 + w2 - w])

    def commit(self):
        """Send both outputs: the predictions, then the evaluation results."""
        self._commit_predictions()
        self._commit_evaluation_results()

    def _commit_evaluation_results(self):
        slots = self._valid_predictors()
        if not slots or self.data.domain.class_var is None:
            self.Outputs.evaluation_results.send(None)
            return

        class_var = self.class_var
        nanmask = numpy.isnan(self.data.get_column_view(class_var)[0])
        data = self.data[~nanmask]
        results = Orange.evaluation.Results(data, store_data=True)
        results.folds = None
        results.row_indices = numpy.arange(len(data))
        results.actual = data.Y.ravel()
        results.predicted = numpy.vstack(
            tuple(p.results.predicted[0][~nanmask] for p in slots))
        if class_var and class_var.is_discrete:
            results.probabilities = numpy.array(
                [p.results.probabilities[0][~nanmask] for p in slots])
        results.learner_names = [p.name for p in slots]
        self.Outputs.evaluation_results.send(results)

    def _commit_predictions(self):
        slots = self._valid_predictors()
        if not slots:
            self.Outputs.predictions.send(None)
            return

        if self.class_var and self.class_var.is_discrete:
            newmetas, newcolumns = self._classification_output_columns()
        else:
            newmetas, newcolumns = self._regression_output_columns()

        attrs = list(self.data.domain.attributes) if self.output_attrs else []
        metas = list(self.data.domain.metas) + newmetas
        domain = \
            Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas)
        predictions = self.data.transform(domain)
        if newcolumns:
            newcolumns = numpy.hstack(
                [numpy.atleast_2d(cols) for cols in newcolumns])
            predictions.metas[:, -newcolumns.shape[1]:] = newcolumns
        self.Outputs.predictions.send(predictions)

    def _classification_output_columns(self):
        newmetas = []
        newcolumns = []
        slots = self._valid_predictors()
        if self.output_predictions:
            newmetas += [
                DiscreteVariable(name=p.name, values=self.class_values)
                for p in slots
            ]
            newcolumns += [
                p.results.predicted[0].reshape((-1, 1)) for p in slots
            ]

        if self.output_probabilities:
            newmetas += [
                ContinuousVariable(name="%s (%s)" % (p.name, value))
                for p in slots for value in self.class_values
            ]
            newcolumns += [p.results.probabilities[0] for p in slots]
        return newmetas, newcolumns

    def _regression_output_columns(self):
        slots = self._valid_predictors()
        newmetas = [ContinuousVariable(name=p.name) for p in slots]
        newcolumns = [p.results.predicted[0].reshape((-1, 1)) for p in slots]
        return newmetas, newcolumns

    def send_report(self):
        """Add the info text and the data-with-predictions table to the report.

        Nothing is reported unless both the data and the class variable are
        set. The table has one header row and one header column; prediction
        cells are formatted by the predictions view's item delegate and only
        the data columns visible in the data view are included.
        """
        if self.data is None or self.class_var is None:
            return

        def styled(value):
            # use the view's item delegate to format prediction values
            delegate = self.predictionsview.itemDelegate()
            return delegate.displayText(value, QLocale())

        def table_rows():
            data_model = self.dataview.model()
            pred_model = self.predictionsview.model()

            # only the columns that are visible in the data view
            visible_cols = [
                col for col in range(data_model.columnCount())
                if not self.dataview.isColumnHidden(col)
            ]

            # header row
            header = ['']
            header.extend(
                pred_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                for col in range(pred_model.columnCount()))
            header.extend(
                data_model.headerData(col, Qt.Horizontal, Qt.DisplayRole)
                for col in visible_cols)
            yield header

            # one row per instance: row label, predictions, data
            for row in range(data_model.rowCount()):
                cells = [
                    data_model.headerData(row, Qt.Vertical, Qt.DisplayRole)
                ]
                cells.extend(
                    styled(pred_model.data(pred_model.index(row, col)))
                    for col in range(pred_model.columnCount()))
                cells.extend(
                    data_model.data(data_model.index(row, col))
                    for col in visible_cols)
                yield cells

        text = self.infolabel.text().replace('\n', '<br>')
        if self.show_probabilities and self.selected_classes:
            shown = [self.class_values[i] for i in self.selected_classes]
            text += '<br>Showing probabilities for: ' + ', '.join(shown)
        self.report_paragraph('Info', text)
        self.report_table("Data & Predictions",
                          table_rows(),
                          header_rows=1,
                          header_columns=1)

    @classmethod
    def predict(cls, predictor, data):
        """Run `predictor` on `data` according to its class variable type.

        Dispatches to `predict_discrete` or `predict_continuous` and returns
        their result; returns None when the predictor's domain has no class
        variable.
        """
        target = predictor.domain.class_var
        if not target:
            return None
        if target.is_discrete:
            handler = cls.predict_discrete
        else:
            handler = cls.predict_continuous
        return handler(predictor, data)

    @staticmethod
    def predict_discrete(predictor, data):
        """Call `predictor` on `data`, requesting `Model.ValueProbs` output.

        NOTE(review): presumably this yields the predicted values together
        with the class probabilities -- confirm against the `Model` API.
        """
        return predictor(data, Model.ValueProbs)

    @staticmethod
    def predict_continuous(predictor, data):
        """Call `predictor` on `data`, requesting `Model.Value` output.

        Returns the predictor's output paired with an empty probabilities
        array of shape (len(data), 0).
        """
        predicted = predictor(data, Model.Value)
        no_probabilities = numpy.zeros((len(data), 0))
        return predicted, no_probabilities