Ejemplo n.º 1
0
    def test_executor(self):
        """Exercise ``ThreadExecutor.submit`` with three kinds of callables.

        Checks that a submitted function's return value is delivered through
        the future, that an exception raised by the callable is re-raised by
        ``result()`` and exposed by ``exception()``, and that the callable
        does not run in the thread that submitted it.
        """
        executor = ThreadExecutor()
        f1 = executor.submit(pow, 100, 100)

        f2 = executor.submit(lambda: 1 / 0)

        f3 = executor.submit(QThread.currentThread)

        # assertEqual, not assertTrue: assertTrue(x, msg) only tests the
        # truthiness of x and uses its second argument as the failure
        # message, so the computed value would never actually be compared.
        self.assertEqual(f1.result(), pow(100, 100))

        with self.assertRaises(ZeroDivisionError):
            f2.result()

        self.assertIsInstance(f2.exception(), ZeroDivisionError)

        # The callable reports the thread it ran in; it must differ from ours.
        self.assertIsNot(f3.result(), QThread.currentThread())
Ejemplo n.º 2
0
    def test_executor(self):
        """Exercise ``ThreadExecutor.submit`` with three kinds of callables.

        Checks that a submitted function's return value is delivered through
        the future, that an exception raised by the callable is re-raised by
        ``result()`` and exposed by ``exception()``, and that the callable
        does not run in the thread that submitted it.
        """
        executor = ThreadExecutor()
        f1 = executor.submit(pow, 100, 100)

        f2 = executor.submit(lambda: 1 / 0)

        f3 = executor.submit(QThread.currentThread)

        # assertEqual, not assertTrue: assertTrue(x, msg) only tests the
        # truthiness of x and uses its second argument as the failure
        # message, so the computed value would never actually be compared.
        self.assertEqual(f1.result(), pow(100, 100))

        with self.assertRaises(ZeroDivisionError):
            f2.result()

        self.assertIsInstance(f2.exception(), ZeroDivisionError)

        # The callable reports the thread it ran in; it must differ from ours.
        self.assertIsNot(f3.result(), QThread.currentThread())
Ejemplo n.º 3
0
    def test_methodinvoke(self):
        """Check which thread a ``methodinvoke`` wrapper runs its slot in.

        A worker submitted to a ``ThreadExecutor`` calls the wrapper with its
        own thread object. After the event loop has been processed, the slot
        must have executed in the current (test) thread while the value it
        received must be a different thread.
        """
        # recorded[0]: the value handed to the slot by the worker,
        # recorded[1]: the thread in which the slot body actually ran.
        recorded = [None, None]
        pool = ThreadExecutor()

        class StateSetter(QObject):
            @pyqtSlot(object)
            def set_state(self, value):
                recorded[0] = value
                recorded[1] = QThread.currentThread()

        def report_own_thread(notify):
            notify(QThread.currentThread())

        receiver = StateSetter()
        invoker = methodinvoke(receiver, "set_state", (object, ))
        pending = pool.submit(report_own_thread, invoker)
        pending.result()
        # Give the event loop a chance to run the invoked slot
        self.app.processEvents()

        received_value, slot_thread = recorded

        self.assertIs(
            slot_thread,
            QThread.currentThread(),
            "set_state was called from the wrong thread",
        )
        self.assertIsNot(
            received_value,
            QThread.currentThread(),
            "set_state was invoked in the main thread",
        )

        pool.shutdown(wait=True)
Ejemplo n.º 4
0
    def test_methodinvoke(self):
        """Check which thread a ``methodinvoke`` wrapper runs its slot in.

        A worker submitted to a ``ThreadExecutor`` calls the wrapper with its
        own thread object. After the event loop has been processed, the slot
        must have executed in the current (test) thread while the value it
        received must be a different thread.
        """
        # recorded[0]: the value handed to the slot by the worker,
        # recorded[1]: the thread in which the slot body actually ran.
        recorded = [None, None]
        pool = ThreadExecutor()

        class StateSetter(QObject):
            @pyqtSlot(object)
            def set_state(self, value):
                recorded[0] = value
                recorded[1] = QThread.currentThread()

        def report_own_thread(notify):
            notify(QThread.currentThread())

        receiver = StateSetter()
        invoker = methodinvoke(receiver, "set_state", (object,))
        pending = pool.submit(report_own_thread, invoker)
        pending.result()
        # Give the event loop a chance to run the invoked slot
        self.app.processEvents()

        received_value, slot_thread = recorded

        self.assertIs(
            slot_thread,
            QThread.currentThread(),
            "set_state was called from the wrong thread",
        )
        self.assertIsNot(
            received_value,
            QThread.currentThread(),
            "set_state was invoked in the main thread",
        )

        pool.shutdown(wait=True)
Ejemplo n.º 5
0
class AgentPlayMixin():
    """Mixin that lets an agent play a gym environment on a worker thread.

    The host class is expected to provide ``environment_id`` (passed to
    ``gym.make``) and to override :meth:`play_action`.
    """
    # True while the play loop should keep running; cleared by `stop()`.
    playing = False

    # Seconds slept before every environment step.
    episodes_interval = 0.0
    # Seconds slept after a game finishes, before the environment is reset.
    games_interval = 0.0

    def play(self):
        """Create the environment and start the play loop in a ThreadExecutor."""
        self._executor = ThreadExecutor()

        self.environment = gym.make(self.environment_id)

        self.playing = True

        state = self.environment.reset()

        self.environment.render()

        self._executor.submit(partial(self.play_task, state))

    def stop(self):
        """Ask the play loop to finish after its current iteration."""
        self.playing = False

    def play_action(self, state):
        """Return the action to take for ``state``; meant to be overridden."""
        pass

    def play_task(self, state):
        """Run the play loop until ``playing`` becomes false.

        Parameters
        ----------
        state
            The observation the first action is chosen from.
        """
        try:
            while self.playing:
                # pylint: disable=assignment-from-no-return
                action = self.play_action(state)

                sleep(self.episodes_interval)

                # Carry the new observation into the next iteration; dropping
                # it would make every action be chosen from a stale state.
                state, _reward, done, _info = self.environment.step(action)

                self.environment.render()

                if done:
                    sleep(self.games_interval)

                    state = self.environment.reset()

                    self.environment.render()
        finally:
            # Release the environment even when the loop exits on an error.
            self.environment.close()
0
    def test_executor(self):
        """Exercise ``ThreadExecutor`` with plain callables and with a ``Task``.

        Verifies that submitted work does not run in the submitting thread,
        that an exception raised by the callable is re-raised from
        ``result()``, and that ``submit_task`` returns a future for a
        ``Task`` object.
        """
        pool = ThreadExecutor()

        # A plain callable must report a thread other than the current one.
        worker_thread = pool.submit(QThread.currentThread).result(3)
        self.assertIsNot(worker_thread, QThread.currentThread())

        # An exception inside the callable surfaces when the result is read.
        failing = pool.submit(lambda: 1 / 0)
        with self.assertRaises(ZeroDivisionError):
            failing.result()

        # Same thread check, this time going through a Task object.
        collected = []
        thread_task = Task(function=QThread.currentThread)
        thread_task.resultReady.connect(collected.append, Qt.DirectConnection)

        task_thread = pool.submit_task(thread_task).result(3)
        self.assertIsNot(task_thread, QThread.currentThread())

        pool.shutdown()
Ejemplo n.º 7
0
    def test_executor(self):
        """Exercise ``ThreadExecutor`` with plain callables and with a ``Task``.

        Verifies that submitted work does not run in the submitting thread,
        that an exception raised by the callable is re-raised from
        ``result()``, and that ``submit_task`` returns a future for a
        ``Task`` object.
        """
        pool = ThreadExecutor()

        # A plain callable must report a thread other than the current one.
        worker_thread = pool.submit(QThread.currentThread).result(3)
        self.assertIsNot(worker_thread, QThread.currentThread())

        # An exception inside the callable surfaces when the result is read.
        failing = pool.submit(lambda: 1 / 0)
        with self.assertRaises(ZeroDivisionError):
            failing.result()

        # Same thread check, this time going through a Task object.
        collected = []
        thread_task = Task(function=QThread.currentThread)
        thread_task.resultReady.connect(collected.append, Qt.DirectConnection)

        task_thread = pool.submit_task(thread_task).result(3)
        self.assertIsNot(task_thread, QThread.currentThread())

        pool.shutdown()
Ejemplo n.º 8
0
class PNNM(OWWidget):
    """Orange widget that trains a small CNN in a background thread.

    The network, loss, optimizer and MNIST data loaders are built in
    ``__init__``. When new input signals arrive, ``train_model`` is submitted
    to a ``ThreadExecutor``; progress is reported back through
    ``setProgressValue`` and, once the run is done, the model is evaluated on
    the test loader and the accuracy is printed.
    """
    name = "Pytorch CNN"
    description = ""
    # icon = "icons/robot.svg"

    want_main_area = True

    class Inputs:
        data = Input('Data', ImageDataBunch, default=True)

    def __init__(self):
        super().__init__()
        # Input data and the Learner built from it; both are assigned in
        # `set_data`. `data` is initialised here so that `_update` can test
        # it even if no input has been received yet.
        self.data = None
        self.learn = None

        self.label = gui.label(self.mainArea, self, "模型结构")

        #: The current evaluating task (if any)
        self._task = None  # type: Optional[Task]
        #: An executor we use to submit learner evaluations into a thread pool
        self._executor = ThreadExecutor()

        # Device configuration
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Hyper parameters
        #: Number of epochs passed to `train_model` by `_update`
        self.num_epochs = 5
        batch_size = 100
        learning_rate = 0.001

        dir_path = Path(__file__).resolve()
        data_path = f'{dir_path.parent.parent.parent}/datasets/'

        # MNIST dataset (expected to be present already: download=False)
        self.train_dataset = torchvision.datasets.MNIST(
            root=data_path,
            train=True,
            transform=transforms.ToTensor(),
            download=False)

        self.test_dataset = torchvision.datasets.MNIST(
            root=data_path,
            train=False,
            transform=transforms.ToTensor())

        # Data loaders
        self.train_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=False)

        self.test_loader = torch.utils.data.DataLoader(
            dataset=self.test_dataset,
            batch_size=batch_size,
            shuffle=False)

        # The trailing numbers are the original author's notes on the
        # spatial size after each stride-2 convolution.
        self.model = nn.Sequential(
            self.conv(1, 8),  # 14
            nn.BatchNorm2d(8),
            nn.ReLU(),
            self.conv(8, 16),  # 7
            nn.BatchNorm2d(16),
            nn.ReLU(),
            self.conv(16, 32),  # 4
            nn.BatchNorm2d(32),
            nn.ReLU(),
            self.conv(32, 16),  # 2
            nn.BatchNorm2d(16),
            nn.ReLU(),
            self.conv(16, 10),  # 1
            nn.BatchNorm2d(10),
            Flatten()  # remove (1,1) grid
        ).to(self.device)

        # Loss and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def handleNewSignals(self):
        """Start (or restart) training after the input signals were set."""
        self._update()

    def _update(self):
        """Cancel any running task and submit a new training run.

        Nothing is submitted when there is no input data or no learner.
        """
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        if self.data is None:
            return
        if not self.learn:
            return

        # setup the task state
        self._task = task = Task()
        # `train_model` is given a callback (presumably called with the
        # fraction of work finished -- confirm against its definition). We use
        # it both to report progress to this widget in a thread safe manner
        # and to implement cooperative cancellation.
        set_progress = methodinvoke(self, "setProgressValue", (float,))

        def callback(finished):
            # check if the task has been cancelled and raise an exception
            # from within. This 'strategy' can only be used with code that
            # properly cleans up after itself in the case of an exception
            # (does not leave any global locks, opened file descriptors, ...)
            if task.cancelled:
                raise KeyboardInterrupt()
            set_progress(finished * 100)

        self.progressBarInit()

        fit_model = partial(train_model, self.model, self.num_epochs,
                            self.train_loader, self.test_loader, self.device,
                            self.criterion, self.optimizer, callback=callback)

        # Submit the training function to the executor and fill in the
        # task with the resultant Future.
        task.future = self._executor.submit(fit_model)
        # Setup the FutureWatcher to notify us of completion
        task.watcher = FutureWatcher(task.future)
        # by using FutureWatcher we ensure `_task_finished` slot will be
        # called from the main GUI thread by the Qt's event loop
        task.watcher.done.connect(self._task_finished)

    @pyqtSlot(float)
    def setProgressValue(self, value):
        """Set the progress bar to ``value``; must run in the widget's thread."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Handle completion of the training run and evaluate the model.

        Parameters
        ----------
        f : Future
            The future instance holding the result of the training run.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        # Retrieve the outcome so that a failed run is reported instead of
        # being silently followed by an evaluation of a half-trained model.
        try:
            f.result()
        except Exception as ex:  # GUI boundary: report, do not crash the slot
            self.error("Exception occurred during training: {!r}".format(ex))
            return

        self.model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in self.test_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(images)
                # index of the largest output per sample is the predicted class
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

        print(self.learn.validate())

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_task_finished` slot
            self._task.watcher.done.disconnect(self._task_finished)
            self._task = None

    def onDeleteWidget(self):
        """Cancel any running task before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()

    def conv(self, ni, nf):
        """Return a 3x3, stride-2, padding-1 ``Conv2d`` from ``ni`` to ``nf`` channels."""
        return nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1)

    def train(self):
        """Run ``fit_one_cycle(3)`` on the current learner, if there is one."""
        if self.learn is None:
            return
        self.learn.fit_one_cycle(3)

    @Inputs.data
    def set_data(self, data):
        """Store the input data and build a Learner for it.

        With ``None`` the stored data is cleared; otherwise a ``Learner``
        wrapping ``self.model`` is created and its summary is shown in the
        main-area label.
        """
        if data is not None:
            self.data = data
            self.learn = Learner(self.data, self.model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy,
                                 add_time=False, bn_wd=False, silent=True)
            self.label.setText(self.learn.summary())
        else:
            self.data = None
class OWClusterAnalysis(widget.OWWidget):
    """Widget that runs a cluster analysis and shows it as a contingency table.

    A ``ClusterAnalysis`` object is computed for the input data and the chosen
    cluster variable in a worker thread ("init" task); a second task
    ("gene_selection") then picks the genes to display according to the
    controls. The result is shown in a ``ContingencyTable`` view and the
    selection is sent to the outputs by ``commit``.
    """
    name = "Cluster Analysis"
    description = "Perform cluster analysis."
    icon = "icons/ClusterAnalysis.svg"
    priority = 2010

    class Inputs:
        data = Input("Data", Table, default=True)
        genes = Input("Genes", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        contingency = Output("Contingency Table", Table)

    # Upper bounds of the two "Top N" spin boxes
    N_GENES_PER_CLUSTER_MAX = 10
    N_MOST_ENRICHED_MAX = 50
    # Table cell sizes selectable through the "S"/"M"/"L" buttons
    CELL_SIZES = (14, 22, 30)

    settingsHandler = DomainContextHandler(metas_in_res=True)
    cluster_var = ContextSetting(None)
    selection = ContextSetting(set())
    # Index of the checked "Gene Selection" radio button (0, 1 or 2)
    gene_selection = ContextSetting(0)
    # Index into `_diff_exprs`
    differential_expression = ContextSetting(0)
    # Index into `CELL_SIZES`
    cell_size_ix = ContextSetting(2)
    _diff_exprs = ("high", "low", "either")
    n_genes_per_cluster = ContextSetting(3)
    n_most_enriched = ContextSetting(20)
    biclustering = ContextSetting(True)
    auto_apply = Setting(True)

    want_main_area = True

    def __init__(self):
        super().__init__()

        self.ca = None
        self.clusters = None
        # Initialised here because `commit` and `send_report` read it.
        self.columns = None
        self.data = None
        self.feature_model = DomainModel(valid_types=DiscreteVariable)
        self.gene_list = None
        self.model = None
        self.pvalues = None

        self._executor = ThreadExecutor()
        # (current, previous) value of `gene_selection`
        self._gene_selection_history = (self.gene_selection,
                                        self.gene_selection)
        self._task = None

        box = gui.vBox(self.controlArea, "Info")
        self.infobox = gui.widgetLabel(box, self._get_info_string())

        box = gui.vBox(self.controlArea, "Cluster Variable")
        gui.comboBox(box,
                     self,
                     "cluster_var",
                     sendSelectedValue=True,
                     model=self.feature_model,
                     callback=self._run_cluster_analysis)

        layout = QGridLayout()
        self.gene_selection_radio_group = gui.radioButtonsInBox(
            self.controlArea,
            self,
            "gene_selection",
            orientation=layout,
            box="Gene Selection",
            callback=self._gene_selection_changed)

        def conditional_set_gene_selection(selection_id):
            # Build a spin-box callback that only refreshes the gene
            # selection while its own radio button is the active one.
            def f():
                if self.gene_selection == selection_id:
                    return self._set_gene_selection()

            return f

        layout.addWidget(
            gui.appendRadioButton(self.gene_selection_radio_group,
                                  "",
                                  addToLayout=False), 1, 1)
        cb = gui.hBox(None, margin=0)
        gui.widgetLabel(cb, "Top")
        self.n_genes_per_cluster_spin = gui.spin(
            cb,
            self,
            "n_genes_per_cluster",
            minv=1,
            maxv=self.N_GENES_PER_CLUSTER_MAX,
            controlWidth=60,
            alignment=Qt.AlignRight,
            callback=conditional_set_gene_selection(0))
        gui.widgetLabel(cb, "genes per cluster")
        gui.rubber(cb)
        layout.addWidget(cb, 1, 2, Qt.AlignLeft)

        layout.addWidget(
            gui.appendRadioButton(self.gene_selection_radio_group,
                                  "",
                                  addToLayout=False), 2, 1)
        mb = gui.hBox(None, margin=0)
        gui.widgetLabel(mb, "Top")
        self.n_most_enriched_spin = gui.spin(
            mb,
            self,
            "n_most_enriched",
            minv=1,
            maxv=self.N_MOST_ENRICHED_MAX,
            controlWidth=60,
            alignment=Qt.AlignRight,
            callback=conditional_set_gene_selection(1))
        gui.widgetLabel(mb, "highest enrichments")
        gui.rubber(mb)
        layout.addWidget(mb, 2, 2, Qt.AlignLeft)

        # Third option stays disabled until `set_genes` receives a gene list
        layout.addWidget(
            gui.appendRadioButton(self.gene_selection_radio_group,
                                  "",
                                  addToLayout=False,
                                  disabled=True), 3, 1)
        sb = gui.hBox(None, margin=0)
        gui.widgetLabel(sb, "User-provided list of genes")
        gui.rubber(sb)
        layout.addWidget(sb, 3, 2)

        layout = QGridLayout()
        self.differential_expression_radio_group = gui.radioButtonsInBox(
            self.controlArea,
            self,
            "differential_expression",
            orientation=layout,
            box="Differential Expression",
            callback=self._set_gene_selection)

        layout.addWidget(
            gui.appendRadioButton(self.differential_expression_radio_group,
                                  "Overexpressed in cluster",
                                  addToLayout=False), 1, 1)
        layout.addWidget(
            gui.appendRadioButton(self.differential_expression_radio_group,
                                  "Underexpressed in cluster",
                                  addToLayout=False), 2, 1)
        layout.addWidget(
            gui.appendRadioButton(self.differential_expression_radio_group,
                                  "Either",
                                  addToLayout=False), 3, 1)

        box = gui.vBox(self.controlArea, "Sorting and Zoom")
        gui.checkBox(box,
                     self,
                     "biclustering",
                     "Biclustering of analysis results",
                     callback=self._set_gene_selection)
        gui.radioButtons(box,
                         self,
                         "cell_size_ix",
                         btnLabels=("S", "M", "L"),
                         callback=lambda: self.tableview.set_cell_size(
                             self.CELL_SIZES[self.cell_size_ix]),
                         orientation=Qt.Horizontal)

        gui.rubber(self.controlArea)

        self.apply_button = gui.auto_commit(self.controlArea,
                                            self,
                                            "auto_apply",
                                            "&Apply",
                                            box=False)

        self.tableview = ContingencyTable(self)
        self.mainArea.layout().addWidget(self.tableview)

    def _get_current_gene_selection(self):
        """Return the most recently recorded ``gene_selection`` value."""
        return self._gene_selection_history[0]

    def _get_previous_gene_selection(self):
        """Return the ``gene_selection`` value recorded before the current one."""
        return self._gene_selection_history[1]

    def _progress_gene_selection_history(self, new_gene_selection):
        """Record ``new_gene_selection`` as current, shifting the old current to previous."""
        self._gene_selection_history = (new_gene_selection,
                                        self._gene_selection_history[0])

    def _get_info_string(self):
        """Return the text of the "Info" box for the current data."""
        formatstr = "Cells: {0}\nGenes: {1}\nClusters: {2}"
        if self.data:
            return formatstr.format(len(self.data),
                                    len(self.data.domain.attributes),
                                    len(self.cluster_var.values))
        else:
            return formatstr.format(*["No input data"] * 3)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Reset the widget state for ``data`` and start the analysis.

        The table view is cleared when there is no data or the data offers
        no variable accepted by the cluster-variable model.
        """
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.feature_model.set_domain(None)
        self.ca = None
        self.cluster_var = None
        self.columns = None
        self.clusters = None
        self.gene_list = None
        self.model = None
        self.pvalues = None
        self.n_genes_per_cluster_spin.setMaximum(self.N_GENES_PER_CLUSTER_MAX)
        self.n_most_enriched_spin.setMaximum(self.N_MOST_ENRICHED_MAX)
        if self.data:
            self.feature_model.set_domain(self.data.domain)
            if self.feature_model:
                self.openContext(self.data)
                if self.cluster_var is None:
                    self.cluster_var = self.feature_model[0]
                self._run_cluster_analysis()
            else:
                self.tableview.clear()
        else:
            self.tableview.clear()

    @Inputs.genes
    def set_genes(self, data):
        """Read the user-provided gene list from ``data``.

        When the table lacks the gene annotations needed to extract the IDs
        the list is dropped and the third gene-selection option is disabled
        (falling back to the previously used option if it was active).
        """
        self.Error.clear()
        gene_list_radio = self.gene_selection_radio_group.group.buttons()[2]

        # Annotations are unusable when the table is missing, does not say
        # where genes live, or lacks the ID key for the layout it declares.
        if (data is None
                or GENE_AS_ATTRIBUTE_NAME not in data.attributes
                or (not data.attributes[GENE_AS_ATTRIBUTE_NAME]
                    and GENE_ID_COLUMN not in data.attributes)
                or (data.attributes[GENE_AS_ATTRIBUTE_NAME]
                    and GENE_ID_ATTRIBUTE not in data.attributes)):
            if data is not None:
                self.error(
                    "Gene annotations missing in the input data. Use Gene Name Matching widget."
                )
            self.gene_list = None
            gene_list_radio.setDisabled(True)
            if self.gene_selection == 2:
                self.gene_selection_radio_group.group.buttons()[
                    self._get_previous_gene_selection()].click()
        else:
            if data.attributes[GENE_AS_ATTRIBUTE_NAME]:
                # Genes are columns: IDs are stored in the variables' attributes
                gene_id_attribute = data.attributes.get(
                    GENE_ID_ATTRIBUTE, None)

                self.gene_list = tuple(
                    str(var.attributes[gene_id_attribute])
                    for var in data.domain.attributes
                    if gene_id_attribute in var.attributes
                    and var.attributes[gene_id_attribute] != "?")
            else:
                # Genes are rows: IDs are the values of one column
                gene_id_column = data.attributes.get(GENE_ID_COLUMN, None)
                self.gene_list = tuple(
                    str(v) for v in data.get_column_view(gene_id_column)[0]
                    if v not in ("", "?"))
            gene_list_radio.setDisabled(False)
            if self.gene_selection == 2:
                self._set_gene_selection()
            else:
                gene_list_radio.click()

    def _run_cluster_analysis(self):
        """Update the info box and spin limits, then start the "init" task."""
        self.infobox.setText(self._get_info_string())
        gene_count = len(self.data.domain.attributes)
        cluster_count = len(self.cluster_var.values)
        # Avoid dividing by zero for a cluster variable without values, and
        # never push a maximum below 1, the minimum both spins were created
        # with.
        if cluster_count:
            genes_per_cluster = gene_count // cluster_count
        else:
            genes_per_cluster = gene_count
        self.n_genes_per_cluster_spin.setMaximum(
            max(1, min(self.N_GENES_PER_CLUSTER_MAX, genes_per_cluster)))
        self.n_most_enriched_spin.setMaximum(
            max(1, min(self.N_MOST_ENRICHED_MAX, gene_count)))
        # TODO: what happens if error occurs? If CA fails, widget should properly handle it.
        self._start_task_init(
            partial(ClusterAnalysis, self.data, self.cluster_var.name))

    def _start_task(self, task_type, f, progress_offset, finished_slot):
        """Cancel any running task and submit ``f`` as a new one.

        Parameters
        ----------
        task_type : str
            Value passed to ``Task``; ``cancel`` uses it to pick the slot to
            disconnect.
        f : callable
            Function to run in the executor; it receives a ``callback``
            keyword argument for progress reporting and cancellation.
        progress_offset : int
            Progress-bar value the task starts from; the task itself covers
            the following 50 units.
        finished_slot : callable
            Slot connected to the watcher's ``done`` signal.
        """
        if self._task is not None:
            self.cancel()
        assert self._task is None

        # Keep a local reference: `self._task` is reset to None (or replaced
        # by a newer task) by `cancel` and the finished slots, so the
        # callback must not look it up again through `self`.
        self._task = task = Task(task_type)

        def callback(finished):
            if task.cancelled:
                raise KeyboardInterrupt()
            # NOTE(review): this is presumably called from the worker thread
            # while `progressBarSet` touches the GUI -- confirm this is safe.
            self.progressBarSet(progress_offset + finished * 50)

        f = partial(f, callback=callback)

        self.progressBarInit()
        if progress_offset:
            self.progressBarSet(progress_offset)
        task.future = self._executor.submit(f)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(finished_slot)

    def _start_task_init(self, f):
        """Run ``f`` as the "init" task (first half of the progress bar)."""
        self._start_task("init", f, 0, self._init_task_finished)

    def _start_task_gene_selection(self, f):
        """Run ``f`` as the "gene_selection" task (second half of the progress bar)."""
        self._start_task("gene_selection", f, 50,
                         self._gene_selection_task_finished)

    @Slot(concurrent.futures.Future)
    def _init_task_finished(self, f):
        """Store the finished analysis and continue with the gene selection."""
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        self.ca = f.result()
        self._set_gene_selection()

    @Slot(concurrent.futures.Future)
    def _gene_selection_task_finished(self, f):
        """Show the gene-selection result in the table view and commit."""
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        self.clusters, genes, self.model, self.pvalues = f.result()
        genes = [str(gene) for gene in genes]
        self.columns = DiscreteVariable("Gene", genes, ordered=True)
        self.tableview.set_headers(
            self.clusters,
            self.columns.values,
            circles=True,
            cell_size=self.CELL_SIZES[self.cell_size_ix],
            bold_headers=False)

        def tooltip(i, j):
            return (
                "<b>cluster</b>: {}<br /><b>gene</b>: {}<br /><b>fraction expressing</b>: {:.2f}<br />\
                                <b>p-value</b>: {:.2e}".format(
                    self.clusters[i], self.columns.values[j], self.model[i, j],
                    self.pvalues[i, j]))

        self.tableview.update_table(self.model, tooltip=tooltip)
        self._invalidate()

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the slot that matches the kind of task
            if self._task.type == "init":
                self._task.watcher.done.disconnect(self._init_task_finished)
            else:
                self._task.watcher.done.disconnect(
                    self._gene_selection_task_finished)
            self._task = None

    def onDeleteWidget(self):
        """Cancel any running task before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()

    def _gene_selection_changed(self):
        """React to a change of the "Gene Selection" radio buttons."""
        if self.gene_selection != self._get_current_gene_selection():
            self._progress_gene_selection_history(self.gene_selection)
            # differential expression is not applicable to a user gene list
            self.differential_expression_radio_group.setDisabled(
                self.gene_selection == 2)
            self._set_gene_selection()

    def _set_gene_selection(self):
        """Start the gene-selection task for the current control values.

        When no analysis is available yet (or the "init" task is still
        running) the outputs are only invalidated.
        """
        self.Warning.clear()
        if self.ca is not None and (self._task is None
                                    or self._task.type != "init"):
            if self.gene_selection == 0:
                f = partial(self.ca.enriched_genes_per_cluster,
                            self.n_genes_per_cluster)
            elif self.gene_selection == 1:
                f = partial(self.ca.enriched_genes_data, self.n_most_enriched)
            else:
                if self.data is not None and GENE_ID_ATTRIBUTE not in self.data.attributes:
                    self.error(
                        "Gene annotations missing in the input data. Use Gene Name Matching widget."
                    )
                    if self.gene_selection == 2:
                        self.gene_selection_radio_group.group.buttons()[
                            self._get_previous_gene_selection()].click()
                    return
                relevant_genes = tuple(self.ca.intersection(self.gene_list))
                if len(relevant_genes) > self.N_MOST_ENRICHED_MAX:
                    self.warning("Only first {} reference genes shown.".format(
                        self.N_MOST_ENRICHED_MAX))
                f = partial(self.ca.enriched_genes,
                            relevant_genes[:self.N_MOST_ENRICHED_MAX])
            f = partial(
                f,
                enrichment=self._diff_exprs[self.differential_expression],
                biclustering=self.biclustering)
            self._start_task_gene_selection(f)
        else:
            self._invalidate()

    def handleNewSignals(self):
        self._invalidate()

    def commit(self):
        """Send the selected data, the annotated data and the contingency table."""
        if len(self.selection):
            # selection holds (row, column) index pairs of the table view
            cluster_ids = set()
            column_ids = set()
            for (ir, ic) in self.selection:
                cluster_ids.add(ir)
                column_ids.add(ic)
            new_domain = Domain([
                self.data.domain[self.columns.values[col]]
                for col in column_ids
            ], self.data.domain.class_vars, self.data.domain.metas)
            # rows belonging to any of the selected clusters
            cluster_filter = Values(
                [
                    FilterDiscrete(self.cluster_var, [self.clusters[ir]])
                    for ir in cluster_ids
                ],
                conjunction=False)
            selected_data = cluster_filter(self.data)
            selected_data = selected_data.transform(new_domain)
            annotated_data = create_annotated_table(
                self.data.transform(new_domain),
                np.where(np.in1d(self.data.ids, selected_data.ids, True)))
        else:
            selected_data = None
            annotated_data = create_annotated_table(self.data, [])
        if self.ca is not None and self._task is None:
            table = self.ca.create_contingency_table()
        else:
            table = None
        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)
        self.Outputs.contingency.send(table)

    def _invalidate(self):
        """Store the table view's selection and commit the outputs."""
        self.selection = self.tableview.get_selection()
        self.commit()

    def send_report(self):
        """Report the variables used for the table's rows and columns."""
        rows = None
        columns = None
        if self.data is not None:
            rows = self.cluster_var
            if rows in self.data.domain:
                rows = self.data.domain[rows]
            columns = self.columns
            if columns in self.data.domain:
                columns = self.data.domain[columns]
        self.report_items((
            ("Rows", rows),
            ("Columns", columns),
        ))
Ejemplo n.º 10
0
class OWNNLearner(OWBaseLearner):
    """Neural network (multi-layer perceptron) learner widget.

    The model is fitted in a worker thread submitted to a ``ThreadExecutor``.
    A ``Task`` object relays progress and completion back to the widget and
    lets a running fit be cancelled.
    """
    name = "神经网络(Neural Network)"
    description = "一种具有反向传播的多层感知器(MLP)算法。"
    icon = "icons/NN.svg"
    priority = 90
    keywords = ["mlp", 'shenjingwangluo']
    category = '模型(Model)'

    LEARNER = NNLearner

    # Values passed to the learner and the labels shown for them; the value
    # lists and label lists are index-aligned (see the *_index settings).
    activation = ["identity", "logistic", "tanh", "relu"]
    act_lbl = ["Identity", "Logistic", "tanh", "ReLu"]
    chinese_act_lbl = ["相等", "Logistic", "tanh", "ReLu"]
    solver = ["lbfgs", "sgd", "adam"]
    solv_lbl = ["L-BFGS-B", "SGD", "Adam"]

    learner_name = Setting("Neural Network")
    hidden_layers_input = Setting("100,")
    activation_index = Setting(3)
    solver_index = Setting(2)
    max_iterations = Setting(200)
    alpha_index = Setting(1)
    replicable = Setting(True)
    settings_version = 2

    # Regularization strengths selectable with the slider; `alpha_index`
    # indexes into this list.
    alphas = list(
        chain([0], [x / 10000 for x in range(1, 10)],
              [x / 1000 for x in range(1, 10)],
              [x / 100 for x in range(1, 10)], [x / 10 for x in range(1, 10)],
              range(1, 10), range(10, 100, 5), range(100, 200, 10),
              range(100, 1001, 50)))

    class Warning(OWBaseLearner.Warning):
        no_layers = Msg("ANN without hidden layers is equivalent to logistic "
                        "regression with worse fitting.\nWe recommend using "
                        "logistic regression.")

    def add_main_layout(self):
        """Build the form with the learner's parameter controls."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        form = QFormLayout()
        form.setFieldGrowthPolicy(form.AllNonFixedFieldsGrow)
        form.setLabelAlignment(Qt.AlignLeft)
        gui.widgetBox(self.controlArea, True, orientation=form)
        form.addRow(
            "隐藏层中的神经元:",
            gui.lineEdit(None,
                         self,
                         "hidden_layers_input",
                         orientation=Qt.Horizontal,
                         callback=self.settings_changed,
                         tooltip="定义神经元的整数列表。列表长度定义层数。例如4、2、2、3。",
                         placeholderText="e.g. 10,"))
        form.addRow(
            "激活:",
            gui.comboBox(None,
                         self,
                         "activation_index",
                         orientation=Qt.Horizontal,
                         label="Activation:",
                         items=list(self.chinese_act_lbl),
                         callback=self.settings_changed))

        form.addRow(
            "求解器(Solver):",
            gui.comboBox(None,
                         self,
                         "solver_index",
                         orientation=Qt.Horizontal,
                         label="Solver:",
                         items=list(self.solv_lbl),
                         callback=self.settings_changed))
        self.reg_label = QLabel()
        # Moving the slider first refreshes the label, then flags the
        # settings as changed.
        slider = gui.hSlider(None,
                             self,
                             "alpha_index",
                             minValue=0,
                             maxValue=len(self.alphas) - 1,
                             callback=lambda:
                             (self.set_alpha(), self.settings_changed()),
                             createLabel=False)
        form.addRow(self.reg_label, slider)
        self.set_alpha()

        form.addRow(
            "最大迭代次数:",
            gui.spin(None,
                     self,
                     "max_iterations",
                     10,
                     1000000,
                     step=10,
                     label="Max iterations:",
                     orientation=Qt.Horizontal,
                     alignment=Qt.AlignRight,
                     callback=self.settings_changed))

        form.addRow(
            gui.checkBox(None,
                         self,
                         "replicable",
                         label="可重复训练",
                         callback=self.settings_changed,
                         attribute=Qt.WA_LayoutUsesWidgetRect))

    def set_alpha(self):
        """Update ``strength_C`` and the slider label from ``alpha_index``."""
        # called from init, pylint: disable=attribute-defined-outside-init
        self.strength_C = self.alpha
        self.reg_label.setText("正则化, α={}:".format(self.strength_C))

    @property
    def alpha(self):
        """Regularization strength selected by ``alpha_index``."""
        return self.alphas[self.alpha_index]

    def setup_layout(self):
        """Extend the base layout with the task state and a cancel button."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        super().setup_layout()

        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        # just a test cancel button
        b = gui.button(self.apply_button,
                       self,
                       "取消",
                       callback=self.cancel,
                       addToLayout=False)
        self.apply_button.layout().insertStretch(0, 100)
        self.apply_button.layout().insertWidget(0, b)

    def create_learner(self):
        """Return a learner configured from the current settings."""
        return self.LEARNER(hidden_layer_sizes=self.get_hidden_layers(),
                            activation=self.activation[self.activation_index],
                            solver=self.solver[self.solver_index],
                            alpha=self.alpha,
                            random_state=1 if self.replicable else None,
                            max_iter=self.max_iterations,
                            preprocessors=self.preprocessors)

    def get_learner_parameters(self):
        """Return ``(label, value)`` pairs describing the learner settings."""
        return (("Hidden layers", ', '.join(map(str,
                                                self.get_hidden_layers()))),
                ("Activation", self.act_lbl[self.activation_index]),
                ("Solver", self.solv_lbl[self.solver_index]),
                ("Alpha", self.alpha), ("Max iterations", self.max_iterations),
                ("Replicable training", self.replicable))

    def get_hidden_layers(self):
        """Parse ``hidden_layers_input`` into a tuple of layer sizes.

        Every run of digits in the text is one layer size. The ``no_layers``
        warning is shown when no number is found and cleared otherwise.
        """
        self.Warning.no_layers.clear()
        layers = tuple(map(int, re.findall(r'\d+', self.hidden_layers_input)))
        if not layers:
            self.Warning.no_layers()
        return layers

    def update_model(self):
        """Start fitting a new model, or send ``None`` if the data is unusable."""
        self.show_fitting_failed(None)
        self.model = None
        if self.check_data():
            self.__update()
        else:
            self.Outputs.model.send(self.model)

    @Slot(float)
    def setProgressValue(self, value):
        """Set the progress bar value; must run in the widget's own thread."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)

    def __update(self):
        """Submit the model fit to the executor and track it with a Task."""
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        max_iter = self.learner.kwargs["max_iter"]

        # Setup the task state
        task = Task()
        lastemitted = 0.

        def callback(iteration):
            # Invoked by the learner during fitting: aborts the fit when
            # interruption was requested and reports whole-percent progress.
            nonlocal lastemitted
            if task.isInterruptionRequested():
                raise CancelTaskException()
            progress = round(iteration / max_iter * 100)
            if progress != lastemitted:
                task.emitProgressUpdate(progress)
                lastemitted = progress

        # copy to set the callback so that the learner output is not modified
        # (currently we can not pass callbacks to learners __call__)
        learner = copy.copy(self.learner)
        learner.callback = callback

        def build_model(data, learner):
            # A cancelled fit yields None instead of propagating the exception.
            try:
                return learner(data)
            except CancelTaskException:
                return None

        build_model_func = partial(build_model, self.data, learner)

        task.setFuture(self._executor.submit(build_model_func))
        task.done.connect(self._task_finished)
        task.progressChanged.connect(self.setProgressValue)

        # set in setup_layout; pylint: disable=attribute-defined-outside-init
        self._task = task

        self.progressBarInit()
        self.setBlocking(True)

    @Slot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Handle completion of the fitting task and send the model.

        Parameters
        ----------
        f : Future
            The future instance holding the built model
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()
        self._task.deleteLater()
        self._task = None  # pylint: disable=attribute-defined-outside-init
        self.setBlocking(False)
        self.progressBarFinished()

        try:
            self.model = f.result()
        except Exception as ex:  # pylint: disable=broad-except
            # Log the exception with a traceback
            logging.getLogger(__name__).exception("Model fitting failed")
            self.model = None
            self.show_fitting_failed(ex)
        else:
            # The future's result is None when the fit was cancelled inside
            # the worker; there is nothing to decorate in that case.
            if self.model is not None:
                self.model.name = self.learner_name
                self.model.instances = self.data
                # remove unpicklable callback
                self.model.skl_model.orange_callback = None
            self.Outputs.model.send(self.model)

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect from the task
            self._task.done.disconnect(self._task_finished)
            self._task.progressChanged.disconnect(self.setProgressValue)
            self._task.deleteLater()
            self._task = None  # pylint: disable=attribute-defined-outside-init

        self.progressBarFinished()
        self.setBlocking(False)

    def onDeleteWidget(self):
        """Cancel any running fit before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade stored settings to the current ``settings_version``.

        Unversioned settings stored ``alpha`` as a value; it is replaced by
        the index of the closest entry in ``alphas``. Version 1 settings have
        their ``alpha_index`` shifted by one.
        """
        if not version:
            alpha = settings.pop("alpha", None)
            if alpha is not None:
                # Store a plain Python int rather than a numpy integer.
                settings["alpha_index"] = int(
                    np.argmin(np.abs(np.array(cls.alphas) - alpha)))
        elif version < 2:
            settings["alpha_index"] = settings.get("alpha_index", 0) + 1
Ejemplo n.º 11
0
class OWSetEnrichment(widget.OWWidget):
    """Widget that runs set enrichment on entity names from the input data."""
    name = "Set Enrichment"
    description = ""
    icon = "../widgets/icons/GeneSetEnrichment.svg"
    priority = 5000

    # Each input is (name, type, handler method name[, flags]).
    inputs = [("Data", Orange.data.Table, "setData", widget.Default),
              ("Reference", Orange.data.Table, "setReference")]
    outputs = [("Data subset", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    # Context settings (managed through the settingsHandler above).
    taxid = settings.ContextSetting(None)
    speciesIndex = settings.ContextSetting(0)
    genesinrows = settings.ContextSetting(False)
    geneattr = settings.ContextSetting(0)
    categoriesCheckState = settings.ContextSetting({})

    # Plain settings: reference choice, result filters and auto commit.
    useReferenceData = settings.Setting(False)
    useMinCountFilter = settings.Setting(True)
    useMaxPValFilter = settings.Setting(True)
    useMaxFDRFilter = settings.Setting(True)
    minClusterCount = settings.Setting(3)
    maxPValue = settings.Setting(0.01)
    maxFDR = settings.Setting(0.01)
    autocommit = settings.Setting(False)

    # Widget state constants; the methods of this class test and clear them
    # with bitwise operators (``&``, ``~``).
    Ready, Initializing, Loading, RunningEnrichment = 0, 1, 2, 4

    def __init__(self, parent=None):
        """Build the GUI and start downloading the required databases.

        The widget is blocked until the ``EnsureDownloaded`` task submitted
        at the end of this method finishes and ``__initialize_finish`` runs.

        Parameters
        ----------
        parent : optional
            Passed on to the base class constructor.
        """
        super().__init__(parent)

        # Enabled flags for the gene matchers; zipped with the matcher list
        # in `_genematcher`.
        self.geneMatcherSettings = [False, False, True, False]

        self.data = None
        self.referenceData = None
        self.taxid_list = []

        # (taxid, future) pair holding the matcher for the current organism.
        self.__genematcher = (None, fulfill(gene.matcher([])))
        self.__invalidated = False

        self.currentAnnotatedCategories = []
        self.state = None
        self.__state = OWSetEnrichment.Initializing

        box = gui.widgetBox(self.controlArea, "Info")
        self.infoBox = gui.widgetLabel(box, "Info")
        self.infoBox.setText("No data on input.\n")

        self.speciesComboBox = gui.comboBox(
            self.controlArea, self,
            "speciesIndex", "Species",
            callback=self.__on_speciesIndexChanged)

        box = gui.widgetBox(self.controlArea, "Entity names")
        self.geneAttrComboBox = gui.comboBox(
            box, self, "geneattr", "Entity feature", sendSelectedValue=0,
            callback=self.updateAnnotations)

        cb = gui.checkBox(
            box, self, "genesinrows", "Use feature names",
            callback=self.updateAnnotations,
            disables=[(-1, self.geneAttrComboBox)])
        cb.makeConsistent()

        # NOTE: there is no button for the gene matcher settings dialog;
        # updateGeneMatcherSettings raises NotImplementedError.

        self.referenceRadioBox = gui.radioButtonsInBox(
            self.controlArea,
            self, "useReferenceData",
            ["All entities", "Reference set (input)"],
            tooltips=["Use entire genome (for gene set enrichment) or all " +
                      "available entities for reference",
                      "Use entities from Reference Examples input signal " +
                      "as reference"],
            box="Reference", callback=self.updateAnnotations)

        box = gui.widgetBox(self.controlArea, "Entity Sets")
        self.groupsWidget = QtGui.QTreeWidget(self)
        self.groupsWidget.setHeaderLabels(["Category"])
        box.layout().addWidget(self.groupsWidget)

        # Filter row shown above the results view.
        hLayout = QtGui.QHBoxLayout()
        hLayout.setSpacing(10)
        hWidget = gui.widgetBox(self.mainArea, orientation=hLayout)
        gui.spin(hWidget, self, "minClusterCount",
                 0, 100, label="Entities",
                 tooltip="Minimum entity count",
                 callback=self.filterAnnotationsChartView,
                 callbackOnReturn=True,
                 checked="useMinCountFilter",
                 checkCallback=self.filterAnnotationsChartView)

        pvalfilterbox = gui.widgetBox(hWidget, orientation="horizontal")
        cb = gui.checkBox(
            pvalfilterbox, self, "useMaxPValFilter", "p-value",
            callback=self.filterAnnotationsChartView)

        sp = gui.doubleSpin(
            pvalfilterbox, self, "maxPValue", 0.0, 1.0, 0.0001,
            tooltip="Maximum p-value",
            callback=self.filterAnnotationsChartView,
            callbackOnReturn=True,
        )
        # The spin box follows its own check box: the initial enabled state
        # must come from useMaxPValFilter, not from the FDR filter setting.
        sp.setEnabled(self.useMaxPValFilter)
        cb.toggled[bool].connect(sp.setEnabled)

        pvalfilterbox.layout().setAlignment(cb, Qt.AlignRight)
        pvalfilterbox.layout().setAlignment(sp, Qt.AlignLeft)

        fdrfilterbox = gui.widgetBox(hWidget, orientation="horizontal")
        cb = gui.checkBox(
            fdrfilterbox, self, "useMaxFDRFilter", "FDR",
            callback=self.filterAnnotationsChartView)

        sp = gui.doubleSpin(
            fdrfilterbox, self, "maxFDR", 0.0, 1.0, 0.0001,
            tooltip="Maximum False discovery rate",
            callback=self.filterAnnotationsChartView,
            callbackOnReturn=True,
        )
        sp.setEnabled(self.useMaxFDRFilter)
        cb.toggled[bool].connect(sp.setEnabled)

        fdrfilterbox.layout().setAlignment(cb, Qt.AlignRight)
        fdrfilterbox.layout().setAlignment(sp, Qt.AlignLeft)

        self.filterLineEdit = QtGui.QLineEdit(
            self, placeholderText="Filter ...")

        self.filterCompleter = QtGui.QCompleter(self.filterLineEdit)
        self.filterCompleter.setCaseSensitivity(Qt.CaseInsensitive)
        self.filterLineEdit.setCompleter(self.filterCompleter)

        hLayout.addWidget(self.filterLineEdit)
        self.mainArea.layout().addWidget(hWidget)

        self.filterLineEdit.textChanged.connect(
            self.filterAnnotationsChartView)

        self.annotationsChartView = QtGui.QTreeView(
            alternatingRowColors=True,
            sortingEnabled=True,
            selectionMode=QtGui.QTreeView.ExtendedSelection,
            rootIsDecorated=False,
            editTriggers=QtGui.QTreeView.NoEditTriggers,
        )
        self.annotationsChartView.viewport().setMouseTracking(True)
        self.mainArea.layout().addWidget(self.annotationsChartView)

        contextEventFilter = gui.VisibleHeaderSectionContextEventFilter(
            self.annotationsChartView)
        self.annotationsChartView.header().installEventFilter(contextEventFilter)

        self.groupsWidget.itemClicked.connect(self.subsetSelectionChanged)
        gui.auto_commit(self.controlArea, self, "autocommit", "Commit")

        # Block the widget until the required databases are available.
        self.setBlocking(True)

        task = EnsureDownloaded(
            [("Taxonomy", "ncbi_taxonomy.tar.gz"),
             (geneset.sfdomain, "index.pck")]
        )

        task.finished.connect(self.__initialize_finish)
        self.setStatusMessage("Initializing")
        self._executor = ThreadExecutor(
            parent=self, threadPool=QtCore.QThreadPool(self))
        self._executor.submit(task)

    def sizeHint(self):
        """Return the preferred widget size (1024 x 600)."""
        width, height = 1024, 600
        return QtCore.QSize(width, height)

    def __initialize_finish(self):
        """Finalize the widget's initialization.

        Preferably called after all required databases have been downloaded.
        Fills the species combo box, stores the available gene sets,
        selects the current organism and unblocks the widget.
        """
        all_sets = geneset.list_all()
        # Candidate organisms: the common taxids plus every non-empty taxid
        # referenced by a gene set.
        candidate_ids = set(taxonomy.common_taxids() +
                            [tid for _, tid, _ in all_sets if tid])

        # Keep only the organisms for which a name is available.
        named = []
        for tid in candidate_ids:
            label = name_or_none(tid)
            if label is not None:
                named.append((tid, label))
        named.sort()

        # The first entry stands for "no organism".
        entries = [(None, "None")] + named
        labels = [label for _, label in entries]
        self.taxid_list = [tid for tid, _ in entries]

        self.speciesComboBox.clear()
        self.speciesComboBox.addItems(labels)
        self.genesets = all_sets

        # Keep the stored taxid when it is still valid, else use the first
        # list entry.
        initial = (self.taxid if self.taxid in self.taxid_list
                   else self.taxid_list[0])

        # Reset taxid so that setCurrentOrganism sees a different value.
        self.taxid = None
        self.setCurrentOrganism(initial)
        self.setBlocking(False)
        self.__state = OWSetEnrichment.Ready
        self.setStatusMessage("")

    def setCurrentOrganism(self, taxid):
        """Make `taxid` the current organism.

        A `taxid` that is not in ``taxid_list`` is replaced by the entry at
        ``speciesIndex`` (clamped to the end of the list). Nothing happens
        when the resulting id equals the stored one; otherwise the hierarchy
        is refreshed and the gene matcher and results are invalidated.
        """
        if taxid not in self.taxid_list:
            fallback = min(self.speciesIndex, len(self.taxid_list) - 1)
            taxid = self.taxid_list[fallback]
        if self.taxid == taxid:
            return
        self.taxid = taxid
        self.speciesIndex = self.taxid_list.index(taxid)
        self.refreshHierarchy()
        self._invalidateGeneMatcher()
        self._invalidate()

    def currentOrganism(self):
        """Return the taxonomy id of the current organism."""
        current = self.taxid
        return current

    def __on_speciesIndexChanged(self):
        """React to a new selection in the species combo box."""
        selected = self.taxid_list[self.speciesIndex]
        # Store a placeholder string first; presumably it equals no real
        # taxid, so setCurrentOrganism's inequality check always passes.
        self.taxid = "< Do not look >"
        self.setCurrentOrganism(selected)
        if self.data is not None and self.__invalidated:
            self.updateAnnotations()

    def clear(self):
        """Clear/reset the widget state.

        Cancels pending work, drops the run state, clears the results view
        and the entity-feature combo box, and refreshes the summary.
        """
        self._cancelPending()
        self.state = None

        # Drop the "running" bit; the other state bits stay untouched.
        self.__state &= ~OWSetEnrichment.RunningEnrichment

        # _clearView() clears the model of the results view, so the model is
        # not cleared a second time here.
        self._clearView()

        self.geneAttrComboBox.clear()
        self.geneAttrs = []
        self._updatesummary()

    def _cancelPending(self):
        """Cancel the futures of the current run, if any, and flag it cancelled."""
        pending = self.state
        if pending is None:
            return
        pending.results.cancel()
        pending.namematcher.cancel()
        pending.cancelled = True

    def _clearView(self):
        """Clear the enrichment report view (main area)."""
        model = self.annotationsChartView.model()
        if model is None:
            return
        model.clear()

    def setData(self, data=None):
        """Set the input dataset with query gene names.

        The widget state is always reset first. For a non-None `data` the
        string variables of the domain are offered as entity-name columns,
        the "taxid" and "genesinrows" data hints are applied, the settings
        context is opened and the annotations are marked as invalidated.
        """
        # Data arrived while the widget is still initializing: finish the
        # initialization synchronously before touching any state.
        if self.__state & OWSetEnrichment.Initializing:
            self.__initialize_finish()

        self.error(0)
        self.closeContext()
        self.clear()

        self.groupsWidget.clear()
        self.data = data

        if data is not None:
            # Only string variables (attributes, class vars or metas) are
            # offered as entity-name columns.
            varlist = [var for var in data.domain.variables + data.domain.metas
                       if isinstance(var, Orange.data.StringVariable)]

            self.geneAttrs = varlist
            for var in varlist:
                self.geneAttrComboBox.addItem(*gui.attributeItem(var))

            oldtaxid = self.taxid
            # Clamp the stored index to the available variables (this yields
            # -1 when there are no string variables).
            self.geneattr = min(self.geneattr, len(self.geneAttrs) - 1)

            # The "taxid" hint is used only when it names a known organism.
            taxid = data_hints.get_hint(data, "taxid", "")
            if taxid in self.taxid_list:
                self.speciesIndex = self.taxid_list.index(taxid)
                self.taxid = taxid

            self.genesinrows = data_hints.get_hint(
                data, "genesinrows", self.genesinrows)

            self.openContext(data)
            # The organism changed (through the hint or the opened context):
            # store a placeholder so that setCurrentOrganism applies it.
            # NOTE(review): `taxid` is the hint value and may not be in
            # taxid_list; setCurrentOrganism then falls back to speciesIndex.
            if oldtaxid != self.taxid:
                self.taxid = "< Do not look >"
                self.setCurrentOrganism(taxid)

            self.refreshHierarchy()
            self._invalidate()

    def setReference(self, data=None):
        """Set the (optional) input dataset with reference gene names.

        The reference radio box is enabled only for truthy `data`. The
        results are invalidated when the reference set is in use.
        """
        self.referenceData = data
        has_reference = bool(data)
        self.referenceRadioBox.setEnabled(has_reference)
        if not self.useReferenceData:
            return
        self._invalidate()

    def handleNewSignals(self):
        """Recompute the annotations if the inputs invalidated them."""
        if not self.__invalidated:
            return
        self.updateAnnotations()

    def _invalidateGeneMatcher(self):
        """Cancel the stored matcher future and replace it with an empty matcher."""
        stale_future = self.__genematcher[1]
        stale_future.cancel()
        empty_matcher = fulfill(gene.matcher([]))
        self.__genematcher = (None, empty_matcher)

    def _invalidate(self):
        """Mark the annotations as out of date.

        The flag is read by ``handleNewSignals`` and
        ``__on_speciesIndexChanged`` to decide whether to recompute.
        """
        self.__invalidated = True

    def genesFromTable(self, table):
        """Return the list of entity names found in `table`.

        With ``genesinrows`` set, the names of the table's attributes are
        returned. Otherwise the value of the selected entity variable is
        returned, as a string, for every row.
        """
        if self.genesinrows:
            return [attr.name for attr in table.domain.attributes]
        name_var = self.geneAttrs[self.geneattr]
        return [str(row[name_var]) for row in table]

    def getHierarchy(self, taxid):
        """Return the category trees for `taxid` and for no organism.

        The gene sets in ``self.genesets`` are grouped by organism id, and
        each hierarchy is stored as a path of nested dictionaries.

        Returns
        -------
        tuple
            ``((taxid, tree), (None, tree))`` where each tree is a nested
            ``defaultdict`` keyed by hierarchy component.
        """
        def nested():
            return defaultdict(nested)

        by_organism = nested()
        for path, organism, _ in self.genesets:
            # Walking the path creates every missing level on access.
            node = by_organism[organism]
            for key in path:
                node = node[key]

        return (taxid, by_organism[taxid]), (None, by_organism[None])

    def setHierarchy(self, hierarchy, hierarchy_noorg):
        """Rebuild the category tree widget from two hierarchies.

        Both arguments are ``(organism, tree)`` pairs, with each tree a
        nested mapping of category names. Every created item is checkable;
        its check state comes from ``categoriesCheckState`` and defaults to
        checked.
        """
        self.groupsWidgetItems = {}

        def populate(tree, parent, prefix=(), org=""):
            # Add the entries of `tree` (sorted by key) under `parent`,
            # then recurse into each entry's subtree.
            for label, subtree in sorted(tree.items()):
                path = prefix + (label,)
                node = QtGui.QTreeWidgetItem(parent, [label])
                node.setFlags(node.flags() | Qt.ItemIsUserCheckable |
                              Qt.ItemIsSelectable | Qt.ItemIsEnabled)
                if subtree:
                    # Items with children additionally get the tristate flag.
                    node.setFlags(node.flags() | Qt.ItemIsTristate)

                state = self.categoriesCheckState.get((path, org), Qt.Checked)
                node.setData(0, Qt.CheckStateRole, state)
                node.setExpanded(True)
                node.category = path
                node.organism = org
                self.groupsWidgetItems[path] = node
                populate(subtree, node, path, org=org)

        self.groupsWidget.clear()
        for organism, tree in (hierarchy[0], hierarchy[1]), \
                (hierarchy_noorg[0], hierarchy_noorg[1]):
            populate(tree, self.groupsWidget, org=organism)

    def refreshHierarchy(self):
        """Rebuild the category tree for the organism at ``speciesIndex``."""
        current = self.taxid_list[self.speciesIndex]
        organism_tree, shared_tree = self.getHierarchy(taxid=current)
        self.setHierarchy(organism_tree, shared_tree)

    def selectedCategories(self):
        """
        Return a list of the currently checked hierarchy keys.

        The keys are those of `getHierarchyCheckState`: pairs of the
        category path (a tuple of identifiers from the root down to the
        item) and the item's organism.
        """
        states = self.getHierarchyCheckState()
        chosen = []
        for key, check in states.items():
            if check == Qt.Checked:
                chosen.append(key)
        return chosen

    def getHierarchyCheckState(self):
        """Return the check state of every item in the category tree.

        Returns
        -------
        dict
            Maps ``(category_path, organism)`` to the item's check state,
            with items visited parent-first.
        """
        def walk(item, prefix=()):
            # Yield this item's entry, then those of its descendants.
            state = item.checkState(0)
            path = prefix + (str(item.data(0, Qt.DisplayRole)),)
            yield (path, item.organism), state
            for index in range(item.childCount()):
                yield from walk(item.child(index), path)

        top_level = [self.groupsWidget.topLevelItem(row)
                     for row in range(self.groupsWidget.topLevelItemCount())]
        states = {}
        for item in top_level:
            states.update(walk(item))
        return states

    def subsetSelectionChanged(self, item, column):
        """Handle a click on an item of the category tree.

        Stores the new check states in the persistent settings. With data
        present, the annotations are recomputed when gene matching is off or
        a newly selected category was not part of the last run; otherwise
        the existing results are only re-filtered.
        """
        self.categoriesCheckState = self.getHierarchyCheckState()
        categories = self.selectedCategories()

        if self.data is None:
            return

        needs_rerun = self._nogenematching() or not (
            set(categories) <= set(self.currentAnnotatedCategories))
        if needs_rerun:
            self.updateAnnotations()
        else:
            self.filterAnnotationsChartView()

    def updateGeneMatcherSettings(self):
        """Open the gene matcher settings dialog (not implemented).

        Raises
        ------
        NotImplementedError
            Always. The dialog-based implementation was unreachable behind
            the unconditional ``raise`` and has been dropped.
        """
        raise NotImplementedError

    def _genematcher(self):
        """
        Return a Future[gene.SequenceMatcher] for the current organism.

        The stored future is reused while it belongs to the current taxid
        and has not been cancelled. With no organism selected, an already
        fulfilled future with an empty matcher is returned. Otherwise the
        enabled matchers are built in the executor.
        """
        taxid = self.taxid_list[self.speciesIndex]
        cached_taxid, cached_future = self.__genematcher

        if taxid == cached_taxid and not cached_future.cancelled():
            return cached_future

        self._invalidateGeneMatcher()

        if taxid is None:
            empty = fulfill(gene.matcher([]))
            self.__genematcher = (None, empty)
            return empty

        available = [gene.GMGO, gene.GMKEGG, gene.GMNCBI, gene.GMAffy]
        enabled = [m for m, use in zip(available, self.geneMatcherSettings)
                   if use]

        def create():
            # Runs in the executor: instantiate each enabled matcher.
            return gene.matcher([m(taxid) for m in enabled])

        future = self._executor.submit(create)
        self.__genematcher = (taxid, future)
        return future

    def _nogenematching(self):
        """Return True when no organism is set or no gene matcher is enabled."""
        if self.taxid is None:
            return True
        return not any(self.geneMatcherSettings)

    def updateAnnotations(self):
        """Start a new enrichment run for the current data and settings.

        Any pending run is cancelled first. Gene set collections, the
        reference set and the name matcher are each obtained through
        futures, and the enrichment itself runs as a ``Task`` in the
        executor. Results arrive in ``__on_enrichment_finished``, failures
        in ``__on_enrichment_failed``.
        """
        if self.data is None:
            return

        assert not self.__state & OWSetEnrichment.Initializing
        self._cancelPending()
        self._clearView()

        self.information(0)
        self.warning(0)
        self.error(0)

        if not self.genesinrows and len(self.geneAttrs) == 0:
            self.error(0, "Input data contains no attributes with gene names")
            return

        self.__state = OWSetEnrichment.RunningEnrichment

        taxid = self.taxid_list[self.speciesIndex]
        self.taxid = taxid

        categories = self.selectedCategories()

        clusterGenes = self.genesFromTable(self.data)

        if self.referenceData is not None and self.useReferenceData:
            referenceGenes = self.genesFromTable(self.referenceData)
        else:
            referenceGenes = None

        self.currentAnnotatedCategories = categories

        genematcher = self._genematcher()

        self.progressBarInit()

        ## Load collections in a worker thread
        # TODO: Use cached collections if already loaded and
        # use ensure_genesetsdownloaded with progress report (OWSelectGenes)
        collections = self._executor.submit(geneset.collections, *categories)

        def refset_null():
            """Return the default background reference set"""
            col = collections.result()
            return reduce(operator.ior, (set(g.genes) for g in col), set())

        def refset_ncbi():
            """Return all NCBI gene names"""
            geneinfo = gene.NCBIGeneInfo(taxid)
            return set(geneinfo.keys())

        def create_namematcher():
            """Return the matcher targeted at the reference set."""
            matcher = genematcher.result()
            match = matcher.set_targets(ref_set.result())
            match.umatch = memoize(match.umatch)
            return match

        def map_unames():
            """Return the matched (query, reference) name lists."""
            matcher = namematcher.result()
            query = list(filter(None, map(matcher.umatch, querynames)))
            reference = list(filter(None, map(matcher.umatch, ref_set.result())))
            return query, reference

        # An explicit reference set always wins; otherwise the background
        # depends on whether gene matching is in use.
        if referenceGenes is not None:
            ref_set = fulfill(referenceGenes)
        elif self._nogenematching():
            ref_set = self._executor.submit(refset_null)
        else:
            ref_set = self._executor.submit(refset_ncbi)

        # `namematcher` is a future; map_unames() and run() below only
        # resolve it when they execute in the worker thread.
        namematcher = self._executor.submit(create_namematcher)
        querynames = clusterGenes

        state = types.SimpleNamespace()
        state.query_set = clusterGenes
        state.reference_set = referenceGenes
        state.namematcher = namematcher
        state.query_count = len(set(clusterGenes))
        state.reference_count = (len(set(referenceGenes))
                                 if referenceGenes is not None else None)

        state.cancelled = False

        progress = methodinvoke(self, "_setProgress", (float,))
        info = methodinvoke(self, "_setRunInfo", (str,))

        @withtraceback
        def run():
            info("Loading data")
            match = namematcher.result()
            query, reference = map_unames()
            gscollections = collections.result()

            results = []
            info("Running enrichment")
            p = 0
            for i, gset in enumerate(gscollections):
                genes = set(filter(None, map(match.umatch, gset.genes)))
                enr = set_enrichment(genes, reference, query)
                results.append((gset, enr))

                # `state.cancelled` is set by _cancelPending().
                if state.cancelled:
                    raise UserInteruptException

                # Report progress only when the whole percent changes.
                pnew = int(100 * i / len(gscollections))
                if pnew != p:
                    progress(pnew)
                    p = pnew
            progress(100)
            info("")
            return query, reference, results

        task = Task(function=run)
        task.resultReady.connect(self.__on_enrichment_finished)
        task.exceptionReady.connect(self.__on_enrichment_failed)
        result = self._executor.submit(task)
        state.results = result

        self.state = state
        self._updatesummary()

    def __on_enrichment_failed(self, exception):
        """Handle an exception raised by the enrichment task.

        Anything other than a user interruption is printed to stderr
        together with its stored traceback. The progress bar, status
        message and running state are reset in every case.
        """
        user_cancelled = isinstance(exception, UserInteruptException)
        if not user_cancelled:
            print("ERROR:", exception, file=sys.stderr)
            print(exception._traceback, file=sys.stderr)

        self.progressBarFinished()
        self.setStatusMessage("")
        self.__state = self.__state & ~OWSetEnrichment.RunningEnrichment

    def __on_enrichment_finished(self, results):
        """
        Populate the annotations view from the results of an enrichment run.

        Builds a fresh item model with one row per gene set that has at
        least one mapped query gene, installs it on the view, rebuilds the
        filter completer, sets the column delegates and widths and finally
        re-applies the filters.

        Parameters
        ----------
        results : tuple
            The `(query, reference, results)` triple returned by the
            enrichment task, where the last item is a sequence of
            `(gene set, enrichment)` pairs.
        """
        assert QThread.currentThread() is self.thread()
        self.__state &= ~OWSetEnrichment.RunningEnrichment

        query, reference, results = results

        if self.annotationsChartView.model():
            self.annotationsChartView.model().clear()

        nquery = len(query)
        nref = len(reference)
        maxcount = max((len(e.query_mapped) for _, e in results),
                       default=1)
        maxrefcount = max((len(e.reference_mapped) for _, e in results),
                          default=1)
        # The field width of each count column is the number of decimal
        # digits of the largest count, so that the counts right-align.
        nspaces = len(str(maxcount))
        refspaces = len(str(maxrefcount))
        query_fmt = "%" + str(nspaces) + "s  (%.2f%%)"
        ref_fmt = "%" + str(refspaces) + "s  (%.2f%%)"

        def fmt_count(fmt, count, total):
            # `total or 1` avoids a division by zero for an empty set.
            return fmt % (count, 100.0 * count / (total or 1))

        fmt_query_count = partial(fmt_count, query_fmt)
        fmt_ref_count = partial(fmt_count, ref_fmt)

        linkFont = QtGui.QFont(self.annotationsChartView.viewOptions().font)
        linkFont.setUnderline(True)

        def item(value=None, tooltip=None, user=None):
            # Create a model item; the user role (the model's sort role)
            # falls back to the display value when `user` is not given.
            si = QtGui.QStandardItem()
            if value is not None:
                si.setData(value, Qt.DisplayRole)
            if tooltip is not None:
                si.setData(tooltip, Qt.ToolTipRole)
            if user is not None:
                si.setData(user, Qt.UserRole)
            else:
                si.setData(value, Qt.UserRole)
            return si

        model = QtGui.QStandardItemModel()
        model.setSortRole(Qt.UserRole)
        model.setHorizontalHeaderLabels(
            ["Category", "Term", "Count", "Reference count", "p-value",
             "FDR", "Enrichment"])
        for gset, enrich in results:
            if len(enrich.query_mapped) == 0:
                continue
            nquery_mapped = len(enrich.query_mapped)
            nref_mapped = len(enrich.reference_mapped)

            row = [
                item(", ".join(gset.hierarchy)),
                item(gsname(gset), tooltip=gset.link),
                item(fmt_query_count(nquery_mapped, nquery),
                     tooltip=nquery_mapped, user=nquery_mapped),
                item(fmt_ref_count(nref_mapped, nref),
                     tooltip=nref_mapped, user=nref_mapped),
                item(fmtp(enrich.p_value), user=enrich.p_value),
                item(),  # column 5, FDR, is computed in filterAnnotationsChartView
                item(enrich.enrichment_score,
                     tooltip="%.3f" % enrich.enrichment_score,
                     user=enrich.enrichment_score)
            ]
            # The first item of the row carries the gene set and its
            # enrichment result; `commit` reads them back from there.
            row[0].geneset = gset
            row[0].enrichment = enrich
            row[1].setData(gset.link, gui.LinkRole)
            row[1].setFont(linkFont)
            row[1].setForeground(QtGui.QColor(Qt.blue))

            model.appendRow(row)

        self.annotationsChartView.setModel(model)
        self.annotationsChartView.selectionModel().selectionChanged.connect(
            self.commit
        )

        if not model.rowCount():
            self.warning(0, "No enriched sets found.")
        else:
            self.warning(0)

        # Completer entries: the full set names plus their individual words.
        allnames = set(gsname(geneset)
                       for geneset, (count, _, _, _) in results if count)

        allnames |= reduce(operator.ior,
                           (set(word_split(name)) for name in allnames),
                           set())

        self.filterCompleter.setModel(None)
        self.completerModel = QtGui.QStringListModel(sorted(allnames))
        self.filterCompleter.setModel(self.completerModel)

        if results:
            # Scale the enrichment bars to the largest finite score.
            max_score = max((e.enrichment_score for _, e in results
                             if np.isfinite(e.enrichment_score)),
                            default=1)

            self.annotationsChartView.setItemDelegateForColumn(
                6, BarItemDelegate(self, scale=(0.0, max_score))
            )

        self.annotationsChartView.setItemDelegateForColumn(
            1, gui.LinkStyledItemDelegate(self.annotationsChartView)
        )

        # Size each column to its contents, clamped to the 30..300 range.
        header = self.annotationsChartView.header()
        for i in range(model.columnCount()):
            sh = self.annotationsChartView.sizeHintForColumn(i)
            sh = max(sh, header.sectionSizeHint(i))
            self.annotationsChartView.setColumnWidth(i, max(min(sh, 300), 30))

        self.filterAnnotationsChartView()

        self.progressBarFinished()
        self.setStatusMessage("")

    def _updatesummary(self):
        """
        Refresh the info box from the current enrichment state.

        Shows a "no data" note when there is no state; otherwise the number
        of unique input names followed by the match count, a pending
        marker, or the task's error.
        """
        state = self.state
        if state is None:
            self.error(0,)
            self.warning(0)
            self.infoBox.setText("No data on input.\n")
            return

        pieces = ["{.query_count} unique names on input\n".format(state)]

        pending = state.results
        finished = pending.done()
        # Only query the exception once the task is done.
        failure = pending.exception() if finished else None

        if finished and not failure:
            mapped, _, _ = pending.result()
            if state.query_count:
                ratio_mapped = len(mapped) / state.query_count
            else:
                ratio_mapped = 0
            pieces.append("%i (%.1f%%) gene names matched" %
                          (len(mapped), 100.0 * ratio_mapped))
        elif not finished:
            pieces.append("...")
        else:
            pieces.append("<Error {}>".format(str(failure)))
        self.infoBox.setText("".join(pieces))

        # TODO: warn on no enriched sets found (i.e no query genes
        # mapped to any set)

    def filterAnnotationsChartView(self, filterString=""):
        """
        Re-apply the category, count, p-value and FDR filters to the view.

        Does nothing while an enrichment is running. Rows whose category is
        not selected are hidden; FDR values are computed over the remaining
        rows and written into column 5 of the rows that stay visible.
        `filterString` is currently not used.
        """
        if self.__state & OWSetEnrichment.RunningEnrichment:
            return

        # TODO: Move filtering to a filter proxy model
        # TODO: Re-enable string search

        selected_cats = {", ".join(cat)
                         for cat, _ in self.selectedCategories()}

        view = self.annotationsChartView
        model = view.model()

        # A row is masked when its category text (column 0) is not selected.
        hidemask = [
            model.item(row).data(Qt.DisplayRole) not in selected_cats
            for row in range(model.rowCount())
        ]

        # compute FDR according the selected categories
        pvals = [model.item(row, 4).data(Qt.UserRole)
                 for row, masked in enumerate(hidemask) if not masked]
        fdr_values = iter(utils.stats.FDR(pvals))

        # update FDR for the selected collections and apply filtering rules
        rows_hidden = []
        for row, hidden in enumerate(hidemask):
            if not hidden:
                fdr = next(fdr_values)
                pval = model.index(row, 4).data(Qt.UserRole)
                count = model.index(row, 2).data(Qt.ToolTipRole)

                hidden = (
                    (self.useMinCountFilter and count < self.minClusterCount)
                    or (self.useMaxPValFilter and pval > self.maxPValue)
                    or (self.useMaxFDRFilter and fdr > self.maxFDR)
                )

                if not hidden:
                    target = model.item(row, 5)
                    target.setData(fmtpdet(fdr), Qt.ToolTipRole)
                    target.setData(fmtp(fdr), Qt.DisplayRole)
                    target.setData(fdr, Qt.UserRole)

            view.setRowHidden(row, QModelIndex(), hidden)
            rows_hidden.append(hidden)

        if model.rowCount() and all(rows_hidden):
            self.information(0, "All sets were filtered out.")
        else:
            self.information(0)

        self._updatesummary()

    @Slot(float)
    def _setProgress(self, value):
        """Slot: set the progress bar to `value`; must run in the widget's thread."""
        widget_thread = self.thread()
        assert QThread.currentThread() is widget_thread
        self.progressBarSet(value, processEvents=None)

    @Slot(str)
    def _setRunInfo(self, text):
        """Slot: show `text` as the widget's status message."""
        self.setStatusMessage(text)

    def commit(self):
        """
        Send the part of the input data that maps to the selected gene sets.

        Nothing is sent when there is no input data or while an enrichment
        is running. Depending on `genesinrows`, either the matching
        attributes (columns) or the matching rows are kept.
        """
        if self.data is None:
            return
        if self.__state & OWSetEnrichment.RunningEnrichment:
            return

        model = self.annotationsChartView.model()
        rows = self.annotationsChartView.selectionModel().selectedRows(0)
        chosen_items = [model.item(index.row(), 0) for index in rows]

        # Union of the mapped query genes over all selected gene sets.
        mapped_genes = set()
        for chosen in chosen_items:
            mapped_genes |= chosen.enrichment.query_mapped

        assert self.state.namematcher.done()
        matcher = self.state.namematcher.result()

        if self.genesinrows:
            kept_attrs = [attr for attr in self.data.domain.attributes
                          if matcher.umatch(attr.name) in mapped_genes]

            newdomain = Orange.data.Domain(
                kept_attrs, self.data.domain.class_vars,
                self.data.domain.metas)
            data = self.data.from_table(newdomain, self.data)
        else:
            geneattr = self.geneAttrs[self.geneattr]
            kept_rows = [i for i, ex in enumerate(self.data)
                         if matcher.umatch(str(ex[geneattr])) in mapped_genes]
            data = self.data[kept_rows]
        self.send("Data subset", data)

    def onDeleteWidget(self):
        """Cancel pending work, drop the state and shut the executor down without waiting."""
        has_state = self.state is not None
        if has_state:
            self._cancelPending()
            self.state = None
        self._executor.shutdown(wait=False)
# Widget that scans a user-selected directory for images in a background
# task and sends the collected image meta data as a table on "Data".
class OWImportImages(widget.OWWidget):
    # Widget meta data.
    name = "Import Images"
    description = "Import images from a directory(s)"
    icon = "icons/ImportImages.svg"
    priority = 110

    # Single output channel carrying the table built in `commit`.
    outputs = [("Data", Orange.data.Table)]

    #: list of recent paths
    recent_paths = settings.Setting([])  # type: List[RecentPath]
    #: the currently selected root directory (or None)
    currentPath = settings.Setting(None)

    want_main_area = False
    resizing_enabled = False

    # Selects which kind of directory dialog `__runOpenDialog` shows: a
    # non-blocking QFileDialog instance for Qt.WindowModal, the static
    # QFileDialog.getExistingDirectory helper otherwise.
    Modality = Qt.ApplicationModal
    # Modality = Qt.WindowModal

    # Upper bound on the number of stored recent paths restored at startup.
    MaxRecentItems = 20

    def __init__(self):
        """
        Build the widget's GUI and schedule the initial image scan.

        The control area holds a combo box with recent directories, a
        browse and a reload button, and an "Info" box that switches between
        a text label and a progress bar / cancel button page.
        """
        super().__init__()
        #: widget's runtime state
        self.__state = State.NoState
        # Results of the last scan: per-image meta data and a mapping of
        # image directory -> category (see `__onRunFinished`).
        self._imageMeta = []
        self._imageCategories = {}

        self.__invalidated = False
        # State of the currently running scan task (None when idle).
        self.__pendingTask = None

        vbox = gui.vBox(self.controlArea)
        hbox = gui.hBox(vbox)
        self.recent_cb = QComboBox(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=16,
        )
        self.recent_cb.activated[int].connect(self.__onRecentActivated)
        icons = standard_icons(self)

        # Actions for selecting a directory and for re-scanning the
        # current one.
        browseaction = QAction(
            "Open/Load Images", self,
            iconText="\N{HORIZONTAL ELLIPSIS}",
            icon=icons.dir_open_icon,
            toolTip="Select a directory from which to load the images"
        )
        browseaction.triggered.connect(self.__runOpenDialog)
        reloadaction = QAction(
            "Reload", self,
            icon=icons.reload_icon,
            toolTip="Reload current image set"
        )
        reloadaction.triggered.connect(self.reload)
        self.__actions = namespace(
            browse=browseaction,
            reload=reloadaction,
        )

        # Push buttons mirroring the two actions.
        browsebutton = QPushButton(
            browseaction.iconText(),
            icon=browseaction.icon(),
            toolTip=browseaction.toolTip(),
            clicked=browseaction.trigger
        )
        reloadbutton = QPushButton(
            reloadaction.iconText(),
            icon=reloadaction.icon(),
            clicked=reloadaction.trigger,
            default=True,
        )

        hbox.layout().addWidget(self.recent_cb)
        hbox.layout().addWidget(browsebutton)
        hbox.layout().addWidget(reloadbutton)

        self.addActions([browseaction, reloadaction])

        # Keep the reload button's enabled state in sync with its action.
        reloadaction.changed.connect(
            lambda: reloadbutton.setEnabled(reloadaction.isEnabled())
        )
        box = gui.vBox(vbox, "Info")
        # Two-page stack: page 0 is the info label, page 1 the progress
        # bar with the cancel button and the path label.
        self.infostack = QStackedWidget()

        self.info_area = QLabel(
            text="No image set selected",
            wordWrap=True
        )
        self.progress_widget = QProgressBar(
            minimum=0, maximum=0
        )
        self.cancel_button = QPushButton(
            "Cancel", icon=icons.cancel_icon,
        )
        self.cancel_button.clicked.connect(self.cancel)

        w = QWidget()
        vlayout = QVBoxLayout()
        vlayout.setContentsMargins(0, 0, 0, 0)
        hlayout = QHBoxLayout()
        hlayout.setContentsMargins(0, 0, 0, 0)

        hlayout.addWidget(self.progress_widget)
        hlayout.addWidget(self.cancel_button)
        vlayout.addLayout(hlayout)

        # Label showing the path last reported by the running scan.
        self.pathlabel = TextLabel()
        self.pathlabel.setTextElideMode(Qt.ElideMiddle)
        self.pathlabel.setAttribute(Qt.WA_MacSmallSize)

        vlayout.addWidget(self.pathlabel)
        w.setLayout(vlayout)

        self.infostack.addWidget(self.info_area)
        self.infostack.addWidget(w)

        box.layout().addWidget(self.infostack)

        self.__initRecentItemsModel()
        self.__invalidated = True
        self.__executor = ThreadExecutor(self)

        # The initial scan is started from `customEvent` once this event
        # is delivered, not directly from the constructor.
        QApplication.postEvent(self, QEvent(RuntimeEvent.Init))

    def __initRecentItemsModel(self):
        """
        Fill the recent-paths combo box from the stored settings.

        Stored entries that are no longer directories are dropped and the
        list is truncated to `MaxRecentItems`. The current path is kept
        only if it is the same directory as the first recent entry.
        """
        if self.currentPath is not None and \
                not os.path.isdir(self.currentPath):
            self.currentPath = None

        still_valid = [entry for entry in self.recent_paths
                       if os.path.isdir(entry.abspath)]
        still_valid = still_valid[:OWImportImages.MaxRecentItems]

        combo_model = self.recent_cb.model()
        for entry in still_valid:
            combo_model.appendRow(RecentPath_asqstandarditem(entry))

        self.recent_paths = still_valid

        current_is_first = (
            self.currentPath is not None
            and os.path.isdir(self.currentPath)
            and self.recent_paths
            and os.path.samefile(self.currentPath,
                                 self.recent_paths[0].abspath)
        )
        if current_is_first:
            self.recent_cb.setCurrentIndex(0)
        else:
            self.currentPath = None
            self.recent_cb.setCurrentIndex(-1)
        self.__actions.reload.setEnabled(self.currentPath is not None)

    def customEvent(self, event):
        """Reimplemented."""
        # Run the deferred initial scan on the `Init` event posted from
        # the constructor.
        if event.type() == RuntimeEvent.Init and self.__invalidated:
            try:
                self.start()
            finally:
                self.__invalidated = False

        super().customEvent(event)

    def __runOpenDialog(self):
        """
        Let the user pick a directory and start scanning it.

        The dialog opens in the most recent directory, or in the user's
        home directory when there are no recent paths.
        """
        if self.recent_paths:
            startdir = self.recent_paths[0].abspath
        else:
            startdir = os.path.expanduser("~/")

        if OWImportImages.Modality != Qt.WindowModal:
            chosen = QFileDialog.getExistingDirectory(
                self, "Select Top Level Directory", startdir
            )
            if chosen:
                self.setCurrentPath(chosen)
                self.start()
            return

        dlg = QFileDialog(
            self, "Select Top Level Directory", startdir,
            acceptMode=QFileDialog.AcceptOpen,
            modal=True,
        )
        dlg.setFileMode(QFileDialog.Directory)
        dlg.setOption(QFileDialog.ShowDirsOnly)
        dlg.setDirectory(startdir)
        dlg.setAttribute(Qt.WA_DeleteOnClose)

        def on_accepted():
            selection = dlg.selectedFiles()
            if selection:
                self.setCurrentPath(selection[0])
                self.start()

        dlg.accepted.connect(on_accepted)
        dlg.open()

    def __onRecentActivated(self, index):
        """Switch to the recent path stored at combo box row `index` and rescan."""
        entry = self.recent_cb.itemData(index)
        if entry is not None:
            assert isinstance(entry, RecentPath)
            self.setCurrentPath(entry.abspath)
            self.start()

    def __updateInfo(self):
        """Update the info label text and the visible info page for the current state."""
        state = self.__state
        if state == State.NoState:
            text = "No image set selected"
        elif state == State.Processing:
            text = "Processing"
        elif state == State.Done:
            valid_count = sum(imeta.isvalid for imeta in self._imageMeta)
            category_count = len(self._imageCategories)
            text = "{} images".format(valid_count)
            if category_count >= 2:
                text = "{} images / {} categories".format(
                    valid_count, category_count)
        elif state == State.Cancelled:
            text = "Cancelled"
        elif state == State.Error:
            text = "Error state"
        else:
            assert False

        self.info_area.setText(text)

        # Page 1 (progress) while processing, page 0 (label) otherwise.
        self.infostack.setCurrentIndex(
            1 if self.__state == State.Processing else 0)

    def setCurrentPath(self, path):
        """
        Make `path` the current root image directory.

        If the path does not exist or is not a directory, a `UserWarning`
        is issued and the current path is left unchanged. A scan that is
        in progress is cancelled when the path changes.

        Parameters
        ----------
        path : str
            New root import path.

        Returns
        -------
        status : bool
            True if the current root import path was successfully
            changed to path (or already refers to the same directory).
        """
        current = self.currentPath
        if current is not None and path is not None:
            if os.path.isdir(current) and os.path.isdir(path) \
                    and os.path.samefile(current, path):
                return True

        if not os.path.exists(path):
            warnings.warn("'{}' does not exist".format(path), UserWarning)
            return False
        if not os.path.isdir(path):
            warnings.warn("'{}' is not a directory".format(path), UserWarning)
            return False

        newindex = self.addRecentPath(path)
        self.recent_cb.setCurrentIndex(newindex)
        self.currentPath = path if newindex >= 0 else None
        self.__actions.reload.setEnabled(self.currentPath is not None)

        if self.__state == State.Processing:
            self.cancel()

        return True

    def addRecentPath(self, path):
        """
        Prepend a path entry to the list of recent paths

        If an entry with the same path already exists in the recent path
        list it is moved to the first place

        Parameters
        ----------
        path : str

        Returns
        -------
        index : int
            The index of the entry in the recent list after the update
            (always 0, the entry is placed first).
        """
        existing = None
        for pathitem in self.recent_paths:
            try:
                is_same = os.path.samefile(pathitem.abspath, path)
            except OSError:
                # `samefile` raises when either path cannot be stat-ed,
                # e.g. a remembered directory that has since been removed.
                # Such an entry cannot be the same as `path`; skip it
                # instead of failing the whole operation.
                is_same = False
            if is_same:
                existing = pathitem
                break

        model = self.recent_cb.model()

        if existing is not None:
            # Move the existing entry (list item and model row) to the top.
            selected_index = self.recent_paths.index(existing)
            assert model.item(selected_index).data(Qt.UserRole) is existing
            self.recent_paths.remove(existing)
            row = model.takeRow(selected_index)
            self.recent_paths.insert(0, existing)
            model.insertRow(0, row)
        else:
            item = RecentPath(path, None, None)
            self.recent_paths.insert(0, item)
            model.insertRow(0, RecentPath_asqstandarditem(item))
        return 0

    def __setRuntimeState(self, state):
        """
        Switch the widget to the runtime `state`.

        Asserts that the transition is a permitted one, updates the
        blocking flag, the info page, the status message and the info text.
        """
        assert state in State
        self.setBlocking(state == State.Processing)

        previous = self.__state
        message = ""
        if state == State.Processing:
            assert previous in [State.Done,
                                State.NoState,
                                State.Error,
                                State.Cancelled]
            message = "Processing"
        elif state == State.Done:
            assert previous == State.Processing
        elif state == State.Cancelled:
            assert previous == State.Processing
            message = "Cancelled"
        elif state == State.Error:
            message = "Error during processing"
        elif state == State.NoState:
            message = ""
        else:
            assert False

        self.__state = state

        self.infostack.setCurrentIndex(
            1 if self.__state == State.Processing else 0)

        self.setStatusMessage(message)
        self.__updateInfo()

    def reload(self):
        """
        Discard the collected image data and run the image scan again.
        """
        if self.__state == State.Processing:
            self.cancel()

        self._imageMeta, self._imageCategories = [], {}
        self.start()

    def start(self):
        """
        Start the image scan of the current path on the executor.

        Returns without doing anything when no path is selected. A scan
        that is still in progress is cancelled first.
        """
        self.error()

        self.__invalidated = False
        if self.currentPath is None:
            return

        if self.__state == State.Processing:
            assert self.__pendingTask is not None
            log.info("Starting a new task while one is in progress. "
                     "Cancel the existing task (dir:'{}')"
                     .format(self.__pendingTask.startdir))
            self.cancel()

        startdir = self.currentPath

        self.__setRuntimeState(State.Processing)

        # Progress reports from the task are delivered through the
        # `__onReportProgress` slot.
        report_progress = methodinvoke(
            self, "__onReportProgress", (object,))

        task = ImageScan(startdir, report_progress=report_progress)

        # collect the task state in one convenient place
        taskstate = namespace(
            task=task,
            startdir=startdir,
            future=None,
            watcher=None,
            cancelled=False,
            cancel=None,
        )
        self.__pendingTask = taskstate

        def cancel_pending():
            # Cancel the task and disconnect
            if not taskstate.future.cancel():
                # Already running: flag it as cancelled and give it a
                # moment to stop.
                taskstate.task.cancelled = True
                taskstate.cancelled = True
                try:
                    taskstate.future.result(timeout=3)
                except UserInterruptError:
                    pass
                except TimeoutError:
                    log.info("The task did not stop in in a timely manner")
            taskstate.watcher.finished.disconnect(self.__onRunFinished)

        taskstate.cancel = cancel_pending

        def scan_ignoring_interrupt():
            try:
                return task.run()
            except UserInterruptError:
                # Suppress interrupt errors, so they are not logged
                return

        taskstate.future = self.__executor.submit(scan_ignoring_interrupt)
        taskstate.watcher = FutureWatcher(taskstate.future)
        taskstate.watcher.finished.connect(self.__onRunFinished)

    @Slot()
    def __onRunFinished(self):
        """
        Slot: collect the result of the finished scan task and commit it.

        On failure the exception is passed to `sys.excepthook`, the
        traceback is shown as the widget error and the state becomes
        `State.Error`; otherwise the state becomes `State.Done`.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__state == State.Processing
        assert self.__pendingTask is not None
        assert self.sender() is self.__pendingTask.watcher
        assert self.__pendingTask.future.done()
        finished = self.__pendingTask
        self.__pendingTask = None

        try:
            image_meta = finished.future.result()
        except Exception:
            sys.excepthook(*sys.exc_info())
            state = State.Error
            image_meta = []
            self.error(traceback.format_exc())
        else:
            state = State.Done
            self.error()

        # derive categories from the path relative to the starting dir
        image_dirs = (os.path.dirname(imeta.path) for imeta in image_meta)
        categories = {
            dirname: os.path.relpath(dirname, finished.startdir)
            for dirname in image_dirs
        }

        self._imageMeta = image_meta
        self._imageCategories = categories

        self.__setRuntimeState(state)
        self.commit()

    def cancel(self):
        """
        Cancel the scan task in progress; a no-op when none is running.
        """
        if self.__state != State.Processing:
            return

        assert self.__pendingTask is not None
        self.__pendingTask.cancel()
        self.__pendingTask = None
        self.__setRuntimeState(State.Cancelled)

    @Slot(object)
    def __onReportProgress(self, arg):
        """Slot: show the path last reported by the scan while it is processing."""
        # report on scan progress from a worker thread
        # arg must be a namespace(count: int, lastpath: str)
        assert QThread.currentThread() is self.thread()
        if self.__state != State.Processing:
            return
        self.pathlabel.setText(prettyfypath(arg.lastpath))

    def commit(self):
        """
        Build a Table from the collected image meta data and send it.

        Sends None when no image meta data has been collected. Only valid
        images are included; a "category" class variable is added when the
        images come from more than one category.
        """
        if not self._imageMeta:
            self.send("Data", None)
            return

        categories = self._imageCategories
        cat_var = None
        if len(categories) > 1:
            cat_var = Orange.data.DiscreteVariable(
                "category", values=sorted(categories.values())
            )

        # Image name (file basename without the extension)
        imagename_var = Orange.data.StringVariable("image name")
        # Full fs path
        image_var = Orange.data.StringVariable("image")
        image_var.attributes["type"] = "image"
        # file size/width/height
        size_var = Orange.data.ContinuousVariable(
            "size", number_of_decimals=0)
        width_var = Orange.data.ContinuousVariable(
            "width", number_of_decimals=0)
        height_var = Orange.data.ContinuousVariable(
            "height", number_of_decimals=0)

        class_vars = [] if cat_var is None else [cat_var]
        domain = Orange.data.Domain(
            [], class_vars,
            [imagename_var, image_var, size_var, width_var, height_var]
        )

        valid_images = [meta for meta in self._imageMeta if meta.isvalid]

        if cat_var is not None:
            cat_rows = [
                [cat_var.to_val(categories.get(os.path.dirname(meta.path)))]
                for meta in valid_images
            ]
        else:
            cat_rows = [[] for _ in valid_images]

        meta_rows = [
            [os.path.splitext(os.path.basename(meta.path))[0],
             meta.path, meta.size, meta.width, meta.height]
            for meta in valid_images
        ]

        cat_data = numpy.array(cat_rows, dtype=float)
        meta_data = numpy.array(meta_rows, dtype=object)
        table = Orange.data.Table.from_numpy(
            domain, numpy.empty((len(cat_data), 0), dtype=float),
            cat_data, meta_data
        )

        self.send("Data", table)

    def onDeleteWidget(self):
        """Cancel any running scan and shut the executor down, waiting for it to finish."""
        self.cancel()
        self.__executor.shutdown(wait=True)
Ejemplo n.º 13
0
# Widget that retrieves a gene network for the genes found in the input
# data and sends it on the "Network" output.
class OWGeneNetwork(widget.OWWidget):
    # Widget meta data.
    name = "Gene Network"
    description = "Extract a gene network for a set of genes."
    icon = "../widgets/icons/GeneNetwork.svg"

    # "Data" input is handled by `set_data`.
    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Network", network.Graph)]

    settingsHandler = settings.DomainContextHandler()

    # Taxonomy id of the selected organism (looked up in `self.taxids`).
    taxid = settings.Setting("9606")
    # Index into `varmodel` of the variable holding the gene names.
    gene_var_index = settings.ContextSetting(-1)
    # Use the names of the data attributes as the query genes.
    use_attr_names = settings.ContextSetting(False)

    # Index into `SOURCES` of the selected network source.
    network_source = settings.Setting(1)
    include_neighborhood = settings.Setting(True)
    # Minimal edge score; only used by sources with `score_filter` set.
    min_score = settings.Setting(0.9)

    want_main_area = False

    def __init__(self, parent=None):
        """
        Build the widget's GUI and create the task executor.

        The control area holds the "Info", "Organism", "Genes", "Network"
        and "Commit" boxes.
        """
        super().__init__(parent)

        self.taxids = taxonomy.common_taxids()
        # NOTE(review): `.index` raises if the stored `taxid` setting is
        # not among the common taxids -- confirm this cannot happen.
        self.current_taxid_index = self.taxids.index(self.taxid)

        self.data = None
        # Pending/finished result of the gene info fetch (see `commit`).
        self.geneinfo = None
        # The current network retrieval task (see `commit`).
        self.nettask = None
        self._invalidated = False

        box = gui.widgetBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, "No data on input\n")

        box = gui.widgetBox(self.controlArea, "Organism")
        self.organism_cb = gui.comboBox(
            box, self, "current_taxid_index",
            items=map(taxonomy.name, self.taxids),
            callback=self._update_organism
        )
        box = gui.widgetBox(self.controlArea, "Genes")
        self.genes_cb = gui.comboBox(
            box, self, "gene_var_index", callback=self._update_query_genes
        )
        # Model listing the candidate gene name variables; filled in
        # `set_data`.
        self.varmodel = itemmodels.VariableListModel()
        self.genes_cb.setModel(self.varmodel)

        gui.checkBox(
            box, self, "use_attr_names",
            "Use attribute names",
            callback=self._update_query_genes
        )

        box = gui.widgetBox(self.controlArea, "Network")
        gui.comboBox(
            box, self, "network_source",
            items=[s.name for s in SOURCES],
            callback=self._on_source_db_changed
        )
        gui.checkBox(
            box, self, "include_neighborhood",
            "Include immediate gene neighbors",
            callback=self.invalidate
        )
        self.score_spin = gui.doubleSpin(
            box, self, "min_score", 0.0, 1.0, step=0.001,
            label="Minimal edge score",
            callback=self.invalidate
        )
        # The score spin box is only enabled for sources that support
        # score filtering.
        self.score_spin.setEnabled(SOURCES[self.network_source].score_filter)

        box = gui.widgetBox(self.controlArea, "Commit")
        gui.button(box, self, "Retrieve", callback=self.commit, default=True)

        self.setSizePolicy(QtGui.QSizePolicy.Fixed, QtGui.QSizePolicy.Fixed)
        self.layout().setSizeConstraint(QtGui.QLayout.SetFixedSize)

        self.executor = ThreadExecutor()

    def set_data(self, data):
        """
        Handle a new table on the "Data" input.

        With data, the gene variable model, the organism and the
        "use attribute names" option are initialised from the data and its
        hints, the context is opened and a network retrieval is started.
        Without data, the model is cleared and None is sent on "Network".
        """
        self.closeContext()
        self.data = data
        if data is None:
            self.varmodel[:] = []
            self.send("Network", None)
            return

        self.varmodel[:] = string_variables(data.domain)

        hinted_taxid = data_hints.get_hint(data, "taxid", default=self.taxid)
        if hinted_taxid in self.taxids:
            self.set_organism(self.taxids.index(hinted_taxid))

        self.use_attr_names = data_hints.get_hint(
            data, "genesinrows", default=self.use_attr_names
        )

        # Fall back to the last variable when the index is out of range.
        if self.gene_var_index < 0 or \
                self.gene_var_index >= len(self.varmodel):
            self.gene_var_index = len(self.varmodel) - 1

        self.openContext(data)
        self.invalidate()
        self.commit()

    def set_source_db(self, dbindex):
        """Select the network source at index `dbindex` and invalidate the result."""
        self.network_source = dbindex
        self.invalidate()

    def set_organism(self, index):
        """Select the organism at `index` in `self.taxids` and invalidate the result."""
        self.current_taxid_index = index
        selected_taxid = self.taxids[index]
        self.taxid = selected_taxid
        self.invalidate()

    def set_gene_var(self, index):
        """Select the gene name variable at `index` and invalidate the result."""
        self.gene_var_index = index
        self.invalidate()

    def query_genes(self):
        """
        Return the list of query gene names for the current settings.

        With `use_attr_names` set, these are the names of the data
        attributes; otherwise the unique, non-missing values of the
        selected gene variable.

        Returns
        -------
        genes : list of str
            The query genes; empty when there is no input data or no
            valid gene variable is selected.
        """
        # Without data there is nothing to query in either mode. This can
        # happen when a retrieval is requested after the input was removed.
        if self.data is None:
            return []

        if self.use_attr_names:
            return [var.name for var in self.data.domain.attributes]

        # The stored index may be stale (e.g. -1, or left over from a
        # previous data set with more variables).
        if not 0 <= self.gene_var_index < len(self.varmodel):
            return []

        var = self.varmodel[self.gene_var_index]
        genes = [str(inst[var]) for inst in self.data
                 if not compat.isunknown(inst[var])]
        return list(unique(genes))

    def invalidate(self):
        """Mark the result as out of date and drop the pending network task, if any."""
        self._invalidated = True

        pending = self.nettask
        if pending is None:
            return

        pending.finished.disconnect(self._on_result_ready)
        pending.future().cancel()
        self.nettask = None

    @Slot()
    def advance(self):
        """Slot: advance the progress bar by one step, wrapping around at 100."""
        next_value = (self.progressBarValue + 1) % 100
        self.progressBarValue = next_value

    @Slot(float)
    def set_progress(self, value):
        """Slot: set the progress bar to `value`."""
        self.progressBarSet(value, processEvents=None)

    def commit(self):
        """
        Start retrieving the gene network for the current query genes.

        The gene info fetch (submitted once and then reused) and the
        network retrieval run on the executor; the widget is blocked and
        disabled until `_on_result_ready` is invoked.

        Raises
        ------
        ValueError
            If the selected source maps the current taxid to None.
        """
        include_neighborhood = self.include_neighborhood
        query_genes = self.query_genes()
        source = SOURCES[self.network_source]

        min_score = None
        if source.score_filter:
            assert source.name == "STRING"
            min_score = self.min_score * 1000

        taxid = self.taxid
        progress = methodinvoke(self, "advance")
        if self.geneinfo is None:
            self.geneinfo = self.executor.submit(
                fetch_ncbi_geneinfo, taxid, progress
            )

        geneinfo_f = self.geneinfo
        db_taxid = source.tax_mapping.get(taxid, taxid)

        if db_taxid is None:
            raise ValueError("invalid taxid for this network")

        def fetch_network():
            # Runs on the executor; waits for the gene info first.
            geneinfo = geneinfo_f.result()
            ppidb = fetch_ppidb(source, db_taxid, progress)
            report = methodinvoke(self, "set_progress", (float,))
            return get_gene_network(ppidb, geneinfo, db_taxid, query_genes,
                                    include_neighborhood=include_neighborhood,
                                    min_score=min_score,
                                    progress=report)

        self.nettask = Task(function=fetch_network)
        self.nettask.finished.connect(self._on_result_ready)

        self.executor.submit(self.nettask)

        self.setBlocking(True)
        self.setEnabled(False)
        self.progressBarInit()
        self._invalidated = False
        self._update_info()

    @Slot(object)
    def _on_result_ready(self):
        """Slot: re-enable the widget and send the retrieved network."""
        self.progressBarFinished()
        self.setBlocking(False)
        self.setEnabled(True)
        network_result = self.nettask.result()
        self._update_info()
        self.send("Network", network_result)

    def _on_source_db_changed(self):
        """Callback: sync the score spin box with the new source and invalidate."""
        supports_score = SOURCES[self.network_source].score_filter
        self.score_spin.setEnabled(supports_score)
        self.invalidate()

    def _update_organism(self):
        """Callback: adopt the newly selected organism and discard its cached gene info."""
        self.taxid = self.taxids[self.current_taxid_index]
        stale_geneinfo = self.geneinfo
        if stale_geneinfo is not None:
            stale_geneinfo.cancel()
        self.geneinfo = None
        self.invalidate()

    def _update_query_genes(self):
        """Callback: the query gene selection changed; invalidate the result."""
        self.invalidate()

    def _update_info(self):
        """Update the info label with the gene count and the network task status."""
        if self.data is None:
            self.info.setText("No data on input\n")
            return

        names = self.query_genes()
        lines = ["%i unique genes on input" % len(set(names))]

        task = self.nettask
        if task is None:
            lines.append("")
        elif not task.future().done():
            lines.append("Retrieving ...")
        else:
            net = task.result()
            lines.append("%i nodes %i edges" %
                         (len(net.nodes()), len(net.edges())))

        self.info.setText("\n".join(lines))
Ejemplo n.º 14
0
class OWGeneSets(OWWidget):
    """Widget that lists gene sets matched against the genes of the input data."""
    name = "Gene Sets"
    description = ""
    icon = "icons/OWGeneSets.svg"
    priority = 9
    want_main_area = True

    # Column indices of the gene set view; same order as DATA_HEADER_LABELS.
    COUNT, GENES, CATEGORY, TERM = range(4)
    DATA_HEADER_LABELS = ["Count", 'Genes In Set', 'Category', 'Term']

    # Settings declared with schema_only=True.
    organism = Setting(None, schema_only=True)
    stored_gene_sets_selection = Setting([], schema_only=True)
    selected_rows = Setting([], schema_only=True)
    custom_gene_set_indicator = Setting(None, schema_only=True)

    # Filter and commit settings.
    min_count = Setting(5)
    use_min_count = Setting(True)
    auto_commit = Setting(True)

    class Inputs:
        genes = Input("Data", Table)
        custom_sets = Input('Custom Gene Sets', Table)

    class Outputs:
        matched_genes = Output("Matched Genes", Table)

    class Information(OWWidget.Information):
        pass

    class Warning(OWWidget.Warning):
        all_sets_filtered = Msg('All sets were filtered out.')

    class Error(OWWidget.Error):
        organism_mismatch = Msg('Organism in input data and custom gene sets does not match')
        missing_annotation = Msg(ERROR_ON_MISSING_ANNOTATION)
        missing_gene_id = Msg(ERROR_ON_MISSING_GENE_ID)
        missing_tax_id = Msg(ERROR_ON_MISSING_TAX_ID)
        cant_reach_host = Msg("Host orange.biolab.si is unreachable.")
        cant_load_organisms = Msg("No available organisms, please check your connection.")

    def __init__(self):
        """Initialise all instance attributes to empty defaults and build the GUI.

        Most attributes start as ``None`` (or an empty container) and are
        filled in later by the input handlers and the ``setup_*`` methods;
        ``setup_gui()`` is called last.
        """
        super().__init__()

        # commit
        self.commit_button = None

        # progress bar
        self.progress_bar = None
        self.progress_bar_iterations = None

        # data
        self.input_data = None
        self.input_genes = []
        self.tax_id = None
        self.use_attr_names = None
        self.gene_id_attribute = None
        self.gene_id_column = None

        # custom gene sets
        self.custom_data = None
        self.feature_model = DomainModel(valid_types=(DiscreteVariable, StringVariable))
        self.custom_gs_col_box = None
        self.gs_label_combobox = None
        self.custom_tax_id = None
        self.custom_use_attr_names = None
        self.custom_gene_id_attribute = None
        self.custom_gene_id_column = None
        self.num_of_custom_sets = None

        # Gene Sets widget
        self.gs_widget = None

        # info box
        self.input_info = None
        self.num_of_sel_genes = 0

        # filter
        self.line_edit_filter = None
        self.search_pattern = ''
        self.organism_select_combobox = None

        # data model view
        self.data_view = None
        self.data_model = None

        # gene matcher NCBI
        self.gene_matcher = None

        # filter proxy model
        self.filter_proxy_model = None

        # hierarchy widget
        self.hierarchy_widget = None
        self.hierarchy_state = None

        # spinbox
        self.spin_widget = None

        # threads
        self.threadpool = QThreadPool(self)
        self.workers = None

        # background task state used by init_gene_sets() / cancel()
        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        # gui
        self.setup_gui()

    def __reset_widget_state(self):
        """Refresh the info box, reset the item model and rebuild the filter model."""
        self.update_info_box()
        # clear data view
        self.init_item_model()
        # reset filters
        self.setup_filter_model()

    def cancel(self):
        """
        Cancel the current task (if any).

        The task's future is expected to be done once ``cancel()`` returns;
        the finished-slot is then disconnected and the task reference dropped.
        """
        if self._task is None:
            return

        self._task.cancel()
        assert self._task.future.done()
        # stop the watcher from delivering the result of the cancelled task
        self._task.watcher.done.disconnect(self._init_gene_sets_finished)
        self._task = None

    @Slot()
    def progress_advance(self):
        """Advance the progress bar by one step, if a progress bar exists.

        Exposed as a slot so that ``callback`` can trigger it through
        ``methodinvoke``; the original comment states the GUI must be updated
        from the main thread.
        """
        bar = self.progress_bar
        if bar:
            bar.advance()

    def __get_input_genes(self):
        """Collect gene identifiers from the input data into ``self.input_genes``.

        When ``use_attr_names`` is set, the id is read from each attribute's
        ``attributes`` dict under ``gene_id_attribute`` ('?' if missing);
        otherwise the values of the ``gene_id_column`` column are used. All
        identifiers are stored as strings.
        """
        self.input_genes = []

        if not self.use_attr_names:
            column_values, _ = self.input_data.get_column_view(self.gene_id_column)
            self.input_genes = [str(value) for value in column_values]
            return

        for attribute in self.input_data.domain.attributes:
            gene_id = attribute.attributes.get(self.gene_id_attribute, '?')
            self.input_genes.append(str(gene_id))

    def handle_custom_gene_sets(self, select_customs_flag=False):
        """Pass the custom gene sets from the custom input to the gene sets widget.

        Nothing but the info box refresh happens when no indicator column is
        selected. On an organism mismatch the error is raised, the hierarchy is
        updated and the method returns without refreshing the info box.

        :param select_customs_flag: forwarded to ``gs_widget.add_custom_sets``
        """
        if self.custom_gene_set_indicator:
            custom_input_ready = self.custom_data is not None and self.custom_gene_id_column is not None

            if not custom_input_ready:
                self.gs_widget.update_gs_hierarchy()
            else:
                if self.__check_organism_mismatch():
                    self.Error.organism_mismatch()
                    self.gs_widget.update_gs_hierarchy()
                    return

                indicator = self.custom_gene_set_indicator
                if isinstance(indicator, DiscreteVariable):
                    # discrete columns hold value indices; translate them to labels
                    value_labels = indicator.values
                    index_column = self.custom_data.get_column_view(indicator)[0]
                    gene_sets_names = [value_labels[int(value_index)] for value_index in index_column]
                else:
                    gene_sets_names, _ = self.custom_data.get_column_view(indicator)

                self.num_of_custom_sets = len(set(gene_sets_names))
                gene_names, _ = self.custom_data.get_column_view(self.custom_gene_id_column)
                # the hierarchy title is a one-element tuple
                hierarchy_title = (self.custom_data.name or 'Custom sets',)

                try:
                    self.gs_widget.add_custom_sets(
                        gene_sets_names,
                        gene_names,
                        hierarchy_title=hierarchy_title,
                        select_customs_flag=select_customs_flag,
                    )
                except geneset.GeneSetException:
                    pass

        self.update_info_box()

    def update_tree_view(self):
        """Rebuild the gene set view by delegating to ``init_gene_sets()``."""
        self.init_gene_sets()

    def invalidate(self):
        """Reset the widget state and, if input data is present, rebuild the view."""
        # clear
        self.__reset_widget_state()
        self.update_info_box()

        if self.input_data is None:
            return

        # setup
        self.__get_input_genes()
        self.update_tree_view()

    def __check_organism_mismatch(self):
        """ Check if organisms from different inputs match.

        The comparison is only made when both ``tax_id`` and ``custom_tax_id``
        are set.

        :return: True if there is a mismatch
        """
        both_known = self.tax_id is not None and self.custom_tax_id is not None
        return both_known and self.tax_id != self.custom_tax_id

    def __get_reference_genes(self):
        """Collect gene identifiers from the reference data into ``self.reference_genes``.

        NOTE(review): the ``reference_*`` attributes read here are not set in
        the visible ``__init__`` and no caller is visible in this class;
        confirm whether this method is still used.
        """
        self.reference_genes = []

        if not self.reference_attr_names:
            column_values, _ = self.reference_data.get_column_view(self.reference_gene_id_column)
            self.reference_genes = [str(value) for value in column_values]
            return

        for attribute in self.reference_data.domain.attributes:
            gene_id = attribute.attributes.get(self.reference_gene_id_attribute, '?')
            self.reference_genes.append(str(gene_id))

    @Inputs.custom_sets
    def handle_custom_input(self, data):
        """Handle the 'Custom Gene Sets' input.

        Clears errors and all state derived from the previous custom input,
        then (when ``data`` is truthy) reads the table annotations, makes sure
        the indicator combo box exists and picks a valid indicator column.
        Finally the custom sets are re-registered and the widget invalidated.

        :param data: table with custom gene sets, or None
        """
        self.Error.clear()
        self.__reset_widget_state()
        self.custom_data = None
        self.custom_tax_id = None
        self.custom_use_attr_names = None
        self.custom_gene_id_attribute = None
        self.custom_gene_id_column = None
        self.feature_model.set_domain(None)

        if data:
            self.custom_data = data
            self.feature_model.set_domain(self.custom_data.domain)
            # Keep a missing tax id as None; str(None) would yield the string
            # 'None' and make every later `is not None` check succeed.
            tax_id = self.custom_data.attributes.get(TAX_ID, None)
            self.custom_tax_id = None if tax_id is None else str(tax_id)
            self.custom_use_attr_names = self.custom_data.attributes.get(GENE_AS_ATTRIBUTE_NAME, None)
            self.custom_gene_id_attribute = self.custom_data.attributes.get(GENE_ID_ATTRIBUTE, None)
            self.custom_gene_id_column = self.custom_data.attributes.get(GENE_ID_COLUMN, None)

            if self.gs_label_combobox is None:
                self.gs_label_combobox = comboBox(
                    self.custom_gs_col_box,
                    self,
                    "custom_gene_set_indicator",
                    sendSelectedValue=True,
                    model=self.feature_model,
                    callback=self.on_gene_set_indicator_changed,
                )
            self.custom_gs_col_box.show()

            if self.custom_gene_set_indicator in self.feature_model:
                index = self.feature_model.indexOf(self.custom_gene_set_indicator)
                self.custom_gene_set_indicator = self.feature_model[index]
            elif len(self.feature_model):
                self.custom_gene_set_indicator = self.feature_model[0]
            else:
                # no discrete/string column available to act as the indicator
                self.custom_gene_set_indicator = None
        else:
            self.custom_gs_col_box.hide()

        self.gs_widget.clear_custom_sets()
        self.handle_custom_gene_sets(select_customs_flag=self.custom_gene_set_indicator is not None)
        self.invalidate()

    @Inputs.genes
    def handle_genes_input(self, data):
        """Handle the 'Data' input.

        Clears errors, the output and all state derived from the previous
        input. When ``data`` is truthy its annotations are read and validated;
        an error is raised and the method returns early if the gene id or
        tax id annotations are missing or if the organism differs from the one
        of the custom gene sets. Otherwise the gene sets for the organism are
        loaded, custom sets are refreshed and the widget is invalidated.

        :param data: input table, or None
        """
        self.Error.clear()
        self.__reset_widget_state()
        # clear output
        self.Outputs.matched_genes.send(None)
        # clear input values
        self.input_genes = []
        self.input_data = None
        self.tax_id = None
        self.use_attr_names = None
        self.gene_id_attribute = None
        # reset together with gene_id_attribute so no stale column survives
        self.gene_id_column = None
        self.gs_widget.clear()
        self.gs_widget.clear_gene_sets()
        self.update_info_box()

        if data:
            self.input_data = data
            # Keep a missing tax id as None; str(None) would yield the string
            # 'None' and make the `tax_id is None` checks below unreachable.
            tax_id = self.input_data.attributes.get(TAX_ID, None)
            self.tax_id = None if tax_id is None else str(tax_id)
            self.use_attr_names = self.input_data.attributes.get(GENE_AS_ATTRIBUTE_NAME, None)
            self.gene_id_attribute = self.input_data.attributes.get(GENE_ID_ATTRIBUTE, None)
            self.gene_id_column = self.input_data.attributes.get(GENE_ID_COLUMN, None)
            self.update_info_box()

            # exactly one of gene_id_attribute / gene_id_column must be given
            if not (
                self.use_attr_names is not None and ((self.gene_id_attribute is None) ^ (self.gene_id_column is None))
            ):

                if self.tax_id is None:
                    self.Error.missing_annotation()
                    return

                self.Error.missing_gene_id()
                return

            elif self.tax_id is None:
                self.Error.missing_tax_id()
                return

            if self.__check_organism_mismatch():
                self.Error.organism_mismatch()
                return

            self.gs_widget.load_gene_sets(self.tax_id)

            # if input data change, we need to refresh custom sets
            if self.custom_data:
                self.gs_widget.clear_custom_sets()
                self.handle_custom_gene_sets()

            self.invalidate()

    def update_info_box(self):
        """Rebuild the text of the info label from the current input state.

        Reports the number of unique input genes and of genes on the output,
        or a hint when the input is missing or lacks gene id annotations, plus
        a summary line for the custom gene sets when they are present.
        """
        info_string = ''
        if self.input_genes:
            # the message promises *unique* names, so count distinct genes
            info_string += '{} unique gene names on input.\n'.format(len(set(self.input_genes)))
            info_string += '{} genes on output.\n'.format(self.num_of_sel_genes)
        else:
            if self.input_data:
                if not any((self.gene_id_column, self.gene_id_attribute)):
                    info_string += 'Input data with incorrect meta data.\nUse Gene Name Matcher widget.'
            else:
                info_string += 'No data on input.\n'

        if self.custom_data:
            info_string += '{} marker genes in {} sets\n'.format(self.custom_data.X.shape[0], self.num_of_custom_sets)

        self.input_info.setText(info_string)

    def create_partial(self):
        """Return ``set_items`` bound to the current gene sets, selection and genes.

        The input genes are passed as a set and ``self.callback`` is bound as
        the per-gene-set callback.
        """
        gene_sets = self.gs_widget.gs_object
        shown_hierarchies = self.stored_gene_sets_selection
        unique_genes = set(self.input_genes)
        return partial(self.set_items, gene_sets, shown_hierarchies, unique_genes, self.callback)

    def callback(self):
        """Per-item hook called from ``set_items``.

        Raises ``KeyboardInterrupt`` when the current task is flagged as
        cancelled; otherwise requests a progress bar step through
        ``methodinvoke`` if a progress bar exists.
        """
        current_task = self._task
        if current_task.cancelled:
            raise KeyboardInterrupt()

        if self.progress_bar:
            advance = methodinvoke(self, "progress_advance")
            advance()

    def init_gene_sets(self):
        """Start a background task that builds the rows of the gene set view.

        A running task is cancelled first. The item model is reset, the
        selected hierarchies are stored, a progress bar sized to the number of
        gene sets in those hierarchies is created, and the work is submitted to
        the executor; ``_init_gene_sets_finished`` is connected to the
        future's watcher.
        """
        if self._task is not None:
            self.cancel()
        assert self._task is None

        self._task = Task()
        self.init_item_model()

        # save setting on selected hierarchies
        self.stored_gene_sets_selection = self.gs_widget.get_hierarchies(only_selected=True)

        worker = self.create_partial()

        # one progress step per gene set in the selected hierarchies
        sets_by_hierarchy = self.gs_widget.gs_object.map_hierarchy_to_sets()
        total_steps = 0
        for hierarchy, gene_sets in sets_by_hierarchy.items():
            if hierarchy in self.stored_gene_sets_selection:
                total_steps += len(gene_sets)

        self.progress_bar = ProgressBar(self, iterations=total_steps)

        future = self._executor.submit(worker)
        self._task.future = future

        watcher = FutureWatcher(future)
        self._task.watcher = watcher
        watcher.done.connect(self._init_gene_sets_finished)

    @Slot(concurrent.futures.Future)
    def _init_gene_sets_finished(self, f):
        """Populate the view with the rows produced by the finished task.

        Must run in the widget's (main) thread for the task currently stored
        in ``self._task``. Clears the task, finishes the progress bar and
        fills the item model; any exception from the task or from updating
        the view is printed rather than propagated.

        :param f: the finished future of the current task
        """
        assert self.thread() is QThread.currentThread()
        assert threading.current_thread() == threading.main_thread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progress_bar.finish()
        self.setStatusMessage('')

        try:
            # set_items returns None when there are no input genes; treat
            # that as "no rows" instead of failing on the iteration below
            results = f.result() or []  # type: list
            for model_item in results:
                self.data_model.appendRow(model_item)
            self.filter_proxy_model.setSourceModel(self.data_model)
            self.data_view.selectionModel().selectionChanged.connect(self.commit)
            self.filter_data_view()
            self.set_selection()
            self.update_info_box()
        except Exception as ex:
            print(ex)

    def create_filters(self):
        """Build the list of ``FilterProxyModel.Filter`` rules for the view.

        The TERM column must contain every whitespace-separated token of the
        lower-cased search pattern; when ``use_min_count`` is set, the COUNT
        column must additionally be at least ``min_count``.
        """
        tokens = self.search_pattern.lower().strip().split()

        def contains_all_tokens(value):
            return all(token in value.lower() for token in tokens)

        def reaches_min_count(value):
            # min_count is read at filter time, not captured here
            return value >= self.min_count

        filters = [FilterProxyModel.Filter(self.TERM, Qt.DisplayRole, contains_all_tokens)]

        if self.use_min_count:
            filters.append(FilterProxyModel.Filter(self.COUNT, Qt.DisplayRole, reaches_min_count))

        return filters

    def filter_data_view(self):
        """Apply the current filters and warn when they hide every gene set.

        Does nothing unless the proxy's source model is a
        ``QStandardItemModel``.
        """
        proxy = self.filter_proxy_model  # type: FilterProxyModel
        source = proxy.sourceModel()  # type: QStandardItemModel

        if not isinstance(source, QStandardItemModel):
            return

        # apply filtering rules
        proxy.set_filters(self.create_filters())

        everything_hidden = source.rowCount() and not proxy.rowCount()
        if everything_hidden:
            self.Warning.all_sets_filtered()
        else:
            self.Warning.clear()

    def set_selection(self):
        """Restore the row selection stored in ``selected_rows``.

        Stored rows are source-model rows. Nothing is selected when there are
        no stored rows or when any stored row lies outside the current model;
        rows currently hidden by the filter are skipped.
        """
        if not self.selected_rows:
            return

        view = self.data_view
        model = self.data_model
        proxy = self.filter_proxy_model

        # Validate before touching the model: stored rows may be stale, and
        # they are not guaranteed to be sorted, so test the largest one.
        if model.rowCount() <= max(self.selected_rows):
            return

        last_column = view.header().count() - 1
        selection = QItemSelection()

        for source_row in self.selected_rows:
            source_index = model.indexFromItem(model.item(source_row))
            proxy_row = proxy.mapFromSource(source_index).row()
            if proxy_row < 0:
                # filtered out of the proxy model: there is nothing to select
                continue
            selection.append(
                QItemSelectionRange(
                    proxy.index(proxy_row, 0),
                    proxy.index(proxy_row, last_column),
                )
            )

        view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

    def commit(self):
        """Send the input data restricted to the genes of the selected gene sets.

        Stores the selected source rows in ``selected_rows``. When rows are
        selected and input genes exist, the union of the matched genes of all
        selected rows is computed and either the matching attributes (genes in
        columns) or the matching rows (genes in a column) are sent on the
        'Matched Genes' output. Nothing is sent otherwise.
        """
        selection_model = self.data_view.selectionModel()

        if selection_model:
            selection = selection_model.selectedRows(self.COUNT)
            self.selected_rows = [self.filter_proxy_model.mapToSource(sel).row() for sel in selection]

            if selection and self.input_genes:
                # the COUNT column stores the set of matched genes under UserRole
                genes = [model_index.data(Qt.UserRole) for model_index in selection]
                # keep a set: it is used for membership tests per attribute below
                output_genes = set().union(*genes)
                self.num_of_sel_genes = len(output_genes)
                self.update_info_box()

                if self.use_attr_names:
                    selected = [
                        column
                        for column in self.input_data.domain.attributes
                        if self.gene_id_attribute in column.attributes
                        and str(column.attributes[self.gene_id_attribute]) in output_genes
                    ]

                    domain = Domain(selected, self.input_data.domain.class_vars, self.input_data.domain.metas)
                    new_data = self.input_data.from_table(domain, self.input_data)
                    self.Outputs.matched_genes.send(new_data)

                else:
                    # create filter from selected column for genes
                    only_known = table_filter.FilterStringList(self.gene_id_column, list(output_genes))
                    # apply filter to the data
                    data_table = table_filter.Values([only_known])(self.input_data)

                    self.Outputs.matched_genes.send(data_table)

    def assign_delegates(self):
        """Install a ``NumericalColumnDelegate`` on the GENES and COUNT columns."""
        for column in (self.GENES, self.COUNT):
            self.data_view.setItemDelegateForColumn(column, NumericalColumnDelegate(self))

    def setup_filter_model(self):
        """Create a fresh ``FilterProxyModel`` keyed on TERM and attach it to the view."""
        self.filter_proxy_model = proxy = FilterProxyModel()
        proxy.setFilterKeyColumn(self.TERM)
        self.data_view.setModel(proxy)

    def setup_filter_area(self):
        """Build the filter row in the main area.

        Adds a checkable 'Count' spin box bound to ``min_count`` /
        ``use_min_count`` and a line edit bound to ``search_pattern``; both
        re-run ``filter_data_view`` when changed.
        """
        h_layout = QHBoxLayout()
        h_layout.setSpacing(100)
        h_widget = widgetBox(self.mainArea, orientation=h_layout)

        spin(
            h_widget,
            self,
            'min_count',
            0,
            1000,
            label='Count',
            tooltip='Minimum genes count',
            checked='use_min_count',
            callback=self.filter_data_view,
            callbackOnReturn=True,
            checkCallback=self.filter_data_view,
        )

        self.line_edit_filter = lineEdit(h_widget, self, 'search_pattern')
        self.line_edit_filter.setPlaceholderText('Filter gene sets ...')
        self.line_edit_filter.textChanged.connect(self.filter_data_view)

    def on_gene_set_indicator_changed(self):
        """Re-register the custom gene sets after the indicator column changed.

        Used as the callback of the indicator combo box created in
        ``handle_custom_input``.
        """
        # self._handle_future_model()
        self.gs_widget.clear_custom_sets()
        self.handle_custom_gene_sets()
        self.invalidate()

    def setup_control_area(self):
        """Build the control area: info label, custom set column box, gene sets and commit.

        The custom gene set column box starts hidden; clicking an item in the
        gene set hierarchy tree triggers ``update_tree_view``.
        """
        # Control area
        self.input_info = widgetLabel(widgetBox(self.controlArea, "Info", addSpace=True), 'No data on input.\n')
        self.custom_gs_col_box = box = vBox(self.controlArea, 'Custom Gene Set Term Column')
        box.hide()

        gene_sets_box = widgetBox(self.controlArea, "Gene Sets")
        self.gs_widget = GeneSetsSelection(gene_sets_box, self, 'stored_gene_sets_selection')
        self.gs_widget.hierarchy_tree_widget.itemClicked.connect(self.update_tree_view)

        self.commit_button = auto_commit(self.controlArea, self, "auto_commit", "&Commit", box=False)

    def setup_gui(self):
        """Build the whole GUI: the control area, then the tree view in the main area."""
        # control area
        self.setup_control_area()

        # main area
        self.data_view = QTreeView()
        # the view must exist before the filter model is attached to it
        self.setup_filter_model()
        self.setup_filter_area()
        self.data_view.setAlternatingRowColors(True)
        self.data_view.sortByColumn(self.COUNT, Qt.DescendingOrder)
        self.data_view.setSortingEnabled(True)
        self.data_view.setSelectionMode(QTreeView.ExtendedSelection)
        self.data_view.setEditTriggers(QTreeView.NoEditTriggers)
        self.data_view.viewport().setMouseTracking(False)
        self.data_view.setItemDelegateForColumn(self.TERM, LinkStyledItemDelegate(self.data_view))

        self.mainArea.layout().addWidget(self.data_view)

        self.data_view.header().setSectionResizeMode(QHeaderView.ResizeToContents)
        self.assign_delegates()

    @staticmethod
    def set_items(gene_sets, sets_to_display, genes, callback):
        """Build the model rows for all displayed gene sets that match ``genes``.

        :param gene_sets: iterable of gene sets (each with ``hierarchy``,
            ``genes``, ``name`` and ``link``)
        :param sets_to_display: hierarchies whose gene sets should be shown
        :param genes: set of input gene identifiers
        :param callback: called once for every gene set in a displayed
            hierarchy; may raise to abort
        :return: list of rows, each ``[count, genes, category, term]`` items;
            empty when ``genes`` is empty
        """
        model_items = []
        if not genes:
            # always return a list so the caller can iterate the result
            return model_items

        for gene_set in sorted(gene_sets):
            if gene_set.hierarchy not in sets_to_display:
                continue

            callback()

            matched_set = gene_set.genes & genes
            if len(matched_set) > 0:
                category_column = QStandardItem()
                term_column = QStandardItem()
                count_column = QStandardItem()
                genes_column = QStandardItem()

                category_column.setData(", ".join(gene_set.hierarchy), Qt.DisplayRole)
                term_column.setData(gene_set.name, Qt.DisplayRole)
                term_column.setData(gene_set.name, Qt.ToolTipRole)
                term_column.setData(gene_set.link, LinkRole)
                term_column.setForeground(QColor(Qt.blue))

                # matched genes are kept under UserRole of the count column
                count_column.setData(matched_set, Qt.UserRole)
                count_column.setData(len(matched_set), Qt.DisplayRole)

                genes_column.setData(len(gene_set.genes), Qt.DisplayRole)
                # all genes of the set are kept under UserRole of the genes column
                genes_column.setData(set(gene_set.genes), Qt.UserRole)

                model_items.append([count_column, genes_column, category_column, term_column])

        return model_items

    def init_item_model(self):
        """Create the item model, or clear it and rebuild the filter model.

        In both cases the sort role and the header labels are (re)applied.
        """
        if not self.data_model:
            self.data_model = QStandardItemModel()
        else:
            self.data_model.clear()
            self.setup_filter_model()

        model = self.data_model
        model.setSortRole(Qt.UserRole)
        model.setHorizontalHeaderLabels(self.DATA_HEADER_LABELS)

    def sizeHint(self):
        """Return the preferred widget size, 1280 x 960."""
        return QSize(1280, 960)
Ejemplo n.º 15
0
class OWGEODatasets(OWWidget):
    """Widget giving access to Gene Expression Omnibus data sets."""
    name = "GEO Data Sets"
    description = "Access to Gene Expression Omnibus data sets."
    icon = "icons/OWGEODatasets.svg"
    priority = 2

    inputs = []
    outputs = [("Expression Data", Table)]

    settingsList = [
        "outputRows", "mergeSpots", "gdsSelectionStates", "splitterSettings",
        "currentGds", "autoCommit", "datasetNames"
    ]

    outputRows = Setting(True)
    mergeSpots = Setting(True)
    # check states of the annotation tree items, keyed by the items' `key` tuples
    gdsSelectionStates = Setting({})
    currentGds = Setting(None)
    # user supplied output names, keyed by dataset id
    datasetNames = Setting({})
    # saved states passed to QSplitter.restoreState for the two splitters
    splitterSettings = Setting((
        b'\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x01\xea\x00\x00\x00\xd7\x01\x00\x00\x00\x07\x01\x00\x00\x00\x02',
        b'\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x01\xb5\x00\x00\x02\x10\x01\x00\x00\x00\x07\x01\x00\x00\x00\x01'
    ))

    autoCommit = Setting(False)

    def __init__(self, parent=None, signalManager=None, name=" GEO Data Sets"):
        """Build the GUI and start loading the data set model in the background.

        The widget is blocked and disabled until the task submitted at the end
        of this method finishes and ``_initializemodel`` runs.
        """
        OWWidget.__init__(self, parent, signalManager, name)

        self.selectionChanged = False
        self.filterString = ""
        self.datasetName = ""

        ## GUI
        box = gui.widgetBox(self.controlArea, "Info", addSpace=True)
        self.infoBox = gui.widgetLabel(box, "Initializing\n\n")

        box = gui.widgetBox(self.controlArea, "Output", addSpace=True)
        gui.radioButtonsInBox(box,
                              self,
                              "outputRows",
                              ["Genes in rows", "Samples in rows"],
                              "Rows",
                              callback=self.commitIf)

        gui.checkBox(box,
                     self,
                     "mergeSpots",
                     "Merge spots of same gene",
                     callback=self.commitIf)

        gui.separator(box)
        self.nameEdit = gui.lineEdit(
            box,
            self,
            "datasetName",
            "Data set name",
            tooltip="Override the default output data set name",
            callback=self.onNameEdited)
        self.nameEdit.setPlaceholderText("")

        # Python 2 builds an explicit commit button; Python 3 uses gui.auto_commit
        if sys.version_info < (3, ):
            box = gui.widgetBox(self.controlArea, "Commit", addSpace=True)
            self.commitButton = gui.button(box,
                                           self,
                                           "Commit",
                                           callback=self.commit)
            cb = gui.checkBox(box, self, "autoCommit", "Commit on any change")
            gui.setStopper(self, self.commitButton, cb, "selectionChanged",
                           self.commit)
        else:
            gui.auto_commit(self.controlArea,
                            self,
                            "autoCommit",
                            "Commit",
                            box="Commit")
            # NOTE(review): self.commitIf is already read above as a callback,
            # so a class-level commitIf presumably exists outside this view;
            # this assignment shadows it on the instance - confirm.
            self.commitIf = self.commit

        gui.rubber(self.controlArea)

        gui.widgetLabel(self.mainArea, "Filter")
        self.filterLineEdit = QLineEdit(textChanged=self.filter)
        self.completer = TokenListCompleter(self,
                                            caseSensitivity=Qt.CaseInsensitive)
        self.filterLineEdit.setCompleter(self.completer)

        self.mainArea.layout().addWidget(self.filterLineEdit)

        splitter = QSplitter(Qt.Vertical, self.mainArea)
        self.mainArea.layout().addWidget(splitter)
        self.treeWidget = QTreeView(splitter)

        self.treeWidget.setSelectionMode(QTreeView.SingleSelection)
        self.treeWidget.setRootIsDecorated(False)
        self.treeWidget.setSortingEnabled(True)
        self.treeWidget.setAlternatingRowColors(True)
        self.treeWidget.setUniformRowHeights(True)
        self.treeWidget.setEditTriggers(QTreeView.NoEditTriggers)

        linkdelegate = gui.LinkStyledItemDelegate(self.treeWidget)
        self.treeWidget.setItemDelegateForColumn(1, linkdelegate)
        self.treeWidget.setItemDelegateForColumn(8, linkdelegate)
        self.treeWidget.setItemDelegateForColumn(
            0, gui.IndicatorItemDelegate(self.treeWidget, role=Qt.DisplayRole))

        # the source model is attached later, in _initializemodel
        proxyModel = MySortFilterProxyModel(self.treeWidget)
        self.treeWidget.setModel(proxyModel)
        self.treeWidget.selectionModel().selectionChanged.connect(
            self.updateSelection)
        self.treeWidget.viewport().setMouseTracking(True)

        splitterH = QSplitter(Qt.Horizontal, splitter)

        box = gui.widgetBox(splitterH, "Description")
        self.infoGDS = gui.widgetLabel(box, "")
        self.infoGDS.setWordWrap(True)
        gui.rubber(box)

        box = gui.widgetBox(splitterH, "Sample Annotations")
        self.annotationsTree = QTreeWidget(box)
        self.annotationsTree.setHeaderLabels(
            ["Type (Sample annotations)", "Sample count"])
        self.annotationsTree.setRootIsDecorated(True)
        box.layout().addWidget(self.annotationsTree)
        self.annotationsTree.itemChanged.connect(
            self.annotationSelectionChanged)
        # set while setAnnotations rebuilds the tree so itemChanged is ignored
        self._annotationsUpdating = False
        self.splitters = splitter, splitterH

        for sp, setting in zip(self.splitters, self.splitterSettings):
            sp.splitterMoved.connect(self.splitterMoved)
            sp.restoreState(setting)

        # dataset fields joined into the completer token list in _initializemodel
        self.searchKeys = [
            "dataset_id", "title", "platform_organism", "description"
        ]

        self.gds = []
        self.gds_info = None

        self.resize(1000, 600)

        self.setBlocking(True)
        self.setEnabled(False)
        self.progressBarInit()

        self._executor = ThreadExecutor()

        # run get_gds_model in the executor; progress is reported via _setProgress
        func = partial(get_gds_model,
                       methodinvoke(self, "_setProgress", (float, )))
        self._inittask = Task(function=func)
        self._inittask.finished.connect(self._initializemodel)
        self._executor.submit(self._inittask)

        self._datatask = None

    @Slot(float)
    def _setProgress(self, value):
        """Store ``value`` in ``self.progressBarValue``.

        Handed to ``get_gds_model`` (wrapped by ``methodinvoke``) in
        ``__init__`` as its progress callback.
        """
        self.progressBarValue = value

    def _initializemodel(self):
        """Install the loaded data set model and finish widget initialisation.

        Takes the result of the init task, attaches the model to the tree
        view's proxy, unblocks the widget, fills the filter completer with
        tokens from the data set descriptions, restores the previously
        selected data set and sizes the columns.
        """
        assert self.thread() is QThread.currentThread()
        model, self.gds_info, self.gds = self._inittask.result()
        model.setParent(self)

        proxy = self.treeWidget.model()
        proxy.setFilterKeyColumn(0)
        proxy.setFilterRole(TextFilterRole)
        proxy.setFilterCaseSensitivity(False)
        proxy.setFilterFixedString(self.filterString)

        proxy.setSourceModel(model)
        proxy.sort(0, Qt.DescendingOrder)

        self.progressBarFinished()
        self.setBlocking(False)
        self.setEnabled(True)

        # Completer tokens: all searchable fields joined, punctuation turned
        # into spaces, split on spaces, de-duplicated, sorted, short ones dropped.
        searchable_text = " ".join(
            entry[key] for entry in self.gds for key in self.searchKeys)
        punctuation = ",.:;!?(){}[]_-+\\|/%#@$^&*<>~`"
        to_spaces = str.maketrans(punctuation, " " * len(punctuation))
        searchable_text = searchable_text.translate(to_spaces)
        tokens = [token for token in sorted(set(searchable_text.split(" ")))
                  if len(token) > 3]

        self.completer.setTokenList(tokens)

        if self.currentGds:
            wanted_id = self.currentGds["dataset_id"]
            matching_rows = []
            for row in range(proxy.rowCount()):
                shown_id = proxy.data(proxy.index(row, 1), Qt.DisplayRole)
                if shown_id and shown_id == wanted_id:
                    matching_rows.append(row)
            if matching_rows:
                target = proxy.index(matching_rows[0], 0)
                self.treeWidget.selectionModel().select(
                    target,
                    QItemSelectionModel.Select | QItemSelectionModel.Rows)
                self.treeWidget.scrollTo(target, QTreeView.PositionAtCenter)

        for column in range(8):
            self.treeWidget.resizeColumnToContents(column)

        # cap the width of columns 1 and 2
        for column, max_width in ((1, 300), (2, 200)):
            self.treeWidget.setColumnWidth(
                column, min(self.treeWidget.columnWidth(column), max_width))

        self.updateInfo()

    def updateInfo(self):
        """Show the data set counts (total, cached, filtered) in the info box.

        The cached count is the number of local paths matching ``GDS*`` under
        the "GEO" server files directory; the filtered count is only shown
        when it differs from the number of loaded data sets.
        """
        total = len(self.gds_info)
        cached = len(glob.glob(serverfiles.localpath("GEO") + "/GDS*"))
        text = "%i datasets\n%i datasets cached\n" % (total, cached)

        visible_rows = self.treeWidget.model().rowCount()
        if visible_rows != len(self.gds):
            text += "%i after filtering" % visible_rows

        self.infoBox.setText(text)

    def updateSelection(self, *args):
        """Update the current data set from the tree view selection and commit.

        With a selection, the first selected source row becomes the current
        data set: its annotations, description, title placeholder and stored
        output name are shown. Without one, these are cleared.
        """
        selected_indexes = self.treeWidget.selectedIndexes()
        to_source = self.treeWidget.model().mapToSource
        source_rows = [to_source(index).row() for index in selected_indexes]

        if not source_rows:
            self.currentGds = None
            self.nameEdit.setPlaceholderText("")
            self.datasetName = ""
        else:
            self.currentGds = self.gds[source_rows[0]]
            self.setAnnotations(self.currentGds)
            self.infoGDS.setText(self.currentGds.get("description", ""))
            self.nameEdit.setPlaceholderText(self.currentGds["title"])
            dataset_id = self.currentGds["dataset_id"]
            self.datasetName = self.datasetNames.get(dataset_id, "")

        self.commitIf()

    def setAnnotations(self, gds):
        """Rebuild the sample annotations tree for the data set ``gds``.

        One checkable, tristate top-level item is created per sample type with
        one checkable child per subset description (second column: sample
        count). Check states come from ``gdsSelectionStates`` and default to
        checked. ``_annotationsUpdating`` is set while the tree is rebuilt.
        """
        self._annotationsUpdating = True
        self.annotationsTree.clear()

        descriptions_by_type = defaultdict(set)
        sample_counts = {}
        for subset_info in gds["subsets"]:
            group = descriptions_by_type[subset_info["type"]]
            group.add(subset_info["description"])
            sample_counts[subset_info["description"]] = str(
                len(subset_info["sample_id"]))

        for sample_type, descriptions in descriptions_by_type.items():
            type_key = (gds["dataset_id"], sample_type)
            type_item = QTreeWidgetItem(self.annotationsTree, [sample_type])
            type_item.setFlags(
                type_item.flags() | Qt.ItemIsUserCheckable | Qt.ItemIsTristate)
            type_item.setCheckState(
                0, self.gdsSelectionStates.get(type_key, Qt.Checked))
            type_item.key = type_key

            for description in descriptions:
                subset_key = (gds["dataset_id"], sample_type, description)
                subset_item = QTreeWidgetItem(
                    type_item,
                    [description, sample_counts.get(description, "")])
                subset_item.setFlags(
                    subset_item.flags() | Qt.ItemIsUserCheckable)
                subset_item.setCheckState(
                    0, self.gdsSelectionStates.get(subset_key, Qt.Checked))
                subset_item.key = subset_key

        self._annotationsUpdating = False
        self.annotationsTree.expandAll()
        for column in range(self.annotationsTree.columnCount()):
            self.annotationsTree.resizeColumnToContents(column)

    def annotationSelectionChanged(self, item, column):
        """
        Store the check state of every annotation tree item.

        All top level items and their direct children are visited and
        their column 0 check state is written to
        ``self.gdsSelectionStates`` under the item's ``key``.  Nothing is
        done while ``self._annotationsUpdating`` is set.  The `item` and
        `column` arguments are not used.
        """
        if self._annotationsUpdating:
            return

        tree = self.annotationsTree
        states = self.gdsSelectionStates
        for top_index in range(tree.topLevelItemCount()):
            group = tree.topLevelItem(top_index)
            states[group.key] = group.checkState(0)
            for child_index in range(group.childCount()):
                leaf = group.child(child_index)
                states[leaf.key] = leaf.checkState(0)

    def filter(self):
        """
        Apply the filter line edit contents to the tree's proxy model.

        The text is lower-cased, stripped and split on whitespace, and the
        resulting words are passed to ``setFilterFixedStrings``;
        ``updateInfo`` is called afterwards.  Nothing happens when the
        tree has no (truthy) model.
        """
        raw_text = self.filterLineEdit.text()
        proxy = self.treeWidget.model()
        if not proxy:
            return

        proxy.setFilterFixedStrings(raw_text.lower().strip().split())
        self.updateInfo()

    def selectedSamples(self):
        """
        Return the currently selected sample annotations.

        The return value is a ``(samples, used_types)`` pair:

        * ``samples`` is a list of ``(sample type, sample value)`` string
          tuples collected from the annotation tree;
        * ``used_types`` is the list of sample types that have at least
          one selected value.

        A value counts as selected when its entry in
        ``self.gdsSelectionStates`` is truthy; values without a stored
        state are treated as selected.

        .. note:: if some sample annotation type has no selected values,
                  all of its values are included in ``samples`` (no
                  filtering is done on that type) and the type is left
                  out of ``used_types``.

        """
        samples = []
        used_types = []
        for stype in childiter(self.annotationsTree.invisibleRootItem()):
            type_name = str(stype.text(0))
            selected_values = []
            all_values = []
            for sval in childiter(stype):
                value = (type_name, str(sval.text(0)))
                all_values.append(value)
                if self.gdsSelectionStates.get(sval.key, True):
                    selected_values.append(value)

            if selected_values:
                samples.extend(selected_values)
                used_types.append(type_name)
            else:
                # If no sample of sample type is selected we don't filter
                # on it.
                samples.extend(all_values)

        return samples, used_types

    def commitIf(self):
        """
        Commit now when ``autoCommit`` is on; otherwise only record that
        the selection changed (``self.selectionChanged = True``).
        """
        if not self.autoCommit:
            self.selectionChanged = True
            return
        self.commit()

    @Slot(int, int)
    def progressCompleted(self, value, total):
        """
        Show `value` out of `total` as a percentage on the progress bar.

        Nothing is reported when `total` is not positive.
        """
        if total <= 0:
            # TODO: report 'indeterminate progress'
            return
        self.progressBarSet(100. * value / total, processEvents=False)

    def commit(self):
        """
        Start loading the currently selected dataset in the background.

        Does nothing when no dataset is selected.  Otherwise the widget is
        disabled and set to blocking, the progress bar is initialised and a
        ``Task`` is submitted to ``self._executor``; ``_on_dataready`` is
        connected to the task's ``finished`` signal.
        """
        if not self.currentGds:
            return

        self.error(0)
        self.progressBarInit(processEvents=None)

        # A sample type is only passed on when exactly one type is in use
        # and samples are output in rows.
        _, groups = self.selectedSamples()
        if len(groups) == 1 and self.outputRows:
            sample_type = groups[0]
        else:
            sample_type = None

        self.setEnabled(False)
        self.setBlocking(True)

        # Wrapper through which the task function reports progress to the
        # `progressCompleted` slot.
        progress = methodinvoke(self, "progressCompleted", (int, int))

        def fetch(gds_id, report_genes, transpose, sample_type, title):
            # Executed by the submitted task.
            gds_ensure_downloaded(gds_id, progress)
            table = GDS(gds_id).get_data(report_genes=report_genes,
                                         transpose=transpose,
                                         sample_type=sample_type)
            table.name = title
            return table

        job = partial(fetch,
                      self.currentGds["dataset_id"],
                      report_genes=self.mergeSpots,
                      transpose=self.outputRows,
                      sample_type=sample_type,
                      title=self.datasetName or self.currentGds["title"])

        task = Task(function=job)
        self._datatask = task
        task.finished.connect(self._on_dataready)
        self._executor.submit(task)

    def _on_dataready(self):
        """
        Handle completion of the data task started by `commit`.

        Re-enables the widget, reads the task result, restricts the data
        to the selected sample annotations and sends it on the
        "Expression Data" output.  A ``URLError`` from the task is shown as
        widget error 0 and passed to ``sys.excepthook``; in that case
        nothing is sent.
        """
        self.setEnabled(True)
        self.setBlocking(False)
        self.progressBarFinished(processEvents=False)

        try:
            data = self._datatask.result()
        except urlrequest.URLError as error:
            self.error(0, ("Error while connecting to the NCBI ftp server! "
                           "'%s'" % error))
            sys.excepthook(type(error), error, getattr(error, "__traceback__"))
            return
        finally:
            # The task reference is dropped whether or not a result was read.
            self._datatask = None

        # Remember the name; `data` is rebound to a filtered table below.
        data_name = data.name
        # `samples` holds (sample type, sample value) tuples.
        samples, _ = self.selectedSamples()

        self.warning(0)
        message = None
        if self.outputRows:

            def samplesinst(ex):
                # Collect the (name, value) pairs of all meta attributes of
                # `ex`, plus the class variable unless it is named 'class'.
                out = []
                for meta in data.domain.metas:
                    out.append((meta.name, ex[meta].value))

                if data.domain.class_var.name != 'class':
                    out.append((data.domain.class_var.name,
                                ex[data.domain.class_var].value))

                return out

            # Keep only the rows whose pairs are all among the selected
            # samples.
            samples = set(samples)
            mask = [samples.issuperset(samplesinst(ex)) for ex in data]
            data = data[numpy.array(mask, dtype=bool)]
            if len(data) == 0:
                message = "No samples with selected sample annotations."
        else:
            # Keep only the attributes whose annotation items are all among
            # the selected samples.
            samples = set(samples)
            domain = Domain([
                attr for attr in data.domain.attributes
                if samples.issuperset(attr.attributes.items())
            ], data.domain.class_var, data.domain.metas)
            #             domain.addmetas(data.domain.getmetas())

            if len(domain.attributes) == 0:
                message = "No samples with selected sample annotations."
            # Drop attribute annotations whose key is not one of the sample
            # types present in `samples`.
            stypes = set(s[0] for s in samples)
            for attr in domain.attributes:
                attr.attributes = dict(
                    (key, value) for key, value in attr.attributes.items()
                    if key in stypes)
            data = Table(domain, data)

        if message is not None:
            self.warning(0, message)

        data_hints.set_hint(data, TAX_ID, self.currentGds.get("taxid", ""))
        data_hints.set_hint(data, GENE_NAME, bool(self.outputRows))

        data.name = data_name
        self.send("Expression Data", data)

        model = self.treeWidget.model().sourceModel()
        row = self.gds.index(self.currentGds)

        # NOTE(review): writes " " as the display text of the first column
        # for this dataset's row -- presumably a "downloaded" marker;
        # confirm against the model/delegate.
        model.setData(model.index(row, 0), " ", Qt.DisplayRole)

        self.updateInfo()
        self.selectionChanged = False

    def splitterMoved(self, *args):
        """
        Store the saved state of every splitter in ``self.splitters`` as
        ``bytes`` in ``self.splitterSettings``.  Arguments are ignored.
        """
        saved_states = []
        for splitter in self.splitters:
            saved_states.append(bytes(splitter.saveState()))
        self.splitterSettings = saved_states

    def send_report(self):
        """
        Add a description of the selected dataset to the report.

        Two item tables ("GEO Dataset" and "Data") are reported, followed
        by a "Sample annotations" section listing, per subset type, each
        subset description with its number of sample ids.  Nothing is
        reported when no dataset is selected.
        """
        # Function-scope import: only this method needs it.
        from html import escape

        gds = self.currentGds
        if not gds:
            # Without a selection there is nothing to describe; indexing
            # `None` below would raise a TypeError.
            return

        self.report_items("GEO Dataset",
                          [("ID", gds['dataset_id']),
                           ("Title", gds['title']),
                           ("Organism", gds['sample_organism'])])
        self.report_items("Data",
                          [("Samples", gds['sample_count']),
                           ("Features", gds['feature_count']),
                           ("Genes", gds['gene_count'])])
        self.report_name("Sample annotations")

        subsets = defaultdict(list)
        for subset in gds['subsets']:
            subsets[subset['type']].append(
                (subset['description'], len(subset['sample_id'])))

        # The annotation texts come from the dataset description, so they
        # are escaped before being embedded in the HTML.
        indent = 9 * "&nbsp;"
        parts = ["<ul>"]
        for subset_type, entries in subsets.items():
            parts.append("<b>" + escape(str(subset_type)) + ":</b><br/>")
            for desc, count in entries:
                parts.append(indent + "<b>{}:</b> {}<br/>".format(
                    escape(str(desc)), count))
        parts.append("</ul>")
        self.report_html += "".join(parts)

    def onDeleteWidget(self):
        """
        Cancel outstanding tasks and shut the executor down.

        For the initialisation and data tasks (when set) the future is
        cancelled and the completion slot is disconnected; the executor is
        then shut down without waiting and the base class handler runs.
        """
        init_task = self._inittask
        if init_task:
            init_task.future().cancel()
            init_task.finished.disconnect(self._initializemodel)

        data_task = self._datatask
        if data_task:
            data_task.future().cancel()
            data_task.finished.disconnect(self._on_dataready)

        self._executor.shutdown(wait=False)

        super(OWGEODatasets, self).onDeleteWidget()

    def onNameEdited(self):
        """
        Remember the edited name for the current dataset and commit.

        The name edit's text is stored in ``self.datasetNames`` under the
        current dataset id, then ``commitIf`` is called.  Does nothing
        when no dataset is selected.
        """
        if not self.currentGds:
            return
        dataset_id = self.currentGds["dataset_id"]
        self.datasetNames[dataset_id] = self.nameEdit.text()
        self.commitIf()
class ClusterModel(QAbstractListModel):
    """
    List model over cluster row objects; scores them on a worker thread.

    Rows are opaque to the model: a row object is returned unchanged for
    ``Qt.DisplayRole``.  The methods below call ``cluster_scores``,
    ``gene_set_enrichment``, ``filter_enriched_genes`` and
    ``filter_gene_sets`` on the rows and read their ``filtered_genes``.
    """

    def __init__(self, parent=None):
        """
        :param parent: owner object; the model calls its
            ``progressBarFinished`` and ``filter_genes`` methods and
            reads/assigns its ``progress_bar`` attribute.
        """
        QAbstractListModel.__init__(self)
        self.__items = []
        # NOTE: kept as a plain attribute (it is not passed to Qt), so it
        # shadows ``QObject.parent()`` on instances of this class.
        self.parent = parent

        self._task = None  # type: Union[Task, None]
        self._executor = ThreadExecutor()

    def add_rows(self, rows):
        """Replace the model's rows with `rows` (no signal is emitted)."""
        self.__items = rows

    def get_rows(self):
        """Return the row objects currently held by the model."""
        return self.__items

    def rowCount(self, *args, **kwargs):
        """Return the number of rows; any arguments are ignored."""
        return len(self.__items)

    def data(self, model_index, role=None):
        """
        Return the row object at `model_index` for ``Qt.DisplayRole``.

        An empty ``QVariant`` is returned when the model has no rows or
        the index is invalid / out of range; other roles yield ``None``.
        """
        # check if data is set
        if not self.__items:
            return QVariant()

        # return empty QVariant if model index is unknown
        if not model_index.isValid() or \
                not (0 <= model_index.row() < len(self.__items)):
            return QVariant()

        row_obj = self.__items[model_index.row()]

        if role == Qt.DisplayRole:
            return row_obj

    @Slot(concurrent.futures.Future)
    def _end_task(self, f):
        """
        Finish the scoring task whose future `f` has completed.

        Clears the current task, finishes the parent's progress bar and
        asks the parent to filter genes.  An exception raised by the
        worker is re-raised here.
        """
        assert self.thread() is QThread.currentThread()
        assert threading.current_thread() == threading.main_thread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.parent.progressBarFinished()
        self.parent.filter_genes()

        # Propagate any exception raised in the worker.
        f.result()

    def _score_genes(self, callback, **kwargs):
        """
        Call ``cluster_scores(**kwargs)`` on every row, invoking `callback`
        after each one.  This is the function submitted to the executor.
        """
        for item in self.get_rows():
            item.cluster_scores(**kwargs)
            callback()

    @Slot(bool)
    def progress_advance(self, finish):
        """
        Advance the parent's progress bar, or finish it when `finish` is
        true.  Does nothing when the parent has no progress bar.
        """
        # GUI should be updated in the main thread. That is why the advance
        # method is called from this slot.
        if self.parent.progress_bar:
            if finish:
                self.parent.progressBarFinished()
            else:
                self.parent.progress_bar.advance()

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_end_task` slot
            self._task.watcher.done.disconnect(self._end_task)
            self._task = None

    def score_genes(self, **kwargs):
        """ Run gene enrichment.

        The keyword arguments are forwarded to ``cluster_scores`` of every
        row.

        :param design:
        :param data_x:
        :param rows_by_cluster:
        :param method:


        Note:
            We do not apply filter nor notify view that data is changed. This is done after filters
        """
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        progress_advance = methodinvoke(self, "progress_advance", (bool,))

        def callback():
            # Called after each scored row by `_score_genes`.
            if self._task.cancelled:
                raise KeyboardInterrupt()
            progress_advance(self._task.cancelled)

        self.parent.progress_bar = ProgressBar(
            self.parent, iterations=len(self.get_rows()))
        f = partial(self._score_genes, callback=callback, **kwargs)
        self._task = Task()
        self._task.future = self._executor.submit(f)

        self._task.watcher = FutureWatcher(self._task.future)
        self._task.watcher.done.connect(self._end_task)

    def gene_sets_enrichment(self, gs_object, gene_sets, reference_genes):
        """ Run gene sets enrichment.

        :param gs_object:
        :param gene_sets:
        :param reference_genes:

        Note:
            We do not apply filter nor notify view that data is changed. This is done after filters

        """

        for item in self.get_rows():
            genes = [gene.ncbi_id for gene in item.filtered_genes]
            item.gene_set_enrichment(
                gs_object, gene_sets, set(genes), reference_genes)

    def _emit_rows_changed(self):
        """Emit ``dataChanged`` covering every row of column 0."""
        # The bottom-right index of dataChanged is inclusive, so the last
        # row is rowCount() - 1.  max() keeps the (0, 0) index for an empty
        # model, which is what was emitted for it before.
        last_row = max(self.rowCount() - 1, 0)
        self.dataChanged.emit(self.createIndex(0, 0),
                              self.createIndex(last_row, 0))

    def apply_gene_filters(self, p_val=None, fdr=None, count=None):
        """
        Call ``filter_enriched_genes(p_val, fdr, max_gene_count=count)`` on
        every row and notify views that the rows changed.
        """
        for item in self.get_rows():
            item.filter_enriched_genes(p_val, fdr, max_gene_count=count)
        self._emit_rows_changed()

    def apply_gene_sets_filters(self, p_val=None, fdr=None, count=None):
        """
        Call ``filter_gene_sets(p_val, fdr, max_set_count=count)`` on every
        row and notify views that the rows changed.
        """
        for item in self.get_rows():
            item.filter_gene_sets(p_val, fdr, max_set_count=count)
        self._emit_rows_changed()
Ejemplo n.º 17
0
class OWExplainPred(OWWidget):
    # Widget that explains a model's prediction for a single sample.
    #
    # The explanation (`ExplainPredictions.anytime_explain`) is submitted
    # to a `ThreadExecutor`; the callbacks handed to it post progress,
    # tables and the predicted value back to this widget through queued
    # slot invocations.

    name = "Explain Predictions"
    description = "Computes attribute contributions to the final prediction with an approximation algorithm for shapely value"
    icon = "icons/ExplainPredictions.svg"
    priority = 200
    gui_error = settings.Setting(0.05)
    gui_p_val = settings.Setting(0.05)

    class Inputs:
        data = Input("Data", Table, default=True)
        model = Input("Model", Model, multiple=False)
        sample = Input("Sample", Table)

    class Outputs:
        explanations = Output("Explanations", Table)

    class Error(OWWidget.Error):
        sample_too_big = widget.Msg("Can only explain one sample at the time.")

    class Warning(OWWidget.Warning):
        unknowns_increased = widget.Msg(
            "Number of unknown values increased, Data and Sample domains mismatch.")

    def __init__(self):
        """Create the widget state and build the control and main areas."""
        super().__init__()
        self.data = None
        self.model = None
        self.to_explain = None
        self.explanations = None
        # True while the button reads "Stop Computation" (see toggle_button)
        self.stop = True
        # ExplainPredictions instance; reset when data or model change
        self.e = None
        # Set by `cancel`, read by `_task_finished`
        self.was_canceled = False

        self._task = None
        self._executor = ThreadExecutor()

        self.dataview = QTableView(verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                   sortingEnabled=True,
                                   selectionMode=QTableView.NoSelection,
                                   focusPolicy=Qt.StrongFocus)

        self.dataview.sortByColumn(2, Qt.DescendingOrder)
        header = self.dataview.horizontalHeader()
        # Qt 5 renamed QHeaderView.setResizeMode to setSectionResizeMode;
        # use whichever the bound Qt version provides.
        set_resize_mode = (getattr(header, "setSectionResizeMode", None)
                           or header.setResizeMode)
        set_resize_mode(QHeaderView.Stretch)

        # Empty table shown until the first explanation arrives.
        domain = Domain([ContinuousVariable("Score"),
                         ContinuousVariable("Error")],
                        metas=[StringVariable(name="Feature"), StringVariable(name="Value")])
        self.placeholder_table_model = TableModel(
            Table.from_domain(domain), parent=None)

        self.dataview.setModel(self.placeholder_table_model)

        info_box = gui.vBox(self.controlArea, "Info")
        self.data_info = gui.widgetLabel(info_box, "Data: N/A")
        self.model_info = gui.widgetLabel(info_box, "Model: N/A")
        self.sample_info = gui.widgetLabel(info_box, "Sample: N/A")

        criteria_box = gui.vBox(self.controlArea, "Stopping criteria")
        self.error_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_error",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error < ",
                                   spinType=float,
                                   callback=self._update_error_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        self.p_val_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_p_val",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error p-value < ",
                                   spinType=float,
                                   callback=self._update_p_val_spin,
                                   controlWidth=80, keyboardTracking=False)

        gui.rubber(self.controlArea)

        self.cancel_button = gui.button(self.controlArea,
                                        self,
                                        "Stop Computation",
                                        callback=self.toggle_button,
                                        autoDefault=True,
                                        tooltip="Stops and restarts computation")
        self.cancel_button.setDisabled(True)

        predictions_box = gui.vBox(self.mainArea, "Model prediction")
        self.predict_info = gui.widgetLabel(predictions_box, "")

        self.mainArea.layout().addWidget(self.dataview)

        self.resize(640, 480)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set input 'Data'"""
        self.data = data
        self.explanations = None
        self.data_info.setText("Data: N/A")
        self.e = None
        if data is not None:
            if data.X.shape[0] == 1:
                inst = "1 instance and "
            else:
                inst = str(data.X.shape[0]) + " instances and "
            if data.X.shape[1] == 1:
                feat = "1 feature "
            else:
                feat = str(data.X.shape[1]) + " features"
            self.data_info.setText("Data: " + inst + feat)

    @Inputs.model
    def set_predictor(self, model):
        """Set input 'Model'"""
        self.model = model
        self.model_info.setText("Model: N/A")
        self.explanations = None
        self.e = None
        if model is not None:
            self.model_info.setText("Model: " + str(model.name))

    @Inputs.sample
    @check_sql_input
    def set_sample(self, sample):
        """Set input 'Sample', checks if size is appropriate"""
        self.to_explain = sample
        self.explanations = None
        self.Error.sample_too_big.clear()
        self.sample_info.setText("Sample: N/A")
        if sample is not None:
            if len(sample.X) != 1:
                self.to_explain = None
                self.Error.sample_too_big()
            else:
                if sample.X.shape[1] == 1:
                    feat = "1 feature"
                else:
                    feat = str(sample.X.shape[1]) + " features"
                self.sample_info.setText("Sample: " + feat)
                if self.e is not None:
                    self.e.saved = False

    def handleNewSignals(self):
        """Cancel a running task, reset the view and start over."""
        if self._task is not None:
            self.cancel()
        assert self._task is None

        self.dataview.setModel(self.placeholder_table_model)
        self.predict_info.setText("")
        self.Warning.unknowns_increased.clear()
        self.stop = True
        self.cancel_button.setText("Stop Computation")
        self.commit_calc_or_output()

    def commit_calc_or_output(self):
        """
        Start the computation when both data and a sample are present,
        otherwise just send the current explanations.
        """
        if self.data is not None and self.to_explain is not None:
            self.commit_calc()
        else:
            self.commit_output()

    def commit_calc(self):
        """
        Transform the sample to the data domain and, when a model is
        present, submit the explanation to the executor.
        """
        num_nan = np.count_nonzero(np.isnan(self.to_explain.X[0]))

        self.to_explain = self.to_explain.transform(self.data.domain)
        if num_nan != np.count_nonzero(np.isnan(self.to_explain.X[0])):
            self.Warning.unknowns_increased()
        if self.model is not None:
            # calculate contributions
            if self.e is None:
                self.e = ExplainPredictions(self.data,
                                            self.model,
                                            batch_size=min(
                                                len(self.data.X), 500),
                                            p_val=self.gui_p_val,
                                            error=self.gui_error)
            self._task = task = Task()

            def callback(progress):
                # update progress bar
                QMetaObject.invokeMethod(
                    self, "set_progress_value", Qt.QueuedConnection, Q_ARG(int, progress))
                # The return value is true once the task was cancelled.
                # NOTE(review): the attribute is spelled `canceled` here;
                # confirm it matches the Task class.
                return bool(task.canceled)

            def callback_update(table):
                QMetaObject.invokeMethod(
                    self, "update_view", Qt.QueuedConnection, Q_ARG(Orange.data.Table, table))

            def callback_prediction(class_value):
                QMetaObject.invokeMethod(
                    self, "update_model_prediction", Qt.QueuedConnection, Q_ARG(float, class_value))

            self.was_canceled = False
            explain_func = partial(
                self.e.anytime_explain, self.to_explain[0], callback=callback,
                update_func=callback_update,
                update_prediction=callback_prediction)

            self.progressBarInit(processEvents=None)
            task.future = self._executor.submit(explain_func)
            task.watcher = FutureWatcher(task.future)
            task.watcher.done.connect(self._task_finished)
            self.cancel_button.setDisabled(False)

    @pyqtSlot(Orange.data.Table)
    def update_view(self, table):
        """Show `table` in the view (keeping the sort order) and send it."""
        self.explanations = table
        model = TableModel(table, parent=None)
        header = self.dataview.horizontalHeader()
        model.sort(
            header.sortIndicatorSection(),
            header.sortIndicatorOrder())
        self.dataview.setModel(model)
        self.commit_output()

    @pyqtSlot(float)
    def update_model_prediction(self, value):
        """Show the predicted value `value` in the prediction label."""
        self._print_prediction(value)

    @pyqtSlot(int)
    def set_progress_value(self, value):
        """Set the progress bar to `value`."""
        self.progressBarSet(value, processEvents=False)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Parameters:
        ----------
        f: concurrent.futures.Future
            future instance holding the result of the explanation
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None

        if not self.was_canceled:
            self.cancel_button.setDisabled(True)

        try:
            results = f.result()
        except Exception as ex:
            # Log and show the failure.  (The widget has no `results`
            # attribute to reset, so nothing else is touched here.)
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.error("Exception occured during evaluation: {!r}".format(ex))
        else:
            self.update_view(results[1])

        self.progressBarFinished(processEvents=False)

    def commit_output(self):
        """
        Sends best-so-far results forward
        """
        self.Outputs.explanations.send(self.explanations)

    def toggle_button(self):
        """Stop the computation, or restart it if it was stopped."""
        if self.stop:
            self.stop = False
            self.cancel_button.setText("Restart Computation")
            self.cancel()
        else:
            self.stop = True
            self.cancel_button.setText("Stop Computation")
            self.commit_calc_or_output()

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_task_finished` slot
            self._task.watcher.done.disconnect(self._task_finished)
            self.was_canceled = True
            # Handle the finished future right away, since the slot was
            # just disconnected.
            self._task_finished(self._task.future)

    def _print_prediction(self, class_value):
        """
        Parameters
        ----------
        class_value: float
            Number representing either index of predicted class value, looked up in domain, or predicted value (regression)
        """
        name = self.data.domain.class_vars[0].name
        if isinstance(self.data.domain.class_vars[0], ContinuousVariable):
            self.predict_info.setText(name + ":      " + str(class_value))
        else:
            self.predict_info.setText(
                name + ":      " + self.data.domain.class_vars[0].values[int(class_value)])

    def _update_error_spin(self):
        """Apply the new error threshold and restart."""
        self.cancel()
        if self.e is not None:
            self.e.error = self.gui_error
        self.handleNewSignals()

    def _update_p_val_spin(self):
        """Apply the new p-value threshold and restart."""
        self.cancel()
        if self.e is not None:
            self.e.p_val = self.gui_p_val
        self.handleNewSignals()

    def onDeleteWidget(self):
        """Cancel a running task before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()
class OWImportImages(widget.OWWidget):
    # Widget metadata
    name = "Import Images"
    description = "Import images from a directory(s)"
    icon = "icons/ImportImages.svg"
    priority = 110

    # Single output channel carrying an Orange data table
    outputs = [("Data", Orange.data.Table)]

    #: list of recent paths
    recent_paths = settings.Setting([])  # type: List[RecentPath]
    #: currently selected root path (None when nothing is selected)
    currentPath = settings.Setting(None)

    want_main_area = False
    resizing_enabled = False

    # Modality of the directory dialog; `__runOpenDialog` compares this
    # against Qt.WindowModal to pick the dialog style.
    Modality = Qt.ApplicationModal
    # Modality = Qt.WindowModal

    # Upper bound on the number of recent paths kept when the recent items
    # model is initialised.
    MaxRecentItems = 20

    def __init__(self):
        """
        Build the widget GUI: a recent-paths combo box with browse and
        reload buttons, and an "Info" box that switches between a text
        label and a progress page.  An ``Init`` runtime event is posted at
        the end (handled in `customEvent`).
        """
        super().__init__()
        #: widget's runtime state
        self.__state = State.NoState
        self._imageMeta = []
        self._imageCategories = {}

        self.__invalidated = False
        self.__pendingTask = None

        # Top row: recent paths combo box followed by the action buttons.
        vbox = gui.vBox(self.controlArea)
        hbox = gui.hBox(vbox)
        self.recent_cb = QComboBox(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=16,
        )
        self.recent_cb.activated[int].connect(self.__onRecentActivated)
        icons = standard_icons(self)

        browseaction = QAction(
            "Open/Load Images",
            self,
            iconText="\N{HORIZONTAL ELLIPSIS}",
            icon=icons.dir_open_icon,
            toolTip="Select a directory from which to load the images")
        browseaction.triggered.connect(self.__runOpenDialog)
        reloadaction = QAction("Reload",
                               self,
                               icon=icons.reload_icon,
                               toolTip="Reload current image set")
        reloadaction.triggered.connect(self.reload)
        self.__actions = namespace(
            browse=browseaction,
            reload=reloadaction,
        )

        # The buttons mirror the actions and trigger them when clicked.
        browsebutton = QPushButton(browseaction.iconText(),
                                   icon=browseaction.icon(),
                                   toolTip=browseaction.toolTip(),
                                   clicked=browseaction.trigger)
        reloadbutton = QPushButton(
            reloadaction.iconText(),
            icon=reloadaction.icon(),
            clicked=reloadaction.trigger,
            default=True,
        )

        hbox.layout().addWidget(self.recent_cb)
        hbox.layout().addWidget(browsebutton)
        hbox.layout().addWidget(reloadbutton)

        self.addActions([browseaction, reloadaction])

        # Keep the reload button's enabled state in sync with its action.
        reloadaction.changed.connect(
            lambda: reloadbutton.setEnabled(reloadaction.isEnabled()))
        box = gui.vBox(vbox, "Info")
        self.infostack = QStackedWidget()

        self.info_area = QLabel(text="No image set selected", wordWrap=True)
        # minimum == maximum == 0 on the progress bar
        self.progress_widget = QProgressBar(minimum=0, maximum=0)
        self.cancel_button = QPushButton(
            "Cancel",
            icon=icons.cancel_icon,
        )
        self.cancel_button.clicked.connect(self.cancel)

        # Progress page: progress bar and cancel button above a path label.
        w = QWidget()
        vlayout = QVBoxLayout()
        vlayout.setContentsMargins(0, 0, 0, 0)
        hlayout = QHBoxLayout()
        hlayout.setContentsMargins(0, 0, 0, 0)

        hlayout.addWidget(self.progress_widget)
        hlayout.addWidget(self.cancel_button)
        vlayout.addLayout(hlayout)

        self.pathlabel = TextLabel()
        self.pathlabel.setTextElideMode(Qt.ElideMiddle)
        self.pathlabel.setAttribute(Qt.WA_MacSmallSize)

        vlayout.addWidget(self.pathlabel)
        w.setLayout(vlayout)

        # Stack index 0 is the info label, index 1 the progress page
        # (`__updateInfo` switches between them).
        self.infostack.addWidget(self.info_area)
        self.infostack.addWidget(w)

        box.layout().addWidget(self.infostack)

        self.__initRecentItemsModel()
        self.__invalidated = True
        self.__executor = ThreadExecutor(self)

        # `customEvent` calls `start` on this event while `__invalidated`
        # is set.
        QApplication.postEvent(self, QEvent(RuntimeEvent.Init))

    def __initRecentItemsModel(self):
        """
        Fill the recent paths combo box from the stored settings.

        Stored entries whose directory no longer exists are dropped and
        the list is truncated to ``MaxRecentItems``.  The current path is
        kept (and selected at index 0) only when it is an existing
        directory that is the same file as the first recent entry;
        otherwise it is reset to None and the combo box selection is
        cleared.  The reload action is enabled accordingly.
        """
        if self.currentPath is not None and \
                not os.path.isdir(self.currentPath):
            self.currentPath = None

        existing = [entry for entry in self.recent_paths
                    if os.path.isdir(entry.abspath)]
        existing = existing[:OWImportImages.MaxRecentItems]

        model = self.recent_cb.model()
        for entry in existing:
            model.appendRow(RecentPath_asqstandarditem(entry))

        self.recent_paths = existing

        current_is_first = (
            self.currentPath is not None
            and os.path.isdir(self.currentPath)
            and self.recent_paths
            and os.path.samefile(self.currentPath,
                                 self.recent_paths[0].abspath)
        )
        if current_is_first:
            self.recent_cb.setCurrentIndex(0)
        else:
            self.currentPath = None
            self.recent_cb.setCurrentIndex(-1)
        self.__actions.reload.setEnabled(self.currentPath is not None)

    def customEvent(self, event):
        """
        Reimplemented.

        On the ``RuntimeEvent.Init`` event, while the widget is flagged
        as invalidated, `start` is called and the flag is cleared (even
        when `start` raises).  The event is always passed on to the base
        class.
        """
        is_init = event.type() == RuntimeEvent.Init
        if is_init and self.__invalidated:
            try:
                self.start()
            finally:
                self.__invalidated = False

        super().customEvent(event)

    def __runOpenDialog(self):
        """
        Let the user pick a directory and start loading from it.

        The dialog opens at the first recent path, or at the user's home
        directory when there are none.  With ``Modality`` equal to
        ``Qt.WindowModal`` a ``QFileDialog`` instance is opened and the
        choice is handled when it is accepted; otherwise the static
        ``getExistingDirectory`` dialog is used.
        """
        if self.recent_paths:
            startdir = self.recent_paths[0].abspath
        else:
            startdir = os.path.expanduser("~/")

        if OWImportImages.Modality != Qt.WindowModal:
            chosen = QFileDialog.getExistingDirectory(
                self, "Select Top Level Directory", startdir)
            if chosen:
                self.setCurrentPath(chosen)
                self.start()
            return

        dialog = QFileDialog(
            self,
            "Select Top Level Directory",
            startdir,
            acceptMode=QFileDialog.AcceptOpen,
            modal=True,
        )
        dialog.setFileMode(QFileDialog.Directory)
        dialog.setOption(QFileDialog.ShowDirsOnly)
        dialog.setDirectory(startdir)
        dialog.setAttribute(Qt.WA_DeleteOnClose)

        def on_accepted():
            selected = dialog.selectedFiles()
            if selected:
                self.setCurrentPath(selected[0])
                self.start()

        dialog.accepted.connect(on_accepted)
        dialog.open()

    def __onRecentActivated(self, index):
        """
        Switch to the recent path stored at combo box `index` and start
        loading.  Entries without item data are ignored.
        """
        path_item = self.recent_cb.itemData(index)
        if path_item is not None:
            assert isinstance(path_item, RecentPath)
            self.setCurrentPath(path_item.abspath)
            self.start()

    def __updateInfo(self):
        """
        Update the info label and the visible info page from the state.

        The label text depends on the runtime state; for ``State.Done``
        it reports the number of valid images and, with two or more
        categories, the category count.  The progress page (stack index
        1) is shown only while processing.
        """
        state = self.__state

        if state == State.Done:
            valid_count = sum(meta.isvalid for meta in self._imageMeta)
            category_count = len(self._imageCategories)
            if category_count >= 2:
                text = "{} images / {} categories".format(
                    valid_count, category_count)
            else:
                text = "{} images".format(valid_count)
        elif state == State.NoState:
            text = "No image set selected"
        elif state == State.Processing:
            text = "Processing"
        elif state == State.Cancelled:
            text = "Cancelled"
        elif state == State.Error:
            text = "Error state"
        else:
            assert False

        self.info_area.setText(text)
        self.infostack.setCurrentIndex(1 if state == State.Processing else 0)

    def setCurrentPath(self, path):
        """
        Set the current root image path to path

        If the path is None, does not exist or is not a directory the
        current path is left unchanged and False is returned (a
        ``UserWarning`` is issued for the latter two cases).  A running
        import is cancelled when the path changes.

        Parameters
        ----------
        path : str
            New root import path.

        Returns
        -------
        status : bool
            True if the current root import path was successfully
            changed to path (or already refers to the same directory).
        """
        if path is None:
            # `os.path.exists(None)` below would raise a TypeError; there
            # is nothing to switch to.
            return False

        # Nothing to do when `path` is the directory already selected.
        if self.currentPath is not None and \
                os.path.isdir(self.currentPath) and os.path.isdir(path) and \
                os.path.samefile(self.currentPath, path):
            return True

        if not os.path.exists(path):
            warnings.warn("'{}' does not exist".format(path), UserWarning)
            return False
        elif not os.path.isdir(path):
            warnings.warn("'{}' is not a directory".format(path), UserWarning)
            return False

        newindex = self.addRecentPath(path)
        self.recent_cb.setCurrentIndex(newindex)
        # A negative index from `addRecentPath` clears the current path.
        if newindex >= 0:
            self.currentPath = path
        else:
            self.currentPath = None
        self.__actions.reload.setEnabled(self.currentPath is not None)

        if self.__state == State.Processing:
            self.cancel()

        return True

    def addRecentPath(self, path):
        """
        Prepend a path entry to the list of recent paths

        If an entry with the same path already exists in the recent path
        list it is moved to the first place

        Parameters
        ----------
        path : str

        Returns
        -------
        index : int
            Position of the entry in the recent paths list (always 0).
        """
        existing = None
        for pathitem in self.recent_paths:
            try:
                if os.path.samefile(pathitem.abspath, path):
                    existing = pathitem
                    break
            except FileNotFoundError:
                # `pathitem.abspath` may have been removed or renamed since
                # it was remembered; such an entry cannot match `path`.
                continue

        model = self.recent_cb.model()

        if existing is not None:
            # Move the existing entry (and its model row) to the front.
            selected_index = self.recent_paths.index(existing)
            assert model.item(selected_index).data(Qt.UserRole) is existing
            self.recent_paths.remove(existing)
            row = model.takeRow(selected_index)
            self.recent_paths.insert(0, existing)
            model.insertRow(0, row)
        else:
            item = RecentPath(path, None, None)
            self.recent_paths.insert(0, item)
            model.insertRow(0, RecentPath_asqstandarditem(item))
        return 0

    def __setRuntimeState(self, state):
        """
        Switch the widget to `state`, checking that the transition is
        allowed, and update the blocking flag, the info stack page, the
        status message and the info text accordingly.
        """
        assert state in State
        now_processing = state == State.Processing
        self.setBlocking(now_processing)

        message = ""
        previous = self.__state
        if now_processing:
            # Processing may be entered from any non-processing state.
            assert previous in [
                State.Done, State.NoState, State.Error, State.Cancelled
            ]
            message = "Processing"
        elif state == State.Done:
            assert previous == State.Processing
        elif state == State.Cancelled:
            assert previous == State.Processing
            message = "Cancelled"
        elif state == State.Error:
            message = "Error during processing"
        elif state == State.NoState:
            message = ""
        else:
            assert False

        self.__state = state
        self.infostack.setCurrentIndex(1 if now_processing else 0)
        self.setStatusMessage(message)
        self.__updateInfo()

    def reload(self):
        """
        Discard the collected image meta data and run the scan again,
        cancelling a scan that is still in progress first.
        """
        still_running = self.__state == State.Processing
        if still_running:
            self.cancel()

        self._imageMeta, self._imageCategories = [], {}
        self.start()

    def start(self):
        """
        Start/execute the image indexing operation

        Does nothing, apart from calling `self.error()` with no message
        and clearing the invalidated flag, when no current path is set.
        A task that is still being processed is cancelled first. The new
        `ImageScan` task is submitted to the widget's executor and a
        `FutureWatcher` on the resulting future is connected to
        `__onRunFinished`.
        """
        self.error()

        self.__invalidated = False
        if self.currentPath is None:
            return

        if self.__state == State.Processing:
            assert self.__pendingTask is not None
            log.info("Starting a new task while one is in progress. "
                     "Cancel the existing task (dir:'{}')".format(
                         self.__pendingTask.startdir))
            self.cancel()

        startdir = self.currentPath

        self.__setRuntimeState(State.Processing)

        # Callable handed to the task for reporting progress through the
        # `__onReportProgress` slot.
        # NOTE(review): presumably methodinvoke queues the call onto this
        # widget's thread (the slot asserts it runs there) -- confirm.
        report_progress = methodinvoke(self, "__onReportProgress", (object, ))

        task = ImageScan(startdir, report_progress=report_progress)

        # collect the task state in one convenient place
        self.__pendingTask = taskstate = namespace(
            task=task,
            startdir=startdir,
            future=None,
            watcher=None,
            cancelled=False,
            cancel=None,
        )

        def cancel():
            # Cancel the task and disconnect
            if taskstate.future.cancel():
                # The future reported a successful cancel; nothing to
                # wait for.
                pass
            else:
                # Ask the task to stop by setting its `cancelled` flag,
                # then wait up to 3 seconds for it to finish.
                taskstate.task.cancelled = True
                taskstate.cancelled = True
                try:
                    taskstate.future.result(timeout=3)
                except UserInterruptError:
                    pass
                except TimeoutError:
                    log.info("The task did not stop in in a timely manner")
            # The result of a cancelled task must not reach
            # `__onRunFinished`.
            taskstate.watcher.finished.disconnect(self.__onRunFinished)

        taskstate.cancel = cancel

        def run_image_scan_task_interupt():
            # Function actually submitted to the executor; returns None
            # instead of raising when the scan is interrupted.
            try:
                return task.run()
            except UserInterruptError:
                # Suppress interrupt errors, so they are not logged
                return

        taskstate.future = self.__executor.submit(run_image_scan_task_interupt)
        taskstate.watcher = FutureWatcher(taskstate.future)
        taskstate.watcher.finished.connect(self.__onRunFinished)

    @Slot()
    def __onRunFinished(self):
        """
        Collect the result of the finished scan task, derive the image
        categories, switch to the Done or Error state and commit.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__state == State.Processing
        finished = self.__pendingTask
        assert finished is not None
        assert self.sender() is finished.watcher
        assert finished.future.done()
        self.__pendingTask = None

        try:
            image_meta = finished.future.result()
        except Exception:
            sys.excepthook(*sys.exc_info())
            state = State.Error
            image_meta = []
            self.error(traceback.format_exc())
        else:
            state = State.Done
            self.error()

        # Map each image directory to its path relative to the directory
        # the scan started from; the relative path names the category.
        image_dirs = (os.path.dirname(meta.path) for meta in image_meta)
        categories = {
            dirname: os.path.relpath(dirname, finished.startdir)
            for dirname in image_dirs
        }

        self._imageMeta = image_meta
        self._imageCategories = categories

        self.__setRuntimeState(state)
        self.commit()

    def cancel(self):
        """
        Cancel the pending task, if one is being processed, and switch to
        the Cancelled state. Does nothing in any other state.
        """
        if self.__state != State.Processing:
            return
        pending = self.__pendingTask
        assert pending is not None
        pending.cancel()
        self.__pendingTask = None
        self.__setRuntimeState(State.Cancelled)

    @Slot(object)
    def __onReportProgress(self, arg):
        # Receives scan progress reported by a worker thread; `arg` must
        # be a namespace(count: int, lastpath: str). The path label is
        # only updated while a scan is being processed.
        assert QThread.currentThread() is self.thread()
        if self.__state != State.Processing:
            return
        last_path = arg.lastpath
        self.pathlabel.setText(prettyfypath(last_path))

    def commit(self):
        """
        Build a Table from the collected image meta data and send it on
        the "Data" output; None is sent when no meta data was collected.
        """
        if not self._imageMeta:
            self.send("Data", None)
            return

        categories = self._imageCategories
        # A class variable is only created for two or more categories.
        cat_var = None
        if len(categories) > 1:
            cat_var = Orange.data.DiscreteVariable(
                "category", values=sorted(categories.values()))

        # Image name (file basename without the extension)
        imagename_var = Orange.data.StringVariable("image name")
        # Full fs path
        image_var = Orange.data.StringVariable("image")
        image_var.attributes["type"] = "image"
        # file size/width/height
        size_var, width_var, height_var = (
            Orange.data.ContinuousVariable(label, number_of_decimals=0)
            for label in ("size", "width", "height")
        )
        class_vars = [] if cat_var is None else [cat_var]
        domain = Orange.data.Domain(
            [], class_vars,
            [imagename_var, image_var, size_var, width_var, height_var])

        # Only valid images contribute a row.
        valid = [meta for meta in self._imageMeta if meta.isvalid]
        if cat_var is None:
            class_rows = [[] for _ in valid]
        else:
            class_rows = [
                [cat_var.to_val(categories.get(os.path.dirname(meta.path)))]
                for meta in valid
            ]
        meta_rows = [
            [os.path.splitext(os.path.basename(meta.path))[0],
             meta.path, meta.size, meta.width, meta.height]
            for meta in valid
        ]

        cat_data = numpy.array(class_rows, dtype=float)
        meta_data = numpy.array(meta_rows, dtype=object)
        no_features = numpy.empty((len(cat_data), 0), dtype=float)
        table = Orange.data.Table.from_numpy(
            domain, no_features, cat_data, meta_data)

        self.send("Data", table)

    def onDeleteWidget(self):
        """Cancel any pending task, then shut the executor down, waiting for it."""
        self.cancel()
        executor = self.__executor
        executor.shutdown(wait=True)
Ejemplo n.º 19
0
class CNNM(OWWidget):
    """
    Widget that builds a small convolutional network and fits it on the
    received ``ImageDataBunch`` by submitting the fit to a
    ``ThreadExecutor``.
    """
    name = "M CNN"
    description = ""
    # icon = "icons/robot.svg"

    want_main_area = True

    class Inputs:
        data = Input('Data', ImageDataBunch, default=True)

    def __init__(self):
        super().__init__()
        #: The current input data (set by `set_data`). Initialised here so
        #: `_update` can be called before any input has been received.
        self.data = None
        #: Learner built from the input data and `self.model`
        self.learn = None

        # NOTE: a "start training" button bound to `self.train` is
        # currently not created.
        self.label = gui.label(self.mainArea, self, "模型结构")

        #: The current evaluating task (if any)
        self._task = None  # type: Optional[Task]
        #: An executor we use to submit learner evaluations into a thread pool
        self._executor = ThreadExecutor()

        # Stack of stride-2 convolutions, each followed by batch norm (and
        # ReLU except after the last one). The trailing numbers are the
        # original author's notes on each stage.
        self.model = nn.Sequential(
            self.conv(1, 8),  # 14
            nn.BatchNorm2d(8),
            nn.ReLU(),
            self.conv(8, 16),  # 7
            nn.BatchNorm2d(16),
            nn.ReLU(),
            self.conv(16, 32),  # 4
            nn.BatchNorm2d(32),
            nn.ReLU(),
            self.conv(32, 16),  # 2
            nn.BatchNorm2d(16),
            nn.ReLU(),
            self.conv(16, 10),  # 1
            nn.BatchNorm2d(10),
            Flatten()  # remove (1,1) grid
        )

    def handleNewSignals(self):
        """Start a new fit once the inputs have been updated."""
        self._update()

    def _update(self):
        """
        Cancel any pending task and, if both data and a learner are
        available, submit a new fit to the executor.
        """
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        if self.data is None:
            return
        # Nothing to fit without a learner.
        if not self.learn:
            return

        # setup the task state
        self._task = task = Task()
        # The fit function takes a callback to report progress. We
        # instrument this callback to both invoke the appropriate slot on
        # this widget for reporting the progress and to implement
        # cooperative cancellation.
        set_progress = methodinvoke(self, "setProgressValue", (float,))

        def callback(finished):
            # check if the task has been cancelled and raise an exception
            # from within. This 'strategy' can only be used with code that
            # properly cleans up after itself in the case of an exception
            # (does not leave any global locks, opened file descriptors, ...)
            if task.cancelled:
                raise KeyboardInterrupt()
            set_progress(finished * 100)

        self.progressBarInit()
        # Submit the fit function to the executor and fill in the
        # task with the resultant Future.
        with progress_disabled_ctx(self.learn) as learn:
            fit_model = partial(my_fit, learn, 1, callback=callback)
            task.future = self._executor.submit(fit_model)
            # Setup the FutureWatcher to notify us of completion
            task.watcher = FutureWatcher(task.future)
            # by using FutureWatcher we ensure `_task_finished` slot will be
            # called from the main GUI thread by the Qt's event loop
            task.watcher.done.connect(self._task_finished)

    @pyqtSlot(float)
    def setProgressValue(self, value):
        """Set the progress bar to `value`; must run on the widget's thread."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Handle completion of the submitted fit.

        Parameters
        ----------
        f : Future
            The future instance holding the result of the fit.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        try:
            f.result()
        except Exception as ex:
            # Surface the failure instead of silently validating a model
            # whose fit did not complete.
            self.error(
                "Exception occurred during evaluation: {!r}".format(ex))
            return

        self.error()
        print(self.learn.validate())

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_task_finished` slot
            self._task.watcher.done.disconnect(self._task_finished)
            self._task = None

    def onDeleteWidget(self):
        """Cancel any running task before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()

    def conv(self, ni, nf):
        """Return a 3x3, stride 2, padding 1 convolution from `ni` to `nf` channels."""
        return nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1)

    def train(self):
        """Fit the learner for three cycles in the calling thread (no-op without a learner)."""
        if self.learn is None:
            return
        self.learn.fit_one_cycle(3)

    @Inputs.data
    def set_data(self, data):
        """
        Store the input data; for non-None data also build a new Learner
        around `self.model` and show its summary in the label.
        """
        if data is not None:
            self.data = data
            self.learn = Learner(self.data, self.model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy,
                                 add_time=False, bn_wd=False, silent=True)
            self.label.setText(self.learn.summary())
        else:
            self.data = None
Ejemplo n.º 20
0
class AgentTrainMixin():
    """
    Mixin adding a training loop with progress reporting to an agent.

    The host class is expected to provide ``self.environment`` (an object
    with ``reset()`` and ``step(action)``) and to override `train_action`
    and, optionally, `process_reward`.
    """
    train_results = None
    initial_train_results = None

    trained_episodes = 0
    initial_trained_episodes = 0

    memory = None
    initial_memory = None

    # Last integer progress value forwarded to the widget.
    _progress = 0

    def train(self, episodes, seconds, ow_widget, ow_widget_on_finish):
        """
        Submit `train_task` to a new ``ThreadExecutor``.

        Parameters
        ----------
        episodes : int
            Passed on to `train_task`.
        seconds : float
            Passed on to `train_task`.
        ow_widget : object
            Object whose ``progressBarInit``, ``progressBarSet`` and
            ``progressBarFinished`` methods are used to report progress.
        ow_widget_on_finish : callable
            Called without arguments once training has finished.
        """
        self.ow_widget = ow_widget
        self.ow_widget_on_finish = ow_widget_on_finish

        self._executor = ThreadExecutor()

        self.ow_widget.progressBarInit()

        # `train_task` calls its callbacks with the agent as the first
        # argument; forward to the agent's own handlers.
        def on_progress(self, progress):
            self.on_progress(progress)

        def on_finish(self):
            self.on_finish()

        self._executor.submit(
            partial(self.train_task, episodes, seconds, on_progress,
                    on_finish))

    def train_episode(self):
        """
        Run one episode in ``self.environment`` until it reports done.

        Returns
        -------
        result : dict
            ``steps_to_finish``, ``total_reward`` and ``last_action_info``
            (the info returned by the last `train_action` call).
        """
        done = False

        steps_to_finish = 0
        total_reward = 0

        state = self.environment.reset()

        while not done:
            # pylint: disable=assignment-from-no-return
            action, action_info = self.train_action(state)

            new_state, reward, done, _info = self.environment.step(action)

            self.process_reward(state, action, reward, new_state)

            state = new_state

            steps_to_finish += 1
            total_reward += reward

        return {
            'steps_to_finish': steps_to_finish,
            'total_reward': total_reward,
            'last_action_info': action_info
        }

    def process_reward(self, state, action, reward, new_state):
        """Hook called after every environment step; does nothing by default."""
        pass

    def train_action(self, state):
        """
        Hook expected to return an ``(action, action_info)`` pair for
        `state`; `train_episode` unpacks the result. The default
        implementation returns None and must be overridden.
        """
        pass

    def on_progress(self, progress):
        """Forward `progress` (truncated to int) to the widget when it changed."""
        progress = int(progress)
        # Performance reasons: only update the
        # progress when is realy necessary.
        if progress != self._progress:
            self._progress = progress
            self.ow_widget.progressBarSet(progress)

    def on_finish(self):
        """Finish the widget's progress bar and call the finish callback."""
        self.ow_widget.progressBarFinished()
        self.ow_widget_on_finish()

    @staticmethod
    def spend_seconds(started_time):
        """Return the seconds elapsed since `started_time` (a `time.time()` value)."""
        return time.time() - started_time

    def has_available_time(self, started_time, seconds):
        """Return True while fewer than `seconds` have elapsed since `started_time`."""
        return self.spend_seconds(started_time) < seconds

    def current_progress(self, started_time, seconds, episodes, interations):
        """
        Estimate the training progress in percent.

        The estimated total duration is `seconds` plus `episodes` times
        the mean duration of the iterations run so far. The result is
        capped at 99.999. When no estimate can be made the last reported
        progress is returned.

        Parameters
        ----------
        started_time : float
            `time.time()` value taken when training started.
        seconds : float
            Time budget of the training.
        episodes : int
            Number of episodes requested.
        interations : int
            Number of loop iterations started so far.

        Returns
        -------
        progress : float
        """
        progress = self._progress

        estimated_seconds = seconds

        spend_seconds = self.spend_seconds(started_time)

        # Without at least one iteration there is no mean duration to
        # extrapolate from (and dividing by zero would raise).
        if episodes > 0 and spend_seconds > 0 and interations > 0:
            interation_mean_seconds = spend_seconds / interations
            estimated_seconds += episodes * interation_mean_seconds

        if estimated_seconds > 0.0:
            progress = (spend_seconds / estimated_seconds) * 100

            if progress >= 100.0:
                progress = 99.999

        return progress

    def train_task(self, episodes, seconds, on_progress, on_finish):
        """
        Run the training loop.

        The loop keeps running while the time budget of `seconds` has not
        elapsed; an episode only counts towards `episodes` if the budget
        is exhausted when it completes. Training state is reset from the
        ``initial_*`` attributes before the loop starts.

        Parameters
        ----------
        episodes : int
            Number of counted episodes to run.
        seconds : float
            Time budget in seconds.
        on_progress : callable
            Called as ``on_progress(agent, progress)`` before each episode.
        on_finish : callable
            Called as ``on_finish(agent)`` after the loop ends.
        """
        episode = 1
        interations = 0

        started_time = time.time()

        self.trained_episodes = self.initial_trained_episodes
        self.train_results = deepcopy(self.initial_train_results)
        self.memory = deepcopy(self.initial_memory)

        while episode <= episodes or self.has_available_time(
                started_time, seconds):
            interations += 1

            on_progress(
                self,
                self.current_progress(started_time, seconds, episodes,
                                      interations))

            self.trained_episodes += 1

            # pylint: disable=assignment-from-no-return
            result = self.train_episode()

            self.train_results = np.append(self.train_results, result)

            if not self.has_available_time(started_time, seconds):
                episode += 1

        on_finish(self)
Ejemplo n.º 21
0
class OWImportImages(widget.OWWidget):
    # Widget metadata.
    name = "Import Images"
    description = "Import images from a directory(s)"
    icon = "icons/ImportImages.svg"
    priority = 110

    # A single output named "Data" carrying an Orange.data.Table.
    outputs = [("Data", Orange.data.Table)]

    #: list of recent paths
    recent_paths = settings.Setting([])  # type: List[RecentPath]

    want_main_area = False
    resizing_enabled = False

    # Selects how `__runOpenDialog` asks for a directory: with
    # Qt.WindowModal it creates a QFileDialog and calls `open()` on it,
    # otherwise it calls QFileDialog.getExistingDirectory.
    Modality = Qt.ApplicationModal
    # Modality = Qt.WindowModal

    # Maximum number of recent paths kept by `__initRecentItemsModel`.
    MaxRecentItems = 20

    def __init__(self):
        """
        Build the widget GUI (recent paths combo box, browse/reload
        buttons and the info area), initialise the runtime state and post
        a `RuntimeEvent.Init` event to itself.
        """
        super().__init__()
        #: widget's runtime state
        self.__state = State.NoState
        #: result of the last finished import (set in `__onRunFinished`)
        self.data = None
        #: counters shown in the info text by `__updateInfo`
        self._n_image_categories = 0
        self._n_image_data = 0
        self._n_skipped = 0

        self.__invalidated = False
        #: state of the task currently submitted to the executor (if any)
        self.__pendingTask = None

        # --- Recent paths combo box with browse and reload buttons ---
        vbox = gui.vBox(self.controlArea)
        hbox = gui.hBox(vbox)
        self.recent_cb = QComboBox(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=16,
            acceptDrops=True
        )
        self.recent_cb.installEventFilter(self)
        self.recent_cb.activated[int].connect(self.__onRecentActivated)
        icons = standard_icons(self)

        browseaction = QAction(
            "Open/Load Images", self,
            iconText="\N{HORIZONTAL ELLIPSIS}",
            icon=icons.dir_open_icon,
            toolTip="Select a directory from which to load the images"
        )
        browseaction.triggered.connect(self.__runOpenDialog)
        reloadaction = QAction(
            "Reload", self,
            icon=icons.reload_icon,
            toolTip="Reload current image set"
        )
        reloadaction.triggered.connect(self.reload)
        self.__actions = namespace(
            browse=browseaction,
            reload=reloadaction,
        )

        # The push buttons mirror the actions' text/icon and trigger them
        # when clicked.
        browsebutton = QPushButton(
            browseaction.iconText(),
            icon=browseaction.icon(),
            toolTip=browseaction.toolTip(),
            clicked=browseaction.trigger
        )
        reloadbutton = QPushButton(
            reloadaction.iconText(),
            icon=reloadaction.icon(),
            clicked=reloadaction.trigger,
            default=True,
        )

        hbox.layout().addWidget(self.recent_cb)
        hbox.layout().addWidget(browsebutton)
        hbox.layout().addWidget(reloadbutton)

        self.addActions([browseaction, reloadaction])

        # Keep the reload button's enabled state in sync with its action.
        reloadaction.changed.connect(
            lambda: reloadbutton.setEnabled(reloadaction.isEnabled())
        )
        # --- Info box: a two-page stack (page 0: info label,
        # page 1: progress bar, cancel button and path label) ---
        box = gui.vBox(vbox, "Info")
        self.infostack = QStackedWidget()

        self.info_area = QLabel(
            text="No image set selected",
            wordWrap=True
        )
        self.progress_widget = QProgressBar(
            minimum=0, maximum=0
        )
        self.cancel_button = QPushButton(
            "Cancel", icon=icons.cancel_icon,
        )
        self.cancel_button.clicked.connect(self.cancel)

        w = QWidget()
        vlayout = QVBoxLayout()
        vlayout.setContentsMargins(0, 0, 0, 0)
        hlayout = QHBoxLayout()
        hlayout.setContentsMargins(0, 0, 0, 0)

        hlayout.addWidget(self.progress_widget)
        hlayout.addWidget(self.cancel_button)
        vlayout.addLayout(hlayout)

        # Label updated from `__onReportProgress` with the path last
        # reported by the running task.
        self.pathlabel = TextLabel()
        self.pathlabel.setTextElideMode(Qt.ElideMiddle)
        self.pathlabel.setAttribute(Qt.WA_MacSmallSize)

        vlayout.addWidget(self.pathlabel)
        w.setLayout(vlayout)

        self.infostack.addWidget(self.info_area)
        self.infostack.addWidget(w)

        box.layout().addWidget(self.infostack)

        self.__initRecentItemsModel()
        # Marks that a scan should be started when the Init event posted
        # below is handled in `customEvent`.
        self.__invalidated = True
        self.__executor = ThreadExecutor(self)

        QApplication.postEvent(self, QEvent(RuntimeEvent.Init))

    def __initRecentItemsModel(self):
        """
        Rebuild the combo box model from the stored recent paths (at most
        `MaxRecentItems` of them) and select the first entry if it is an
        existing directory.
        """
        self._relocate_recent_files()
        kept = list(self.recent_paths)[:OWImportImages.MaxRecentItems]

        model = self.recent_cb.model()
        model.clear()
        for entry in kept:
            model.appendRow(RecentPath_asqstandarditem(entry))

        self.recent_paths = kept

        # Reload is only possible when the first entry is usable.
        first_is_dir = bool(kept) and os.path.isdir(kept[0].abspath)
        self.recent_cb.setCurrentIndex(0 if first_is_dir else -1)
        self.__actions.reload.setEnabled(first_is_dir)

    def customEvent(self, event):
        """Reimplemented."""
        # On the Init event start the scan if one is still pending; the
        # invalidated flag is cleared even when `start` raises.
        needs_start = event.type() == RuntimeEvent.Init and self.__invalidated
        if needs_start:
            try:
                self.start()
            finally:
                self.__invalidated = False

        super().customEvent(event)

    def __runOpenDialog(self):
        """
        Ask the user for a top level directory and, if one was chosen,
        make it the current path and start a scan.
        """
        # Start browsing next to the most recent path, or in the home
        # directory when there is none.
        if self.recent_paths:
            startdir = os.path.dirname(self.recent_paths[0].abspath)
        else:
            startdir = os.path.expanduser("~/")

        def use_directory(chosen):
            self.setCurrentPath(chosen)
            self.start()

        if OWImportImages.Modality != Qt.WindowModal:
            chosen = QFileDialog.getExistingDirectory(
                self, "Select Top Level Directory", startdir
            )
            if chosen:
                use_directory(chosen)
            return

        dlg = QFileDialog(
            self, "Select Top Level Directory", startdir,
            acceptMode=QFileDialog.AcceptOpen,
            modal=True,
        )
        dlg.setFileMode(QFileDialog.Directory)
        dlg.setOption(QFileDialog.ShowDirsOnly)
        dlg.setDirectory(startdir)
        dlg.setAttribute(Qt.WA_DeleteOnClose)

        def on_accepted():
            selected = dlg.selectedFiles()
            if selected:
                use_directory(selected[0])

        dlg.accepted.connect(on_accepted)
        dlg.open()

    def __onRecentActivated(self, index):
        """
        Make the recent path stored at combo box row `index` current and
        start importing from it. Rows without item data are ignored.
        """
        chosen = self.recent_cb.itemData(index)
        if chosen is not None:
            assert isinstance(chosen, RecentPath)
            self.setCurrentPath(chosen.abspath)
            self.start()

    def __updateInfo(self):
        """
        Refresh the info label text and select the info stack page that
        matches the current runtime state.
        """
        current = self.__state
        # States whose label text does not depend on the import results.
        fixed_text = {
            State.NoState: "No image set selected",
            State.Processing: "Processing",
            State.Cancelled: "Cancelled",
            State.Error: "Error state",
        }

        if current == State.Done:
            image_count = self._n_image_data
            category_count = self._n_image_categories
            skipped_count = self._n_skipped
            if category_count >= 2:
                text = "{} images / {} categories".format(
                    image_count, category_count)
            else:
                plural = "s" if image_count != 1 else ""
                text = "{} image{}".format(image_count, plural)
            if skipped_count > 0:
                text = text + ", {} skipped".format(skipped_count)
        elif current in fixed_text:
            text = fixed_text[current]
        else:
            assert False

        self.info_area.setText(text)

        # Page 1 is shown only while processing; page 0 otherwise.
        self.infostack.setCurrentIndex(
            1 if current == State.Processing else 0)

    def setCurrentPath(self, path):
        """
        Set the current root image path to path

        If the path does not exists or is not a directory the current path
        is left unchanged

        Parameters
        ----------
        path : str
            New root import path.

        Returns
        -------
        status : bool
            True if the current root import path was successfully
            changed to path.
        """
        # Nothing to do when `path` is the directory already at the head
        # of the recent paths list. Both sides of `samefile` must be
        # paths: passing the result of `isdir` would compare the wrong
        # thing.
        if self.recent_paths and path is not None:
            current = self.recent_paths[0].abspath
            if os.path.isdir(current) and os.path.isdir(path) \
                    and os.path.samefile(current, path):
                return True

        success = True
        error = None
        if path is not None:
            if not os.path.exists(path):
                error = "'{}' does not exist".format(path)
                path = None
                success = False
            elif not os.path.isdir(path):
                error = "'{}' is not a directory".format(path)
                path = None
                success = False

        if error is not None:
            self.error(error)
            warnings.warn(error, UserWarning, stacklevel=3)
        else:
            self.error()

        if path is not None:
            newindex = self.addRecentPath(path)
            self.recent_cb.setCurrentIndex(newindex)

        self.__actions.reload.setEnabled(len(self.recent_paths) > 0)

        # A scan of the previous directory is no longer relevant.
        if self.__state == State.Processing:
            self.cancel()

        return success

    def _search_paths(self):
        """
        Return the search paths as a list of (name, path) pairs: the
        workflow's "basedir" if the workflow environment defines one,
        otherwise an empty list.
        """
        basedir = self.workflowEnv().get("basedir", None)
        return [] if basedir is None else [("basedir", basedir)]

    def addRecentPath(self, path):
        """
        Put `path` at the front of the recent paths list.

        An entry that already refers to the same location is moved to the
        front instead of being duplicated.

        Parameters
        ----------
        path : str

        Returns
        -------
        index : int
            Position of the entry in the recent paths list (always 0).
        """
        def refers_to_path(entry):
            try:
                return os.path.samefile(entry.abspath, path)
            except FileNotFoundError:
                # file not found if the `entry.abspath` no longer exists
                return False

        existing = next(
            (entry for entry in self.recent_paths if refers_to_path(entry)),
            None)

        model = self.recent_cb.model()

        if existing is None:
            item = RecentPath.create(path, self._search_paths())
            self.recent_paths.insert(0, item)
            model.insertRow(0, RecentPath_asqstandarditem(item))
            return 0

        # Move the existing entry (and its model row) to the front.
        selected_index = self.recent_paths.index(existing)
        assert model.item(selected_index).data(Qt.UserRole) is existing
        self.recent_paths.remove(existing)
        row = model.takeRow(selected_index)
        self.recent_paths.insert(0, existing)
        model.insertRow(0, row)
        return 0

    def __setRuntimeState(self, state):
        """
        Switch the widget to `state`, checking that the transition is
        allowed, and update the blocking flag, the info stack page, the
        status message and the info text accordingly.
        """
        assert state in State
        now_processing = state == State.Processing
        self.setBlocking(now_processing)

        message = ""
        previous = self.__state
        if now_processing:
            # Processing may be entered from any non-processing state.
            assert previous in [State.Done,
                                State.NoState,
                                State.Error,
                                State.Cancelled]
            message = "Processing"
        elif state == State.Done:
            assert previous == State.Processing
        elif state == State.Cancelled:
            assert previous == State.Processing
            message = "Cancelled"
        elif state == State.Error:
            message = "Error during processing"
        elif state == State.NoState:
            message = ""
        else:
            assert False

        self.__state = state
        self.infostack.setCurrentIndex(1 if now_processing else 0)
        self.setStatusMessage(message)
        self.__updateInfo()

    def reload(self):
        """
        Drop the current data and run the import again, cancelling an
        import that is still in progress first.
        """
        still_running = self.__state == State.Processing
        if still_running:
            self.cancel()

        self.data = None
        self.start()

    def start(self):
        """
        Start/execute the image indexing operation

        Does nothing, apart from calling `self.error()` with no message
        and clearing the invalidated flag, when there are no recent
        paths. A task that is still being processed is cancelled first.
        The first recent path is used as the start directory; the
        `ImportImages` task is submitted to the widget's executor and a
        `FutureWatcher` on the resulting future is connected to
        `__onRunFinished`.
        """
        self.error()

        self.__invalidated = False
        if not self.recent_paths:
            return

        if self.__state == State.Processing:
            assert self.__pendingTask is not None
            log.info("Starting a new task while one is in progress. "
                     "Cancel the existing task (dir:'{}')"
                     .format(self.__pendingTask.startdir))
            self.cancel()

        startdir = self.recent_paths[0].abspath

        self.__setRuntimeState(State.Processing)

        # Callable handed to the task for reporting progress through the
        # `__onReportProgress` slot.
        # NOTE(review): presumably methodinvoke queues the call onto this
        # widget's thread (the slot asserts it runs there) -- confirm.
        report_progress = methodinvoke(
            self, "__onReportProgress", (object,))

        task = ImportImages(report_progress=report_progress)

        # collect the task state in one convenient place
        self.__pendingTask = taskstate = namespace(
            task=task,
            startdir=startdir,
            future=None,
            watcher=None,
            cancelled=False,
            cancel=None,
        )

        def cancel():
            # Cancel the task and disconnect
            if taskstate.future.cancel():
                # The future reported a successful cancel; nothing to
                # wait for.
                pass
            else:
                # Ask the task to stop by setting its `cancelled` flag,
                # then wait up to 3 seconds for it to finish.
                taskstate.task.cancelled = True
                taskstate.cancelled = True
                try:
                    taskstate.future.result(timeout=3)
                except UserInterruptError:
                    pass
                except TimeoutError:
                    log.info("The task did not stop in in a timely manner")
            # The result of a cancelled task must not reach
            # `__onRunFinished`.
            taskstate.watcher.finished.disconnect(self.__onRunFinished)

        taskstate.cancel = cancel

        def run_image_scan_task_interupt():
            # Function actually submitted to the executor; returns None
            # instead of raising when the import is interrupted.
            try:
                return task(startdir)
            except UserInterruptError:
                # Suppress interrupt errors, so they are not logged
                return

        taskstate.future = self.__executor.submit(run_image_scan_task_interupt)
        taskstate.watcher = FutureWatcher(taskstate.future)
        taskstate.watcher.finished.connect(self.__onRunFinished)

    @Slot()
    def __onRunFinished(self):
        """
        Collect the result of the finished scan and publish it.

        Connected to the `finished` signal of the pending task's watcher.
        A failed scan is reported through the exception hook and the
        widget's error message, and puts the widget into the `Error`
        state; otherwise the widget enters the `Done` state. The resulting
        data (or None) is committed in both cases.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__state == State.Processing
        assert self.__pendingTask is not None
        assert self.sender() is self.__pendingTask.watcher
        assert self.__pendingTask.future.done()
        finished = self.__pendingTask
        self.__pendingTask = None

        try:
            data, n_skipped = finished.future.result()
        except Exception:
            sys.excepthook(*sys.exc_info())
            outcome = State.Error
            data, n_skipped = None, 0
            self.error(traceback.format_exc())
        else:
            outcome = State.Done
            self.error()

        # summary counts used by the info display
        if data:
            self._n_image_data = len(data)
            class_var = data.domain.class_var
            self._n_image_categories = \
                len(class_var.values) if class_var else 0
        else:
            self._n_image_data = 0
            self._n_image_categories = 0

        self.data = data
        self._n_skipped = n_skipped

        self.__setRuntimeState(outcome)
        self.commit()

    def cancel(self):
        """
        Cancel the scan that is currently being processed, if there is one.

        The pending task is dropped and the widget enters the `Cancelled`
        state. Does nothing in any other state.
        """
        if self.__state != State.Processing:
            return

        assert self.__pendingTask is not None
        pending = self.__pendingTask
        pending.cancel()
        self.__pendingTask = None
        self.__setRuntimeState(State.Cancelled)

    @Slot(object)
    def __onReportProgress(self, arg):
        """
        Show the path most recently visited by the running scan.

        `arg` must be a namespace(count: int, lastpath: str) sent by the
        worker thread; the slot itself runs in the widget's own thread.
        Reports that arrive when no scan is being processed are ignored.
        """
        assert QThread.currentThread() is self.thread()
        if self.__state != State.Processing:
            return
        self.pathlabel.setText(prettyfypath(arg.lastpath))

    def commit(self):
        """
        Send the current `self.data` (possibly None) on the "Data" output.
        """
        table = self.data
        self.send("Data", table)

    def onDeleteWidget(self):
        """
        Release the widget's background resources before it is deleted.

        Cancels a running scan and shuts the executor down, waiting for
        it to finish.
        """
        self.cancel()
        executor = self.__executor
        executor.shutdown(wait=True)
        self.__invalidated = False

    def eventFilter(self, receiver, event):
        """
        Handle drag and drop of a directory onto the recent paths combo box.

        Re-implemented from QWidget. Drag-enter, drag-move and drop events
        received by `self.recent_cb` are consumed here (True is returned):
        a single local directory that can be linked is accepted, and on
        drop it becomes the current path and a scan is started. Every
        other event is passed on to the base implementation.
        """
        def dropped_directory(drop_event):
            # type: (QDropEvent) -> Optional[str]
            """Return the single local directory carried by the event."""
            urls = drop_event.mimeData().urls()
            if len(urls) != 1:
                return None
            local = urls[0].toLocalFile()
            if local.endswith("/"):
                local = local[:-1]  # strip the trailing separator
            return local if os.path.isdir(local) else None

        if receiver is not self.recent_cb or \
                event.type() not in {QEvent.DragEnter, QEvent.DragMove,
                                     QEvent.Drop}:
            return super().eventFilter(receiver, event)

        assert isinstance(event, QDropEvent)
        directory = dropped_directory(event)
        if directory is None or not event.possibleActions() & Qt.LinkAction:
            event.ignore()
            return True

        event.setDropAction(Qt.LinkAction)
        event.accept()
        if event.type() == QEvent.Drop:
            self.setCurrentPath(directory)
            self.start()
        return True

    def _relocate_recent_files(self):
        search_paths = self._search_paths()
        rec = []
        for recent in self.recent_paths:
            kwargs = dict(
                title=recent.title, sheet=recent.sheet,
                file_format=recent.file_format)
            resolved = recent.resolve(search_paths)
            if resolved is not None:
                rec.append(
                    RecentPath.create(resolved.abspath, search_paths, **kwargs))
            else:
                rec.append(recent)
        # change the list in-place for the case the widgets wraps this list
        self.recent_paths[:] = rec

    def workflowEnvChanged(self, key, value, oldvalue):
        """
        React to a change of the workflow environment (e.g. when the
        scheme is saved).

        The recent items model is re-initialized so that values which
        depend on the environment (e.g. relative file paths) are brought
        up to date. The `key`, `value` and `oldvalue` arguments are not
        inspected.
        """
        self.__initRecentItemsModel()
Ejemplo n.º 22
0
class OWNNLearner(OWBaseLearner):
    """
    Learner widget for a multi-layer perceptron (`NNLearner`).

    The model is fitted on a worker thread submitted to a `ThreadExecutor`.
    The running fit is represented by a `Task` object which reports
    progress back to the widget and through which the fit can be cancelled.
    """
    name = "Neural Network"
    description = "A multi-layer perceptron (MLP) algorithm with " \
                  "backpropagation."
    icon = "icons/NN.svg"
    priority = 90

    LEARNER = NNLearner

    # learner parameter values and the labels shown for them in the GUI;
    # the *_index settings below index into these parallel lists
    activation = ["identity", "logistic", "tanh", "relu"]
    act_lbl = ["Identity", "Logistic", "tanh", "ReLu"]
    solver = ["lbfgs", "sgd", "adam"]
    solv_lbl = ["L-BFGS-B", "SGD", "Adam"]

    learner_name = Setting("Neural Network")
    hidden_layers_input = Setting("100,")
    activation_index = Setting(3)
    solver_index = Setting(2)
    alpha = Setting(0.0001)
    max_iterations = Setting(200)

    def add_main_layout(self):
        """Create the controls for the network parameters."""
        box = gui.vBox(self.controlArea, "Network")
        self.hidden_layers_edit = gui.lineEdit(
            box, self, "hidden_layers_input", label="Neurons per hidden layer:",
            orientation=Qt.Horizontal, callback=self.settings_changed,
            tooltip="A list of integers defining neurons. Length of list "
                    "defines the number of layers. E.g. 4, 2, 2, 3.",
            placeholderText="e.g. 100,")
        self.activation_combo = gui.comboBox(
            box, self, "activation_index", orientation=Qt.Horizontal,
            label="Activation:", items=list(self.act_lbl),
            callback=self.settings_changed)
        self.solver_combo = gui.comboBox(
            box, self, "solver_index", orientation=Qt.Horizontal,
            label="Solver:", items=list(self.solv_lbl),
            callback=self.settings_changed)
        self.alpha_spin = gui.doubleSpin(
            box, self, "alpha", 1e-5, 1.0, 1e-2,
            label="Alpha:", decimals=5, alignment=Qt.AlignRight,
            callback=self.settings_changed, controlWidth=80)
        self.max_iter_spin = gui.spin(
            box, self, "max_iterations", 10, 10000, step=10,
            label="Max iterations:", orientation=Qt.Horizontal,
            alignment=Qt.AlignRight, callback=self.settings_changed,
            controlWidth=80)

    def setup_layout(self):
        """Build the base layout and set up the background task state."""
        super().setup_layout()

        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        # just a test cancel button
        gui.button(self.controlArea, self, "Cancel", callback=self.cancel)

    def create_learner(self):
        """Return a new learner configured from the current settings."""
        return self.LEARNER(
            hidden_layer_sizes=self.get_hidden_layers(),
            activation=self.activation[self.activation_index],
            solver=self.solver[self.solver_index],
            alpha=self.alpha,
            max_iter=self.max_iterations,
            preprocessors=self.preprocessors)

    def get_learner_parameters(self):
        """Return (label, value) pairs describing the learner settings."""
        return (("Hidden layers", ', '.join(map(str, self.get_hidden_layers()))),
                ("Activation", self.act_lbl[self.activation_index]),
                ("Solver", self.solv_lbl[self.solver_index]),
                ("Alpha", self.alpha),
                ("Max iterations", self.max_iterations))

    def get_hidden_layers(self):
        """
        Parse the hidden layer sizes from the line edit's setting.

        Every run of digits in the input is one layer size. When the input
        contains no number, ``(100,)`` is returned and the edit's text is
        reset to "100,".
        """
        layers = tuple(map(int, re.findall(r'\d+', self.hidden_layers_input)))
        if not layers:
            layers = (100,)
            self.hidden_layers_edit.setText("100,")
        return layers

    def update_model(self):
        """
        Rebuild the model.

        With valid data the fit is started in the background; otherwise
        the (empty) model is sent to the output right away.
        """
        self.show_fitting_failed(None)
        self.model = None
        if self.check_data():
            self.__update()
        else:
            self.Outputs.model.send(self.model)

    @Slot(float)
    def setProgressValue(self, value):
        """Set the progress bar to `value`; runs in the widget's thread."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)

    def __update(self):
        """
        Cancel a running fit (if any) and start fitting on a worker thread.
        """
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        max_iter = self.learner.kwargs["max_iter"]

        # Setup the task state
        task = Task()
        lastemitted = 0.

        def callback(iteration):
            # Called by the learner with the current iteration; aborts the
            # fit when interruption was requested and otherwise reports
            # progress (only when the rounded percentage changes).
            nonlocal lastemitted
            if task.isInterruptionRequested():
                raise CancelTaskException()
            progress = round(iteration / max_iter * 100)
            if progress != lastemitted:
                task.emitProgressUpdate(progress)
                lastemitted = progress

        # copy to set the callback so that the learner output is not modified
        # (currently we can not pass callbacks to learners __call__)
        learner = copy.copy(self.learner)
        learner.callback = callback

        def build_model(data, learner):
            # A cancelled fit yields None instead of raising
            try:
                return learner(data)
            except CancelTaskException:
                return None

        build_model_func = partial(build_model, self.data, learner)

        task.setFuture(self._executor.submit(build_model_func))
        task.done.connect(self._task_finished)
        task.progressChanged.connect(self.setProgressValue)

        self._task = task
        self.progressBarInit()
        self.setBlocking(True)

    @Slot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Store and send the fitted model once the background fit is done.

        Parameters
        ----------
        f : Future
            The future instance holding the built model
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()
        self._task.deleteLater()
        self._task = None
        self.setBlocking(False)
        self.progressBarFinished()

        try:
            self.model = f.result()
        except Exception as ex:  # pylint: disable=broad-except
            # Log the exception with a traceback on this module's logger;
            # Logger.exception attaches the active exception by itself.
            logging.getLogger(__name__).exception("Error fitting the model")
            self.model = None
            self.show_fitting_failed(ex)
        else:
            self.model.name = self.learner_name
            self.model.instances = self.data
            self.Outputs.model.send(self.model)

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect from the task
            self._task.done.disconnect(self._task_finished)
            self._task.progressChanged.disconnect(self.setProgressValue)
            self._task.deleteLater()
            self._task = None

        self.progressBarFinished()
        self.setBlocking(False)

    def onDeleteWidget(self):
        """Cancel a running fit before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()
Ejemplo n.º 23
0
class OWKMeans(widget.OWWidget):
    """
    k-Means clustering widget.

    Clusters the input data either for one fixed `k` or for every `k` in
    the range `k_from`..`k_to`; in the latter case the silhouette scores
    are listed and the selected row determines which clustering is sent.
    Clusterings are computed on worker threads and cached in
    `self.clusterings`, keyed by `k`; a failed clustering is cached as its
    error message (a str).
    """
    name = "k-Means"
    description = "k-Means clustering algorithm with silhouette-based " \
                  "quality estimation."
    icon = "icons/KMeans.svg"
    priority = 2100
    keywords = ["kmeans", "clustering"]

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        annotated_data = Output(
            ANNOTATED_DATA_SIGNAL_NAME, Table, default=True,
            replaces=["Annotated Data"]
        )
        centroids = Output("Centroids", Table)

    class Error(widget.OWWidget.Error):
        failed = widget.Msg("Clustering failed\nError: {}")
        not_enough_data = widget.Msg(
            "Too few ({}) unique data instances for {} clusters"
        )
        no_attributes = widget.Msg("Data is missing features.")

    class Warning(widget.OWWidget.Warning):
        no_silhouettes = widget.Msg(
            "Silhouette scores are not computed for >{} samples".format(
                SILHOUETTE_MAX_SAMPLES)
        )
        not_enough_data = widget.Msg(
            "Too few ({}) unique data instances for {} clusters"
        )

    # (label shown in the combo box, value passed as `init` to KMeans)
    INIT_METHODS = (("Initialize with KMeans++", "k-means++"),
                    ("Random initialization", "random"))

    resizing_enabled = False
    buttons_area_orientation = Qt.Vertical

    k = Setting(3)
    k_from = Setting(2)
    k_to = Setting(8)
    optimize_k = Setting(False)
    max_iterations = Setting(300)
    n_init = Setting(10)
    smart_init = Setting(0)  # KMeans++
    auto_commit = Setting(True)

    settings_version = 2

    @classmethod
    def migrate_settings(cls, settings, version):
        # type: (Dict, int) -> None
        """Rename the pre-version-2 'auto_apply' setting to 'auto_commit'."""
        if version < 2:
            if 'auto_apply' in settings:
                settings['auto_commit'] = settings.pop('auto_apply')

    def __init__(self):
        super().__init__()

        self.data = None  # type: Optional[Table]
        self.clusterings = {}

        self.__executor = ThreadExecutor(parent=self)
        self.__task = None  # type: Optional[Task]
        # the k of every future of the most recently launched task, in
        # submission order (maps a future's index back to its k)
        self.__task_ks = []  # type: List[int]

        layout = QGridLayout()
        bg = gui.radioButtonsInBox(
            self.controlArea, self, "optimize_k", orientation=layout,
            box="Number of Clusters", callback=self.update_method,
        )

        layout.addWidget(
            gui.appendRadioButton(bg, "Fixed:", addToLayout=False), 1, 1)
        sb = gui.hBox(None, margin=0)
        gui.spin(
            sb, self, "k", minv=2, maxv=30,
            controlWidth=60, alignment=Qt.AlignRight, callback=self.update_k)
        gui.rubber(sb)
        layout.addWidget(sb, 1, 2)

        layout.addWidget(
            gui.appendRadioButton(bg, "From", addToLayout=False), 2, 1)
        ftobox = gui.hBox(None)
        ftobox.layout().setContentsMargins(0, 0, 0, 0)
        layout.addWidget(ftobox, 2, 2)
        gui.spin(
            ftobox, self, "k_from", minv=2, maxv=29,
            controlWidth=60, alignment=Qt.AlignRight,
            callback=self.update_from)
        gui.widgetLabel(ftobox, "to")
        gui.spin(
            ftobox, self, "k_to", minv=3, maxv=30,
            controlWidth=60, alignment=Qt.AlignRight,
            callback=self.update_to)
        gui.rubber(ftobox)

        box = gui.vBox(self.controlArea, "Initialization")
        gui.comboBox(
            box, self, "smart_init", items=[m[0] for m in self.INIT_METHODS],
            callback=self.invalidate)

        layout = QGridLayout()
        gui.widgetBox(box, orientation=layout)
        layout.addWidget(gui.widgetLabel(None, "Re-runs: "), 0, 0, Qt.AlignLeft)
        sb = gui.hBox(None, margin=0)
        layout.addWidget(sb, 0, 1)
        gui.lineEdit(
            sb, self, "n_init", controlWidth=60,
            valueType=int, validator=QIntValidator(), callback=self.invalidate)
        layout.addWidget(
            gui.widgetLabel(None, "Maximum iterations: "), 1, 0, Qt.AlignLeft)
        sb = gui.hBox(None, margin=0)
        layout.addWidget(sb, 1, 1)
        gui.lineEdit(
            sb, self, "max_iterations", controlWidth=60, valueType=int,
            validator=QIntValidator(), callback=self.invalidate)

        self.apply_button = gui.auto_commit(
            self.buttonsArea, self, "auto_commit", "Apply", box=None,
            commit=self.commit)
        gui.rubber(self.controlArea)

        box = gui.vBox(self.mainArea, box="Silhouette Scores")
        self.mainArea.setVisible(self.optimize_k)
        self.table_model = ClusterTableModel(self)
        table = self.table_view = QTableView(self.mainArea)
        table.setModel(self.table_model)
        table.setSelectionMode(QTableView.SingleSelection)
        table.setSelectionBehavior(QTableView.SelectRows)
        table.setItemDelegate(gui.ColoredBarItemDelegate(self, color=Qt.cyan))
        table.selectionModel().selectionChanged.connect(self.select_row)
        table.setMaximumWidth(200)
        table.horizontalHeader().setStretchLastSection(True)
        table.horizontalHeader().hide()
        table.setShowGrid(False)
        box.layout().addWidget(table)

    def adjustSize(self):
        """Resize the widget to its size hint."""
        self.ensurePolished()
        s = self.sizeHint()
        self.resize(s)

    def update_method(self):
        """Callback for switching between fixed k and the k range."""
        self.table_model.clear_scores()
        self.commit()

    def update_k(self):
        """Callback for the fixed k spin box; selects the fixed mode."""
        self.optimize_k = False
        self.table_model.clear_scores()
        self.commit()

    def update_from(self):
        """Callback for `k_from`; keeps `k_to` above it, selects the range."""
        self.k_to = max(self.k_from + 1, self.k_to)
        self.optimize_k = True
        self.table_model.clear_scores()
        self.commit()

    def update_to(self):
        """Callback for `k_to`; keeps `k_from` below it, selects the range."""
        self.k_from = min(self.k_from, self.k_to - 1)
        self.optimize_k = True
        self.table_model.clear_scores()
        self.commit()

    def enough_data_instances(self, k):
        """k cannot be larger than the number of data instances."""
        return len(self.data) >= k

    @property
    def has_attributes(self):
        """The number of attributes in the data's domain (0 if none)."""
        return len(self.data.domain.attributes)

    @staticmethod
    def _compute_clustering(data, k, init, n_init, max_iter, silhouette, random_state):
        # type: (...) -> KMeansModel
        """
        Fit and return a k-means model with `k` clusters on `data`.

        The remaining arguments are passed on to `KMeans`
        (`silhouette` as `compute_silhouette_score`).

        Raises
        ------
        NotEnoughData
            If `k` is larger than the number of instances in `data`.
        """
        if k > len(data):
            raise NotEnoughData()

        return KMeans(
            n_clusters=k, init=init, n_init=n_init, max_iter=max_iter,
            compute_silhouette_score=silhouette, random_state=random_state,
        )(data)

    @Slot(int, int)
    def __progress_changed(self, n, d):
        """Set the progress bar to `n` out of `d`."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        self.progressBarSet(100 * n / d)

    @Slot(int, Exception)
    def __on_exception(self, idx, ex):
        """
        Record the failure of the `idx`-th clustering of the running task.

        The error message is cached in `self.clusterings` under the k that
        the failed future was computing.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None

        # `idx` is the position of the future within the launched task,
        # which is not necessarily an offset from `k_from`: the fixed mode
        # launches [self.k] and a range run launches only the missing ks.
        k = self.__task_ks[idx]

        if isinstance(ex, NotEnoughData):
            self.Error.not_enough_data(len(self.data), k)

        # Only show failed message if there is only 1 k to compute
        elif not self.optimize_k:
            self.Error.failed(str(ex))

        self.clusterings[k] = str(ex)

    @Slot(int, object)
    def __clustering_complete(self, _, result):
        # type: (int, KMeansModel) -> None
        """Cache a finished clustering under its number of clusters."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None

        self.clusterings[result.k] = result

    @Slot()
    def __commit_finished(self):
        """Finish the running task: update the scores and send the data."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        assert self.data is not None

        self.__task = None
        self.setBlocking(False)
        self.progressBarFinished()

        if self.optimize_k:
            self.update_results()

        if self.optimize_k and all(isinstance(self.clusterings[i], str)
                                   for i in range(self.k_from, self.k_to + 1)):
            # Show the error of the last clustering
            self.Error.failed(self.clusterings[self.k_to])

        self.send_data()

    def __launch_tasks(self, ks):
        # type: (List[int]) -> None
        """Execute clustering in separate threads for all given ks."""
        # remember which k each future (by position) is computing
        self.__task_ks = list(ks)
        futures = [self.__executor.submit(
            self._compute_clustering,
            data=self.data,
            k=k,
            init=self.INIT_METHODS[self.smart_init][1],
            n_init=self.n_init,
            max_iter=self.max_iterations,
            silhouette=True,
            random_state=RANDOM_STATE,
        ) for k in self.__task_ks]
        watcher = FutureSetWatcher(futures)
        watcher.resultReadyAt.connect(self.__clustering_complete)
        watcher.progressChanged.connect(self.__progress_changed)
        watcher.exceptionReadyAt.connect(self.__on_exception)
        watcher.doneAll.connect(self.__commit_finished)

        self.__task = Task(futures, watcher)
        self.progressBarInit(processEvents=False)
        self.setBlocking(True)

    def cancel(self):
        """Cancel the running task (if any) and stop listening to it."""
        if self.__task is not None:
            task, self.__task = self.__task, None
            task.cancel()

            task.watcher.resultReadyAt.disconnect(self.__clustering_complete)
            task.watcher.progressChanged.disconnect(self.__progress_changed)
            task.watcher.exceptionReadyAt.disconnect(self.__on_exception)
            task.watcher.doneAll.disconnect(self.__commit_finished)

            self.progressBarFinished()
            self.setBlocking(False)

    def run_optimization(self):
        """
        Compute the clusterings for every k in `k_from`..`k_to` that is
        not cached yet.

        Nothing is computed (an error or warning is shown instead) when
        there are fewer data instances than `k_from` or `k_to`.
        """
        if not self.enough_data_instances(self.k_from):
            self.Error.not_enough_data(len(self.data), self.k_from)
            return

        if not self.enough_data_instances(self.k_to):
            self.Warning.not_enough_data(len(self.data), self.k_to)
            return

        needed_ks = [k for k in range(self.k_from, self.k_to + 1)
                     if k not in self.clusterings]

        if needed_ks:
            self.__launch_tasks(needed_ks)
        else:
            # If we don't need to recompute anything, just set the results to
            # what they were before
            self.update_results()

    def cluster(self):
        """Compute (or reuse the cached) clustering for the fixed `k`."""
        # Check if the k already has a computed clustering
        if self.k in self.clusterings:
            self.send_data()
            return

        # Check if there is enough data
        if not self.enough_data_instances(self.k):
            self.Error.not_enough_data(len(self.data), self.k)
            return

        self.__launch_tasks([self.k])

    def commit(self):
        """
        Cancel any running task and (re)compute what the current settings
        require; without usable data the outputs are sent right away.
        """
        self.cancel()
        self.clear_messages()

        # Some time may pass before the new scores are computed, so clear the
        # old scores to avoid potential confusion. Hiding the mainArea could
        # cause flickering when the clusters are computed quickly, so this is
        # the better alternative
        self.table_model.clear_scores()
        self.mainArea.setVisible(self.optimize_k and self.data is not None and
                                 self.has_attributes)

        if self.data is None:
            self.send_data()
            return

        if not self.has_attributes:
            self.Error.no_attributes()
            self.send_data()
            return

        if self.optimize_k:
            self.run_optimization()
        else:
            self.cluster()

        QTimer.singleShot(100, self.adjustSize)

    def invalidate(self):
        """Drop all cached clusterings and recompute."""
        self.cancel()
        self.Error.clear()
        self.Warning.clear()
        self.clusterings = {}
        self.table_model.clear_scores()

        self.commit()

    def update_results(self):
        """
        Show the scores for `k_from`..`k_to` and select the best row.

        Failed clusterings (cached as str) are shown as their message and
        count as 0 when choosing the best row.
        """
        scores = [
            mk if isinstance(mk, str) else mk.silhouette for mk in (
                self.clusterings[k] for k in range(self.k_from, self.k_to + 1))
        ]
        best_row = max(
            range(len(scores)), default=0,
            key=lambda x: 0 if isinstance(scores[x], str) else scores[x]
        )
        self.table_model.set_scores(scores, self.k_from)
        self.table_view.selectRow(best_row)
        self.table_view.setFocus(Qt.OtherFocusReason)
        self.table_view.resizeRowsToContents()

    def selected_row(self):
        """Return the selected row of the scores table, or None."""
        indices = self.table_view.selectedIndexes()
        if indices:
            return indices[0].row()
        return None

    def select_row(self):
        """Send the clustering of the newly selected row."""
        self.send_data()

    def send_data(self):
        """
        Send the annotated data and the centroids of the chosen clustering.

        The clustering is the one for the selected row (range mode) or for
        `self.k` (fixed mode). None is sent on both outputs when there is
        no data or no successful clustering for that k.
        """
        if self.optimize_k:
            row = self.selected_row()
            k = self.k_from + row if row is not None else None
        else:
            k = self.k

        km = self.clusterings.get(k)
        if self.data is None or km is None or isinstance(km, str):
            self.Outputs.annotated_data.send(None)
            self.Outputs.centroids.send(None)
            return

        domain = self.data.domain
        cluster_var = DiscreteVariable(
            get_unique_names(domain, "Cluster"),
            values=["C%d" % (x + 1) for x in range(km.k)]
        )
        clust_ids = km(self.data)
        silhouette_var = ContinuousVariable(
            get_unique_names(domain, "Silhouette"))
        if km.silhouette_samples is not None:
            self.Warning.no_silhouettes.clear()
            scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
        else:
            self.Warning.no_silhouettes()
            scores = np.nan

        # the cluster label and silhouette score are added as meta columns
        new_domain = add_columns(domain, metas=[cluster_var, silhouette_var])
        new_table = self.data.transform(new_domain)
        new_table.get_column_view(cluster_var)[0][:] = clust_ids.X.ravel()
        new_table.get_column_view(silhouette_var)[0][:] = scores

        centroids = Table(Domain(km.pre_domain.attributes), km.centroids)

        self.Outputs.annotated_data.send(new_table)
        self.Outputs.centroids.send(centroids)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input data; recluster only when its X has changed."""
        self.data, old_data = data, self.data

        # Do not needlessly recluster the data if X hasn't changed
        if old_data and self.data and np.array_equal(self.data.X, old_data.X):
            if self.auto_commit:
                self.send_data()
        else:
            self.invalidate()

    def send_report(self):
        """Add the settings, the data and (range mode) the scores to the report."""
        # False positives (Setting is not recognized as int)
        # pylint: disable=invalid-sequence-index
        if self.optimize_k and self.selected_row() is not None:
            k_clusters = self.k_from + self.selected_row()
        else:
            k_clusters = self.k
        init_method = self.INIT_METHODS[self.smart_init][0]
        init_method = init_method[0].lower() + init_method[1:]
        self.report_items((
            ("Number of clusters", k_clusters),
            ("Optimization", "{}, {} re-runs limited to {} steps".format(
                init_method, self.n_init, self.max_iterations))))
        if self.data is not None:
            self.report_data("Data", self.data)
            if self.optimize_k:
                self.report_table(
                    "Silhouette scores for different numbers of clusters",
                    self.table_view)

    def onDeleteWidget(self):
        """Cancel a running task before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()
class OWDatabasesUpdate(OWWidget):
    """
    Widget that lists data files known locally and on the server and lets
    the user download, update and remove them.
    """

    name = "Databases Update"
    description = "Update local systems biology databases."
    icon = "../widgets/icons/OWDatabasesUpdate.svg"
    priority = 1

    # The widget declares no input or output channels.
    inputs = []
    outputs = []

    want_main_area = False

    def __init__(self, parent=None, signalManager=None,
                 name="Databases update"):
        """
        Build the widget GUI and start retrieving the server file list.

        :param parent: forwarded to ``OWWidget.__init__``.
        :param signalManager: forwarded to ``OWWidget.__init__``.
        :param name: forwarded to ``OWWidget.__init__``.
        """
        OWWidget.__init__(self, parent, signalManager, name,
                          wantMainArea=False)

        self.searchString = ""

        # State read by the search/progress/download slots. It is set up
        # before any signal is connected or any task is started, so a slot
        # that fires while the initial retrieval is being launched never
        # finds these attributes missing.
        self.updateItems = []
        self._tasks = []
        self._haveProgress = False

        fbox = gui.widgetBox(self.controlArea, "Filter")
        self.completer = TokenListCompleter(
            self, caseSensitivity=Qt.CaseInsensitive)
        self.lineEditFilter = QLineEdit(textChanged=self.SearchUpdate)
        self.lineEditFilter.setCompleter(self.completer)

        fbox.layout().addWidget(self.lineEditFilter)

        box = gui.widgetBox(self.controlArea, "Files")

        self.filesView = QTreeWidget(self)
        self.filesView.setHeaderLabels(
            ["", "Data Source", "Update", "Last Updated", "Size"])

        self.filesView.setRootIsDecorated(False)
        self.filesView.setUniformRowHeights(True)
        self.filesView.setSelectionMode(QAbstractItemView.NoSelection)
        self.filesView.setSortingEnabled(True)
        self.filesView.sortItems(1, Qt.AscendingOrder)
        self.filesView.setItemDelegateForColumn(
            0, UpdateOptionsItemDelegate(self.filesView))

        # Re-apply the filter whenever the model layout changes.
        self.filesView.model().layoutChanged.connect(self.SearchUpdate)

        box.layout().addWidget(self.filesView)

        box = gui.widgetBox(self.controlArea, orientation="horizontal")
        self.updateButton = gui.button(
            box, self, "Update all",
            callback=self.UpdateAll,
            tooltip="Update all updatable files",
        )

        self.downloadButton = gui.button(
            box, self, "Download all",
            callback=self.DownloadFiltered,
            tooltip="Download all filtered files shown"
        )

        self.cancelButton = gui.button(
            box, self, "Cancel", callback=self.Cancel,
            tooltip="Cancel scheduled downloads/updates."
        )

        # Only shown after a connection error (see HandleError).
        self.retryButton = gui.button(
            box, self, "Reconnect", callback=self.RetrieveFilesList
        )
        self.retryButton.hide()

        gui.rubber(box)
        self.warning(0)

        box = gui.widgetBox(self.controlArea, orientation="horizontal")
        gui.rubber(box)

        self.infoLabel = QLabel()
        self.infoLabel.setAlignment(Qt.AlignCenter)

        self.controlArea.layout().addWidget(self.infoLabel)
        self.infoLabel.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self.resize(800, 600)

        self.progress = ProgressState(self, maximum=3)
        self.progress.valueChanged.connect(self._updateProgress)
        self.progress.rangeChanged.connect(self._updateProgress)
        self.executor = ThreadExecutor(
            threadPool=QThreadPool(maxThreadCount=2)
        )

        task = Task(self, function=self.RetrieveFilesList)
        task.exceptionReady.connect(self.HandleError)
        task.start()

    def RetrieveFilesList(self):
        """
        Submit a task that retrieves the server file list.

        The widget is disabled after submission; `SetFilesList` receives
        the result and `HandleError` receives any exception.
        """
        self.retryButton.hide()
        self.warning(0)
        self.progress.setRange(0, 3)

        advance_progress = methodinvoke(self.progress, "advance")
        fetch_task = Task(
            function=partial(retrieveFilesList, advance_progress))
        fetch_task.resultReady.connect(self.SetFilesList)
        fetch_task.exceptionReady.connect(self.HandleError)

        self.executor.submit(fetch_task)

        self.setEnabled(False)

    def SetFilesList(self, serverInfo):
        """
        Set the files to show.

        Rebuilds ``self.updateItems`` and the tree view from the local info
        (``serverfiles.allinfo()``) joined with `serverInfo`, then refreshes
        the completer, filter, info label and buttons.

        :param serverInfo: server-side file info passed to
            ``join_info_dict``; `HandleError` passes an empty dict.
        """
        self.setEnabled(True)

        localInfo = serverfiles.allinfo()
        all_tags = set()

        self.filesView.clear()
        self.updateItems = []

        # First pass: create the tree items and their option widgets.
        for item in join_info_dict(localInfo, serverInfo):
            tree_item = UpdateTreeWidgetItem(item)
            options_widget = UpdateOptionsWidget(item.state)
            options_widget.item = item

            options_widget.installClicked.connect(
                partial(self.SubmitDownloadTask, item.domain, item.filename)
            )
            options_widget.removeClicked.connect(
                partial(self.SubmitRemoveTask, item.domain, item.filename)
            )

            self.updateItems.append((item, tree_item, options_widget))
            all_tags.update(item.tags)

        self.filesView.addTopLevelItems(
            [tree_item for _, tree_item, _ in self.updateItems]
        )

        # Second pass: item widgets are attached once the items are in
        # the view.
        for item, tree_item, options_widget in self.updateItems:
            self.filesView.setItemWidget(tree_item, 0, options_widget)

            # Add an update button if the file is updateable
            if item.state == OUTDATED:
                button = QToolButton(
                    None, text="Update",
                    maximumWidth=120,
                    minimumHeight=20,
                    maximumHeight=20
                )

                if sys.platform == "darwin":
                    button.setAttribute(Qt.WA_MacSmallSize)

                button.clicked.connect(
                    partial(self.SubmitDownloadTask, item.domain,
                            item.filename)
                )

                self.filesView.setItemWidget(tree_item, 2, button)

        self.progress.advance()

        self.filesView.setColumnWidth(0, self.filesView.sizeHintForColumn(0))

        # Columns 1-3: content width capped at 400, but never narrower
        # than the header section.
        for column in range(1, 4):
            contents_hint = self.filesView.sizeHintForColumn(column)
            header_hint = self.filesView.header().sectionSizeHint(column)
            width = max(min(contents_hint, 400), header_hint)
            self.filesView.setColumnWidth(column, width)

        # Tags starting with "#" are not offered as completions.
        hints = [hint for hint in sorted(all_tags) if not hint.startswith("#")]
        self.completer.setTokenList(hints)
        self.SearchUpdate()
        self.UpdateInfoLabel()
        self.toggleButtons()
        self.cancelButton.setEnabled(False)

        # An empty range marks the overall progress as finished
        # (see _updateProgress).
        self.progress.setRange(0, 0)

    def buttonCheck(self, selected_items, state, button):
        """
        Enable `button` only if at least one item is in `state`.

        The button is disabled when no item matches, including when
        `selected_items` is empty, so it never keeps a stale enabled state
        after everything has been filtered out.

        :param selected_items: iterable of objects with a ``state``
            attribute.
        :param state: the state value to look for.
        :param button: object with a ``setEnabled(bool)`` method.
        """
        button.setEnabled(
            any(item.state == state for item in selected_items))

    def toggleButtons(self):
        """Refresh the update/download buttons from the visible items."""
        visible_items = []
        for item, tree_item, _ in self.updateItems:
            if not tree_item.isHidden():
                visible_items.append(item)

        self.buttonCheck(visible_items, OUTDATED, self.updateButton)
        self.buttonCheck(visible_items, AVAILABLE, self.downloadButton)

    def HandleError(self, exception):
        """
        React to an exception raised by a task.

        A ``ConnectionError`` shows a warning, empties the file list and
        reveals the reconnect button. Any other exception is passed to
        ``sys.excepthook`` together with its own traceback, after which the
        progress is reset and the widget re-enabled.

        :param exception: the exception instance delivered by the task.
        """
        if isinstance(exception, ConnectionError):
            self.warning(0,
                         "Could not connect to server! Check your connection "
                         "and try to reconnect.")
            self.SetFilesList({})
            self.retryButton.show()
        else:
            # Hand over the exception's traceback so the hook can show
            # where the error was raised.
            sys.excepthook(type(exception), exception,
                           exception.__traceback__)
            self.progress.setRange(0, 0)
            self.setEnabled(True)

    def UpdateInfoLabel(self):
        """
        Show the count and total size of the visible items.

        The label reports the visible items whose state is not
        ``AVAILABLE`` first, then all visible items ("on server").
        """
        shown = [item for item, tree_item, _ in self.updateItems
                 if not tree_item.isHidden()]
        local = [item for item in shown if item.state != AVAILABLE]

        local_size = sum(float(item.size) for item in local)
        server_size = sum(float(item.size) for item in shown)

        summary = "%i items, %s (on server: %i items, %s)" % (
            len(local),
            serverfiles.sizeformat(local_size),
            len(shown),
            serverfiles.sizeformat(server_size),
        )
        self.infoLabel.setText(summary)

    def UpdateAll(self):
        """Submit a download task for every visible ``OUTDATED`` item."""
        self.warning(0)
        for item, tree_item, _ in self.updateItems:
            if item.state != OUTDATED or tree_item.isHidden():
                continue
            self.SubmitDownloadTask(item.domain, item.filename)

    def DownloadFiltered(self):
        """
        Submit a download task for every visible item that is either
        ``AVAILABLE`` or ``OUTDATED``.
        """
        # TODO: submit items in the order shown.
        downloadable = (AVAILABLE, OUTDATED)
        for item, tree_item, _ in self.updateItems:
            if tree_item.isHidden():
                continue
            if item.state in downloadable:
                self.SubmitDownloadTask(item.domain, item.filename)

    def SearchUpdate(self, searchString=None):
        """
        Hide every item that does not match all filter tokens.

        The tokens are the whitespace-separated words currently in the
        filter line edit; `searchString` is accepted but not used.
        """
        tokens = str(self.lineEditFilter.text()).split()
        for item, tree_item, _ in self.updateItems:
            matches_all = all(UpdateItem_match(item, token)
                              for token in tokens)
            tree_item.setHidden(not matches_all)

        self.UpdateInfoLabel()
        self.toggleButtons()

    def SubmitDownloadTask(self, domain, filename):
        """
        Submit the (domain, filename) to be downloaded/updated.

        A progress bar replaces the widget in column 2 of the item's row
        and the item's options widget is disabled while the task is
        pending.

        :raises ValueError: from `updateItemIndex` when the pair is not in
            the update list.
        """
        self.cancelButton.setEnabled(True)

        index = self.updateItemIndex(domain, filename)
        _, tree_item, opt_widget = self.updateItems[index]

        task = DownloadTask(domain, filename, serverfiles.LOCALFILES)

        self.progress.adjustRange(0, 100)

        pb = ItemProgressBar(self.filesView)
        pb.setRange(0, 100)
        pb.setTextVisible(False)

        # The task advances both the per-item bar and the overall progress.
        task.advanced.connect(pb.advance)
        task.advanced.connect(self.progress.advance)
        task.finished.connect(pb.hide)
        # Completion and errors are delivered through queued connections.
        task.finished.connect(self.onDownloadFinished, Qt.QueuedConnection)
        task.exception.connect(self.onDownloadError, Qt.QueuedConnection)

        self.filesView.setItemWidget(tree_item, 2, pb)

        # Clear the text so it does not show behind the progress bar.
        tree_item.setData(2, Qt.DisplayRole, "")
        pb.show()

        # Disable the options widget
        opt_widget.setEnabled(False)
        # Tracked so that Cancel and onDownloadFinished can find the task.
        self._tasks.append(task)

        self.executor.submit(task)

    def EndDownloadTask(self, task):
        """
        Update the row of a finished download task.

        A cancelled task restores the previous item state. A failed task
        also restores it, shows a warning and puts a "Retry" button in
        column 2. A successful task replaces the item with one built from
        the refreshed local info.

        :param task: finished task exposing ``future()``, ``domain`` and
            ``filename``.
        """
        future = task.future()
        index = self.updateItemIndex(task.domain, task.filename)
        item, tree_item, opt_widget = self.updateItems[index]

        self.filesView.removeItemWidget(tree_item, 2)
        opt_widget.setEnabled(True)

        was_cancelled = future.cancelled()
        # exception() is only consulted for futures that were not cancelled.
        if was_cancelled or future.exception():
            # Restore the previous state
            tree_item.setUpdateItem(item)
            opt_widget.setState(item.state)
            if was_cancelled:
                return

            self.warning(0, "Error while downloading. Check your connection "
                            "and retry.")

            # recreate button for download
            retry_button = QToolButton(
                None, text="Retry",
                maximumWidth=120,
                minimumHeight=20,
                maximumHeight=20
            )
            if sys.platform == "darwin":
                retry_button.setAttribute(Qt.WA_MacSmallSize)

            resubmit = partial(self.SubmitDownloadTask, item.domain,
                               item.filename)
            retry_button.clicked.connect(resubmit)

            self.filesView.setItemWidget(tree_item, 2, retry_button)
            return

        # get the new updated info dict and replace the old item
        self.warning(0)
        info = serverfiles.info(item.domain, item.filename)
        refreshed = update_item_from_info(item.domain, item.filename,
                                          info, info)

        self.updateItems[index] = (refreshed, tree_item, opt_widget)

        tree_item.setUpdateItem(refreshed)
        opt_widget.setState(refreshed.state)

        self.UpdateInfoLabel()

    def SubmitRemoveTask(self, domain, filename):
        """
        Remove the local copy of (domain, filename) and update its row.

        When the item has server info it becomes ``AVAILABLE``; otherwise
        only its local info is dropped and its options widget is disabled.
        """
        serverfiles.LOCALFILES.remove(domain, filename)
        index = self.updateItemIndex(domain, filename)
        item, tree_item, opt_widget = self.updateItems[index]

        changes = {"local": None, "info_local": None}
        if item.info_server:
            changes["state"] = AVAILABLE
        else:
            # Disable the options widget. No more actions can be performed
            # for the item.
            opt_widget.setEnabled(False)
        updated = item._replace(**changes)

        tree_item.setUpdateItem(updated)
        opt_widget.setState(updated.state)
        self.updateItems[index] = (updated, tree_item, opt_widget)

        self.UpdateInfoLabel()

    def Cancel(self):
        """
        Request cancellation of every tracked update/download task.

        Calls ``cancel()`` on the future of each task in ``self._tasks``;
        tasks that have already started are not expected to stop.
        """
        for pending in self._tasks:
            pending_future = pending.future()
            pending_future.cancel()

    def onDeleteWidget(self):
        """
        Cancel tracked tasks, shut the executor down without waiting and
        delegate to ``OWWidget.onDeleteWidget``.
        """
        self.Cancel()
        self.executor.shutdown(wait=False)
        OWWidget.onDeleteWidget(self)

    def onDownloadFinished(self):
        """
        Finalise every tracked task whose future is done.

        Runs for completed, cancelled and failed downloads alike. Once no
        task remains, the overall progress is reset and the cancel button
        disabled.
        """
        # on download completed/canceled/error
        assert QThread.currentThread() is self.thread()

        # Iterate over a snapshot because finished tasks are removed.
        for task in tuple(self._tasks):
            if not task.future().done():
                continue
            self.EndDownloadTask(task)
            self._tasks.remove(task)

        if self._tasks:
            return

        # Clear/reset the overall progress
        self.progress.setRange(0, 0)
        self.cancelButton.setEnabled(False)

    def onDownloadError(self, exc_info):
        """
        Report a download error.

        :param exc_info: sequence unpacked as the arguments of
            ``sys.excepthook``.
        """
        sys.excepthook(*exc_info)
        self.warning(0, "Error while downloading. Check your connection and "
                        "retry.")

    def updateItemIndex(self, domain, filename):
        """
        Return the position in ``self.updateItems`` of the first entry
        whose item has the given `domain` and `filename`.

        :raises ValueError: when no entry matches.
        """
        for position, (candidate, _, _) in enumerate(self.updateItems):
            same_domain = candidate.domain == domain
            if same_domain and candidate.filename == filename:
                return position

        raise ValueError("%r, %r not in update list" % (domain, filename))

    def _updateProgress(self, *args):
        rmin, rmax = self.progress.range()
        if rmin != rmax:
            if not self._haveProgress:
                self._haveProgress = True
                self.progressBarInit()

            self.progressBarSet(self.progress.ratioCompleted() * 100,
                                processEvents=None)
        if rmin == rmax:
            self._haveProgress = False
            self.progressBarFinished()
Ejemplo n.º 25
0
class OWResolwetSNE(OWWidget):
    """
    t-SNE projection widget whose input and output are ``resolwe.Data``
    objects.
    """
    name = "t-SNE"
    description = "Two-dimensional data projection with t-SNE."
    icon = "icons/OWResolwetSNE.svg"
    priority = 50

    class Inputs:
        data = Input("Data", resolwe.Data, default=True)

    class Outputs:
        selected_data = Output("Selected Data", resolwe.Data, default=True)

    settings_version = 2

    #: Runtime state
    Running, Finished, Waiting = 1, 2, 3

    settingsHandler = settings.DomainContextHandler()

    # Parameters exposed in the control area (spin boxes and PCA slider).
    max_iter = settings.Setting(300)
    perplexity = settings.Setting(30)
    pca_components = settings.Setting(20)

    # output embedding role.
    NoRole, AttrRole, AddAttrRole, MetaRole = 0, 1, 2, 3

    auto_commit = settings.Setting(True)

    selection_indices = settings.Setting(None, schema_only=True)

    legend_anchor = settings.Setting(((1, 0), (1, 0)))

    graph = SettingProvider(OWMDSGraph)

    jitter_sizes = [0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10]

    graph_name = "graph.plot_widget.plotItem"

    class Error(OWWidget.Error):
        not_enough_rows = Msg("Input data needs at least 2 rows")
        constant_data = Msg("Input data is constant")
        no_attributes = Msg("Data has no attributes")
        out_of_memory = Msg("Out of memory")
        optimization_error = Msg("Error during optimization\n{}")

    def __init__(self):
        """Initialise the widget state and build its GUI."""
        super().__init__()
        #: Effective data used for plot styling/annotations.
        self.data = None  # type: Optional[Orange.data.Table]
        #: Input subset data table
        self.subset_data = None  # type: Optional[Orange.data.Table]
        #: Input data table
        self.signal_data = None

        # resolwe variables
        self.data_table_object = None  # type: Optional[resolwe.Data]

        # Slugs of the two processes this widget runs.
        self._tsne_slug = 't-sne'
        self._tsne_selection_slug = 't-sne-selection'
        self._embedding_data_object = None
        self._embedding = None
        self._embedding_clas_var = None
        # Plot axes; created once (they used to be created twice here).
        self.variable_x = ContinuousVariable("tsne-x")
        self.variable_y = ContinuousVariable("tsne-y")

        # threading
        self._task = None  # type: Optional[ResolweTask]
        self._executor = ThreadExecutor()

        self.res = ResolweHelper()

        self._subset_mask = None  # type: Optional[np.ndarray]
        self._invalidated = False
        self.pca_data = None
        self._curve = None
        self._data_metas = None

        self.__update_loop = None

        self.__in_next_step = False
        self.__draw_similar_pairs = False

        box = gui.vBox(self.controlArea, "t-SNE")
        form = QFormLayout(labelAlignment=Qt.AlignLeft,
                           formAlignment=Qt.AlignLeft,
                           fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
                           verticalSpacing=10)

        form.addRow("Max iterations:",
                    gui.spin(box, self, "max_iter", 250, 2000, step=50))

        form.addRow("Perplexity:",
                    gui.spin(box, self, "perplexity", 1, 100, step=1))

        box.layout().addLayout(form)

        gui.separator(box, 10)
        # The same button starts the embedding and stops a running task
        # (see _run_embeding).
        self.runbutton = gui.button(box,
                                    self,
                                    "Run",
                                    callback=self._run_embeding)

        box = gui.vBox(self.controlArea, "PCA Preprocessing")
        gui.hSlider(box,
                    self,
                    'pca_components',
                    label="Components: ",
                    minValue=2,
                    maxValue=50,
                    step=1)  #, callback=self._initialize)

        box = gui.vBox(self.mainArea, True, margin=0)
        self.graph = OWMDSGraph(self,
                                box,
                                "MDSGraph",
                                view_box=MDSInteractiveViewBox)
        box.layout().addWidget(self.graph.plot_widget)
        self.plot = self.graph.plot_widget

        g = self.graph.gui
        box = g.point_properties_box(self.controlArea)
        self.models = g.points_models
        # Because sc data frequently has many genes,
        # showing all attributes in combo boxes can cause problems
        # QUICKFIX: Remove a separator and attributes from order
        # (leaving just the class and metas)
        for model in self.models:
            model.order = model.order[:-2]

        g.add_widgets(ids=[g.JitterSizeSlider], widget=box)

        box = gui.vBox(self.controlArea, "Plot Properties")
        g.add_widgets([
            g.ShowLegend, g.ToolTipShowsAll, g.ClassDensity,
            g.LabelOnlySelected
        ], box)

        self.controlArea.layout().addStretch(100)
        self.icons = gui.attributeIconDict

        palette = self.graph.plot_widget.palette()
        self.graph.set_palette(palette)

        gui.rubber(self.controlArea)

        self.graph.box_zoom_select(self.controlArea)

        gui.auto_commit(self.controlArea, self, "auto_commit",
                        "Send Selection", "Send Automatically")

        self.plot.getPlotItem().hideButtons()
        self.plot.setRenderHint(QPainter.Antialiasing)

        self.graph.jitter_continuous = True
        # self._initialize()

    def update_colors(self):
        """Do nothing; this widget has no colour update of its own."""
        # NOTE(review): presumably a callback expected by OWMDSGraph --
        # confirm against the graph class.
        pass

    def update_density(self):
        """Redraw the graph via `update_graph`, passing ``reset_view=False``."""
        self.update_graph(reset_view=False)

    def update_regression_line(self):
        """Redraw the graph via `update_graph`, passing ``reset_view=False``."""
        self.update_graph(reset_view=False)

    def prepare_data(self):
        """Do nothing; no data preparation is performed by this widget."""
        pass

    def update_graph(self, reset_view=True, **_):
        """
        Redraw the graph for the two t-SNE axis variables.

        The zoom stack is always cleared; nothing else happens when the
        graph has no data.

        :param reset_view: forwarded to ``graph.update_data``. It used to
            be ignored (``True`` was always passed), so callers asking for
            ``reset_view=False`` still had their view reset.
        :param _: extra keyword arguments are accepted and ignored.
        """
        self.graph.zoomStack = []
        if self.graph.data is None:
            return
        self.graph.update_data(self.variable_x,
                               self.variable_y,
                               reset_view=reset_view)

    def reset_graph_data(self, *_):
        """
        Rescale the graph data and redraw, provided ``self.data`` is set.

        Positional arguments are accepted and ignored.
        """
        if self.data is None:
            return
        self.graph.rescale_data()
        self.update_graph()

    def selection_changed(self):
        """
        Commit the current selection.

        A pending task is cancelled first without clearing the widget
        state; the task slot and the executor are then reset here.
        """
        has_pending_task = bool(self._task)
        if has_pending_task:
            # cancel() shuts the executor down but, with
            # clear_state=False, leaves self._task in place.
            self.cancel(clear_state=False)
            self._task = None
            self._executor = ThreadExecutor()

        self.commit()

    def _clear_plot(self):
        """Clear the graph's plot widget."""
        self.graph.plot_widget.clear()

    def _clear_state(self):
        """
        Reset the widget to its empty state.

        Clears the plot and the graph data, forgets the stored embedding
        and task, and installs a fresh ``ThreadExecutor``.
        """
        self._clear_plot()
        self.graph.new_data(None)

        # Forget everything derived from the last embedding run.
        for attribute in ("_embedding_data_object", "_embedding",
                          "_embedding_clas_var", "_task"):
            setattr(self, attribute, None)

        self._executor = ThreadExecutor()

    def cancel(self, clear_state=True):
        """
        Cancel the current task (if any).

        Shuts the executor down without waiting, resets the run button and
        the progress bar and, when `clear_state` is true, clears the widget
        state. Does nothing when no task is set.

        NOTE(review): the submitted future is not cancelled here and its
        watcher stays connected -- confirm `task_finished` cannot fire
        after a cancel.
        """
        if self._task is None:
            return

        self._executor.shutdown(wait=False)
        self.runbutton.setText('Run')
        self.progressBarFinished()

        if clear_state:
            self._clear_state()

    def run_task(self, slug, func):
        """
        Run `func` on the executor and track it as the current task.

        Any task already set is cancelled first. `task_finished` is
        connected to the watcher created for the new future.

        :param slug: identifier stored on the new ``ResolweTask``.
        :param func: zero-argument callable submitted to the executor.
        """
        already_running = self._task is not None
        if already_running:
            try:
                self.cancel()
            except CancelledError as e:
                print(e)
        assert self._task is None

        self.progressBarInit()

        task = ResolweTask(slug)
        self._task = task
        task.future = self._executor.submit(func)
        task.watcher = FutureWatcher(task.future)
        task.watcher.finished.connect(self.task_finished)

    @Slot(Future, name='Finished')
    def task_finished(self, future):
        """
        Handle completion of the future submitted by `run_task`.

        For the t-SNE task the returned data object is stored, its
        ``class_var`` and ``embedding_json`` payloads are fetched through
        ``self.res.get_json`` and the plot is rebuilt. For the selection
        task the result is sent on the ``selected_data`` output.

        An exception raised by the task propagates to the caller. The
        progress bar, the run button and ``self._task`` are reset either
        way.

        :param future: the finished future; must be ``self._task.future``.
        """
        assert threading.current_thread() == threading.main_thread()
        assert self._task is not None
        assert self._task.future is future
        assert future.done()

        try:
            # TODO: report task exceptions to the user instead of letting
            # them propagate out of the slot.
            future_result = future.result()
            slug = self._task.slug

            if slug == self._tsne_slug:
                self._embedding_data_object = future_result
                self._embedding_clas_var = self.res.get_json(
                    self._embedding_data_object, 'class_var')
                self._embedding = np.array(
                    self.res.get_json(self._embedding_data_object,
                                      'embedding_json', 'embedding'))
                self._setup_plot()

            if slug == self._tsne_selection_slug:
                self.Outputs.selected_data.send(future_result)
        finally:
            self.progressBarFinished()
            # Same idle label as the one used at construction and in
            # cancel(); this used to read 'Start'.
            self.runbutton.setText('Run')
            self._task = None

    @Inputs.data
    def set_data(self, data):
        # type: (Optional[resolwe.Data]) -> None
        """
        Store the incoming data object and trigger the embedding.

        A falsy `data` is ignored and the current state is left untouched.
        """
        if not data:
            return
        self.data_table_object = data
        self._run_embeding()

    def _run_embeding(self):
        if self._task:
            self.cancel()
            return

        if self._task is None:
            inputs = {
                'data_table': self.data_table_object,
                'pca_components': self.pca_components,
                'perplexity': self.perplexity,
                'iterations': self.max_iter
            }
            if self._embedding is not None and self._embedding_data_object is not None:
                inputs['init'] = self._embedding_data_object

            func = partial(self.res.run_process, self._tsne_slug, **inputs)

            # move filter process in thread
            self.run_task(self._tsne_slug, func)
            self.runbutton.setText('Stop')

    def _setup_plot(self):
        """
        Build the plot table from the stored embedding and display it.

        The embedding columns are joined with the ``y_data`` column of the
        stored class variable description; the class variable is used for
        point colour, while shape, size and label are unset.
        """
        class_info = self._embedding_clas_var
        class_var = DiscreteVariable(class_info['name'],
                                     values=class_info['values'])
        y_data = class_info['y_data']
        values = np.c_[self._embedding, y_data]

        plot_domain = Domain([self.variable_x, self.variable_y],
                             class_vars=class_var)
        plot_data = Table(plot_domain, values)

        # None unless the table is truthy and has at least one row.
        domain = (plot_data and len(plot_data) and plot_data.domain) or None
        for model in self.models:
            model.set_domain(domain)

        graph = self.graph
        graph.attr_color = plot_data.domain.class_var if domain else None
        graph.attr_shape = None
        graph.attr_size = None
        graph.attr_label = None

        graph.new_data(plot_data)
        graph.update_data(self.variable_x, self.variable_y, True)

    def commit(self):
        """
        Start the selection process for the current graph selection.

        The process is only run when both an embedding data object and a
        selection exist. ``None`` is sent on the ``selected_data`` output
        in every case.
        """
        selected = self.graph.get_selection()
        can_run = (self._embedding_data_object is not None
                   and selected is not None)

        if can_run:
            slug = self._tsne_selection_slug
            job = partial(
                self.res.run_process,
                slug,
                data_table=self.data_table_object,
                embedding=self._embedding_data_object,
                selection=selected.tolist(),
                x_tsne_var=self.variable_x.name,
                y_tsne_var=self.variable_y.name,
            )
            self.run_task(slug, job)

        self.Outputs.selected_data.send(None)

    def onDeleteWidget(self):
        """
        Run the base class cleanup, then clear the plot and reset the
        widget state.
        """
        super().onDeleteWidget()
        self._clear_plot()
        self._clear_state()

    def send_report(self):
        """
        Report the plot with a caption listing the graph's colour, label,
        shape, size and jittering settings.

        Nothing is reported while ``self.data`` is ``None``.
        """
        if self.data is None:
            return

        def var_name(var):
            # Falsy values (e.g. an unset attribute) pass through as-is.
            return var and var.name

        graph = self.graph
        rows = (
            ("Color", var_name(graph.attr_color)),
            ("Label", var_name(graph.attr_label)),
            ("Shape", var_name(graph.attr_shape)),
            ("Size", var_name(graph.attr_size)),
            ("Jittering", graph.jitter_size != 0
             and "{} %".format(graph.jitter_size)),
        )
        caption = report.render_items_vert(rows)

        self.report_plot()
        if caption:
            self.report_caption(caption)
class OWGOBrowser(widget.OWWidget):
    """Widget for enrichment analysis of Gene Ontology terms.

    Takes a "Cluster Data" table (and optionally a "Reference Data" table)
    and outputs the data on the selected genes together with an enrichment
    report table.
    """
    name = "GO Browser"
    description = "Enrichment analysis for Gene Ontology terms."
    icon = "../widgets/icons/OWGOBrowser.svg"
    priority = 7

    # Input channels: (name, type, handler method name[, flags]).
    inputs = [("Cluster Data", Orange.data.Table,
               "setDataset", widget.Single + widget.Default),
              ("Reference Data", Orange.data.Table,
               "setReferenceDataset")]

    # Output channels: (name, type).
    outputs = [("Data on Selected Genes", Orange.data.Table),
               ("Enrichment Report", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    # Context settings.
    geneAttrIndex = settings.ContextSetting(0)
    useAttrNames = settings.ContextSetting(False)
    # Index into the "Reference" radio box (entire genome / reference input).
    useReferenceDataset = settings.Setting(False)
    # Index into the "Aspect" radio box.
    aspectIndex = settings.Setting(0)

    # Evidence code -> whether annotations with that code are used.
    useEvidenceType = settings.Setting(
        {et: True for et in go.evidenceTypesOrdered})

    # Filter tab state. ``filterByPValue``/``maxPValue`` drive the FDR filter,
    # the ``*_nofdr`` pair drives the plain p-value filter.
    filterByNumOfInstances = settings.Setting(False)
    minNumOfInstances = settings.Setting(1)
    filterByPValue = settings.Setting(True)
    maxPValue = settings.Setting(0.2)
    filterByPValue_nofdr = settings.Setting(False)
    maxPValue_nofdr = settings.Setting(0.01)
    # Index into the "Significance test" radio box.
    probFunc = settings.Setting(0)

    # Select tab state (radio box indices).
    selectionDirectAnnotation = settings.Setting(0)
    selectionDisjoint = settings.Setting(0)

    class Error(widget.OWWidget.Error):
        """Error messages shown by the widget."""
        serverfiles_unavailable = widget.Msg('Can not locate annotation files, '
                                             'please check your connection and try again.')
        missing_annotation = widget.Msg(ERROR_ON_MISSING_ANNOTATION)
        missing_gene_id = widget.Msg(ERROR_ON_MISSING_GENE_ID)
        missing_tax_id = widget.Msg(ERROR_ON_MISSING_TAX_ID)

    def __init__(self, parent=None):
        """Build the widget GUI and collect the list of available annotations.

        :param parent: optional parent passed on to the base class.
        """
        # ``super()`` already binds the instance; passing ``self`` again
        # would shift it into the ``parent`` position.
        super().__init__(parent)

        self.input_data = None
        self.ref_data = None
        self.ontology = None
        self.annotations = None
        self.loaded_annotation_code = None
        self.treeStructRootKey = None
        self.probFunctions = [statistics.Binomial(), statistics.Hypergeometric()]
        self.selectedTerms = []

        # Attributes that other methods read; defined here so they exist
        # before the first dataset/enrichment has been processed.
        self.annotation_index = None
        self.input_genes = None
        self.ref_genes = None
        self.treeStructDict = {}
        self.termListViewItemDict = {}

        self.selectionChanging = 0
        self.__state = State.Ready
        # Single-shot timer used to schedule a recomputation.
        self.__scheduletimer = QTimer(self, singleShot=True)
        self.__scheduletimer.timeout.connect(self.__update)

        #############
        # GUI
        #############
        self.tabs = gui.tabWidget(self.controlArea)
        # Input tab
        self.inputTab = gui.createTabPage(self.tabs, "Input")
        box = gui.widgetBox(self.inputTab, "Info")
        self.infoLabel = gui.widgetLabel(box, "No data on input\n")

        gui.button(box, self, "Ontology/Annotation Info",
                   callback=self.ShowInfo,
                   tooltip="Show information on loaded ontology and annotations")

        self.referenceRadioBox = gui.radioButtonsInBox(
            self.inputTab, self, "useReferenceDataset",
            ["Entire genome", "Reference set (input)"],
            tooltips=["Use entire genome for reference",
                      "Use genes from Referece Examples input signal as reference"],
            box="Reference", callback=self.__invalidate)

        # The reference option stays disabled until reference data arrives.
        self.referenceRadioBox.buttons[1].setDisabled(True)
        gui.radioButtonsInBox(
            self.inputTab, self, "aspectIndex",
            ["Biological process", "Cellular component", "Molecular function"],
            box="Aspect", callback=self.__invalidate)

        # Filter tab
        self.filterTab = gui.createTabPage(self.tabs, "Filter")
        box = gui.widgetBox(self.filterTab, "Filter GO Term Nodes")
        gui.checkBox(box, self, "filterByNumOfInstances", "Genes",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by number of input genes mapped to a term")
        ibox = gui.indentedBox(box)
        gui.spin(ibox, self, 'minNumOfInstances', 1, 100,
                 step=1, label='#:', labelWidth=15,
                 callback=self.FilterAndDisplayGraph,
                 callbackOnReturn=True,
                 tooltip="Min. number of input genes mapped to a term")

        gui.checkBox(box, self, "filterByPValue_nofdr", "p-value",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by term p-value")

        gui.doubleSpin(gui.indentedBox(box), self, 'maxPValue_nofdr', 1e-8, 1,
                       step=1e-8,  label='p:', labelWidth=15,
                       callback=self.FilterAndDisplayGraph,
                       callbackOnReturn=True,
                       tooltip="Max term p-value")

        # use filterByPValue for FDR, as it was the default in prior versions
        gui.checkBox(box, self, "filterByPValue", "FDR",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by term FDR")
        gui.doubleSpin(gui.indentedBox(box), self, 'maxPValue', 1e-8, 1,
                       step=1e-8,  label='p:', labelWidth=15,
                       callback=self.FilterAndDisplayGraph,
                       callbackOnReturn=True,
                       tooltip="Max term p-value")

        box = gui.widgetBox(box, "Significance test")

        gui.radioButtonsInBox(box, self, "probFunc", ["Binomial", "Hypergeometric"],
                              tooltips=["Use binomial distribution test",
                                        "Use hypergeometric distribution test"],
                              callback=self.__invalidate)  # TODO: only update the p values
        box = gui.widgetBox(self.filterTab, "Evidence codes in annotation",
                            addSpace=True)
        # One check box per evidence code, keyed by the code.
        self.evidenceCheckBoxDict = {}
        for etype in go.evidenceTypesOrdered:
            ecb = QCheckBox(
                etype, toolTip=go.evidenceTypes[etype],
                checked=self.useEvidenceType[etype])
            ecb.toggled.connect(self.__on_evidenceChanged)
            box.layout().addWidget(ecb)
            self.evidenceCheckBoxDict[etype] = ecb

        # Select tab
        self.selectTab = gui.createTabPage(self.tabs, "Select")
        box = gui.radioButtonsInBox(
            self.selectTab, self, "selectionDirectAnnotation",
            ["Directly or Indirectly", "Directly"],
            box="Annotated genes",
            callback=self.ExampleSelection)

        box = gui.widgetBox(self.selectTab, "Output", addSpace=True)
        gui.radioButtonsInBox(
            box, self, "selectionDisjoint",
            btnLabels=["All selected genes",
                       "Term-specific genes",
                       "Common term genes"],
            tooltips=["Outputs genes annotated to all selected GO terms",
                      "Outputs genes that appear in only one of selected GO terms",
                      "Outputs genes common to all selected GO terms"],
            callback=self.ExampleSelection)

        # ListView for DAG, and table for significant GOIDs
        self.DAGcolumns = ['GO term', 'Cluster', 'Reference', 'p-value',
                           'FDR', 'Genes', 'Enrichment']

        self.splitter = QSplitter(Qt.Vertical, self.mainArea)
        self.mainArea.layout().addWidget(self.splitter)

        # list view
        self.listView = GOTreeWidget(self.splitter)
        self.listView.setSelectionMode(QTreeView.ExtendedSelection)
        self.listView.setAllColumnsShowFocus(1)
        self.listView.setColumnCount(len(self.DAGcolumns))
        self.listView.setHeaderLabels(self.DAGcolumns)

        self.listView.header().setSectionsClickable(True)
        self.listView.header().setSortIndicatorShown(True)
        self.listView.header().setSortIndicator(self.DAGcolumns.index('p-value'), Qt.AscendingOrder)
        self.listView.setSortingEnabled(True)
        # Column 6 is 'Enrichment' in DAGcolumns.
        self.listView.setItemDelegateForColumn(
            6, EnrichmentColumnItemDelegate(self))
        self.listView.setRootIsDecorated(True)

        self.listView.itemSelectionChanged.connect(self.ViewSelectionChanged)

        # table of significant GO terms
        self.sigTerms = QTreeWidget(self.splitter)
        self.sigTerms.setColumnCount(len(self.DAGcolumns))
        self.sigTerms.setHeaderLabels(self.DAGcolumns)
        self.sigTerms.setSortingEnabled(True)
        self.sigTerms.setSelectionMode(QTreeView.ExtendedSelection)
        self.sigTerms.header().setSortIndicator(self.DAGcolumns.index('p-value'), Qt.AscendingOrder)
        self.sigTerms.setItemDelegateForColumn(
            6, EnrichmentColumnItemDelegate(self))

        self.sigTerms.itemSelectionChanged.connect(self.TableSelectionChanged)

        self.sigTableTermsSorted = []
        self.graph = {}
        self.originalGraph = None

        self.inputTab.layout().addStretch(1)
        self.filterTab.layout().addStretch(1)
        self.selectTab.layout().addStretch(1)

        class AnnotationSlot(SimpleNamespace):
            taxid = ...  # type: str
            name = ...   # type: str
            filename = ...  # type:str

            @staticmethod
            def parse_tax_id(f_name):
                # The tax id is the second dot-separated field of the name.
                return f_name.split('.')[1]

        try:
            remote_files = serverfiles.ServerFiles().listfiles(DOMAIN)
        except (ConnectTimeout, RequestException, ConnectionError):
            # TODO: Warn user about failed connection to the remote server
            remote_files = []

        # Union of remote and local files; the ontology file is not an
        # annotation and is skipped.
        self.available_annotations = []
        for _, annotation_file in set(remote_files + serverfiles.listfiles(DOMAIN)):
            if annotation_file == FILENAME_ONTOLOGY:
                continue
            # Parse the tax id once instead of once per field.
            tax_id = AnnotationSlot.parse_tax_id(annotation_file)
            self.available_annotations.append(
                AnnotationSlot(
                    taxid=tax_id,
                    name=taxonomy.common_taxid_to_name(tax_id),
                    filename=FILENAME_ANNOTATION.format(tax_id)
                )
            )
        self._executor = ThreadExecutor()

    def sizeHint(self):
        """Return the preferred widget size (1000 x 700)."""
        return QSize(1000, 700)

    def __on_evidenceChanged(self):
        """Store every evidence check box state in ``useEvidenceType`` and invalidate."""
        boxes = self.evidenceCheckBoxDict
        for code in boxes:
            self.useEvidenceType[code] = boxes[code].isChecked()
        self.__invalidate()

    def clear(self):
        """Reset the info label, warnings 0 and 1, both views and both outputs."""
        self.infoLabel.setText("No data on input\n")
        for warning_id in (0, 1):
            self.warning(warning_id)
        self.ClearGraph()

        for channel in ("Data on Selected Genes", "Enrichment Report"):
            self.send(channel, None)

    def setDataset(self, data=None):
        """Handle the "Cluster Data" input.

        Reads the taxonomy id and gene-id description from the table
        attributes, validates them and, on success, stores the table and
        schedules a recomputation. When the input is missing, empty or
        invalid, no input table is kept.

        :param data: the input table, or ``None`` when the input is removed.
        """
        self.closeContext()
        self.clear()
        self.Error.clear()
        # Drop the previous input first so that a removed or rejected table
        # can never be analysed again by a later update.
        self.input_data = None
        self.annotation_index = None
        if not data:
            return

        attributes = data.attributes
        tax_id = attributes.get(TAX_ID, None)
        # Keep a missing tax id as ``None``; ``str(None)`` would produce the
        # string "None" and defeat the checks below.
        self.tax_id = None if tax_id is None else str(tax_id)
        self.use_attr_names = attributes.get(GENE_AS_ATTRIBUTE_NAME, None)
        self.gene_id_attribute = attributes.get(GENE_ID_ATTRIBUTE, None)
        self.gene_id_column = attributes.get(GENE_ID_COLUMN, None)

        # Exactly one of gene_id_attribute / gene_id_column must be given,
        # and the gene location flag must be present.
        gene_id_ok = (
            self.use_attr_names is not None
            and ((self.gene_id_attribute is None) ^ (self.gene_id_column is None))
        )
        if not gene_id_ok:
            if self.tax_id is None:
                self.Error.missing_annotation()
            else:
                self.Error.missing_gene_id()
            return

        if self.tax_id is None:
            self.Error.missing_tax_id()
            return

        _c2i = {a.taxid: i for i, a in enumerate(self.available_annotations)}
        try:
            self.annotation_index = _c2i[self.tax_id]
        except KeyError:
            # No annotation file is known for this organism.
            self.Error.serverfiles_unavailable()
            return

        self.input_data = data
        self.__invalidate()

    def setReferenceDataset(self, data=None):
        """Handle the "Reference Data" input.

        Reads the taxonomy id and gene-id description from the table
        attributes and validates them. A valid table is stored and enables
        the "Reference set" option; a missing or invalid table leaves no
        reference data stored.

        :param data: the reference table, or ``None`` when it is removed.
        """
        self.Error.clear()
        # Never keep a removed or rejected reference table around.
        self.ref_data = None
        if data:
            attributes = data.attributes
            tax_id = attributes.get(TAX_ID, None)
            # Keep a missing tax id as ``None``; ``str(None)`` would produce
            # the string "None" and defeat the checks below.
            self.ref_tax_id = None if tax_id is None else str(tax_id)
            self.ref_use_attr_names = attributes.get(GENE_AS_ATTRIBUTE_NAME, None)
            self.ref_gene_id_attribute = attributes.get(GENE_ID_ATTRIBUTE, None)
            self.ref_gene_id_column = attributes.get(GENE_ID_COLUMN, None)

            # Exactly one of the gene id attribute / column must be given,
            # and the gene location flag must be present.
            gene_id_ok = (
                self.ref_use_attr_names is not None
                and ((self.ref_gene_id_attribute is None)
                     ^ (self.ref_gene_id_column is None))
            )
            if not gene_id_ok:
                if self.ref_tax_id is None:
                    self.Error.missing_annotation()
                else:
                    self.Error.missing_gene_id()
                return

            if self.ref_tax_id is None:
                self.Error.missing_tax_id()
                return

            self.ref_data = data

        reference_button = self.referenceRadioBox.buttons[1]
        reference_button.setDisabled(not bool(data))
        reference_button.setText("Reference set")
        if self.input_data is not None and self.useReferenceDataset:
            # Fall back to the entire genome when the reference was removed.
            self.useReferenceDataset = 0 if not data else 1
            self.__invalidate()

    @Slot()
    def __invalidate(self):
        """Discard the displayed results and schedule a recomputation."""
        # Invalidate the current results or pending task and schedule an
        # update.
        self.__scheduletimer.start()
        if self.__state != State.Ready:
            # Something is still in progress; flag the state as stale.
            self.__state |= State.Stale

        # Empty the views and forget the gene sets of the previous run.
        self.SetGraph({})
        self.ref_genes = None
        self.input_genes = None

    def __invalidateAnnotations(self):
        """Forget the loaded annotations and invalidate the current results."""
        self.annotations = None
        self.loaded_annotation_code = None
        if self.input_data:
            # Placeholder text while there is input but no result yet.
            self.infoLabel.setText("...\n")
        self.__invalidate()

    @Slot()
    def __update(self):
        """Run the scheduled update.

        Does nothing without input data. While a task or download is in
        progress the state is only flagged as stale. In the ready state the
        databases are loaded and the enrichment is started, provided the
        required files are present.
        """
        self.__scheduletimer.stop()
        if self.input_data is None:
            return

        in_progress = (self.__state & State.Running
                       or self.__state & State.Downloading)
        if in_progress:
            self.__state |= State.Stale
            return

        if not (self.__state & State.Ready):
            return

        if self.__ensure_data():
            self.Load()
            self.Enrichment()
            return

        # Files are missing, so a download must have been started.
        assert self.__state & State.Downloading
        assert self.isBlocking()

    def __get_ref_genes(self):
        """Collect the gene ids of the reference dataset into ``self.ref_genes``.

        When the reference genes are stored as attribute (column) names, the
        id is read from each attribute of the *reference* table (``'?'`` when
        the id is missing); otherwise the ids are read from the gene id
        column of the reference table.
        """
        if self.ref_use_attr_names:
            # Read from the reference table's own domain; the gene id
            # attribute key belongs to the reference data, not the input data.
            self.ref_genes = [
                str(variable.attributes.get(self.ref_gene_id_attribute, '?'))
                for variable in self.ref_data.domain.attributes
            ]
        else:
            genes, _ = self.ref_data.get_column_view(self.ref_gene_id_column)
            self.ref_genes = [str(g) for g in genes]

    def __get_input_genes(self):
        """Collect the gene ids of the input dataset into ``self.input_genes``.

        Ids come either from the attributes of the domain (``'?'`` when an
        attribute carries no id) or from the gene id column.
        """
        if not self.use_attr_names:
            column, _ = self.input_data.get_column_view(self.gene_id_column)
            self.input_genes = [str(value) for value in column]
            return

        id_key = self.gene_id_attribute
        self.input_genes = [
            str(attribute.attributes.get(id_key, '?'))
            for attribute in self.input_data.domain.attributes
        ]

    def FilterAnnotatedGenes(self, genes):
        """Split ``genes`` by whether the annotations can translate them.

        :param genes: gene names to look up.
        :return: ``(matched, unmatched)`` where ``matched`` is the values view
            of the gene name translator and ``unmatched`` lists the input
            genes that do not occur among those values.
        """
        translator = self.annotations.get_gene_names_translator(genes)
        matched = translator.values()
        unmatched = []
        for gene in genes:
            if gene not in matched:
                unmatched.append(gene)
        return matched, unmatched

    def __start_download(self, files_list):
        # type: (List[Tuple[str, str]]) -> None
        """Start downloading ``files_list`` on the executor.

        The widget is put into the blocking ``Downloading`` state; completion
        is reported to ``__download_finish``.
        """
        task = EnsureDownloaded(files_list)
        task.progress.connect(self._progressBarSet)

        f = self._executor.submit(task)
        fw = FutureWatcher(f, self)
        fw.finished.connect(self.__download_finish)
        fw.finished.connect(fw.deleteLater)
        # Schedule a new update once a result is available.
        fw.resultReady.connect(self.__invalidate)

        self.progressBarInit(processEvents=None)
        self.setBlocking(True)
        self.setStatusMessage("Downloading")
        self.__state = State.Downloading

    @Slot(Future)
    def __download_finish(self, result):
        # type: (Future[None]) -> None
        """Finish a download: unblock the widget and report any failure.

        Connection and request errors are shown as widget error 2; any other
        exception is shown and re-raised. The state always returns to
        ``Ready``.

        :param result: the completed future of the download task.
        """
        assert QThread.currentThread() is self.thread()
        assert result.done()
        self.setBlocking(False)
        self.setStatusMessage("")
        self.progressBarFinished(processEvents=False)
        log = logging.getLogger(__name__)
        try:
            result.result()
        except ConnectTimeout:
            # ``exception`` records the traceback, not just a bare label.
            log.exception("Connection timed out while downloading files")
            self.error(2, "Internet connection error, unable to load data. " +
                       "Check connection and create a new GO Browser widget.")
        except RequestException as err:
            log.exception("Request failed while downloading files")
            self.error(2, "Internet error:\n" + str(err))
        except BaseException as err:
            log.exception("Unexpected error while downloading files")
            self.error(2, "Error:\n" + str(err))
            raise
        else:
            self.error(2)
        finally:
            self.__state = State.Ready

    def __ensure_data(self):
        """Make sure the ontology and the current annotation file are present.

        Missing files are downloaded in the background.

        :return: ``True`` when all files are already present, ``False`` when
            a download had to be started.
        """
        assert self.__state == State.Ready
        annotation = self.available_annotations[self.annotation_index]
        go_files = {fname for _, fname in serverfiles.listfiles(DOMAIN)}

        # Both files are looked up under DOMAIN above, so both must also be
        # requested under DOMAIN.
        files = [(DOMAIN, fname)
                 for fname in (annotation.filename, FILENAME_ONTOLOGY)
                 if fname not in go_files]

        if not files:
            return True

        self.__start_download(files)
        assert self.__state == State.Downloading
        return False

    def Load(self):
        """Load the ontology and the annotations of the selected organism.

        The ontology is created once. Annotations are reloaded only when the
        organism changed; the evidence check boxes are then relabelled with
        the per-evidence annotation and gene counts.
        """
        slot = self.available_annotations[self.annotation_index]

        if self.ontology is None:
            self.ontology = go.Ontology()

        if slot.taxid == self.loaded_annotation_code:
            return

        # Release the old annotations before loading the new ones.
        self.annotations = None
        gc.collect()  # Force run garbage collection
        self.annotations = go.Annotations(slot.taxid)
        self.loaded_annotation_code = slot.taxid

        annotation_count = defaultdict(int)
        genes_by_evidence = defaultdict(set)
        for annotation in self.annotations.annotations:
            evidence = annotation.evidence
            annotation_count[evidence] += 1
            genes_by_evidence[evidence].add(annotation.gene_id)

        for etype in go.evidenceTypesOrdered:
            check_box = self.evidenceCheckBoxDict[etype]
            n_annotations = annotation_count[etype]
            n_genes = len(genes_by_evidence[etype])
            check_box.setEnabled(bool(n_annotations))
            check_box.setText(etype + ": %i annots(%i genes)" % (n_annotations, n_genes))

    def Enrichment(self):
        """Collect the gene sets and start the enrichment computation.

        The input genes come from the input data. The reference genes come
        from the reference dataset when it is selected and available, and
        from all annotated genes otherwise. The computation runs on the
        executor and is finished by ``__on_enrichment_done``.

        :return: ``{}`` when no valid reference set exists, otherwise ``None``.
        """
        assert self.input_data is not None
        assert self.__state == State.Ready

        if not self.annotations.ontology:
            self.annotations.ontology = self.ontology

        self.error(1)
        self.warning([0, 1])

        self.__get_input_genes()
        self.input_genes = set(self.input_genes)
        self.known_input_genes = self.annotations.get_genes_with_known_annotation(self.input_genes)

        # self.clusterGenes = clusterGenes = self.annotations.map_to_ncbi_id(self.input_genes).values()

        self.infoLabel.setText("%i unique genes on input\n%i (%.1f%%) genes with known annotations" %
                               (len(self.input_genes), len(self.known_input_genes),
                                100.0*len(self.known_input_genes)/len(self.input_genes)
                                if len(self.input_genes) else 0.0))

        if not self.useReferenceDataset or self.ref_data is None:
            self.information(2)
            self.information(1)
            self.ref_genes = set(self.annotations.genes())
        else:
            # Reached only with the reference option on and reference data
            # present, so no further case distinction is needed.
            self.__get_ref_genes()
            self.ref_genes = set(self.ref_genes)

            ref_count = len(self.ref_genes)
            reference_button = self.referenceRadioBox.buttons[1]
            if ref_count == 0:
                # Nothing usable in the reference data: fall back to the
                # entire genome and switch the option off.
                self.ref_genes = self.annotations.genes()
                reference_button.setText("Reference set")
                reference_button.setDisabled(True)
                self.information(2, "Unable to extract gene names from reference dataset. "
                                    "Using entire genome for reference")
                self.useReferenceDataset = 0
            else:
                reference_button.setText("Reference set ({} genes)".format(ref_count))
                reference_button.setDisabled(False)
                self.information(2)

        if not self.ref_genes:
            self.error(1, "No valid reference set")
            return {}

        evidences = [etype for etype in go.evidenceTypesOrdered
                     if self.useEvidenceType[etype]]
        aspect = ['Process', 'Component', 'Function'][self.aspectIndex]

        self.progressBarInit(processEvents=False)
        self.setBlocking(True)
        self.__state = State.Running

        if self.input_genes:
            f = self._executor.submit(
                self.annotations.get_enriched_terms,
                self.input_genes, self.ref_genes, evidences, aspect=aspect,
                prob=self.probFunctions[self.probFunc], use_fdr=False,

                progress_callback=methodinvoke(
                    self, "_progressBarSet", (float,))
            )
            fw = FutureWatcher(f, parent=self)
            fw.done.connect(self.__on_enrichment_done)
            fw.done.connect(fw.deleteLater)
            return
        else:
            # No input genes: finish immediately with an empty result.
            f = Future()
            f.set_result({})
            self.__on_enrichment_done(f)

    def __on_enrichment_done(self, results):
        # type: (Future[Dict[str, tuple]]) -> None
        """Process the finished enrichment computation.

        A stale result is discarded and a new update is scheduled. Otherwise
        the FDR value is appended to every term, the term tree is rebuilt,
        and the views and outputs are updated.

        :param results: the completed future of the enrichment computation.
        """
        self.progressBarFinished(processEvents=False)
        self.setBlocking(False)
        self.setStatusMessage("")
        if self.__state & State.Stale:
            self.__state = State.Ready
            self.__invalidate()
            return

        self.__state = State.Ready
        try:
            results = results.result()  # type: Dict[str, tuple]
        except Exception as ex:
            results = {}
            self.error(1, str(ex))

        if results:
            term_ids = list(results)
            # Index 1 of every result tuple is the p-value.
            fdr_vals = statistics.FDR([results[term_id][1] for term_id in term_ids])
            terms = {term_id: results[term_id] + (fdr,)
                     for term_id, fdr in zip(term_ids, fdr_vals)}
        else:
            terms = {}

        self.terms = terms

        if not self.terms:
            self.warning(0, "No enriched terms found.")
        else:
            self.warning(0)

        self.treeStructDict = {}
        self.treeStructRootKey = None

        # Build the child sets in one pass over the parent links instead of
        # rescanning all terms for every term.
        children = {term_id: set() for term_id in terms}
        for term_id in terms:
            for _, parent in self.ontology[term_id].related:
                if parent in children:
                    children[parent].add(term_id)

        for term_id in terms:
            self.treeStructDict[term_id] = TreeNode(terms[term_id], children[term_id])
            ontology_term = self.ontology[term_id]
            # A non-obsolete term without related terms is used as the root.
            if not ontology_term.related and not getattr(ontology_term, "is_obsolete", False):
                self.treeStructRootKey = term_id

        self.SetGraph(terms)
        self._updateEnrichmentReportOutput()
        self.commit()

    def _updateEnrichmentReportOutput(self):
        """Build the enrichment report table and send it to "Enrichment Report".

        Terms are ordered by p-value; terms without genes or without a
        reference count are left out. ``None`` is sent when no row remains.
        """
        string_var = Orange.data.StringVariable
        continuous_var = Orange.data.ContinuousVariable
        # All is meta!
        report_metas = [
            string_var("GO Term Id"),
            string_var("GO Term Name"),
            continuous_var("Cluster Frequency"),
            continuous_var("Genes in Cluster", number_of_decimals=0),
            continuous_var("Reference Frequency"),
            continuous_var("Genes in Reference", number_of_decimals=0),
            continuous_var("p-value"),
            continuous_var("FDR"),
            continuous_var("Enrichment"),
            string_var("Genes"),
        ]
        termsDomain = Orange.data.Domain([], [], report_metas)

        by_p_value = sorted(self.terms.items(), key=lambda item: item[1][1])
        rows = []
        for t_id, (genes, p_value, r_count, fdr) in by_p_value:
            if not (genes and r_count):
                continue
            cluster_frequency = len(genes) / len(self.input_genes)
            rows.append([
                t_id,
                self.ontology[t_id].name,
                cluster_frequency,
                len(genes),
                r_count / len(self.ref_genes),
                r_count,
                p_value,
                fdr,
                cluster_frequency * len(self.ref_genes) / r_count,
                ",".join(genes),
            ])

        termsTable = None
        if rows:
            # The table has no regular columns, only metas.
            X = numpy.empty((len(rows), 0))
            M = numpy.array(rows, dtype=object)
            termsTable = Orange.data.Table.from_numpy(termsDomain, X,
                                                      metas=M)
        self.send("Enrichment Report", termsTable)

    @Slot(float)
    def _progressBarSet(self, value):
        """Set the progress bar to ``value``; must run in the widget's thread."""
        assert QThread.currentThread() is self.thread()
        self.progressBarSet(value, processEvents=None)

    @Slot()
    def _progressBarFinish(self):
        """Finish the progress bar; must run in the widget's thread."""
        assert QThread.currentThread() is self.thread()
        self.progressBarFinished(processEvents=None)

    def FilterGraph(self, graph):
        """Return ``graph`` restricted by the enabled filters.

        The filters are applied in order: p-value, FDR (index 3 of the term
        tuple) and number of genes (length of index 0).

        :param graph: mapping of term id to term tuple.
        """
        if self.filterByPValue_nofdr:
            graph = go.filterByPValue(graph, self.maxPValue_nofdr)
        if self.filterByPValue:  # FDR
            graph = {term_id: stats for term_id, stats in graph.items()
                     if stats[3] <= self.maxPValue}
        if self.filterByNumOfInstances:
            graph = {term_id: stats for term_id, stats in graph.items()
                     if len(stats[0]) >= self.minNumOfInstances}
        return graph

    def FilterAndDisplayGraph(self):
        """Re-apply the filters to the original graph and redraw both views.

        Does nothing without input data or before a graph has been set.
        Warning 1 is shown when the filters removed every term.
        """
        if not self.input_data or self.originalGraph is None:
            return

        self.graph = self.FilterGraph(self.originalGraph)
        everything_filtered = self.originalGraph and not self.graph
        if everything_filtered:
            self.warning(1, "All found terms were filtered out.")
        else:
            self.warning(1)
        self.ClearGraph()
        self.DisplayGraph()

    def SetGraph(self, graph=None):
        """Store ``graph`` as the unfiltered graph and refresh the views.

        An empty or missing graph empties the views instead.
        """
        self.originalGraph = graph
        if not graph:
            self.graph = {}
            self.ClearGraph()
            return
        self.FilterAndDisplayGraph()

    def ClearGraph(self):
        """Remove all items from the tree view and the significant-terms table."""
        self.listView.clear()
        self.listViewItems=[]
        self.sigTerms.clear()

    def DisplayGraph(self):
        """Fill the tree view and the significant-terms table from ``self.graph``.

        The tree is walked from the root term; terms missing from the
        filtered graph are skipped and their children are attached to the
        nearest displayed ancestor. The table lists the filtered terms
        ordered by p-value.
        """
        # (parent, term) pairs that already have a tree item.
        shown_edges = set()
        self.termListViewItemDict = {}
        self.listViewItems = []

        def enrichment(t):
            try:
                return len(t[0]) / t[2] * (len(self.ref_genes) / len(self.input_genes))
            except ZeroDivisionError:
                # TODO: find out why this happens
                return 0

        maxFoldEnrichment = max([enrichment(term) for term in self.graph.values()] or [1])

        def addNode(term, parent, parentDisplayNode):
            if (parent, term) in shown_edges:
                return
            if term in self.graph:
                displayNode = GOTreeWidgetItem(self.ontology[term], self.graph[term], len(self.input_genes),
                                               len(self.ref_genes), maxFoldEnrichment, parentDisplayNode)
                displayNode.goId = term
                self.listViewItems.append(displayNode)
                self.termListViewItemDict.setdefault(term, []).append(displayNode)
                shown_edges.add((parent, term))
                parent = term
            else:
                # Filtered-out term: attach its children to the ancestor.
                displayNode = parentDisplayNode

            for c in self.treeStructDict[term].children:
                addNode(c, parent, displayNode)

        # Skip the tree when no root term was identified; looking up a
        # missing root would raise KeyError.
        if self.treeStructRootKey in self.treeStructDict:
            addNode(self.treeStructRootKey, None, self.listView)

        terms = sorted(self.graph.items(), key=lambda item: item[1][1])
        self.sigTableTermsSorted = [t[0] for t in terms]

        self.sigTerms.clear()
        for t_id, (genes, p_value, refCount, fdr) in terms:
            # Passing the table as parent adds the item to it.
            item = GOTreeWidgetItem(self.ontology[t_id],
                                    (genes, p_value, refCount, fdr),
                                    len(self.input_genes),
                                    len(self.ref_genes),
                                    maxFoldEnrichment,
                                    self.sigTerms)
            item.goId = t_id

        self.listView.expandAll()
        for i in range(5):
            self.listView.resizeColumnToContents(i)
            self.sigTerms.resizeColumnToContents(i)
        self.sigTerms.resizeColumnToContents(6)
        # Cap the width of the first column at 350.
        width = min(self.listView.columnWidth(0), 350)
        self.listView.setColumnWidth(0, width)
        self.sigTerms.setColumnWidth(0, width)

    def ViewSelectionChanged(self):
        """Update the selected terms from the tree view and commit the output.

        Ignored while another selection update is in progress.
        """
        if self.selectionChanging:
            return

        self.selectionChanging = 1
        self.selectedTerms = []
        # A term can have several tree items; keep each id once.
        chosen_ids = {item.term.id for item in self.listView.selectedItems()}
        self.selectedTerms = list(chosen_ids)
        self.ExampleSelection()
        self.selectionChanging = 0

    def TableSelectionChanged(self):
        """Mirror the table selection in the tree view and commit the output.

        Ignored while another selection update is in progress.
        """
        if self.selectionChanging:
            return

        self.selectionChanging = 1
        self.selectedTerms = []
        selectedIds = {self.sigTerms.itemFromIndex(index).goId
                       for index in self.sigTerms.selectedIndexes()}

        for i in range(self.sigTerms.topLevelItemCount()):
            term = self.sigTerms.topLevelItem(i).goId
            selected = term in selectedIds

            if selected:
                self.selectedTerms.append(term)

            # A term listed in the table may have no item in the tree;
            # treat that as "nothing to update" instead of raising KeyError.
            for lvi in self.termListViewItemDict.get(term, ()):
                try:
                    lvi.setSelected(selected)
                    if selected:
                        lvi.setExpanded(True)
                except RuntimeError:  # Underlying C/C++ object deleted
                    pass
        self.selectionChanging = 0
        self.ExampleSelection()

    def ExampleSelection(self):
        """Callback for selection changes; delegates to ``commit``."""
        self.commit()

    def commit(self):
        """Send the input data restricted to the genes of the selected terms.

        Nothing is sent without input data, graph or annotations, or while
        the state is stale. Depending on where the genes are stored, the
        output keeps either the matching attributes (columns) or the
        matching rows; ``None`` is sent when no row matches.
        """
        if self.input_data is None or self.originalGraph is None or \
                self.annotations is None:
            return
        if self.__state & State.Stale:
            return

        # Ignore selected terms that are no longer part of the shown graph.
        terms = set(self.selectedTerms)
        genes = reduce(operator.ior,
                       (set(self.graph[term][0]) for term in terms
                        if term in self.graph),
                       set())

        evidences = [etype for etype in go.evidenceTypesOrdered
                     if self.useEvidenceType[etype]]

        allTerms = self.annotations.get_annotated_terms(
            genes, direct_annotation_only=self.selectionDirectAnnotation,
            evidence_codes=evidences)

        if self.selectionDisjoint > 0:
            # Count in how many selected terms every gene occurs.
            count = defaultdict(int)
            for term in self.selectedTerms:
                for g in allTerms.get(term, []):
                    count[g] += 1
            # 1: genes of exactly one term; otherwise: genes of every term.
            ccount = 1 if self.selectionDisjoint == 1 else len(self.selectedTerms)
            selected_genes = {gene for gene, c in count.items()
                              if c == ccount and gene in genes}
        else:
            selected_genes = reduce(
                operator.ior,
                (set(allTerms.get(term, [])) for term in self.selectedTerms), set())

        if self.use_attr_names:
            selected = [column for column in self.input_data.domain.attributes
                        if self.gene_id_attribute in column.attributes and
                        str(column.attributes[self.gene_id_attribute]) in selected_genes]

            domain = Orange.data.Domain(selected, self.input_data.domain.class_vars, self.input_data.domain.metas)
            new_data = self.input_data.from_table(domain, self.input_data)
            self.send("Data on Selected Genes", new_data)
        else:
            selected_rows = []
            for row_index, row in enumerate(self.input_data):
                gene_in_row = str(row[self.gene_id_column])
                if gene_in_row in selected_genes and gene_in_row in self.input_genes:
                    selected_rows.append(row_index)

            # Send once, after all rows were examined.
            selected = self.input_data[selected_rows] if selected_rows else None
            self.send("Data on Selected Genes", selected)

    def ShowInfo(self):
        """Show a non-modal dialog with the ontology and annotation headers."""
        dialog = QDialog(self)
        dialog.setModal(False)
        dialog.setLayout(QVBoxLayout())

        if self.ontology:
            ontology_text = "Ontology:\n" + self.ontology.header
        else:
            ontology_text = "Ontology not loaded!"

        if self.annotations:
            annotations_text = "Annotations:\n" + self.annotations.header.replace("!", "")
        else:
            annotations_text = "Annotations not loaded!"

        for text in (ontology_text, annotations_text):
            label = QLabel(dialog)
            label.setText(text)
            dialog.layout().addWidget(label)
        dialog.show()

    def onDeleteWidget(self):
        """Called before the widget is removed from the canvas.

        Runs the base-class hook, then drops the loaded annotations and
        ontology and forces a garbage collection.
        """
        super().onDeleteWidget()
        self.annotations = None
        self.ontology = None
        gc.collect()  # Force collection
Ejemplo n.º 27
0
class OWTestLearners(OWWidget):
    """
    Widget that tests the learners on its inputs with a selectable resampling
    scheme and lists the resulting scores in a table.
    """
    # Widget metadata.
    name = "Test & Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100

    # Input signals; "Learner" is declared with multiple=True.
    class Inputs:
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    # Output signals.
    class Outputs:
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    # Stored settings older than this are upgraded by `migrate_settings`.
    settings_version = 3
    UserAdviceMessages = [
        widget.Message(
            "Click on the table header to select shown columns",
            "click_header")]

    settingsHandler = settings.PerfectDomainContextHandler(metas_in_res=True)

    #: Resampling/testing types
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    # NOTE: `n_folds`, `n_repeats` and `sample_size` below hold *indices*
    # into `NFolds`, `NRepeats` and `SampleSizes`, not the values themselves.

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    # Remembers (per context) whether "Cross validation by feature" was the
    # active resampling type; written by `_invalidate`.
    fold_feature_selected = settings.ContextSetting(False)

    # Pseudo-entry of the target class combo meaning "no specific class".
    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    # Preferred column order of the built-in scores, per class variable type.
    BUILTIN_ORDER = {
        DiscreteVariable: ("AUC", "CA", "F1", "Precision", "Recall"),
        ContinuousVariable: ("MSE", "RMSE", "MAE", "R2")}

    # Names of the score columns that are visible; defaults to every
    # built-in score listed above.
    shown_scores = \
        settings.Setting(set(chain(*BUILTIN_ORDER.values())))

    class Error(OWWidget.Error):
        train_data_empty = Msg("Train data set is empty.")
        test_data_empty = Msg("Test data set is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg("Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train data sets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        only_one_class_var_value = Msg("Target variable has only one value.")

    class Warning(OWWidget.Warning):
        # The placeholder is filled with " ", " train " or " test "
        # (see `_which_missing_data`).
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    class Information(OWWidget.Information):
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")

    def __init__(self):
        """Initialise the widget state and build the control and main areas."""
        super().__init__()

        # Current inputs and flags derived from them.
        self.data = None
        self.test_data = None
        self.preprocessor = None
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        # Score classes applicable to the current class variable
        # (filled in by `_update_scorers`).
        self.scorers = []

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        # The evaluation currently submitted to the executor, if any.
        self.__task = None  # type: Optional[Task]
        self.__executor = ThreadExecutor()

        # --- "Sampling" box -------------------------------------------------
        # The radio buttons are appended in the same order as the resampling
        # constants (KFold, FeatureFold, ShuffleSplit, LeaveOneOut,
        # TestOnTrain, TestOnTest), so the button index equals the constant.
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(
            sbox, self, "resampling", callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_folds", label="Number of folds: ",
            items=[str(x) for x in self.NFolds], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.kfold_changed)
        gui.checkBox(
            ibox, self, "cv_stratified", "Stratified",
            callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # Only discrete meta attributes are offered as the fold feature.
        self.feature_model = DomainModel(
            order=DomainModel.METAS, valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(
            ibox, self, "fold_feature", model=self.feature_model,
            orientation=Qt.Horizontal, callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(
            ibox, self, "n_repeats", label="Repeat train/test: ",
            items=[str(x) for x in self.NRepeats], maximumContentsLength=3,
            orientation=Qt.Horizontal, callback=self.shuffle_split_changed)
        gui.comboBox(
            ibox, self, "sample_size", label="Training set size: ",
            items=["{} %".format(x) for x in self.SampleSizes],
            maximumContentsLength=5, orientation=Qt.Horizontal,
            callback=self.shuffle_split_changed)
        gui.checkBox(
            ibox, self, "shuffle_stratified", "Stratified",
            callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # --- "Target Class" box ---------------------------------------------
        # The combo is created empty; `_update_class_selection` fills it.
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox, self, "class_selection", items=[],
            sendSelectedValue=True, valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        gui.rubber(self.controlArea)

        # --- Results table --------------------------------------------------
        self.view = gui.TableView(
            wordWrap=True,
        )
        header = self.view.horizontalHeader()
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setStretchLastSection(False)
        # A context-menu request on the header opens the column chooser.
        header.setContextMenuPolicy(Qt.CustomContextMenu)
        header.customContextMenuRequested.connect(self.show_column_chooser)

        # Column 0 holds the method name; score columns are added by
        # `_update_header`.
        self.result_model = QStandardItemModel(self)
        self.result_model.setHorizontalHeaderLabels(["Method"])
        self.view.setModel(self.result_model)
        self.view.setItemDelegate(ItemDelegate())

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.view)

    def sizeHint(self):
        """Return the preferred widget size: 780 wide with a height of 1."""
        preferred = QSize(780, 1)
        return preferred

    def _update_controls(self):
        """Refresh the fold-feature controls for the current train data.

        The feature model is reset and, when data is present, rebuilt from
        its domain with the first available feature preselected. If the
        model ends up empty, the "by feature" option is disabled and an
        active by-feature resampling falls back to k-fold.
        """
        by_feature = OWTestLearners.FeatureFold
        features = self.feature_model

        # Always start from a clean slate.
        self.fold_feature = None
        features.set_domain(None)
        if self.data:
            features.set_domain(self.data.domain)
            if self.fold_feature is None and features:
                self.fold_feature = features[0]

        available = bool(features)
        fold_button = self.controls.resampling.buttons[by_feature]
        fold_button.setEnabled(available)
        self.features_combo.setEnabled(available)
        if not available and self.resampling == by_feature:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Set the input `learner` for `key`.

        A ``None`` learner means the input identified by `key` went away:
        a known key is invalidated and removed, while an unknown key is
        ignored. Previously a ``None`` learner with an unknown key was stored
        as a slot without a learner, which would later be named, deep-copied
        and evaluated like a real one.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        if learner is None:
            if key in self.learners:
                # Removed; invalidate first, while the slot still exists.
                self._invalidate([key])
                del self.learners[key]
            return

        # New or replaced learner: results and stats start out empty.
        self.learners[key] = InputLearner(learner, None, None)
        self._invalidate([key])

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Receive the training dataset and refresh everything derived from it.

        Empty tables and tables whose target is unusable (none, several, or
        a discrete one with a single value) are reported through the matching
        error and treated as no data. SQL tables are downloaded, sampling
        them when they are large. Rows are filtered with ``HasClass`` when
        unknown target values were detected.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        for message in (self.Information.data_sampled,
                        self.Error.train_data_empty,
                        self.Error.class_required,
                        self.Error.too_many_classes,
                        self.Error.only_one_class_var_value):
            message.clear()

        if data is not None and not len(data):
            self.Error.train_data_empty()
            data = None

        if data:
            # Report only the first problem found with the target variable.
            domain = data.domain
            if not domain.class_vars:
                problem = self.Error.class_required
            elif len(domain.class_vars) > 1:
                problem = self.Error.too_many_classes
            elif domain.has_discrete_class and \
                    len(domain.class_var.values) == 1:
                problem = self.Error.only_one_class_var_value
            else:
                problem = None
            if problem is not None:
                problem()
                data = None

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                source = data
            else:
                # Too large to fetch as a whole: work on a sample instead.
                self.Information.data_sampled()
                source = data.sample_time(1, no_cache=True)
                source.download_data(AUTO_DL_LIMIT, partial=True)
            data = Table(source)

        self.train_data_missing_vals = (
            data is not None and np.isnan(data.Y).any())
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            # Restore "by feature" resampling if the context remembered it
            # and the new data offers a feature to fold on.
            if self.fold_feature_selected and self.feature_model:
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Orange.data.Table) -> None
        """
        Receive the separate testing dataset.

        An empty table, or one without a target variable, is reported through
        the matching error and treated as no data. SQL tables are downloaded,
        sampling them when they are large. Results are invalidated only when
        "Test on test data" is the active resampling type.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()
        if data is not None and not len(data):
            self.Error.test_data_empty()
            data = None

        lacks_target = bool(data) and not data.domain.class_var
        if lacks_target:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() < AUTO_DL_LIMIT:
                source = data
            else:
                # Too large to fetch as a whole: work on a sample instead.
                self.Information.test_data_sampled()
                source = data.sample_time(1, no_cache=True)
                source.download_data(AUTO_DL_LIMIT, partial=True)
            data = Table(source)

        self.test_data_missing_vals = (
            data is not None and np.isnan(data.Y).any())
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        """Return the text inserted into the missing-data warning.

        The result names the dataset(s) with unknown target values; a
        ``KeyError`` is raised when neither flag is set.
        """
        wording = {
            (True, True): " ",  # both, don't specify
            (True, False): " train ",
            (False, True): " test ",
        }
        flags = (self.train_data_missing_vals, self.test_data_missing_vals)
        return wording[flags]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        """
        Rebuild ``self.scorers`` for the class variable of the train data.

        The list is empty when there is no data or no class variable.
        Otherwise it holds every registered scalar, non-abstract score whose
        ``class_types`` accept the class variable, with the built-in scores
        first (in ``BUILTIN_ORDER``) and all others after them.
        """
        # Read the class variable straight from the domain. The former
        # `self.data and ...` evaluated to the table itself whenever the
        # table was falsy (e.g. empty), which then failed the lookup below.
        class_var = None if self.data is None else self.data.domain.class_var
        if class_var is None:
            self.scorers = []
            return

        # Prefer an exact type match, then fall back to isinstance so that
        # subclasses of the listed variable types reuse their parent's order
        # instead of raising KeyError.
        builtin_names = self.BUILTIN_ORDER.get(type(class_var))
        if builtin_names is None:
            builtin_names = next(
                (names for var_type, names in self.BUILTIN_ORDER.items()
                 if isinstance(class_var, var_type)),
                ())
        order = {name: i for i, name in enumerate(builtin_names)}
        # 'abstract' is retrieved from __dict__ to avoid inheriting
        usable = (cls for cls in scoring.Score.registry.values()
                  if cls.is_scalar and not cls.__dict__.get("abstract")
                  and isinstance(class_var, cls.class_types))
        # Scores without a built-in position share rank 99 and keep their
        # registry order among themselves (sorted() is stable).
        self.scorers = sorted(usable, key=lambda cls: order.get(cls.name, 99))

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Store the preprocessor input and mark every learner's results stale.

        Parameters
        ----------
        preproc : the preprocessor received on the input (may be None)
        """
        self.preprocessor = preproc
        # No explicit keys: all learners are invalidated.
        self._invalidate(None)

    def handleNewSignals(self):
        """Reimplemented from OWWidget.handleNewSignals.

        Refreshes the target class combo, the table header and the score
        table, then starts a new evaluation if any input was invalidated.
        """
        for refresh in (self._update_class_selection,
                        self._update_header,
                        self._update_stats_model):
            refresh()
        if not self.__needupdate:
            return
        self.__update()

    def kfold_changed(self):
        """Select k-fold cross validation after one of its options changed."""
        chosen = OWTestLearners.KFold
        self.resampling = chosen
        self._param_changed()

    def fold_feature_changed(self):
        """Select cross validation by feature after the feature changed."""
        chosen = OWTestLearners.FeatureFold
        self.resampling = chosen
        self._param_changed()

    def shuffle_split_changed(self):
        """Select random sampling after one of its options changed."""
        chosen = OWTestLearners.ShuffleSplit
        self.resampling = chosen
        self._param_changed()

    def _param_changed(self):
        """Discard all results and immediately start a new evaluation."""
        self._invalidate(None)
        self.__update()

    def _update_header(self):
        """Give the result model one labelled column per current scorer.

        Column 0 is left alone; every scorer gets a header item showing its
        name, with the long name as tooltip. Column visibility is refreshed
        afterwards.
        """
        scorers = self.scorers
        results = self.result_model
        results.setColumnCount(len(scorers) + 1)
        for column, scorer in enumerate(scorers, start=1):
            title = QStandardItem(scorer.name)
            title.setToolTip(scorer.long_name)
            results.setHorizontalHeaderItem(column, title)
        self._update_shown_columns()

    def _update_shown_columns(self):
        """Hide every score column whose name is not in ``shown_scores``."""
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        results = self.result_model
        hheader = self.view.horizontalHeader()
        # Column 0 is never hidden.
        for section in range(1, results.columnCount()):
            title = results.horizontalHeaderItem(section)
            score_name = title.data(Qt.DisplayRole)
            hidden = score_name not in self.shown_scores
            hheader.setSectionHidden(section, hidden)

    def _update_stats_model(self):
        """Refill the score table from the stored learner results.

        One row is appended per learner. When a specific target class is
        selected for a discrete class variable, one-vs-rest scores are
        computed here on demand; otherwise the stored stats are shown.
        Failed learners are marked in red and collected into the widget's
        error message; scores that could not be computed raise a warning.
        """
        table = self.view.model()
        # Remove all rows (bottom-up); the header labels are preserved.
        for row_index in range(table.rowCount() - 1, -1, -1):
            table.takeRow(row_index)

        class_var = None
        target_index = None
        if self.data is not None:
            domain = self.data.domain
            class_var = domain.class_var
            if domain.has_discrete_class and \
                    self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)

        # True when scores must be recomputed for a single target class.
        per_class = class_var is not None and class_var.is_discrete and \
            target_index is not None

        failures = []
        any_score_missing = False

        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            head = QStandardItem(name)
            head.setData(key, Qt.UserRole)
            if isinstance(slot.results, Try.Fail):
                exc = slot.results.exception
                head.setToolTip(str(exc))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                failures.append("{name} failed with error:\n"
                                "{exc.__class__.__name__}: {exc!s}"
                                .format(name=name, exc=exc))

            if not per_class:
                stats = slot.stats
            elif slot.results is not None and slot.results.success:
                ovr_results = results_one_vs_rest(
                    slot.results.value, target_index)

                # Cell variable is used immediately, it's not stored
                # pylint: disable=cell-var-from-loop
                stats = [Try(scorer_caller(scorer, ovr_results))
                         for scorer in self.scorers]
            else:
                stats = None

            cells = [head]
            if stats is not None:
                for stat in stats:
                    cell = QStandardItem()
                    if stat.success:
                        cell.setText("{:.3f}".format(stat.value[0]))
                    else:
                        cell.setToolTip(str(stat.exception))
                        any_score_missing = True
                    cells.append(cell)

            table.appendRow(cells)

        self.error("\n".join(failures), shown=bool(failures))
        self.Warning.scores_not_computed(shown=any_score_missing)

    def _update_class_selection(self):
        """Repopulate the target class combo for the current train data.

        The combo is always cleared first. With a discrete class variable it
        is shown and filled with the "average" entry followed by the class
        values, keeping the previous selection when it is still a valid class
        value and falling back to the "average" entry otherwise. With any
        other class variable the box is hidden. Without data nothing beyond
        the clearing happens.
        """
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if self.data.domain.has_discrete_class:
            self.cbox.setVisible(True)
            class_var = self.data.domain.class_var
            # Copy the values into a list before concatenating, so this works
            # for any sequence type of `values` (list or tuple) rather than
            # only for lists.
            values = list(class_var.values)
            items = [self.TARGET_AVERAGE] + values
            self.class_selection_combo.addItems(items)

            # Index 0 is the "average" entry; class values start at index 1.
            class_index = 0
            if self.class_selection in values:
                class_index = values.index(self.class_selection) + 1

            self.class_selection_combo.setCurrentIndex(class_index)
            self.class_selection = items[class_index]
        else:
            self.cbox.setVisible(False)

    def _on_target_class_changed(self):
        """Redisplay the scores after a different target class was chosen."""
        refresh = self._update_stats_model
        refresh()

    def _invalidate(self, which=None):
        """Discard stored results and flag that an update is needed.

        Parameters
        ----------
        which : Optional[Iterable[Any]]
            Input keys of the learners to invalidate; ``None`` invalidates
            all of them.

        Besides resetting the learners' results and stats, the displayed
        text and tooltip of their score cells are cleared, and whether
        "by feature" resampling is active is recorded in
        ``fold_feature_selected``.
        """
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        keys = self.learners.keys() if which is None else which

        table = self.view.model()
        # Input key stored on the first cell of every table row.
        row_keys = [table.item(r, 0).data(Qt.UserRole)
                    for r in range(table.rowCount())]

        for key in keys:
            slot = self.learners[key]
            self.learners[key] = slot._replace(results=None, stats=None)

            if key not in row_keys:
                continue
            row = row_keys.index(key)
            for column in range(1, table.columnCount()):
                cell = table.item(row, column)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.__needupdate = True

    def show_column_chooser(self, pos):
        """Show a menu for choosing which score columns are visible.

        Parameters
        ----------
        pos : position, in header coordinates, at which the menu is opened

        Every score column gets a checkable entry; toggling one adds or
        removes its name in ``shown_scores`` and refreshes the columns.
        """
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        def toggle_score(score_name, checked):
            if not checked:
                self.shown_scores.remove(score_name)
            else:
                self.shown_scores.add(score_name)
            self._update_shown_columns()

        chooser = QMenu()
        results = self.result_model
        for section in range(1, results.columnCount()):
            score_name = \
                results.horizontalHeaderItem(section).data(Qt.DisplayRole)
            entry = chooser.addAction(score_name)
            entry.setCheckable(True)
            entry.setChecked(score_name in self.shown_scores)
            entry.triggered.connect(partial(toggle_score, score_name))

        hheader = self.view.horizontalHeader()
        chooser.exec(hheader.mapToGlobal(pos))

    def commit(self):
        """
        Send the merged evaluation results and predictions to the outputs.

        Only learners with successful results contribute. Both outputs are
        ``None`` when there are none; predictions stay ``None`` (and the
        memory error is shown) when building them runs out of memory.
        """
        self.Error.memory_error.clear()

        successful = []
        for slot in self.learners.values():
            if slot.results is not None and slot.results.success:
                successful.append(slot)

        merged = None
        augmented = None
        if successful:
            # Evaluation results
            merged = results_merge(
                [slot.results.value for slot in successful])
            merged.learner_names = [
                learner_name(slot.learner) for slot in successful]

            # Predictions & Probabilities
            try:
                augmented = merged.get_augmented_data(merged.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(merged)
        self.Outputs.predictions.send(augmented)

    def send_report(self):
        """Report on the testing schema and results.

        Nothing is reported without train data or learners. The "Settings"
        section names the sampling type (now including cross validation by
        feature, which used to be left out) and, for a discrete class, the
        target class; the score table follows.
        """
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".
                      format(stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.FeatureFold:
            # Same wording as the radio button for this resampling type.
            items = [("Sampling type", "Cross validation by feature")]
            if self.fold_feature is not None:
                items.append(("Fold feature", str(self.fold_feature)))
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [("Sampling type",
                      "{}Shuffle split, {} random samples with {}% data "
                      .format(stratified, self.NRepeats[self.n_repeats],
                              self.SampleSizes[self.sample_size]))]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            # The parentheses of the "average" entry are dropped in the report.
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """Upgrade stored settings, in place, to the current version.

        Parameters
        ----------
        settings_ : dict
            Stored settings; modified in place.
        version : int
            Version the settings were stored with.
        """
        if version < 2:
            # Every resampling type after the first moved up by one
            # (presumably when "Cross validation by feature" was inserted at
            # index 1 -- TODO confirm). Settings without a stored
            # "resampling" value are left alone instead of raising KeyError.
            if settings_.get("resampling", 0) > 0:
                settings_["resampling"] += 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')]

    @Slot(float)
    def setProgressValue(self, value):
        """Slot: set the progress bar to `value` without processing events.

        Invoked by name (through a queued call) from the progress callback of
        a running evaluation.
        """
        self.progressBarSet(value, processEvents=False)

    def __update(self):
        """
        (Re)start the evaluation of all learners that have no results.

        Any running evaluation is cancelled first and the related warnings
        and errors are cleared. If a precondition fails (no data, no
        learners, too few instances for the chosen number of folds, missing
        or inconsistent test data) the widget goes to the waiting state,
        commits the current outputs and returns. Otherwise a test function
        for the selected resampling type is built and submitted.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Warning.test_data_missing.clear()
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                # Don't add the "missing" warning on top of the "empty test
                # data" error.
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            # Test data is connected but another resampling type is selected.
            self.Warning.test_data_unused()

        # Fixed seed, passed as random_state to k-fold and shuffle split.
        rstate = 42
        common_args = dict(
            store_data=True,
            preprocessor=self.preprocessor,
        )
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # replace_learners below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        # Build the test function for the selected resampling type. Only the
        # progress `callback` is left to be supplied by `__submit`.
        if self.resampling == OWTestLearners.KFold:
            folds = self.NFolds[self.n_folds]
            test_f = partial(
                Orange.evaluation.CrossValidation,
                self.data, learners_c, k=folds,
                random_state=rstate, **common_args)
        elif self.resampling == OWTestLearners.FeatureFold:
            test_f = partial(
                Orange.evaluation.CrossValidationFeature,
                self.data, learners_c, self.fold_feature,
                **common_args
            )
        elif self.resampling == OWTestLearners.LeaveOneOut:
            test_f = partial(
                Orange.evaluation.LeaveOneOut,
                self.data, learners_c, **common_args
            )
        elif self.resampling == OWTestLearners.ShuffleSplit:
            # SampleSizes are percentages; train_size is passed as a fraction.
            train_size = self.SampleSizes[self.sample_size] / 100
            test_f = partial(
                Orange.evaluation.ShuffleSplit,
                self.data, learners_c,
                n_resamples=self.NRepeats[self.n_repeats],
                train_size=train_size, test_size=None,
                stratified=self.shuffle_stratified,
                random_state=rstate, **common_args
            )
        elif self.resampling == OWTestLearners.TestOnTrain:
            test_f = partial(
                Orange.evaluation.TestOnTrainingData,
                self.data, learners_c, **common_args
            )
        elif self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(
                Orange.evaluation.TestOnTestData,
                self.data, self.test_data, learners_c, **common_args
            )
        else:
            assert False, "self.resampling %s" % self.resampling

        def replace_learners(evalfunc, *args, **kwargs):
            # Run the evaluation, then put the original learner objects back
            # in place of the deep copies that were actually tested, so the
            # results can later be matched to the input slots.
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[float]], Results]) -> None
        """
        Submit a testing function for evaluation.

        MUST NOT be called if an evaluation is already pending/running.
        Cancel the existing task first.

        The function is handed to the widget's executor with a progress
        callback bound as ``callback``. The widget is then put into the
        running state: progress bar initialised, blocking set and the status
        message changed.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results])
            Must be a callable taking a single `callback` argument and
            returning a Results instance
        """
        assert self.__state != State.Running
        # Setup the task
        task = Task()

        def progress_callback(finished):
            # Called by the test function with the finished fraction.
            # Raising here is how a cancelled task gets interrupted.
            if task.cancelled:
                raise UserInterrupt()
            # Queued invocation: the slot runs from the widget's event loop
            # rather than in the caller (presumably a worker thread).
            QMetaObject.invokeMethod(
                self, "setProgressValue", Qt.QueuedConnection,
                Q_ARG(float, 100 * finished)
            )

        def ondone(_):
            # Future done-callback: hand the finished task over to
            # `__task_complete` through a queued invocation.
            QMetaObject.invokeMethod(
                self, "__task_complete", Qt.QueuedConnection,
                Q_ARG(object, task))

        testfunc = partial(testfunc, callback=progress_callback)
        task.future = self.__executor.submit(testfunc)
        task.future.add_done_callback(ondone)

        self.progressBarInit(processEvents=None)
        self.setBlocking(True)
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = task

    @Slot(object)
    def __task_complete(self, task):
        """
        Slot: handle a finished evaluation task.

        A task other than the current one (i.e. a cancelled one) is ignored.
        Otherwise the running indicators are reset, and either the error is
        shown (if the evaluation raised) or the per-learner results and
        scores are stored, the table refreshed and the outputs committed.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        if self.__task is not task:
            # `cancel` has already detached this task from the widget.
            assert task.cancelled
            log.debug("Reaping cancelled task: %r", "<>")
            return

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)
        self.setStatusMessage("")
        result = task.future
        assert result.done()
        self.__task = None
        try:
            results = result.result()    # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:
            # The evaluation itself failed: log it, show only the final
            # exception line(s) to the user and stop.
            log.exception("testing error (in __task_complete):",
                          exc_info=True)
            self.error("\n".join(traceback.format_exception_only(type(er), er)))
            self.__state = State.Done
            return

        self.__state = State.Done

        # Map each input learner back to the key of its input slot.
        learner_key = {slot.learner: key for key, slot in
                       self.learners.items()}
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            stats = None
            if class_var.is_primitive():
                # `failed[0]` is falsy when the learner succeeded, otherwise
                # it holds the exception.
                ex = result.failed[0]
                if ex:
                    # One failed entry per scorer, so each score cell of the
                    # row carries the error.
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    stats = [Try(scorer_caller(scorer, result))
                             for scorer in self.scorers]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self._update_header()
        self._update_stats_model()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation, if there is one.

        The task is detached from the widget before it is cancelled, so a
        later completion notification for it is recognised as stale.
        """
        if self.__task is None:
            return

        assert self.__state == State.Running
        self.__state = State.Cancelled
        pending = self.__task
        self.__task = None
        pending.cancel()
        assert pending.future.done()

    def onDeleteWidget(self):
        """Cancel any running evaluation, then defer to the base class."""
        # Stop the evaluation first so no task outlives the widget's cleanup.
        self.cancel()
        super().onDeleteWidget()
Ejemplo n.º 28
0
class OWSignificantGroups(widget.OWWidget):
    """
    Widget that runs a significance test over groups of instances.

    The user selects discrete grouping variables and a test variable.
    `compute` picks a test function from the test variable's type and the
    chosen options and submits it to a `ThreadExecutor`; `on_computed` then
    keeps the rows whose `CORRECTED_LABEL` value is below 0.2, shows them in
    the table view and sends them to the 'Test Results' output.
    """
    name = 'Significant Groups'
    description = "Test whether instances grouped by nominal values are " \
                  "significantly different from random samples or the "\
                  "dataset in whole."
    icon = 'icons/SignificantGroups.svg'
    priority = 200

    class Inputs(widget.OWWidget.Inputs):
        data = widget.Input('Data', Table)

    class Outputs(widget.OWWidget.Outputs):
        selected_data = widget.Output('Selected Data', Table, default=True)
        data = widget.Output('Data', Table)
        results = widget.Output('Test Results', Table)

    want_main_area = True
    want_control_area = True

    class Information(widget.OWWidget.Information):
        nothing_significant = widget.Msg('Chosen parameters reveal no significant groups')

    class Error(widget.OWWidget.Error):
        no_vars_selected = widget.Msg('No independent variables selected')
        no_class_selected = widget.Msg('No dependent variable selected')

    # Statistic name (as shown in the combo box) -> NaN-aware NumPy reducer.
    # The reducer is only passed on to the permutation test (see `compute`).
    TEST_STATISTICS = OrderedDict((
        ('mean', np.nanmean),
        ('variance', np.nanvar),
        ('median', np.nanmedian),
        ('minimum', np.nanmin),
        ('maximum', np.nanmax),
    ))

    settingsHandler = settings.DomainContextHandler()

    chosen_X = settings.ContextSetting([])
    chosen_y = settings.ContextSetting(0)
    is_permutation = settings.Setting(False)
    # Defaults to the first key of TEST_STATISTICS ('mean').
    test_statistic = settings.Setting(next(iter(TEST_STATISTICS)))
    min_count = settings.Setting(20)

    def __init__(self):
        """Build the control area, the results table and the worker executor."""
        self._task = None  # type: Optional[self.Task]
        self._executor = ThreadExecutor(self)

        self.data = None
        self.test_type = ''

        self.discrete_model = DomainModel(separators=False, valid_types=(DiscreteVariable,), parent=self)
        self.domain_model = DomainModel(valid_types=DomainModel.PRIMITIVE, parent=self)

        box = gui.vBox(self.controlArea, 'Hypotheses Testing')
        gui.listView(box, self, 'chosen_X', model=self.discrete_model,
                     box='Grouping Variables',
                     selectionMode=QListView.ExtendedSelection,
                     callback=self.Error.no_vars_selected.clear,
                     toolTip='Select multiple variables with Ctrl+ or Shift+Click.')
        target = gui.comboBox(box, self, 'chosen_y',
                              sendSelectedValue=True,
                              label='Test Variable',
                              callback=[self.set_test_type,
                                        self.Error.no_class_selected.clear])
        target.setModel(self.domain_model)

        gui.checkBox(box, self, 'is_permutation', label='Permutation test',
                     callback=self.set_test_type)
        gui.comboBox(box, self, 'test_statistic', label='Statistic:',
                     items=tuple(self.TEST_STATISTICS),
                     orientation=Qt.Horizontal,
                     sendSelectedValue=True,
                     callback=self.set_test_type)
        gui.label(box, self, 'Test: %(test_type)s')

        box = gui.vBox(self.controlArea, 'Filter')
        gui.spin(box, self, 'min_count', 5, 1000, 5,
                 label='Minimum group size (count):')

        self.btn_compute = gui.button(self.controlArea, self, '&Compute', callback=self.compute)
        gui.rubber(self.controlArea)

        class Model(PyTableModel):
            """Table model that shades the grouping columns and shows counts as ints."""
            _n_vars = 0
            _BACKGROUND = [QBrush(QColor('#eee')),
                           QBrush(QColor('#ddd'))]

            def setHorizontalHeaderLabels(self, labels, n_vars):
                """Set header labels; the first `n_vars` columns are grouping columns."""
                self._n_vars = n_vars
                super().setHorizontalHeaderLabels(labels)

            def data(self, index, role=Qt.DisplayRole):
                # Alternate the background of the grouping-variable columns
                # by row parity.
                if role == Qt.BackgroundRole and index.column() < self._n_vars:
                    return self._BACKGROUND[index.row() % 2]
                if role == Qt.DisplayRole or role == Qt.ToolTipRole:
                    colname = self.headerData(index.column(), Qt.Horizontal)
                    # Count columns are returned as plain integers.
                    if colname.lower() in ('count', 'count | class'):
                        row = self.mapToSourceRows(index.row())
                        return int(self[row][index.column()])
                return super().data(index, role)

        # The nested view needs a handle on the widget; `self` inside the
        # view's methods refers to the view itself.
        owwidget = self

        class View(gui.TableView):
            """Table view that turns the row selection into the data outputs."""
            _vars = None

            def set_vars(self, vars):
                """Remember the variables that correspond to the leading columns."""
                self._vars = vars

            def selectionChanged(self, *args):
                super().selectionChanged(*args)

                rows = list({index.row() for index in self.selectionModel().selectedRows(0)})

                if not rows:
                    # Nothing is selected: clear both selection-dependent
                    # outputs so that no stale selection stays on
                    # 'Selected Data'.
                    owwidget.Outputs.selected_data.send(None)
                    owwidget.Outputs.data.send(None)
                    return

                model = self.model().tolist()
                # One conjunctive filter per selected row (every grouping
                # variable must equal the value shown in that row) ...
                filters = [Values([FilterDiscrete(self._vars[col], {model[row][col]})
                                   for col in range(len(self._vars))])
                           for row in self.model().mapToSourceRows(rows)]
                # ... combined disjunctively over all selected rows.
                data = Values(filters, conjunction=False)(owwidget.data)

                annotated = create_annotated_table(owwidget.data, data.ids)

                owwidget.Outputs.selected_data.send(data)
                owwidget.Outputs.data.send(annotated)

        self.view = view = View(self)
        self.model = Model(parent=self)
        view.setModel(self.model)
        view.horizontalHeader().setStretchLastSection(False)
        self.mainArea.layout().addWidget(view)

        self.set_test_type()

    @Inputs.data
    def set_data(self, data):
        """Handle the 'Data' input: update the variable models and the context."""
        self.data = data
        domain = None if data is None else data.domain

        self.closeContext()

        self.domain_model.set_domain(domain)
        self.discrete_model.set_domain(domain)
        if domain is not None:
            if domain.class_var:
                # Default the test variable to the class variable; a stored
                # context (opened below) may override it.
                self.chosen_y = domain.class_var.name

        self.openContext(domain)

        self.set_test_type()

    def set_test_type(self):
        """Update `test_type` (the label text) from the current settings.

        Does nothing while there is no input data.
        """
        if self.data is None:
            return

        yvar = self.data.domain[self.chosen_y]

        # The statistic selector is enabled only for a continuous test variable.
        self.controls.test_statistic.setEnabled(yvar.is_continuous)

        if self.is_permutation:
            test = 'Permutation '
            if yvar.is_discrete:
                test += 'χ² '
            else:
                test += str(self.test_statistic) + ' '
        else:
            test = ''
            if yvar.is_discrete:
                test += 'χ² ' if len(yvar.values) > 2 else 'Hypergeometric '
            else:
                if self.test_statistic == 'mean':
                    test += "Student's t-"
                elif self.test_statistic == 'variance':
                    test += "Fligner–Killeen "
                elif self.test_statistic == 'median':
                    test += "Mann–Whitney U "
                elif self.test_statistic in ('minimum', 'maximum'):
                    test += "Gumbel distribution "
                else:
                    assert False, self.test_statistic
        test += 'test'
        self.test_type = test

    def compute(self):
        """Select the test for the current settings and run it on the executor.

        Shows an error and returns early when no grouping variable or no test
        variable is chosen. Otherwise the Compute button is disabled until
        `on_computed` runs.
        """
        if not self.chosen_X:
            self.Error.no_vars_selected()
            return

        if not self.chosen_y:
            self.Error.no_class_selected()
            return

        self.btn_compute.setEnabled(False)
        yvar = self.data.domain[self.chosen_y]

        def get_col(var, col):
            # Map the column's value indices to the variable's value labels.
            # Missing entries are replaced by -1, which indexes the NaN
            # appended at the end of `values`.
            values = np.array(list(var.values) + [np.nan], dtype=object)
            col = pd.Series(col).fillna(-1).astype(int)
            return values[col]

        X = np.column_stack([get_col(var, self.data.get_column_view(var)[0])
                             for var in (self.data.domain[i]
                                         for i in self.chosen_X)])
        X = pd.DataFrame(X, columns=self.chosen_X)
        y = pd.Series(self.data.get_column_view(yvar)[0])

        test, args, kwargs = None, (X, y), dict(min_count=self.min_count)
        if self.is_permutation:
            statistic = 'chi2' if yvar.is_discrete else self.TEST_STATISTICS[self.test_statistic]
            test = perm_test
            # The progress callback is wrapped with `methodinvoke` because
            # the test runs on a worker thread.
            kwargs.update(
                statistic=statistic, n_jobs=-2,
                callback=methodinvoke(self, "setProgressValue", (int, int)))
        else:
            if yvar.is_discrete:
                if len(yvar.values) > 2:
                    test = chi2_test
                else:
                    test = hyper_test
                    args = (X, y.astype(bool))
            else:
                test = {
                    'mean': t_test,
                    'variance': fligner_killeen_test,
                    'median': mannwhitneyu_test,
                    'minimum': gumbel_min_test,
                    'maximum': gumbel_max_test,
                }[self.test_statistic]

        self._task = task = self.Task()
        self.progressBarInit()
        task.future = self._executor.submit(test, *args, **kwargs)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self.on_computed)

    @Slot(int, int)
    def setProgressValue(self, n, N):
        """Set the progress bar to `n / (N + 1)` of its range (GUI thread only)."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(n / (N + 1) * 100)

    class Task:
        """Holder for a submitted computation and the watcher observing it."""
        future = ...  # type: concurrent.futures.Future
        watcher = ...  # type: FutureWatcher
        cancelled = False  # type: bool

        def cancel(self):
            """Mark the task cancelled and block until its future is finished."""
            self.cancelled = True
            # Cancel the future. Note this succeeds only if the execution has
            # not yet started (see `concurrent.futures.Future.cancel`) ..
            self.future.cancel()
            # ... and wait until computation finishes
            concurrent.futures.wait([self.future])

    @Slot(concurrent.futures.Future)
    def on_computed(self, future):
        """Show and output the test results once the worker has finished.

        Any exception raised by the test propagates out of
        `future.result()`; the Compute button is re-enabled in either case.
        """
        assert self.thread() is QThread.currentThread()
        assert future.done()

        self._task = None
        self.progressBarFinished()

        try:
            df = future.result()
            # Only retain "significant" p-values
            df = df[df[CORRECTED_LABEL] < .2]

            # `df.index.name` is iterated as a sequence of variables here;
            # their names become the leading column headers.
            columns = [var.name for var in df.index.name] + list(df.columns)
            lst = [list(i) + list(j)
                   for i, j in zip(df.index, df.values)]

            results_table = table_from_frame(pd.DataFrame(lst, columns=columns),
                                             force_nominal=True)
            results_table.name = 'Significant Groups'
            self.Outputs.results.send(results_table)

            self.view.set_vars(list(df.index.name))
            self.model.setHorizontalHeaderLabels(columns, len(df.index.name))
            self.model.wrap(lst)
            self.view.sortByColumn(len(columns) - 1, Qt.AscendingOrder)

            self.Information.nothing_significant(shown=not lst)
        finally:
            # Re-enable the button even when the test failed; otherwise the
            # widget could never be asked to compute again.
            self.btn_compute.setEnabled(True)

    def send_report(self):
        """Add the current settings and the results table to the report."""
        self.report_items([
            ('Test Variable', self.chosen_y),
            ('Test', self.test_type),
            ('Min. group size', self.min_count),
        ])
        self.report_table('Significant Groups', self.view)
Ejemplo n.º 29
0
class OWNNLearner(OWBaseLearner):
    """
    Learner widget for a multi-layer perceptron.

    The model is fitted on a worker thread (see `__update`); progress is
    reported through the learner's `callback` attribute and the fit can be
    interrupted with `cancel`.
    """
    name = "Neural Network"
    description = "A multi-layer perceptron (MLP) algorithm with " \
                  "backpropagation."
    icon = "icons/NN.svg"
    priority = 90
    keywords = ["mlp"]

    LEARNER = NNLearner

    # Parameter values passed to the learner and, in parallel, the labels
    # shown in the corresponding combo boxes.
    activation = ["identity", "logistic", "tanh", "relu"]
    act_lbl = ["Identity", "Logistic", "tanh", "ReLu"]
    solver = ["lbfgs", "sgd", "adam"]
    solv_lbl = ["L-BFGS-B", "SGD", "Adam"]

    learner_name = Setting("Neural Network")
    hidden_layers_input = Setting("100,")
    activation_index = Setting(3)
    solver_index = Setting(2)
    max_iterations = Setting(200)
    alpha_index = Setting(0)
    settings_version = 1

    # Regularization strengths selectable with the slider; `alpha_index` is a
    # position in this list.
    # NOTE(review): the last two ranges overlap (100 and 150 occur twice), so
    # the list is not strictly increasing. It is left unchanged because
    # stored `alpha_index` settings refer to positions in it.
    alphas = list(chain([x / 10000 for x in range(1, 10)],
                        [x / 1000 for x in range(1, 10)],
                        [x / 100 for x in range(1, 10)],
                        [x / 10 for x in range(1, 10)],
                        range(1, 10),
                        range(10, 100, 5),
                        range(100, 200, 10),
                        range(100, 1001, 50)))

    def add_main_layout(self):
        """Create the parameter controls inside a form layout."""
        form = QFormLayout()
        form.setFieldGrowthPolicy(form.AllNonFixedFieldsGrow)
        form.setVerticalSpacing(25)
        gui.widgetBox(self.controlArea, True, orientation=form)
        form.addRow(
            "Neurons in hidden layers:",
            gui.lineEdit(
                None, self, "hidden_layers_input",
                orientation=Qt.Horizontal, callback=self.settings_changed,
                tooltip="A list of integers defining neurons. Length of list "
                        "defines the number of layers. E.g. 4, 2, 2, 3.",
                placeholderText="e.g. 100,"))
        form.addRow(
            "Activation:",
            gui.comboBox(
                None, self, "activation_index", orientation=Qt.Horizontal,
                label="Activation:", items=list(self.act_lbl),
                callback=self.settings_changed))

        form.addRow(" ", gui.separator(None, 16))
        form.addRow(
            "Solver:",
            gui.comboBox(
                None, self, "solver_index", orientation=Qt.Horizontal,
                label="Solver:", items=list(self.solv_lbl),
                callback=self.settings_changed))
        self.reg_label = QLabel()
        slider = gui.hSlider(
            None, self, "alpha_index",
            minValue=0, maxValue=len(self.alphas) - 1,
            callback=lambda: (self.set_alpha(), self.settings_changed()),
            createLabel=False)
        form.addRow(self.reg_label, slider)
        # Initialize the label text for the stored slider position.
        self.set_alpha()

        form.addRow(
            "Maximal number of iterations:",
            gui.spin(
                None, self, "max_iterations", 10, 10000, step=10,
                label="Max iterations:", orientation=Qt.Horizontal,
                alignment=Qt.AlignRight, callback=self.settings_changed))

    def set_alpha(self):
        """Refresh `strength_C` and the regularization label from `alpha_index`."""
        self.strength_C = self.alphas[self.alpha_index]
        self.reg_label.setText("Regularization, α={}:".format(self.strength_C))

    @property
    def alpha(self):
        """The regularization strength selected by `alpha_index`."""
        return self.alphas[self.alpha_index]

    def setup_layout(self):
        """Extend the base layout with the task state and a cancel button."""
        super().setup_layout()

        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        # just a test cancel button
        gui.button(self.apply_button, self, "Cancel", callback=self.cancel)

    def create_learner(self):
        """Instantiate `LEARNER` with the parameters chosen in the GUI."""
        return self.LEARNER(
            hidden_layer_sizes=self.get_hidden_layers(),
            activation=self.activation[self.activation_index],
            solver=self.solver[self.solver_index],
            alpha=self.alpha,
            max_iter=self.max_iterations,
            preprocessors=self.preprocessors)

    def get_learner_parameters(self):
        """Return (label, value) pairs describing the current parameters."""
        return (("Hidden layers", ', '.join(map(str, self.get_hidden_layers()))),
                ("Activation", self.act_lbl[self.activation_index]),
                ("Solver", self.solv_lbl[self.solver_index]),
                ("Alpha", self.alpha),
                ("Max iterations", self.max_iterations))

    def get_hidden_layers(self):
        """Parse `hidden_layers_input` into a tuple of layer sizes.

        Every run of digits in the text is taken as one layer size. When the
        text contains no digits, the setting is reset to "100," and
        `(100,)` is returned.
        """
        layers = tuple(map(int, re.findall(r'\d+', self.hidden_layers_input)))
        if not layers:
            layers = (100,)
            self.hidden_layers_input = "100,"
        return layers

    def update_model(self):
        """Start fitting a new model, or send `None` when the data check fails."""
        self.show_fitting_failed(None)
        self.model = None
        if self.check_data():
            self.__update()
        else:
            self.Outputs.model.send(self.model)

    @Slot(float)
    def setProgressValue(self, value):
        """Forward a progress value to the progress bar (GUI thread only)."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)

    def __update(self):
        """Cancel any running fit and submit a new one to the executor."""
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        max_iter = self.learner.kwargs["max_iter"]

        # Setup the task state
        task = Task()
        lastemitted = 0.

        def callback(iteration):
            # Installed on the learner copy below; raises to abort the fit
            # once interruption was requested on the task.
            nonlocal task  # type: Task
            nonlocal lastemitted
            if task.isInterruptionRequested():
                raise CancelTaskException()
            progress = round(iteration / max_iter * 100)
            # Emit only when the rounded percentage changed.
            if progress != lastemitted:
                task.emitProgressUpdate(progress)
                lastemitted = progress

        # copy to set the callback so that the learner output is not modified
        # (currently we can not pass callbacks to learners __call__)
        learner = copy.copy(self.learner)
        learner.callback = callback

        def build_model(data, learner):
            # A cancelled fit yields None instead of propagating the exception.
            try:
                return learner(data)
            except CancelTaskException:
                return None

        build_model_func = partial(build_model, self.data, learner)

        task.setFuture(self._executor.submit(build_model_func))
        task.done.connect(self._task_finished)
        task.progressChanged.connect(self.setProgressValue)

        self._task = task
        self.progressBarInit()
        self.setBlocking(True)

    @Slot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Parameters
        ----------
        f : Future
            The future instance holding the built model
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()
        self._task.deleteLater()
        self._task = None
        self.setBlocking(False)
        self.progressBarFinished()

        try:
            self.model = f.result()
        except Exception as ex:  # pylint: disable=broad-except
            # Log the exception with a traceback
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.model = None
            self.show_fitting_failed(ex)
        else:
            self.model.name = self.learner_name
            self.model.instances = self.data
            self.model.skl_model.orange_callback = None  # remove unpicklable callback
            self.Outputs.model.send(self.model)

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect from the task
            self._task.done.disconnect(self._task_finished)
            self._task.progressChanged.disconnect(self.setProgressValue)
            self._task.deleteLater()
            self._task = None

        self.progressBarFinished()
        self.setBlocking(False)

    def onDeleteWidget(self):
        """Cancel a running fit before the base class clean-up."""
        self.cancel()
        super().onDeleteWidget()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Convert unversioned settings, which stored `alpha` as a value.

        The stored value is replaced by `alpha_index`, the position of the
        closest entry in `alphas`.
        """
        if not version:
            alpha = settings.pop("alpha", None)
            if alpha is not None:
                # np.argmin returns a NumPy integer scalar; store a plain
                # Python int so that the persisted settings hold only
                # built-in types.
                settings["alpha_index"] = \
                    int(np.argmin(np.abs(np.array(cls.alphas) - alpha)))
Ejemplo n.º 30
0
class OWGOEnrichmentAnalysis(widget.OWWidget):
    name = "GO Browser"
    description = "Enrichment analysis for Gene Ontology terms."
    icon = "../widgets/icons/GOBrowser.svg"
    priority = 2020

    inputs = [("Cluster Data", Orange.data.Table, "setDataset",
               widget.Single + widget.Default),
              ("Reference Data", Orange.data.Table, "setReferenceDataset")]

    outputs = [("Data on Selected Genes", Orange.data.Table),
               ("Data on Unselected Genes", Orange.data.Table),
               ("Data on Unknown Genes", Orange.data.Table),
               ("Enrichment Report", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    annotationIndex = settings.ContextSetting(0)
    geneAttrIndex = settings.ContextSetting(0)
    useAttrNames = settings.ContextSetting(False)
    geneMatcherSettings = settings.Setting([True, False, False, False])
    useReferenceDataset = settings.Setting(False)
    aspectIndex = settings.Setting(0)

    useEvidenceType = settings.Setting(
        {et: True
         for et in go.evidenceTypesOrdered})

    filterByNumOfInstances = settings.Setting(False)
    minNumOfInstances = settings.Setting(1)
    filterByPValue = settings.Setting(True)
    maxPValue = settings.Setting(0.2)
    filterByPValue_nofdr = settings.Setting(False)
    maxPValue_nofdr = settings.Setting(0.01)
    probFunc = settings.Setting(0)

    selectionDirectAnnotation = settings.Setting(0)
    selectionDisjoint = settings.Setting(0)
    selectionAddTermAsClass = settings.Setting(0)

    Ready, Initializing, Running = 0, 1, 2

    def __init__(self, parent=None):
        """Build the GUI and start the background download of required files.

        The widget is blocked until the `EnsureDownloaded` task submitted at
        the end of this method signals `finished`, which triggers
        `__initialize_finish`.

        Parameters
        ----------
        parent : optional
            Forwarded to the base class initializer.
        """
        # `super()` already binds `self`; passing it again would put the
        # widget itself into the `parent` position.
        super().__init__(parent)

        self.clusterDataset = None
        self.referenceDataset = None
        self.ontology = None
        self.annotations = None
        self.loadedAnnotationCode = "---"
        self.treeStructRootKey = None
        self.probFunctions = [stats.Binomial(), stats.Hypergeometric()]
        self.selectedTerms = []

        self.selectionChanging = 0
        self.__state = OWGOEnrichmentAnalysis.Initializing

        # Filled in by `__initialize_finish` once the file list is available.
        self.annotationCodes = []

        #############
        ## GUI
        #############
        self.tabs = gui.tabWidget(self.controlArea)
        ## Input tab
        self.inputTab = gui.createTabPage(self.tabs, "Input")
        box = gui.widgetBox(self.inputTab, "Info")
        self.infoLabel = gui.widgetLabel(box, "No data on input\n")

        gui.button(
            box,
            self,
            "Ontology/Annotation Info",
            callback=self.ShowInfo,
            tooltip="Show information on loaded ontology and annotations")

        box = gui.widgetBox(self.inputTab, "Organism")
        self.annotationComboBox = gui.comboBox(box,
                                               self,
                                               "annotationIndex",
                                               items=self.annotationCodes,
                                               callback=self._updateEnrichment,
                                               tooltip="Select organism")

        genebox = gui.widgetBox(self.inputTab, "Gene Names")
        self.geneAttrIndexCombo = gui.comboBox(
            genebox,
            self,
            "geneAttrIndex",
            callback=self._updateEnrichment,
            tooltip="Use this attribute to extract gene names from input data")
        self.geneAttrIndexCombo.setDisabled(self.useAttrNames)

        cb = gui.checkBox(genebox,
                          self,
                          "useAttrNames",
                          "Use column names",
                          tooltip="Use column names for gene names",
                          callback=self._updateEnrichment)
        # The attribute selector is disabled while column names are used.
        cb.toggled[bool].connect(self.geneAttrIndexCombo.setDisabled)

        gui.button(genebox,
                   self,
                   "Gene matcher settings",
                   callback=self.UpdateGeneMatcher,
                   tooltip="Open gene matching settings dialog")

        self.referenceRadioBox = gui.radioButtonsInBox(
            self.inputTab,
            self,
            "useReferenceDataset", ["Entire genome", "Reference set (input)"],
            tooltips=[
                "Use entire genome for reference",
                "Use genes from Referece Examples input signal as reference"
            ],
            box="Reference",
            callback=self._updateEnrichment)

        # The "Reference set (input)" option starts out disabled.
        self.referenceRadioBox.buttons[1].setDisabled(True)
        gui.radioButtonsInBox(
            self.inputTab,
            self,
            "aspectIndex",
            ["Biological process", "Cellular component", "Molecular function"],
            box="Aspect",
            callback=self._updateEnrichment)

        ## Filter tab
        self.filterTab = gui.createTabPage(self.tabs, "Filter")
        box = gui.widgetBox(self.filterTab, "Filter GO Term Nodes")
        gui.checkBox(
            box,
            self,
            "filterByNumOfInstances",
            "Genes",
            callback=self.FilterAndDisplayGraph,
            tooltip="Filter by number of input genes mapped to a term")
        ibox = gui.indentedBox(box)
        gui.spin(ibox,
                 self,
                 'minNumOfInstances',
                 1,
                 100,
                 step=1,
                 label='#:',
                 labelWidth=15,
                 callback=self.FilterAndDisplayGraph,
                 callbackOnReturn=True,
                 tooltip="Min. number of input genes mapped to a term")

        gui.checkBox(box,
                     self,
                     "filterByPValue_nofdr",
                     "p-value",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by term p-value")

        gui.doubleSpin(gui.indentedBox(box),
                       self,
                       'maxPValue_nofdr',
                       1e-8,
                       1,
                       step=1e-8,
                       label='p:',
                       labelWidth=15,
                       callback=self.FilterAndDisplayGraph,
                       callbackOnReturn=True,
                       tooltip="Max term p-value")

        #use filterByPValue for FDR, as it was the default in prior versions
        gui.checkBox(box,
                     self,
                     "filterByPValue",
                     "FDR",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by term FDR")
        gui.doubleSpin(gui.indentedBox(box),
                       self,
                       'maxPValue',
                       1e-8,
                       1,
                       step=1e-8,
                       label='p:',
                       labelWidth=15,
                       callback=self.FilterAndDisplayGraph,
                       callbackOnReturn=True,
                       tooltip="Max term p-value")

        box = gui.widgetBox(box, "Significance test")

        gui.radioButtonsInBox(box,
                              self,
                              "probFunc", ["Binomial", "Hypergeometric"],
                              tooltips=[
                                  "Use binomial distribution test",
                                  "Use hypergeometric distribution test"
                              ],
                              callback=self._updateEnrichment)
        box = gui.widgetBox(self.filterTab,
                            "Evidence codes in annotation",
                            addSpace=True)
        # One check box per evidence type, keyed by the evidence code.
        self.evidenceCheckBoxDict = {}
        for etype in go.evidenceTypesOrdered:
            ecb = QCheckBox(etype,
                            toolTip=go.evidenceTypes[etype],
                            checked=self.useEvidenceType[etype])
            ecb.toggled.connect(self.__on_evidenceChanged)
            box.layout().addWidget(ecb)
            self.evidenceCheckBoxDict[etype] = ecb

        ## Select tab
        self.selectTab = gui.createTabPage(self.tabs, "Select")
        gui.radioButtonsInBox(self.selectTab,
                              self,
                              "selectionDirectAnnotation",
                              ["Directly or Indirectly", "Directly"],
                              box="Annotated genes",
                              callback=self.ExampleSelection)

        box = gui.widgetBox(self.selectTab, "Output", addSpace=True)
        gui.radioButtonsInBox(
            box,
            self,
            "selectionDisjoint",
            btnLabels=[
                "All selected genes", "Term-specific genes",
                "Common term genes"
            ],
            tooltips=[
                "Outputs genes annotated to all selected GO terms",
                "Outputs genes that appear in only one of selected GO terms",
                "Outputs genes common to all selected GO terms"
            ],
            callback=[self.ExampleSelection, self.UpdateAddClassButton])

        self.addClassCB = gui.checkBox(box,
                                       self,
                                       "selectionAddTermAsClass",
                                       "Add GO Term as class",
                                       callback=self.ExampleSelection)

        # ListView for DAG, and table for significant GOIDs
        self.DAGcolumns = [
            'GO term', 'Cluster', 'Reference', 'p-value', 'FDR', 'Genes',
            'Enrichment'
        ]

        self.splitter = QSplitter(Qt.Vertical, self.mainArea)
        self.mainArea.layout().addWidget(self.splitter)

        # list view
        self.listView = GOTreeWidget(self.splitter)
        self.listView.setSelectionMode(QTreeView.ExtendedSelection)
        self.listView.setAllColumnsShowFocus(1)
        self.listView.setColumnCount(len(self.DAGcolumns))
        self.listView.setHeaderLabels(self.DAGcolumns)

        self.listView.header().setSectionsClickable(True)
        self.listView.header().setSortIndicatorShown(True)
        self.listView.setSortingEnabled(True)
        # Column 6 is 'Enrichment' (see DAGcolumns).
        self.listView.setItemDelegateForColumn(
            6, EnrichmentColumnItemDelegate(self))
        self.listView.setRootIsDecorated(True)

        self.listView.itemSelectionChanged.connect(self.ViewSelectionChanged)

        # table of significant GO terms
        self.sigTerms = QTreeWidget(self.splitter)
        self.sigTerms.setColumnCount(len(self.DAGcolumns))
        self.sigTerms.setHeaderLabels(self.DAGcolumns)
        self.sigTerms.setSortingEnabled(True)
        self.sigTerms.setSelectionMode(QTreeView.ExtendedSelection)
        self.sigTerms.setItemDelegateForColumn(
            6, EnrichmentColumnItemDelegate(self))

        self.sigTerms.itemSelectionChanged.connect(self.TableSelectionChanged)

        self.sigTableTermsSorted = []
        self.graph = {}

        self.inputTab.layout().addStretch(1)
        self.filterTab.layout().addStretch(1)
        self.selectTab.layout().addStretch(1)

        # Block the widget until the required files have been fetched on the
        # executor; `__initialize_finish` unblocks it.
        self.setBlocking(True)
        self._executor = ThreadExecutor()
        self._init = EnsureDownloaded([(taxonomy.Taxonomy.DOMAIN,
                                        taxonomy.Taxonomy.FILENAME),
                                       ("GO", "taxonomy.pickle")])
        self._init.finished.connect(self.__initialize_finish)
        self._executor.submit(self._init)

    def sizeHint(self):
        """Return the preferred widget size, a `QSize(1000, 700)`."""
        return QSize(1000, 700)

    def __initialize_finish(self):
        """
        Complete the initialization once the required files are present.

        Unblocks the widget and queries the available annotation files. On a
        connection timeout an error is shown and the tabs and both term views
        are disabled; otherwise the organism combo box is populated and the
        widget state becomes `Ready`.
        """
        self.setBlocking(False)

        try:
            available = listAvailable()
        except ConnectTimeout:
            self.error(2, "Internet connection error, unable to load data. "
                          "Check connection and create a new GO Browser widget.")
            # Without annotation files nothing in the widget is usable.
            for component in (self.filterTab, self.inputTab, self.selectTab,
                              self.listView, self.sigTerms):
                component.setEnabled(False)
            return

        self.annotationFiles = available
        self.annotationCodes = sorted(self.annotationFiles.keys())

        combo = self.annotationComboBox
        combo.clear()
        combo.addItems(self.annotationCodes)
        combo.setCurrentIndex(self.annotationIndex)
        self.__state = OWGOEnrichmentAnalysis.Ready

    def __on_evidenceChanged(self):
        """Copy the evidence check box states into `useEvidenceType` and
        recompute the enrichment."""
        checked_by_code = {
            code: checkbox.isChecked()
            for code, checkbox in self.evidenceCheckBoxDict.items()
        }
        # Update in place so the existing settings dict object is kept.
        self.useEvidenceType.update(checked_by_code)
        self._updateEnrichment()

    def UpdateGeneMatcher(self):
        """Show the gene matcher settings dialog and apply the result.

        Unless the dialog is rejected, its values are stored in
        `geneMatcherSettings`; when annotations are loaded the gene matcher
        is rebuilt and the enrichment recomputed.
        """
        matcher_dialog = GeneMatcherDialog(
            self, defaults=self.geneMatcherSettings, modal=True)
        if matcher_dialog.exec_() == QDialog.Rejected:
            return

        # The first element of each entry in `dialog.items` names the dialog
        # attribute holding the chosen value.
        self.geneMatcherSettings = [
            getattr(matcher_dialog, entry[0]) for entry in matcher_dialog.items
        ]
        if not self.annotations:
            return

        self.SetGeneMatcher()
        self._updateEnrichment()

    def clear(self):
        """Reset the info label, warnings 0 and 1, the gene attribute combo
        box and the term views, and send ``None`` on every output."""
        self.infoLabel.setText("No data on input\n")
        for warning_id in (0, 1):
            self.warning(warning_id)
        self.geneAttrIndexCombo.clear()
        self.ClearGraph()

        output_channels = (
            "Data on Selected Genes",
            "Data on Unselected Genes",
            "Data on Unknown Genes",
            "Enrichment Report",
        )
        for channel in output_channels:
            self.send(channel, None)

    def setDataset(self, data=None):
        """Set the input (cluster) dataset and recompute the enrichment.

        The previous context is closed and the widget cleared.  For a
        non-``None`` dataset the string variables (attributes, class
        variables and metas) become the candidate gene attributes, the
        organism is guessed from the dataset's ``taxid`` hint, the context
        is reopened and the enrichment is updated.
        """
        if self.__state == OWGOEnrichmentAnalysis.Initializing:
            self.__initialize_finish()

        self.closeContext()
        self.clear()
        self.clusterDataset = data

        if data is None:
            return

        domain = data.domain
        self.candidateGeneAttrs = [
            var for var in domain.variables + domain.metas if isstring(var)
        ]

        self.geneAttrIndexCombo.clear()
        for var in self.candidateGeneAttrs:
            self.geneAttrIndexCombo.addItem(*gui.attributeItem(var))

        # Try to preselect the organism from the dataset's taxonomy hint.
        taxid = data_hints.get_hint(data, "taxid", "")
        code = None
        try:
            code = go.from_taxid(taxid)
        except KeyError:
            pass
        except Exception as ex:
            print(ex)

        if code is not None:
            filename = "gene_association.%s.tar.gz" % code
            if filename in self.annotationFiles.values():
                matching = [
                    i for i, name in enumerate(self.annotationCodes)
                    if self.annotationFiles[name] == filename
                ]
                self.annotationIndex = matching[-1]

        self.useAttrNames = data_hints.get_hint(
            data, "genesinrows", self.useAttrNames)
        self.openContext(data)

        n_candidates = len(self.candidateGeneAttrs)
        self.geneAttrIndex = min(self.geneAttrIndex, n_candidates - 1)
        if n_candidates == 0:
            # No string column to read genes from: use attribute names.
            self.useAttrNames = True
            self.geneAttrIndex = -1
        elif self.geneAttrIndex < n_candidates:
            # NOTE(review): after the clamp above this condition always
            # holds, so the index always ends up on the last candidate
            # regardless of the restored setting -- confirm this is intended.
            self.geneAttrIndex = n_candidates - 1

        self._updateEnrichment()

    def setReferenceDataset(self, data=None):
        """Set the reference dataset and refresh the dependent state.

        The "Reference set" radio button is enabled only for a non-empty
        dataset.  If a cluster dataset is present and the reference set is in
        use, the enrichment is recomputed; otherwise, with a non-empty
        cluster dataset, only the button label is refreshed.
        """
        self.referenceDataset = data

        reference_button = self.referenceRadioBox.buttons[1]
        reference_button.setDisabled(not bool(data))
        reference_button.setText("Reference set")

        if self.clusterDataset is not None and self.useReferenceDataset:
            # Fall back to the whole genome when the reference set went away.
            self.useReferenceDataset = 1 if data else 0
            self.SetGraph(self.Enrichment())
        elif self.clusterDataset:
            self.__updateReferenceSetButton()

    def handleNewSignals(self):
        """Delegate to the base class implementation.

        This override adds no behavior of its own.
        """
        super().handleNewSignals()

    def _updateEnrichment(self):
        """Recompute and redisplay the enrichment for the current dataset.

        Does nothing unless a cluster dataset is set and the widget is in
        the ``Ready`` state.  One progress bar is shared by the loading and
        the enrichment steps.
        """
        if self.clusterDataset is None:
            return
        if self.__state != OWGOEnrichmentAnalysis.Ready:
            return

        progress = gui.ProgressBar(self, 100)
        self.Load(pb=progress)
        enriched = self.Enrichment(pb=progress)
        self.FilterUnknownGenes()
        self.SetGraph(enriched)

    def __updateReferenceSetButton(self):
        """Refresh the state and label of the "Reference set" radio button.

        The button is disabled when no gene names can be read from the
        reference dataset; the label gets a gene/match count suffix when both
        the extracted and the matched gene collections are non-empty.
        """
        all_genes = matched_genes = None
        if self.referenceDataset:
            try:
                all_genes = self.genesFromTable(self.referenceDataset)
            except Exception:
                # Best effort: treat an unreadable table as having no genes.
                all_genes = []
            matched_genes, _ = self.FilterAnnotatedGenes(all_genes)

        if all_genes and matched_genes:
            suffix = "(%i genes, %i matched)" % (len(all_genes),
                                                len(matched_genes))
        else:
            suffix = ""

        button = self.referenceRadioBox.buttons[1]
        button.setDisabled(not bool(all_genes))
        button.setText("Reference set " + suffix)

    def genesFromTable(self, data):
        """Return the list of gene names found in ``data``.

        With ``useAttrNames`` set the names of the domain attributes are
        returned.  Otherwise the values of the selected candidate gene
        attribute are used (rows where ``numpy.isnan`` reports a missing
        value are skipped); if any value contains a comma, every value is
        split on commas and information message 0 is shown.

        :param data: table to read gene names from.
        :return: list of gene name strings.
        :raises IndexError: if gene names are read from a column but there
            are no candidate gene attributes.
        """
        if self.useAttrNames:
            # Use `attributes` rather than `variables` so class variables are
            # not counted as genes; `Enrichment` reads gene names the same way.
            return [var.name for var in data.domain.attributes]

        attr = self.candidateGeneAttrs[min(
            self.geneAttrIndex,
            len(self.candidateGeneAttrs) - 1)]
        genes = [str(ex[attr]) for ex in data if not numpy.isnan(ex[attr])]
        if any("," in gene for gene in genes):
            self.information(
                0,
                "Separators detected in gene names. Assuming multiple genes per example."
            )
            genes = [name for entry in genes for name in entry.split(",")]
        return genes

    def FilterAnnotatedGenes(self, genes):
        """Split ``genes`` into matched and unmatched gene names.

        :param genes: iterable of gene names; it is traversed more than once.
        :return: ``(matched, unknown)`` where ``matched`` is the ``values()``
            of the annotations' gene name translator for ``genes`` and
            ``unknown`` is the list of input names not among those values,
            in input order.
        """
        matchedgenes = self.annotations.get_gene_names_translator(
            genes).values()
        # Build the lookup once: testing membership directly against the
        # values collection is a linear scan per gene.
        # Assumes gene names are hashable (strings) -- TODO confirm.
        matched_lookup = set(matchedgenes)
        unknown = [gene for gene in genes if gene not in matched_lookup]
        return matchedgenes, unknown

    def FilterUnknownGenes(self):
        """Send the rows with no recognized gene on "Data on Unknown Genes".

        A row counts as unknown when none of the comma separated names in the
        selected gene attribute is matched by the annotations' gene matcher.
        ``None`` is sent when gene names come from attribute names, when
        there are no candidate gene attributes, or when no row is unknown.
        """
        if self.useAttrNames or not self.candidateGeneAttrs:
            self.send("Data on Unknown Genes", None)
            return

        position = min(self.geneAttrIndex, len(self.candidateGeneAttrs) - 1)
        gene_attr = self.candidateGeneAttrs[position]

        def is_unknown(row):
            names = str(row[gene_attr]).split(",")
            return not any(
                self.annotations.genematcher.match(name.strip())
                for name in names)

        unknown_rows = [
            i for i, row in enumerate(self.clusterDataset) if is_unknown(row)
        ]
        data = self.clusterDataset[unknown_rows] if unknown_rows else None
        self.send("Data on Unknown Genes", data)

    def Load(self, pb=None):
        """Download (when missing) and load the ontology and annotations.

        Does nothing unless the widget is in the ``Ready`` state.  Missing
        taxonomy, ontology and annotation files are downloaded, the ontology
        is loaded once, and the annotations are (re)loaded when the selected
        organism differs from the loaded one.  After loading annotations the
        evidence check boxes are updated with per-evidence counts.

        :param pb: progress bar to advance.  When omitted a new one is
            created and finished here; a caller-supplied bar is left
            unfinished.
        """
        if self.__state != OWGOEnrichmentAnalysis.Ready:
            return

        go_files = serverfiles.listfiles("GO")
        tax_files = serverfiles.listfiles("Taxonomy")

        finish = pb is None
        if finish:
            pb = gui.ProgressBar(self, 0)

        downloads = []
        steps = 0
        if not tax_files:
            # Use the taxonomy module's own constants (as the widget's
            # initial download does) instead of a hard-coded file name; the
            # previous literal was misspelled ("ncbi_taxnomy.tar.gz").
            downloads.append((taxonomy.Taxonomy.DOMAIN,
                              taxonomy.Taxonomy.FILENAME))
            steps += 1

        org = self.annotationCodes[min(self.annotationIndex,
                                       len(self.annotationCodes) - 1)]
        reload_annotations = org != self.loadedAnnotationCode
        if reload_annotations:
            steps += 1
            if self.annotationFiles[org] not in go_files:
                downloads.append(("GO", self.annotationFiles[org]))
                steps += 1

        if "gene_ontology_edit.obo.tar.gz" not in go_files:
            downloads.append(("GO", "gene_ontology_edit.obo.tar.gz"))
            steps += 1
        if not self.ontology:
            steps += 1
        # Each pending step is budgeted 100 progress bar iterations.
        pb.iter += steps * 100

        for domain, filename in downloads:
            serverfiles.localpath_download(domain, filename,
                                           callback=pb.advance)

        if not self.ontology:
            self.ontology = go.Ontology(
                progress_callback=lambda value: pb.advance())

        if reload_annotations:
            # Drop the old annotations first so they can be collected before
            # the new ones are loaded.
            self.annotations = None
            gc.collect()  # Force run garbage collection
            code = self.annotationFiles[org].split(".")[-3]
            self.annotations = go.Annotations(
                code,
                genematcher=gene.GMDirect(),
                progress_callback=lambda value: pb.advance())
            self.loadedAnnotationCode = org

            annotation_counts = defaultdict(int)
            gene_sets = defaultdict(set)
            for anno in self.annotations.annotations:
                annotation_counts[anno.evidence] += 1
                gene_sets[anno.evidence].add(anno.geneName)

            for etype in go.evidenceTypesOrdered:
                ecb = self.evidenceCheckBoxDict[etype]
                ecb.setEnabled(bool(annotation_counts[etype]))
                ecb.setText(etype + ": %i annots(%i genes)" %
                            (annotation_counts[etype], len(gene_sets[etype])))

        if finish:
            pb.finish()

    def SetGeneMatcher(self):
        """Build the gene matcher from ``geneMatcherSettings`` and install it
        on the loaded annotations.

        Does nothing when no annotations are loaded.  Matcher types whose
        construction fails are skipped (the exception is printed).
        """
        if not self.annotations:
            return

        taxid = self.annotations.taxid
        is_dicty = taxid == "352472"
        available = [gene.GMGO, gene.GMKEGG, gene.GMNCBI, gene.GMAffy]

        matchers = []
        for matcher_type, enabled in zip(available, self.geneMatcherSettings):
            if not enabled:
                continue
            try:
                if is_dicty:
                    # The matchers are duplicated on purpose: `matcher_type`
                    # or `GMDicty` should match genes by themselves if
                    # possible; the joint matcher is used only if they fail.
                    matchers.extend([
                        matcher_type(taxid),
                        gene.GMDicty(),
                        [matcher_type(taxid), gene.GMDicty()],
                    ])
                else:
                    matchers.append(matcher_type(taxid))
            except Exception as ex:
                print(ex)

        self.annotations.genematcher = gene.matcher(matchers)
        self.annotations.genematcher.set_targets(self.annotations.gene_names)

    def Enrichment(self, pb=None):
        """Compute the enriched GO terms for the cluster dataset.

        Gene names are read from the attribute names (``useAttrNames``) or
        from the selected gene attribute, translated with the annotations'
        gene name translator and tested against the reference genes (the
        reference dataset when enabled and usable, otherwise all annotated
        gene names).  The result is stored in ``self.terms``; the hierarchy
        used for display is stored in ``self.treeStructDict`` and
        ``self.treeStructRootKey``.

        :param pb: progress bar to advance; a new one is created when
            omitted.  It is finished before this method returns, including
            on the early failure returns.
        :return: dict mapping term id to a tuple whose last element is the
            FDR value appended here (``DisplayGraph`` unpacks it as
            ``(genes, p_value, reference_count, fdr)``), or ``{}`` when gene
            names or a valid reference set cannot be obtained.
        """
        assert self.clusterDataset is not None

        if pb is None:
            pb = gui.ProgressBar(self, 100)
        if not self.annotations.ontology:
            self.annotations.ontology = self.ontology

        if isinstance(self.annotations.genematcher, gene.GMDirect):
            self.SetGeneMatcher()
        self.error(1)
        self.warning([0, 1])

        if self.useAttrNames:
            clusterGenes = [
                v.name for v in self.clusterDataset.domain.attributes
            ]
            self.information(0)
        elif 0 <= self.geneAttrIndex < len(self.candidateGeneAttrs):
            geneAttr = self.candidateGeneAttrs[self.geneAttrIndex]
            clusterGenes = [
                str(ex[geneAttr]) for ex in self.clusterDataset
                if not numpy.isnan(ex[geneAttr])
            ]
            if any("," in name for name in clusterGenes):
                self.information(
                    0,
                    "Separators detected in cluster gene names. Assuming multiple genes per example."
                )
                clusterGenes = [
                    name for entry in clusterGenes
                    for name in entry.split(",")
                ]
            else:
                self.information(0)
        else:
            self.error(1, "Failed to extract gene names from input dataset!")
            # Finish the bar here too, otherwise it is left hanging.
            pb.finish()
            return {}

        genesSetCount = len(set(clusterGenes))

        self.clusterGenes = clusterGenes = \
            self.annotations.get_gene_names_translator(clusterGenes).values()

        self.infoLabel.setText(
            "%i unique genes on input\n%i (%.1f%%) genes with known annotations"
            % (genesSetCount, len(clusterGenes), 100.0 * len(clusterGenes) /
               genesSetCount if genesSetCount else 0.0))

        referenceGenes = None
        if not self.useReferenceDataset or self.referenceDataset is None:
            self.information(2)
            self.information(1)
            referenceGenes = self.annotations.gene_names
        else:
            if self.useAttrNames:
                referenceGenes = [
                    v.name for v in self.referenceDataset.domain.attributes
                ]
                self.information(1)
            elif geneAttr in (self.referenceDataset.domain.variables +
                              self.referenceDataset.domain.metas):
                referenceGenes = [
                    str(ex[geneAttr]) for ex in self.referenceDataset
                    if not numpy.isnan(ex[geneAttr])
                ]
                # Inspect the reference names themselves (this used to test
                # the cluster genes, so reference separators could go
                # undetected or be reported spuriously).
                if any("," in name for name in referenceGenes):
                    self.information(
                        1,
                        "Separators detected in reference gene names. Assuming multiple genes per example."
                    )
                    referenceGenes = [
                        name for entry in referenceGenes
                        for name in entry.split(",")
                    ]
                else:
                    self.information(1)
            else:
                self.information(1)
                referenceGenes = None

            reference_button = self.referenceRadioBox.buttons[1]
            if referenceGenes is None:
                referenceGenes = list(self.annotations.gene_names)
                reference_button.setText("Reference set")
                reference_button.setDisabled(True)
                self.information(
                    2,
                    "Unable to extract gene names from reference dataset. Using entire genome for reference"
                )
                self.useReferenceDataset = 0
            else:
                refc = len(referenceGenes)
                referenceGenes = self.annotations.get_gene_names_translator(
                    referenceGenes).values()
                reference_button.setText(
                    "Reference set (%i genes, %i matched)" %
                    (refc, len(referenceGenes)))
                reference_button.setDisabled(False)
                self.information(2)

        if not referenceGenes:
            self.error(1, "No valid reference set")
            pb.finish()
            return {}

        self.referenceGenes = referenceGenes
        evidences = [
            etype for etype in go.evidenceTypesOrdered
            if self.useEvidenceType[etype]
        ]
        aspect = ["P", "C", "F"][self.aspectIndex]

        if clusterGenes:
            self.terms = terms = self.annotations.get_enriched_terms(
                clusterGenes,
                referenceGenes,
                evidences,
                aspect=aspect,
                prob=self.probFunctions[self.probFunc],
                use_fdr=False,
                progress_callback=lambda value: pb.advance())
            term_ids = list(terms)
            pvals = [terms[term_id][1] for term_id in term_ids]
            # Save the FDR as the last element of each result tuple.
            for term_id, fdr in zip(term_ids, stats.FDR(pvals)):
                terms[term_id] = tuple(terms[term_id]) + (fdr,)
        else:
            self.terms = terms = {}

        if not self.terms:
            self.warning(0, "No enriched terms found.")
        else:
            self.warning(0)

        pb.finish()
        self.treeStructDict = {}
        self.treeStructRootKey = None

        # Related (parent) terms of every enriched term.
        parents = {
            term_id: {rel for _, rel in self.ontology[term_id].related}
            for term_id in self.terms
        }

        # Invert the relation, keeping only children that are enriched too.
        children = {term_id: set() for term_id in self.terms}
        for term_id, term_parents in parents.items():
            for parent in term_parents:
                if parent in children:
                    children[parent].add(term_id)

        for term in self.terms:
            self.treeStructDict[term] = TreeNode(self.terms[term],
                                                 children[term])
            # A non-obsolete term without related terms is taken as the root.
            if not self.ontology[term].related and not getattr(
                    self.ontology[term], "is_obsolete", False):
                self.treeStructRootKey = term
        return terms

    def FilterGraph(self, graph):
        """Return ``graph`` restricted by the enabled filters.

        The filters are applied in order: raw p-value (via
        ``go.filterByPValue``), FDR (element 3 of each result tuple) and the
        minimum number of genes (length of element 0).

        :param graph: dict mapping term id to its result tuple.
        :return: the filtered dict; ``graph`` itself when no filter is on.
        """
        if self.filterByPValue_nofdr:
            graph = go.filterByPValue(graph, self.maxPValue_nofdr)
        if self.filterByPValue:  # FDR
            graph = {
                term: result for term, result in graph.items()
                if result[3] <= self.maxPValue
            }
        if self.filterByNumOfInstances:
            graph = {
                term: result for term, result in graph.items()
                if len(result[0]) >= self.minNumOfInstances
            }
        return graph

    def FilterAndDisplayGraph(self):
        """Re-filter the original graph and redraw the term views.

        Does nothing when the cluster dataset is missing or empty.  Warning 1
        is shown when filtering removed every term of a non-empty graph and
        cleared otherwise.
        """
        if not self.clusterDataset:
            return

        self.graph = self.FilterGraph(self.originalGraph)
        everything_filtered = self.originalGraph and not self.graph
        if everything_filtered:
            self.warning(1, "All found terms were filtered out.")
        else:
            self.warning(1)
        self.ClearGraph()
        self.DisplayGraph()

    def SetGraph(self, graph=None):
        """Store ``graph`` as the unfiltered result and refresh the views.

        A non-empty graph is filtered and displayed; an empty or missing one
        resets ``self.graph`` and clears the views.
        """
        self.originalGraph = graph
        if not graph:
            self.graph = {}
            self.ClearGraph()
            return
        self.FilterAndDisplayGraph()

    def ClearGraph(self):
        """Empty the term tree view and the significant terms table, and
        forget the tree items created for the last display."""
        self.listView.clear()
        self.listViewItems = list()
        self.sigTerms.clear()

    def DisplayGraph(self):
        """Fill the term tree and the significant terms table from
        ``self.graph`` and send the "Enrichment Report" table.

        The tree is built by walking ``self.treeStructDict`` from
        ``self.treeStructRootKey``; only terms present in ``self.graph`` get
        an item, the others are skipped while their children are still
        visited.  The table lists the terms ordered by p-value.
        """
        visited_edges = set()
        self.termListViewItemDict = {}
        self.listViewItems = []

        def fold_enrichment(result):
            return len(result[0]) / result[2] * (len(self.referenceGenes) /
                                                 len(self.clusterGenes))

        maxFoldEnrichment = max(
            [fold_enrichment(result) for result in self.graph.values()]
            or [1])

        def add_node(term, parent, parent_item):
            # Show each (parent, term) edge at most once.
            if (parent, term) in visited_edges:
                return
            if term in self.graph:
                item = GOTreeWidgetItem(self.ontology[term],
                                        self.graph[term],
                                        len(self.clusterGenes),
                                        len(self.referenceGenes),
                                        maxFoldEnrichment,
                                        parent_item)
                item.goId = term
                self.listViewItems.append(item)
                self.termListViewItemDict.setdefault(term, []).append(item)
                visited_edges.add((parent, term))
                parent = term
            else:
                # Filtered out term: its children attach to the nearest
                # displayed ancestor.
                item = parent_item

            for child in self.treeStructDict[term].children:
                add_node(child, parent, item)

        if self.treeStructDict:
            add_node(self.treeStructRootKey, None, self.listView)

        ranked = sorted(self.graph.items(), key=lambda entry: entry[1][1])
        self.sigTableTermsSorted = [term_id for term_id, _ in ranked]

        self.sigTerms.clear()
        for t_id, (genes, p_value, refCount, fdr) in ranked:
            # The item registers itself with the table through its parent.
            row_item = GOTreeWidgetItem(self.ontology[t_id],
                                        (genes, p_value, refCount, fdr),
                                        len(self.clusterGenes),
                                        len(self.referenceGenes),
                                        maxFoldEnrichment, self.sigTerms)
            row_item.goId = t_id

        self.listView.expandAll()
        for column in range(5):
            self.listView.resizeColumnToContents(column)
            self.sigTerms.resizeColumnToContents(column)
        self.sigTerms.resizeColumnToContents(6)
        width = min(self.listView.columnWidth(0), 350)
        self.listView.setColumnWidth(0, width)
        self.sigTerms.setColumnWidth(0, width)

        # Create and send the enrichment report table (all columns are metas).
        report_metas = [
            Orange.data.StringVariable("GO Term Id"),
            Orange.data.StringVariable("GO Term Name"),
            Orange.data.ContinuousVariable("Cluster Frequency"),
            Orange.data.ContinuousVariable("Genes in Cluster",
                                           number_of_decimals=0),
            Orange.data.ContinuousVariable("Reference Frequency"),
            Orange.data.ContinuousVariable("Genes in Reference",
                                           number_of_decimals=0),
            Orange.data.ContinuousVariable("p-value"),
            Orange.data.ContinuousVariable("FDR"),
            Orange.data.ContinuousVariable("Enrichment"),
            Orange.data.StringVariable("Genes")
        ]
        termsDomain = Orange.data.Domain([], [], report_metas)

        rows = []
        for t_id, (genes, p_value, r_count, fdr) in ranked:
            rows.append([
                t_id,
                self.ontology[t_id].name,
                len(genes) / len(self.clusterGenes),
                len(genes),
                r_count / len(self.referenceGenes),
                r_count,
                p_value,
                fdr,
                len(genes) / len(self.clusterGenes) *
                len(self.referenceGenes) / r_count,
                ",".join(genes),
            ])

        if rows:
            X = numpy.empty((len(rows), 0))
            M = numpy.array(rows, dtype=object)
            termsTable = Orange.data.Table.from_numpy(termsDomain, X, metas=M)
        else:
            termsTable = Orange.data.Table(termsDomain)
        self.send("Enrichment Report", termsTable)

    def ViewSelectionChanged(self):
        """Update ``selectedTerms`` from the tree view selection and commit.

        ``selectionChanging`` guards against re-entrant calls while the
        selection is being processed.  It is reset in a ``finally`` block so
        that an error during the commit cannot leave the guard set, which
        would silently ignore every later selection change.
        """
        if self.selectionChanging:
            return

        self.selectionChanging = 1
        try:
            self.selectedTerms = []
            selected = self.listView.selectedItems()
            # A term can appear under several parents; keep each id once.
            self.selectedTerms = list({lvi.term.id for lvi in selected})
            self.ExampleSelection()
        finally:
            self.selectionChanging = 0

    def TableSelectionChanged(self):
        """Mirror the table selection onto the tree view and commit.

        ``selectedTerms`` is rebuilt in table order from the selected rows,
        and every tree item of each term is (de)selected to match; selected
        items are also expanded.  ``selectionChanging`` guards against
        re-entrant calls and is always reset, even if an error occurs.
        """
        if self.selectionChanging:
            return

        self.selectionChanging = 1
        try:
            self.selectedTerms = []
            selectedIds = {
                self.sigTerms.itemFromIndex(index).goId
                for index in self.sigTerms.selectedIndexes()
            }

            for i in range(self.sigTerms.topLevelItemCount()):
                term = self.sigTerms.topLevelItem(i).goId
                selected = term in selectedIds

                if selected:
                    self.selectedTerms.append(term)

                # A term listed in the table may have no item in the tree;
                # skip it instead of failing with a KeyError.
                for lvi in self.termListViewItemDict.get(term, ()):
                    try:
                        lvi.setSelected(selected)
                        if selected:
                            lvi.setExpanded(True)
                    except RuntimeError:  # Underlying C/C++ object deleted
                        pass

            self.ExampleSelection()
        finally:
            self.selectionChanging = 0

    def UpdateAddClassButton(self):
        """Enable the "add class" check box only when ``selectionDisjoint``
        is 1."""
        enabled = self.selectionDisjoint == 1
        self.addClassCB.setEnabled(enabled)

    def ExampleSelection(self):
        """Send the outputs for the current term selection.

        Thin alias for :meth:`commit`; nothing is returned.
        """
        self.commit()

    def commit(self):
        """Send the data subsets that correspond to the selected terms.

        The genes mapped to the selected terms are collected, narrowed
        according to ``selectionDisjoint`` and used to select either columns
        (when genes are attribute names) or rows (when genes are read from a
        gene attribute) of the cluster dataset.  When ``selectionDisjoint``
        is 1 and ``selectionAddTermAsClass`` is set, the selected rows get a
        "GO Term" class column.  Does nothing without a cluster dataset.
        """
        if self.clusterDataset is None:
            return

        # All genes mapped to any selected term in the displayed graph.
        genes = set()
        for term in set(self.selectedTerms):
            genes |= set(self.graph[term][0])

        evidences = [
            etype for etype in go.evidenceTypesOrdered
            if self.useEvidenceType[etype]
        ]
        allTerms = self.annotations.get_annotated_terms(
            genes,
            direct_annotation_only=self.selectionDirectAnnotation,
            evidence_codes=evidences)

        if self.selectionDisjoint > 0:
            # Count in how many selected terms each gene occurs.
            occurrences = defaultdict(int)
            for term in self.selectedTerms:
                for g in allTerms.get(term, []):
                    occurrences[g] += 1
            if self.selectionDisjoint == 1:
                required = 1
            else:
                required = len(self.selectedTerms)
            selectedGenes = [
                g for g, n in occurrences.items()
                if n == required and g in genes
            ]
        else:
            selectedGenes = set()
            for term in self.selectedTerms:
                selectedGenes |= set(allTerms.get(term, []))

        if self.useAttrNames:
            source = self.clusterDataset
            gene_vars = [source.domain[name] for name in set(selectedGenes)]
            domain = Orange.data.Domain(gene_vars,
                                        source.domain.class_vars,
                                        source.domain.metas)
            newdata = source.from_table(domain, source)

            self.send("Data on Selected Genes", newdata)
            self.send("Data on Unselected Genes", None)
            return

        if not self.candidateGeneAttrs:
            return

        selected_rows = []
        unselected_rows = []

        geneAttr = self.candidateGeneAttrs[min(
            self.geneAttrIndex,
            len(self.candidateGeneAttrs) - 1)]

        if self.selectionDisjoint == 1:
            goVar = Orange.data.DiscreteVariable(
                "GO Term", values=list(self.selectedTerms))
            newDomain = Orange.data.Domain(
                self.clusterDataset.domain.variables, goVar,
                self.clusterDataset.domain.metas)
            goColumn = []
        add_term_class = (self.selectionDisjoint == 1
                          and self.selectionAddTermAsClass)

        for i, ex in enumerate(self.clusterDataset):
            is_selected = not numpy.isnan(ex[geneAttr]) and any(
                name in selectedGenes
                for name in str(ex[geneAttr]).split(","))
            if not is_selected:
                unselected_rows.append(i)
                continue
            if add_term_class:
                # Label the row with the first (sorted) selected term that
                # contains one of its genes.
                matching = [
                    term for term in self.selectedTerms
                    if any(name in self.graph[term][0]
                           for name in str(ex[geneAttr]).split(","))
                ]
                goColumn.append(goVar.values.index(sorted(matching)[0]))
            selected_rows.append(i)

        if selected_rows:
            selectedExamples = self.clusterDataset[selected_rows]
            if add_term_class:
                selectedExamples = Orange.data.Table.from_table(
                    newDomain, selectedExamples)
                view, issparse = selectedExamples.get_column_view(goVar)
                assert not issparse
                view[:] = goColumn
        else:
            selectedExamples = None

        if unselected_rows:
            unselectedExamples = self.clusterDataset[unselected_rows]
        else:
            unselectedExamples = None

        self.send("Data on Selected Genes", selectedExamples)
        self.send("Data on Unselected Genes", unselectedExamples)

    def ShowInfo(self):
        """Show a non-modal dialog with the ontology and annotations headers
        (or a "not loaded" note for whichever is missing)."""
        dialog = QDialog(self)
        dialog.setModal(False)
        dialog.setLayout(QVBoxLayout())

        def add_label(text):
            label = QLabel(dialog)
            label.setText(text)
            dialog.layout().addWidget(label)

        if self.ontology:
            ontology_text = "Ontology:\n" + self.ontology.header
        else:
            ontology_text = "Ontology not loaded!"
        add_label(ontology_text)

        if self.annotations:
            annotations_text = ("Annotations:\n" +
                                self.annotations.header.replace("!", ""))
        else:
            annotations_text = "Annotations not loaded!"
        add_label(annotations_text)

        dialog.show()

    def onDeleteWidget(self):
        """Release the loaded data before the widget is removed from the
        canvas.

        The annotations and the ontology are dropped and a garbage collection
        run is forced right away.
        """
        self.annotations = self.ontology = None
        gc.collect()  # Force collection
class OWExplainPredictions(OWWidget):
    """Widget that explains a model's prediction for one sample.

    Takes a data table, a model and a sample table as inputs and outputs a
    table of explanations (per-attribute contributions).
    """

    # Widget metadata.
    name = "Explain Predictions"
    # NOTE(review): "shapely" is presumably meant to be "Shapley" - confirm
    # before changing this user-visible string.
    description = "Computes attribute contributions to the final prediction with an approximation algorithm for shapely value"
    icon = "icons/ExplainPredictions.svg"
    priority = 200
    # Persistent settings bound to the controls created in __init__.
    gui_error = settings.Setting(0.05)
    gui_p_val = settings.Setting(0.05)
    gui_num_atr = settings.Setting(20)
    sort_index = settings.Setting(SortBy.ABSOLUTE)

    class Inputs:
        # Input signals; handled by set_data, set_predictor and set_sample.
        data = Input("Data", Table, default=True)
        model = Input("Model", Model, multiple=False)
        sample = Input("Sample", Table)

    class Outputs:
        # Best-so-far explanations, sent from commit_output.
        explanations = Output("Explanations", Table)

    class Error(OWWidget.Error):
        # Raised by set_sample when the sample does not have exactly one row.
        sample_too_big = widget.Msg("Can only explain one sample at the time.")

    class Warning(OWWidget.Warning):
        # Raised by commit_calc when transforming the sample into the data
        # domain changes the number of NaN values.
        unknowns_increased = widget.Msg(
            "Number of unknown values increased, Data and Sample domains mismatch."
        )

    def __init__(self):
        """Set up widget state, the control-area widgets and the three
        graphics views (header, body, footer) that display the explanations.
        """
        super().__init__()
        # Current inputs and the explanations table computed from them.
        self.data = None
        self.model = None
        self.to_explain = None
        self.explanations = None
        # Button state: while True, the next click on the button cancels the
        # running computation (see toggle_button).
        self.stop = True
        # ExplainPredictions instance, created lazily in commit_calc.
        self.e = None

        # Currently running background task (if any) and the executor used
        # to run it.
        self._task = None
        self._executor = ThreadExecutor()

        # Info labels summarising the three inputs.
        info_box = gui.vBox(self.controlArea, "Info")
        self.data_info = gui.widgetLabel(info_box, "Data: N/A")
        self.model_info = gui.widgetLabel(info_box, "Model: N/A")
        self.sample_info = gui.widgetLabel(info_box, "Sample: N/A")

        # Stopping criteria passed to ExplainPredictions (error and p-value).
        criteria_box = gui.vBox(self.controlArea, "Stopping criteria")
        self.error_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_error",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error < ",
                                   spinType=float,
                                   callback=self._update_error_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        self.p_val_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_p_val",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error p-value < ",
                                   spinType=float,
                                   callback=self._update_p_val_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        # Display options: number of attributes drawn and their ordering.
        plot_properties_box = gui.vBox(self.controlArea, "Display features")
        self.num_atr_spin = gui.spin(plot_properties_box,
                                     self,
                                     "gui_num_atr",
                                     1,
                                     100,
                                     step=1,
                                     label="Show attributes",
                                     callback=self._update_num_atr_spin,
                                     controlWidth=80,
                                     keyboardTracking=False)

        self.sort_combo = gui.comboBox(plot_properties_box,
                                       self,
                                       "sort_index",
                                       label="Rank by",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._update_combo)

        gui.rubber(self.controlArea)

        # Stop/restart button; disabled until a computation is started in
        # commit_calc.
        self.cancel_button = gui.button(
            self.controlArea,
            self,
            "Stop Computation",
            callback=self.toggle_button,
            autoDefault=True,
            tooltip="Stops and restarts computation")
        self.cancel_button.setDisabled(True)

        predictions_box = gui.vBox(self.mainArea, "Model prediction")
        self.predict_info = gui.widgetLabel(predictions_box, "")

        self.mainArea.setMinimumWidth(700)
        self.resize(700, 400)

        class _GraphicsView(QGraphicsView):
            """QGraphicsView with default properties that callers may
            override through keyword arguments."""

            def __init__(self, scene, parent, **kwargs):
                # Fill in each default only when the caller did not supply
                # a value for that property.
                for k, v in dict(
                        verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                        horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                        viewportUpdateMode=QGraphicsView.
                        BoundingRectViewportUpdate,
                        renderHints=(QPainter.Antialiasing
                                     | QPainter.TextAntialiasing
                                     | QPainter.SmoothPixmapTransform),
                        alignment=(Qt.AlignTop | Qt.AlignLeft),
                        sizePolicy=QSizePolicy(
                            QSizePolicy.MinimumExpanding,
                            QSizePolicy.MinimumExpanding)).items():
                    kwargs.setdefault(k, v)
                super().__init__(scene, parent, **kwargs)

        class GraphicsView(_GraphicsView):
            """Main (scrollable) view; redraws the widget on every resize."""

            def __init__(self, scene, parent):
                super().__init__(
                    scene,
                    parent,
                    verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                    styleSheet='QGraphicsView {background: white}')
                self.viewport().setMinimumWidth(500)
                self._is_resizing = False

            # Class attribute holding the enclosing widget so that
            # resizeEvent can call its draw() method.
            w = self

            def resizeEvent(self, resizeEvent):
                # Flag the redraw as resize-triggered for the duration of
                # the draw() call.
                self._is_resizing = True
                self.w.draw()
                self._is_resizing = False
                return super().resizeEvent(resizeEvent)

            def is_resizing(self):
                return self._is_resizing

            def sizeHint(self):
                return QSize(600, 300)

        class FixedSizeGraphicsView(_GraphicsView):
            """View used for the header and footer strips."""

            def __init__(self, scene, parent):
                super().__init__(scene,
                                 parent,
                                 sizePolicy=QSizePolicy(
                                     QSizePolicy.MinimumExpanding,
                                     QSizePolicy.Minimum))

            def sizeHint(self):
                return QSize(600, 30)

        """all will share the same scene, but will show different parts of it"""
        self.box_scene = QGraphicsScene(self)

        self.box_view = GraphicsView(self.box_scene, self)
        self.header_view = FixedSizeGraphicsView(self.box_scene, self)
        self.footer_view = FixedSizeGraphicsView(self.box_scene, self)

        # Stack the views vertically: header, body, footer.
        self.mainArea.layout().addWidget(self.header_view)
        self.mainArea.layout().addWidget(self.box_view)
        self.mainArea.layout().addWidget(self.footer_view)

        # GraphAttributes painter, created in draw().
        self.painter = None

    def draw(self):
        """Redraw the explanations and split the scene among the three views.

        The scene is cleared and, when explanations are available, repainted
        with a ``GraphAttributes`` painter showing at most ``gui_num_atr``
        attributes. Afterwards the header, body and footer views are each
        pointed at their own slice of the shared scene.
        """
        self.box_scene.clear()
        viewport_rect = self.box_view.viewport().rect()
        header_height = 30
        if self.explanations is not None:
            self.painter = GraphAttributes(
                self.box_scene,
                min(self.gui_num_atr, self.explanations.Y.shape[0]))
            self.painter.paint(viewport_rect,
                               self.explanations,
                               header_h=header_height)

        # Query the bounding rectangle once; it cannot change between the
        # four coordinate reads below.
        bounds = self.box_scene.itemsBoundingRect()
        rect = QRectF(bounds.x(), bounds.y(), bounds.width(), bounds.height())

        # Set the appropriate scene slice for each view.
        self.box_scene.setSceneRect(rect)
        self.box_view.setSceneRect(rect.x(),
                                   rect.y() + header_height + 2, rect.width(),
                                   rect.height() - 80)
        self.header_view.setSceneRect(rect.x(), rect.y(), rect.width(), 10)
        self.header_view.setFixedHeight(header_height)
        self.footer_view.setSceneRect(rect.x(),
                                      rect.y() + rect.height() - 50,
                                      rect.width(), 35)

    def sort_explanations(self):
        """Sort ``self.explanations`` according to the "Rank by" combo box.

        The first column of ``X`` is used as the sort key for the numeric
        orderings and the first meta column (lower-cased) for the ordering
        by name. Nothing happens when there are no explanations or when
        ``sort_index`` is not one of the known ``SortBy`` values.
        """
        if self.explanations is None:
            # Nothing to sort; indexing None below would raise.
            return

        if self.sort_index == SortBy.POSITIVE:
            order = np.argsort(self.explanations.X[:, 0])
            self.explanations = self.explanations[order][::-1]
        elif self.sort_index == SortBy.NEGATIVE:
            order = np.argsort(self.explanations.X[:, 0])
            self.explanations = self.explanations[order]
        elif self.sort_index == SortBy.ABSOLUTE:
            order = np.argsort(np.abs(self.explanations.X[:, 0]))
            self.explanations = self.explanations[order][::-1]
        elif self.sort_index == SortBy.BY_NAME:
            # Case-insensitive ordering; plain str.lower replaces the legacy
            # unbound np.chararray.lower call.
            names = np.array(
                [str(name).lower() for name in self.explanations.metas[:, 0]])
            self.explanations = self.explanations[np.argsort(names)]

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the 'Data' input and refresh the data summary label.

        Any previously computed explanations and the cached explainer are
        discarded because they depend on the data.

        Parameters
        ----------
        data : Table or None
            New input data; ``None`` resets the label to "Data: N/A".
        """
        self.data = data
        self.explanations = None
        self.e = None
        if data is None:
            self.data_info.setText("Data: N/A")
            return

        n_instances = data.X.shape[0]
        n_features = data.X.shape[1]
        if n_instances == 1:
            inst = "1 instance and "
        else:
            inst = str(n_instances) + " instances and "
        if n_features == 1:
            feat = "1 feature"
        else:
            feat = str(n_features) + " features"
        self.data_info.setText("Data: " + inst + feat)

    @Inputs.model
    def set_predictor(self, model):
        """Set the 'Model' input and refresh the model label.

        Previously computed explanations and the cached explainer are
        dropped because they were computed for the old model.
        """
        self.model = model
        self.explanations = None
        self.e = None
        if model is None:
            label_text = "Model: N/A"
        else:
            label_text = "Model: " + str(model.name)
        self.model_info.setText(label_text)

    @Inputs.sample
    @check_sql_input
    def set_sample(self, sample):
        """Set the 'Sample' input; only a single-row table is accepted.

        A table with any other number of rows is rejected: the stored sample
        is reset to ``None`` and the ``sample_too_big`` error is shown.
        """
        self.to_explain = sample
        self.explanations = None
        self.Error.sample_too_big.clear()
        self.sample_info.setText("Sample: N/A")
        if sample is None:
            return

        if len(sample.X) != 1:
            self.to_explain = None
            self.Error.sample_too_big()
            return

        n_features = sample.X.shape[1]
        if n_features == 1:
            summary = "1 feature"
        else:
            summary = str(n_features) + " features"
        self.sample_info.setText("Sample: " + summary)
        # Mark the cached explainer's state as not saved for the new sample.
        if self.e is not None:
            self.e.saved = False

    def handleNewSignals(self):
        """Restart the computation after the inputs (or settings) changed.

        Cancels a running task, resets the prediction label, the warning and
        the stop button, then either starts a new computation or sends the
        current output.
        """
        if self._task is not None:
            self.cancel()
        # cancel() ends in _task_finished, which clears self._task.
        assert self._task is None

        self.predict_info.setText("")
        self.Warning.unknowns_increased.clear()
        self.stop = True
        self.cancel_button.setText("Stop Computation")
        self.commit_calc_or_output()

    def commit_calc_or_output(self):
        """Start a computation when both data and sample are present;
        otherwise just send the current explanations.
        """
        if self.data is None or self.to_explain is None:
            self.commit_output()
            return
        self.commit_calc()

    def commit_calc(self):
        """Transform the sample into the data domain and start explaining it.

        Shows the ``unknowns_increased`` warning when the transformation
        changes the number of NaN values in the sample. When a model is
        present, the explanation is run on the thread executor; progress,
        intermediate tables and the predicted class value are delivered back
        to this object through queued method invocations.
        """
        nan_before = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        self.to_explain = self.to_explain.transform(self.data.domain)
        nan_after = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        if nan_before != nan_after:
            self.Warning.unknowns_increased()

        if self.model is None:
            return

        # calculate contributions
        if self.e is None:
            batch_size = min(len(self.data.X), 500)
            self.e = ExplainPredictions(self.data,
                                        self.model,
                                        batch_size=batch_size,
                                        p_val=self.gui_p_val,
                                        error=self.gui_error)
        task = Task()
        self._task = task

        def report_progress(progress):
            # update progress bar
            QMetaObject.invokeMethod(self, "set_progress_value",
                                     Qt.QueuedConnection,
                                     Q_ARG(int, progress))
            # The return value reports whether the task was canceled.
            return True if task.canceled else False

        def report_table(table):
            QMetaObject.invokeMethod(self, "update_view",
                                     Qt.QueuedConnection,
                                     Q_ARG(Orange.data.Table, table))

        def report_prediction(class_value):
            QMetaObject.invokeMethod(self, "update_model_prediction",
                                     Qt.QueuedConnection,
                                     Q_ARG(float, class_value))

        self.was_canceled = False
        explain_func = partial(self.e.anytime_explain,
                               self.to_explain[0],
                               callback=report_progress,
                               update_func=report_table,
                               update_prediction=report_prediction)

        self.progressBarInit(processEvents=None)
        task.future = self._executor.submit(explain_func)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self._task_finished)
        self.cancel_button.setDisabled(False)

    @pyqtSlot(Orange.data.Table)
    def update_view(self, table):
        """Store a new explanations table, then sort, redraw and output it.

        Invoked through a queued call from commit_calc's callback and
        directly from _task_finished.
        """
        self.explanations = table
        self.sort_explanations()
        self.draw()
        self.commit_output()

    @pyqtSlot(float)
    def update_model_prediction(self, value):
        """Slot that shows the model's predicted value in the info label."""
        self._print_prediction(value)

    @pyqtSlot(int)
    def set_progress_value(self, value):
        """Slot that forwards a progress value to the widget's progress bar."""
        self.progressBarSet(value, processEvents=False)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Handle completion of the background explanation task.

        Clears the current task, shows the final explanations on success and
        reports an error if the computation raised. The progress bar is
        finished in every case.

        Parameters
        ----------
        f : concurrent.futures.Future
            Finished future of the current task; on success the explanations
            table is read from index 1 of its result.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None

        if not self.was_canceled:
            self.cancel_button.setDisabled(True)

        try:
            results = f.result()
        except concurrent.futures.CancelledError:
            # The future was cancelled before it started running: there is
            # no result to show and this is not an evaluation error.
            pass
        except Exception as ex:
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.error("Exception occured during evaluation: {!r}".format(ex))
        else:
            self.update_view(results[1])

        self.progressBarFinished(processEvents=False)

    def commit_output(self):
        """
        Send the best-so-far explanations (possibly ``None``) to the
        "Explanations" output.
        """
        self.Outputs.explanations.send(self.explanations)

    def toggle_button(self):
        """Toggle between stopping and restarting the computation.

        While ``self.stop`` is true a click cancels the running task and
        relabels the button "Restart Computation"; otherwise the click
        restores the "Stop Computation" label and restarts the computation.
        """
        should_stop = self.stop
        if should_stop:
            self.stop = False
            new_label = "Restart Computation"
        else:
            self.stop = True
            new_label = "Stop Computation"
        self.cancel_button.setText(new_label)

        if should_stop:
            self.cancel()
        else:
            self.commit_calc_or_output()

    def cancel(self):
        """
        Cancel the running task, if there is one, and finish it synchronously.
        """
        task = self._task
        if task is None:
            return

        task.cancel()
        assert task.future.done()
        # Disconnect the slot so it is not invoked a second time by the
        # watcher; it is called directly below instead.
        task.watcher.done.disconnect(self._task_finished)
        self.was_canceled = True
        self._task_finished(task.future)

    def _print_prediction(self, class_value):
        """
        Show the model's prediction next to the class variable name.

        Parameters
        ----------
        class_value : float
            The predicted value itself when the class variable is
            continuous; otherwise an index into the class variable's
            ``values``.
        """
        class_var = self.data.domain.class_vars[0]
        prefix = class_var.name + ":      "
        if isinstance(class_var, ContinuousVariable):
            text = prefix + str(class_value)
        else:
            text = prefix + class_var.values[int(class_value)]
        self.predict_info.setText(text)

    def _update_error_spin(self):
        """Callback of the error spin box: cancel the running task, push the
        new error threshold to the existing explainer and restart."""
        self.cancel()
        if self.e is not None:
            self.e.error = self.gui_error
        self.handleNewSignals()

    def _update_p_val_spin(self):
        """Callback of the p-value spin box: cancel the running task, push
        the new p-value to the existing explainer and restart."""
        self.cancel()
        if self.e is not None:
            self.e.p_val = self.gui_p_val
        self.handleNewSignals()

    def _update_num_atr_spin(self):
        """Callback of the "Show attributes" spin box: cancel the running
        task and restart through handleNewSignals."""
        self.cancel()
        self.handleNewSignals()

    def _update_combo(self):
        """Callback of the "Rank by" combo box: re-sort, redraw and re-send
        the current explanations, if there are any.
        """
        # Identity test: comparing a table to None with != would go through
        # the table's comparison operators.
        if self.explanations is not None:
            self.sort_explanations()
            self.draw()
            self.commit_output()

    def onDeleteWidget(self):
        """Cancel a running computation before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()
Ejemplo n.º 32
0
class OWGeneInfo(widget.OWWidget):
    """Widget that lists information records for the genes found in the
    input data and outputs the data subset matching the selected genes."""

    # Widget metadata.
    name = "Gene Info"
    description = "Displays gene information from NCBI and other sources."
    icon = "../widgets/icons/GeneInfo.svg"
    priority = 2010

    # Signal declarations: (name, type[, handler method name]).
    inputs = [("Data", Orange.data.Table, "setData")]
    outputs = [("Data Subset", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    # Selected organism: index into self.organisms and its taxonomy id.
    organism_index = settings.ContextSetting(0)
    taxid = settings.ContextSetting("9606")

    # Index into self.attributes of the variable holding gene names.
    gene_attr = settings.ContextSetting(0)

    auto_commit = settings.Setting(False)
    # Text of the "Filter" line edit.
    search_string = settings.Setting("")

    # useAttr: take gene names from attribute names instead of a column.
    # useAltSource: index of the alternative info source (see infoSource).
    useAttr = settings.ContextSetting(False)
    useAltSource = settings.ContextSetting(False)

    def __init__(
        self,
        parent=None,
    ):
        """Build the GUI and start downloading the taxonomy files.

        The widget is blocked until the background download finishes and
        ``initialize`` runs (or ``_onInitializeError`` reports a failure).

        Parameters
        ----------
        parent : optional
            Parent passed on to the base class constructor.
        """
        # `self` is bound implicitly by super(); passing it again would
        # shift `parent` to the wrong positional slot.
        super().__init__(parent)

        self.selectionChangedFlag = False

        self.__initialized = False
        # Futures of the initial download and of the running info query.
        self.initfuture = None
        self.itemsfuture = None

        self.infoLabel = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Info", addSpace=True),
            "Initializing\n")

        self.organisms = None
        self.organismBox = gui.widgetBox(self.controlArea,
                                         "Organism",
                                         addSpace=True)

        self.organismComboBox = gui.comboBox(
            self.organismBox,
            self,
            "organism_index",
            callback=self._onSelectedOrganismChanged)

        # For now only support one alt source, with a checkbox
        # In the future this can be extended to multiple selections
        self.altSourceCheck = gui.checkBox(self.organismBox,
                                           self,
                                           "useAltSource",
                                           "Show information from dictyBase",
                                           callback=self.onAltSourceChange)

        self.altSourceCheck.hide()

        box = gui.widgetBox(self.controlArea, "Gene names", addSpace=True)
        self.geneAttrComboBox = gui.comboBox(box,
                                             self,
                                             "gene_attr",
                                             "Gene attribute",
                                             callback=self.updateInfoItems)
        self.geneAttrComboBox.setEnabled(not self.useAttr)
        cb = gui.checkBox(box,
                          self,
                          "useAttr",
                          "Use attribute names",
                          callback=self.updateInfoItems)
        # The gene attribute combo is irrelevant while attribute names are
        # used as gene names.
        cb.toggled[bool].connect(self.geneAttrComboBox.setDisabled)

        gui.auto_commit(self.controlArea, self, "auto_commit", "Commit")

        # A label for the dictyExpress links; activation is handled by
        # onDictyExpressLink instead of opening the raw link.
        self.dictyExpressBox = gui.widgetBox(self.controlArea, "Dicty Express")
        self.linkLabel = gui.widgetLabel(self.dictyExpressBox, "")
        self.linkLabel.setOpenExternalLinks(False)
        self.linkLabel.linkActivated.connect(self.onDictyExpressLink)

        self.dictyExpressBox.hide()

        gui.rubber(self.controlArea)

        gui.lineEdit(self.mainArea,
                     self,
                     "search_string",
                     "Filter",
                     callbackOnType=True,
                     callback=self.searchUpdate)

        self.treeWidget = QTreeView(self.mainArea,
                                    selectionMode=QTreeView.ExtendedSelection,
                                    rootIsDecorated=False,
                                    uniformRowHeights=True,
                                    sortingEnabled=True)

        self.treeWidget.setItemDelegate(
            gui.LinkStyledItemDelegate(self.treeWidget))
        self.treeWidget.viewport().setMouseTracking(True)
        self.mainArea.layout().addWidget(self.treeWidget)

        box = gui.widgetBox(self.mainArea, "", orientation="horizontal")
        gui.button(box, self, "Select Filtered", callback=self.selectFiltered)
        gui.button(box,
                   self,
                   "Clear Selection",
                   callback=self.treeWidget.clearSelection)

        self.geneinfo = []
        self.cells = []
        self.row2geneinfo = {}
        self.data = None
        # Read by inputGenes/commit and _onItemsCompleted; initialised here
        # so those methods never hit a missing attribute.
        self.attributes = []
        self.genes = []

        # : (# input genes, # matches genes)
        self.matchedInfo = 0, 0

        self.setBlocking(True)
        self.executor = ThreadExecutor(self)

        self.progressBarInit()

        task = Task(
            function=partial(taxonomy.ensure_downloaded,
                             callback=methodinvoke(self, "advance", ())))

        # Connect before submitting so no completion signal can be missed.
        task.resultReady.connect(self.initialize)
        task.exceptionReady.connect(self._onInitializeError)

        self.initfuture = self.executor.submit(task)

    def sizeHint(self):
        """Return the preferred widget size (1024 x 720)."""
        return QSize(1024, 720)

    @Slot()
    def advance(self):
        """Slot that advances the progress bar by one step.

        Asserts that it is executed in the widget's own thread.
        """
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(self.progressBarValue + 1, processEvents=None)

    def initialize(self):
        """Populate the organism list once and unblock the widget.

        Subsequent calls return immediately. The stored ``taxid`` is kept
        when it is among the known organisms; otherwise the first organism
        is selected.
        """
        if self.__initialized:
            # Already initialized
            return
        self.__initialized = True

        listed_taxids = [
            filename.split(".")[-2]
            for filename in serverfiles.listfiles("NCBI_geneinfo")
        ]
        unique_taxids = set(listed_taxids +
                            gene.NCBIGeneInfo.common_taxids())
        self.organisms = sorted(unique_taxids)

        organism_names = [taxonomy.name(tax_id) for tax_id in self.organisms]
        self.organismComboBox.addItems(organism_names)

        if self.taxid not in self.organisms:
            self.organism_index = 0
            self.taxid = self.organisms[self.organism_index]
        else:
            self.organism_index = self.organisms.index(self.taxid)

        # The dictyBase controls are shown only for the Dicty organism.
        is_dicty = self.taxid == DICTY_TAXID
        self.altSourceCheck.setVisible(is_dicty)
        self.dictyExpressBox.setVisible(is_dicty)

        self.infoLabel.setText("No data on input\n")
        self.initfuture = None

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)

    def _onInitializeError(self, exc):
        """Report a failure of the initial download task.

        The exception is passed to ``sys.excepthook`` (without a traceback)
        and an error message is shown on the widget.
        """
        sys.excepthook(type(exc), exc, None)
        self.error(0, "Could not download the necessary files.")

    def _onSelectedOrganismChanged(self):
        """Callback of the organism combo box.

        Stores the taxonomy id of the selected organism, shows the
        dictyBase controls only for the Dicty organism and refreshes the
        gene info when data is present.
        """
        # Valid indices stop at len - 1; `<=` would let an out-of-range
        # index through to the lookup below.
        assert 0 <= self.organism_index < len(self.organisms)
        self.taxid = self.organisms[self.organism_index]
        is_dicty = self.taxid == DICTY_TAXID
        self.altSourceCheck.setVisible(is_dicty)
        self.dictyExpressBox.setVisible(is_dicty)

        if self.data is not None:
            self.updateInfoItems()

    def setData(self, data=None):
        """Handler of the "Data" input.

        Waits for the initial download if it has not finished yet. With
        data, the gene attribute combo is refilled with the string and
        discrete variables of the domain, the organism and gene source are
        taken from data hints and the saved context, and the info records
        are refreshed. Without data the widget is cleared.

        Raises
        ------
        Exception
            If an info query is still running.
        """
        if not self.__initialized:
            self.initfuture.result()
            self.initialize()

        if self.itemsfuture is not None:
            raise Exception("Already processing")

        self.closeContext()
        self.data = data

        if data is None:
            self.clear()
            return

        self.geneAttrComboBox.clear()
        gene_var_types = (Orange.data.StringVariable,
                          Orange.data.DiscreteVariable)
        candidates = data.domain.variables + data.domain.metas
        self.attributes = [
            var for var in candidates if isinstance(var, gene_var_types)
        ]
        for var in self.attributes:
            self.geneAttrComboBox.addItem(*gui.attributeItem(var))

        # Data hints seed the settings; the context opened afterwards may
        # override them.
        self.taxid = data_hints.get_hint(self.data, "taxid", self.taxid)
        self.useAttr = data_hints.get_hint(self.data, "genesinrows",
                                           self.useAttr)

        self.openContext(data)
        self.gene_attr = min(self.gene_attr, len(self.attributes) - 1)

        if self.taxid not in self.organisms:
            self.organism_index = 0
            self.taxid = self.organisms[self.organism_index]
        else:
            self.organism_index = self.organisms.index(self.taxid)

        self.updateInfoItems()

    def infoSource(self):
        """Return ``(name, getter)`` of the currently selected info source.

        The source list is looked up in ``INFO_SOURCES`` by the selected
        organism, falling back to the "default" entry; ``useAltSource``
        (clamped to the list length) picks the entry within that list.
        """
        last_organism = len(self.organisms) - 1
        organism = self.organisms[min(self.organism_index, last_organism)]
        key = organism if organism in INFO_SOURCES else "default"
        available = INFO_SOURCES[key]
        source_name, getter = available[min(self.useAltSource,
                                            len(available) - 1)]
        return source_name, getter

    def inputGenes(self):
        """Return the list of gene names taken from the input data.

        The names are the attribute names when ``useAttr`` is set;
        otherwise the non-NaN values of the selected gene variable as
        strings, or an empty list when there is no candidate variable.
        """
        if self.useAttr:
            return [var.name for var in self.data.domain.attributes]
        if not self.attributes:
            return []
        gene_var = self.attributes[self.gene_attr]
        return [
            str(row[gene_var]) for row in self.data
            if not math.isnan(row[gene_var])
        ]

    def updateInfoItems(self):
        """Start a background query for the info records of the input genes.

        Does nothing without data. While the query runs the widget is
        blocked and disabled; ``_onItemsCompleted`` picks up the result.
        """
        self.warning(0)
        if self.data is None:
            return

        # Single source of truth for gene extraction (the inline copy of
        # the same logic was removed).
        genes = self.inputGenes()
        if not genes:
            self.warning(0, "Could not extract genes from input dataset.")

        self.warning(1)
        org = self.organisms[min(self.organism_index, len(self.organisms) - 1)]
        _source_name, info_getter = self.infoSource()

        self.error(0)

        self.updateDictyExpressLink(genes, show=org == DICTY_TAXID)
        self.altSourceCheck.setVisible(org == DICTY_TAXID)

        self.progressBarInit()
        self.setBlocking(True)
        self.setEnabled(False)
        self.infoLabel.setText("Retrieving info records.\n")

        self.genes = genes

        task = Task(function=partial(
            info_getter, org, genes, advance=methodinvoke(self, "advance", ())))
        # Connect before submitting so the completion signal cannot be
        # emitted while nothing is listening yet.
        task.finished.connect(self._onItemsCompleted)
        self.itemsfuture = self.executor.submit(task)

    def _onItemsCompleted(self):
        """Show the records returned by the finished info query.

        Unblocks the widget, builds the table rows for the genes that have
        a record, installs a sortable proxy model on the tree view and
        updates the info label. An exception raised by the query propagates
        after ``itemsfuture`` has been reset.
        """
        self.setBlocking(False)
        self.progressBarFinished()
        self.setEnabled(True)

        try:
            schema, geneinfo = self.itemsfuture.result()
        finally:
            self.itemsfuture = None

        self.geneinfo = geneinfo = list(zip(self.genes, geneinfo))
        self.cells = cells = []
        # Maps a table row to the index of its gene in self.geneinfo.
        self.row2geneinfo = {}
        links = []
        for i, (_, gi) in enumerate(geneinfo):
            if not gi:
                # No record was found for this gene.
                continue
            row = []
            for _, item in zip(schema, gi):
                if isinstance(item, Link):
                    # TODO: This should be handled by delegates
                    row.append(item.text)
                    links.append(item.link)
                else:
                    row.append(item)
            cells.append(row)
            self.row2geneinfo[len(cells) - 1] = i

        model = TreeModel(cells, [str(col) for col in schema], None)

        model.setColumnLinks(0, links)
        proxyModel = QSortFilterProxyModel(self)
        proxyModel.setSourceModel(model)
        self.treeWidget.setModel(proxyModel)
        self.treeWidget.selectionModel().selectionChanged.connect(self.commit)

        # Size every column the model actually has (not a fixed count of
        # seven), capping each at 200 pixels.
        for column in range(model.columnCount()):
            self.treeWidget.resizeColumnToContents(column)
            self.treeWidget.setColumnWidth(
                column, min(self.treeWidget.columnWidth(column), 200))

        self.infoLabel.setText("%i genes\n%i matched NCBI's IDs" %
                               (len(self.genes), len(cells)))
        self.matchedInfo = len(self.genes), len(cells)

    def clear(self):
        """Reset the view to an empty table and send ``None`` on the output."""
        self.infoLabel.setText("No data on input\n")

        headers = [
            "NCBI ID", "Symbol", "Locus Tag", "Chromosome", "Description",
            "Synonyms", "Nomenclature"
        ]
        empty_model = TreeModel([], headers, self.treeWidget)
        self.treeWidget.setModel(empty_model)

        self.geneAttrComboBox.clear()
        self.send("Data Subset", None)

    def commit(self):
        """Send the subset of the input data matching the selected genes.

        With ``useAttr`` the output keeps only the attributes whose names
        are selected. Otherwise it keeps the rows whose gene value is
        selected and appends one string meta column per table column,
        filled from the matching table row. ``None`` is sent when there is
        no data, no candidate gene variable or no matching row.
        """
        if self.data is None:
            self.send("Data Subset", None)
            return

        proxy = self.treeWidget.model()
        selection = self.treeWidget.selectionModel().selection()
        selection = proxy.mapSelectionToSource(selection)
        selectedRows = list(
            chain.from_iterable(
                range(r.top(), r.bottom() + 1) for r in selection))

        model = proxy.sourceModel()

        # Gene name -> source-model row for every selected row; its keys
        # are exactly the selected gene names.
        gene2row = {
            self.geneinfo[self.row2geneinfo[row]][0]: row
            for row in selectedRows
        }
        selectedIds = set(gene2row)

        if self.useAttr:
            attrs = [
                attr for attr in self.data.domain.attributes
                if attr.name in selectedIds
            ]
            domain = Orange.data.Domain(attrs, self.data.domain.class_vars,
                                        self.data.domain.metas)
            newdata = self.data.from_table(domain, self.data)
            self.send("Data Subset", newdata)

        elif self.attributes:
            attr = self.attributes[self.gene_attr]
            gene_col = [
                attr.str_val(v) for v in self.data.get_column_view(attr)[0]
            ]
            gene_col = [(i, name) for i, name in enumerate(gene_col)
                        if name in selectedIds]
            indices = [i for i, _ in gene_col]

            # Add a gene info columns to the output
            headers = [
                str(model.headerData(i, Qt.Horizontal, Qt.DisplayRole))
                for i in range(model.columnCount())
            ]
            metas = [Orange.data.StringVariable(name) for name in headers]
            domain = Orange.data.Domain(self.data.domain.attributes,
                                        self.data.domain.class_vars,
                                        self.data.domain.metas + tuple(metas))

            newdata = self.data.from_table(domain, self.data)[indices]

            model_rows = [gene2row[gene] for _, gene in gene_col]
            for col, meta in zip(range(model.columnCount()), metas):
                col_data = [
                    str(model.index(row, col).data(Qt.DisplayRole))
                    for row in model_rows
                ]
                # Column vector, one row per output instance.
                col_data = np.array(col_data, dtype=object, ndmin=2).T
                newdata[:, meta] = col_data

            if not len(newdata):
                newdata = None

            self.send("Data Subset", newdata)
        else:
            self.send("Data Subset", None)

    def rowFiltered(self, row):
        """Return True when the row should be filtered out.

        A row is kept only if every whitespace-separated term of the search
        string occurs (case-insensitively) in the row's joined cell text.
        """
        terms = self.search_string.lower().split()
        text = " ".join(self.cells[row]).lower()
        return any(term not in text for term in terms)

    def searchUpdate(self):
        """Callback of the filter line edit: hide rows that do not contain
        every search term. Does nothing when ``self.data`` is falsy.
        """
        if not self.data:
            return

        terms = self.search_string.lower().split()
        proxy = self.treeWidget.model()
        source = proxy.sourceModel()
        for source_row, cells in enumerate(self.cells):
            text = " ".join(cells).lower()
            hidden = any(term not in text for term in terms)
            # Rows are hidden by their position in the (sorted) proxy model.
            view_row = proxy.mapFromSource(source.index(source_row, 0)).row()
            self.treeWidget.setRowHidden(view_row, QModelIndex(), hidden)

    def selectFiltered(self):
        """Add every row that passes the current filter to the selection.

        Does nothing when ``self.data`` is falsy.
        """
        if not self.data:
            return

        proxy = self.treeWidget.model()
        source = proxy.sourceModel()
        matching = QItemSelection()
        for source_row, _ in enumerate(self.cells):
            if self.rowFiltered(source_row):
                continue
            view_index = proxy.mapFromSource(source.index(source_row, 0))
            matching.select(view_index, view_index)

        self.treeWidget.selectionModel().select(
            matching,
            QItemSelectionModel.Select | QItemSelectionModel.Rows)

    def updateDictyExpressLink(self, genes, show=False):
        """Update the dictyExpress link label and the visibility of its box.

        When ``show`` is true the link templates are set on the label and
        the box stays visible only if at least one gene name starts with
        "DDB"; otherwise the box is hidden.
        """
        def to_ddb_g(ddb):
            # Only names starting with "DDB" qualify; plain "DDB" prefixes
            # are rewritten to "DDB_G".
            if not ddb.startswith("DDB"):
                return None
            if ddb.startswith("DDB_G"):
                return ddb
            return ddb.replace("DDB", "DDB_G")

        if show:
            dicty_genes = [
                fixed for fixed in map(to_ddb_g, genes) if fixed
            ]
            link1 = '<a href="http://dictyexpress.biolab.si/run/index.php?gene=%s">Microarray profile</a>'
            link2 = '<a href="http://dictyexpress.biolab.si/run/index.php?gene=%s&db=rnaseq">RNA-Seq profile</a>'
            self.linkLabel.setText(link1 + "<br/>" + link2)

            show = any(dicty_genes)

        if not show:
            self.dictyExpressBox.hide()
        else:
            self.dictyExpressBox.show()

    def onDictyExpressLink(self, link):
        """Open the activated dictyExpress link for the selected genes.

        ``link`` is a template containing one ``%s`` placeholder, which is
        filled with the space-separated "DDB_G" identifiers of the selected
        genes. An information box is shown when nothing is selected.
        """
        if not self.data:
            return

        selectedIndexes = self.treeWidget.selectedIndexes()
        if not selectedIndexes:
            QMessageBox.information(self, "No gene ids selected",
                                    "Please select some genes and try again.")
            return

        mapToSource = self.treeWidget.model().mapToSource
        # A set removes the duplicates produced when several indexes of the
        # same row are selected.
        selectedRows = {mapToSource(index).row() for index in selectedIndexes}
        selectedIds = {
            self.geneinfo[self.row2geneinfo[row]][0] for row in selectedRows
        }

        def fix(ddb):
            if ddb.startswith("DDB"):
                if not ddb.startswith("DDB_G"):
                    ddb = ddb.replace("DDB", "DDB_G")
                return ddb
            return None

        # Sorted so that the generated URL does not depend on set order.
        genes = sorted(fixed for fixed in map(fix, selectedIds) if fixed)
        url = str(link) % " ".join(genes)
        QDesktopServices.openUrl(QUrl(url))

    def onAltSourceChange(self):
        """Callback of the alternative-source check box: reload the info
        records."""
        self.updateInfoItems()

    def onDeleteWidget(self):
        """Cancel pending background work and shut the executor down
        (without waiting) before the widget is deleted.
        """
        # try to cancel pending tasks
        for pending in (self.initfuture, self.itemsfuture):
            if pending:
                pending.cancel()

        self.executor.shutdown(wait=False)
        super().onDeleteWidget()
Ejemplo n.º 33
0
class OWImportImages(widget.OWWidget):
    # Widget metadata and its single output: a table sent on the "Data"
    # channel.
    name = "Import Images"
    description = "Import images from a directory(s)"
    icon = "icons/ImportImages.svg"
    priority = 110

    outputs = [("Data", Orange.data.Table)]

    #: list of recent paths
    recent_paths = settings.Setting([])  # type: List[RecentPath]
    #: currently selected root directory (None when nothing is selected)
    currentPath = settings.Setting(None)

    want_main_area = False
    resizing_enabled = False

    # Modality of the directory chooser: __runOpenDialog uses a non-blocking
    # QFileDialog only when this equals Qt.WindowModal.
    Modality = Qt.ApplicationModal
    # Modality = Qt.WindowModal

    #: upper bound on the number of recent paths restored into the combo box
    MaxRecentItems = 20

    def __init__(self):
        """
        Build the widget's controls and schedule the initial scan.

        The control area holds the recent directories combo box together
        with the browse and reload buttons, and an "Info" box whose stacked
        widget shows either a text summary (page 0) or a progress bar with
        a cancel button and the last reported path (page 1).
        """
        super().__init__()
        #: widget's runtime state
        self.__state = State.NoState
        self.data = None
        # counters shown in the info text by __updateInfo
        self._n_image_categories = 0
        self._n_image_data = 0
        self._n_skipped = 0

        self.__invalidated = False
        # namespace describing the running scan task (see `start`), or None
        self.__pendingTask = None

        vbox = gui.vBox(self.controlArea)
        hbox = gui.hBox(vbox)
        self.recent_cb = QComboBox(
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=16,
            acceptDrops=True)
        # drag/drop events on the combo box are handled in `eventFilter`
        self.recent_cb.installEventFilter(self)
        self.recent_cb.activated[int].connect(self.__onRecentActivated)
        icons = standard_icons(self)

        browseaction = QAction(
            "Open/Load Images",
            self,
            iconText="\N{HORIZONTAL ELLIPSIS}",
            icon=icons.dir_open_icon,
            toolTip="Select a directory from which to load the images")
        browseaction.triggered.connect(self.__runOpenDialog)
        reloadaction = QAction("Reload",
                               self,
                               icon=icons.reload_icon,
                               toolTip="Reload current image set")
        reloadaction.triggered.connect(self.reload)
        self.__actions = namespace(
            browse=browseaction,
            reload=reloadaction,
        )

        # The push buttons mirror the actions' text/icon and trigger them.
        browsebutton = QPushButton(browseaction.iconText(),
                                   icon=browseaction.icon(),
                                   toolTip=browseaction.toolTip(),
                                   clicked=browseaction.trigger)
        reloadbutton = QPushButton(
            reloadaction.iconText(),
            icon=reloadaction.icon(),
            clicked=reloadaction.trigger,
            default=True,
        )

        hbox.layout().addWidget(self.recent_cb)
        hbox.layout().addWidget(browsebutton)
        hbox.layout().addWidget(reloadbutton)

        self.addActions([browseaction, reloadaction])

        # keep the reload button's enabled state in sync with its action
        reloadaction.changed.connect(
            lambda: reloadbutton.setEnabled(reloadaction.isEnabled()))
        box = gui.vBox(vbox, "Info")
        self.infostack = QStackedWidget()

        self.info_area = QLabel(text="No image set selected", wordWrap=True)
        # minimum == maximum == 0: presumably an indeterminate (busy)
        # progress bar -- confirm against the Qt documentation
        self.progress_widget = QProgressBar(minimum=0, maximum=0)
        self.cancel_button = QPushButton(
            "Cancel",
            icon=icons.cancel_icon,
        )
        self.cancel_button.clicked.connect(self.cancel)

        # page 1 of the stack: progress bar + cancel button above the path
        w = QWidget()
        vlayout = QVBoxLayout()
        vlayout.setContentsMargins(0, 0, 0, 0)
        hlayout = QHBoxLayout()
        hlayout.setContentsMargins(0, 0, 0, 0)

        hlayout.addWidget(self.progress_widget)
        hlayout.addWidget(self.cancel_button)
        vlayout.addLayout(hlayout)

        self.pathlabel = TextLabel()
        self.pathlabel.setTextElideMode(Qt.ElideMiddle)
        self.pathlabel.setAttribute(Qt.WA_MacSmallSize)

        vlayout.addWidget(self.pathlabel)
        w.setLayout(vlayout)

        self.infostack.addWidget(self.info_area)
        self.infostack.addWidget(w)

        box.layout().addWidget(self.infostack)

        self.__initRecentItemsModel()
        self.__invalidated = True
        self.__executor = ThreadExecutor(self)

        # `customEvent` starts the scan when this event is delivered, as long
        # as the widget is still marked as invalidated.
        QApplication.postEvent(self, QEvent(RuntimeEvent.Init))

    def __initRecentItemsModel(self):
        """
        Populate the recent paths combo box from the stored settings.

        Entries that are no longer directories are dropped and at most
        ``MaxRecentItems`` entries are kept. The stored current path is only
        retained when it is the same directory as the first recent entry;
        otherwise it is cleared and nothing is selected in the combo box.
        """
        if not (self.currentPath is None or os.path.isdir(self.currentPath)):
            self.currentPath = None

        existing = [entry for entry in self.recent_paths
                    if os.path.isdir(entry.abspath)]
        existing = existing[:OWImportImages.MaxRecentItems]

        combo_model = self.recent_cb.model()
        for entry in existing:
            combo_model.appendRow(RecentPath_asqstandarditem(entry))

        self.recent_paths = existing

        current = self.currentPath
        matches_first = (
            current is not None
            and os.path.isdir(current)
            and bool(existing)
            and os.path.samefile(current, existing[0].abspath)
        )
        if matches_first:
            self.recent_cb.setCurrentIndex(0)
        else:
            self.currentPath = None
            self.recent_cb.setCurrentIndex(-1)
        self.__actions.reload.setEnabled(self.currentPath is not None)

    def customEvent(self, event):
        """
        Reimplemented: run the deferred initial scan on ``RuntimeEvent.Init``.

        The scan is started only while the widget is still marked as
        invalidated; the mark is cleared even if `start` raises.
        """
        if event.type() == RuntimeEvent.Init and self.__invalidated:
            try:
                self.start()
            finally:
                self.__invalidated = False

        super().customEvent(event)

    def __runOpenDialog(self):
        """
        Ask the user for a top level directory and start scanning it.

        The dialog opens in the parent of the most recent path, or in the
        user's home directory when there are no recent paths.
        """
        if self.recent_paths:
            startdir = os.path.dirname(self.recent_paths[0].abspath)
        else:
            startdir = os.path.expanduser("~/")

        if OWImportImages.Modality != Qt.WindowModal:
            # blocking, application modal chooser
            chosen = QFileDialog.getExistingDirectory(
                self, "Select Top Level Directory", startdir)
            if chosen:
                self.setCurrentPath(chosen)
                self.start()
            return

        # window modal chooser; the selection is handled once it is accepted
        dlg = QFileDialog(
            self,
            "Select Top Level Directory",
            startdir,
            acceptMode=QFileDialog.AcceptOpen,
            modal=True,
        )
        dlg.setFileMode(QFileDialog.Directory)
        dlg.setOption(QFileDialog.ShowDirsOnly)
        dlg.setDirectory(startdir)
        dlg.setAttribute(Qt.WA_DeleteOnClose)

        def on_accepted():
            selected = dlg.selectedFiles()
            if selected:
                self.setCurrentPath(selected[0])
                self.start()

        dlg.accepted.connect(on_accepted)
        dlg.open()

    def __onRecentActivated(self, index):
        """Switch to the recent path stored at combo box `index` and scan it."""
        entry = self.recent_cb.itemData(index)
        if entry is not None:
            assert isinstance(entry, RecentPath)
            self.setCurrentPath(entry.abspath)
            self.start()

    def __updateInfo(self):
        """
        Refresh the info label and info stack page from the runtime state.

        In the ``Done`` state the label summarizes the number of images,
        the number of categories (when there are at least two) and the
        number of skipped images (when non-zero).
        """
        state = self.__state
        if state == State.NoState:
            text = "No image set selected"
        elif state == State.Processing:
            text = "Processing"
        elif state == State.Done:
            count = self._n_image_data
            categories = self._n_image_categories
            skipped = self._n_skipped
            if categories >= 2:
                text = "{} images / {} categories".format(count, categories)
            else:
                plural = "" if count == 1 else "s"
                text = "{} image{}".format(count, plural)
            if skipped > 0:
                text += ", {} skipped".format(skipped)
        elif state == State.Cancelled:
            text = "Cancelled"
        elif state == State.Error:
            text = "Error state"
        else:
            assert False

        self.info_area.setText(text)

        # page 1 holds the progress widgets, page 0 the text summary
        page = 1 if self.__state == State.Processing else 0
        self.infostack.setCurrentIndex(page)

    def setCurrentPath(self, path):
        """
        Set the current root image path to `path`.

        If `path` does not exist or is not a directory, an error is shown,
        a `UserWarning` is issued and the current path is cleared (set to
        None). A scan that is in progress is cancelled whenever the path
        is (re)assigned.

        Parameters
        ----------
        path : Optional[str]
            New root import path; None clears the current path.

        Returns
        -------
        status : bool
            True if `path` was accepted (including the case where it is
            the same directory as the current path), False if it was
            rejected as missing or not a directory.
        """
        current = self.currentPath
        if current is not None and path is not None \
                and os.path.isdir(current) and os.path.isdir(path) \
                and os.path.samefile(current, path):
            # already the current directory, nothing to do
            return True

        error = None
        if path is not None:
            if not os.path.exists(path):
                error = "'{}' does not exist".format(path)
            elif not os.path.isdir(path):
                error = "'{}' is not a directory".format(path)

        success = error is None
        if success:
            self.error()
        else:
            path = None
            self.error(error)
            warnings.warn(error, UserWarning, stacklevel=3)

        if path is None:
            self.currentPath = None
        else:
            newindex = self.addRecentPath(path)
            self.recent_cb.setCurrentIndex(newindex)
            self.currentPath = path if newindex >= 0 else None
        self.__actions.reload.setEnabled(self.currentPath is not None)

        if self.__state == State.Processing:
            self.cancel()

        return success

    def addRecentPath(self, path):
        """
        Put `path` at the front of the recent paths list and combo box.

        An entry that refers to the same file system location is moved to
        the front instead of being duplicated.

        Parameters
        ----------
        path : str

        Returns
        -------
        index : int
            Index of the entry in the recent paths list (always 0).
        """
        def refers_to_path(entry):
            try:
                return os.path.samefile(entry.abspath, path)
            except FileNotFoundError:
                # raised when `entry.abspath` no longer exists
                return False

        existing = next(
            (entry for entry in self.recent_paths if refers_to_path(entry)),
            None)

        model = self.recent_cb.model()

        if existing is None:
            fresh = RecentPath(path, None, None)
            self.recent_paths.insert(0, fresh)
            model.insertRow(0, RecentPath_asqstandarditem(fresh))
        else:
            position = self.recent_paths.index(existing)
            assert model.item(position).data(Qt.UserRole) is existing
            self.recent_paths.remove(existing)
            taken = model.takeRow(position)
            self.recent_paths.insert(0, existing)
            model.insertRow(0, taken)
        return 0

    def __setRuntimeState(self, state):
        """
        Switch the widget to runtime `state` and update the dependent UI.

        The widget is blocking only while processing. Asserts that the
        transition is valid: processing may start from any other state,
        while ``Done`` and ``Cancelled`` may only follow ``Processing``.
        """
        assert state in State
        processing = state == State.Processing
        self.setBlocking(processing)

        previous = self.__state
        message = ""
        if processing:
            assert previous in [
                State.Done, State.NoState, State.Error, State.Cancelled
            ]
            message = "Processing"
        elif state == State.Done:
            assert previous == State.Processing
        elif state == State.Cancelled:
            assert previous == State.Processing
            message = "Cancelled"
        elif state == State.Error:
            message = "Error during processing"
        elif state == State.NoState:
            message = ""
        else:
            assert False

        self.__state = state

        # page 1 shows the progress widgets, page 0 the summary text
        self.infostack.setCurrentIndex(
            1 if self.__state == State.Processing else 0)

        self.setStatusMessage(message)
        self.__updateInfo()

    def reload(self):
        """
        Discard the current data and scan the current directory again.

        A scan that is still in progress is cancelled first.
        """
        still_running = self.__state == State.Processing
        if still_running:
            self.cancel()

        self.data = None
        self.start()

    def start(self):
        """
        Start/execute the image indexing operation

        Clears any shown error and the invalidated flag, and returns
        immediately when no current path is set. A task that is still in
        progress is cancelled before a new one is submitted to the
        executor for the current path.
        """
        self.error()

        self.__invalidated = False
        if self.currentPath is None:
            return

        if self.__state == State.Processing:
            assert self.__pendingTask is not None
            log.info("Starting a new task while one is in progress. "
                     "Cancel the existing task (dir:'{}')".format(
                         self.__pendingTask.startdir))
            self.cancel()

        startdir = self.currentPath

        self.__setRuntimeState(State.Processing)

        # Callback handed to the task; `__onReportProgress` asserts that it
        # ends up running in the widget's own thread.
        report_progress = methodinvoke(self, "__onReportProgress", (object, ))

        task = ImportImages(report_progress=report_progress)

        # collect the task state in one convenient place
        self.__pendingTask = taskstate = namespace(
            task=task,
            startdir=startdir,
            future=None,
            watcher=None,
            cancelled=False,
            cancel=None,
        )

        def cancel():
            # Cancel the task and disconnect
            if taskstate.future.cancel():
                # the future was cancelled before it started running
                pass
            else:
                # already running: flag the task as cancelled and give it
                # up to 3 seconds to finish
                taskstate.task.cancelled = True
                taskstate.cancelled = True
                try:
                    taskstate.future.result(timeout=3)
                except UserInterruptError:
                    pass
                # NOTE(review): Future.result raises
                # concurrent.futures.TimeoutError, which is an alias of the
                # builtin TimeoutError only since Python 3.11 -- confirm
                # which name this module has in scope.
                except TimeoutError:
                    log.info("The task did not stop in in a timely manner")
            # a cancelled task must not reach __onRunFinished
            taskstate.watcher.finished.disconnect(self.__onRunFinished)

        taskstate.cancel = cancel

        def run_image_scan_task_interupt():
            try:
                return task(startdir)
            except UserInterruptError:
                # Suppress interrupt errors, so they are not logged
                return

        taskstate.future = self.__executor.submit(run_image_scan_task_interupt)
        taskstate.watcher = FutureWatcher(taskstate.future)
        taskstate.watcher.finished.connect(self.__onRunFinished)

    @Slot()
    def __onRunFinished(self):
        """
        Collect the result of the finished scan task and send it on.

        Connected to the ``finished`` signal of the pending task's watcher.
        On success the widget enters the ``Done`` state; if the task raised,
        the exception is passed to ``sys.excepthook``, the traceback is shown
        as the widget error and the widget enters the ``Error`` state with
        no data. The result (possibly None) is committed in both cases.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__state == State.Processing
        assert self.__pendingTask is not None
        assert self.sender() is self.__pendingTask.watcher
        assert self.__pendingTask.future.done()
        task = self.__pendingTask
        self.__pendingTask = None

        try:
            data, n_skipped = task.future.result()
        except Exception:
            sys.excepthook(*sys.exc_info())
            state = State.Error
            data = None
            n_skipped = 0
            self.error(traceback.format_exc())
        else:
            state = State.Done
            self.error()

        if data:
            self._n_image_data = len(data)
            self._n_image_categories = len(data.domain.class_var.values)\
                if data.domain.class_var else 0
        else:
            # No (or empty) result: reset the counters so that the info text
            # does not report the numbers of a previous scan.
            self._n_image_data = 0
            self._n_image_categories = 0

        self.data = data
        self._n_skipped = n_skipped

        self.__setRuntimeState(state)
        self.commit()

    def cancel(self):
        """
        Cancel the pending scan task; a no-op unless a scan is in progress.
        """
        if self.__state != State.Processing:
            return

        assert self.__pendingTask is not None
        self.__pendingTask.cancel()
        self.__pendingTask = None
        self.__setRuntimeState(State.Cancelled)

    @Slot(object)
    def __onReportProgress(self, arg):
        """
        Show the path most recently reported by the scan task.

        `arg` must be a namespace(count: int, lastpath: str). Reports that
        arrive while no scan is in progress are ignored.
        """
        # must run in the widget's own thread
        assert QThread.currentThread() is self.thread()
        if self.__state != State.Processing:
            return
        self.pathlabel.setText(prettyfypath(arg.lastpath))

    def commit(self):
        """
        Commit a Table from the collected image meta data.

        Sends the current ``self.data`` (which may be None) on the "Data"
        output.
        """
        self.send("Data", self.data)

    def onDeleteWidget(self):
        """
        Stop background work when the widget is deleted.

        Cancels a scan in progress, waits for the executor to shut down and
        then lets the base class run its own cleanup.
        """
        self.cancel()
        self.__executor.shutdown(wait=True)
        self.__invalidated = False
        # chain up so that base class (or mixin) cleanup is not skipped
        super().onDeleteWidget()

    def eventFilter(self, receiver, event):
        """
        Reimplemented: accept a single dropped directory on the combo box.

        Drag enter/move/drop events on the recent directories combo box are
        consumed here. A drag carrying exactly one local directory URL that
        allows a link action is accepted, and on drop the directory becomes
        the current path and is scanned. All other events are passed on to
        the base implementation.
        """
        def dropped_directory(drop_event):
            # type: (QDropEvent) -> Optional[str]
            """Return the single directory carried by the event, if any."""
            urls = drop_event.mimeData().urls()
            if len(urls) != 1:
                return None
            local = urls[0].toLocalFile()
            return local if os.path.isdir(local) else None

        if receiver is not self.recent_cb or \
                event.type() not in {QEvent.DragEnter, QEvent.DragMove,
                                     QEvent.Drop}:
            return super().eventFilter(receiver, event)

        assert isinstance(event, QDropEvent)
        path = dropped_directory(event)
        if path is None or not (event.possibleActions() & Qt.LinkAction):
            event.ignore()
            return True

        event.setDropAction(Qt.LinkAction)
        event.accept()
        if event.type() == QEvent.Drop:
            self.setCurrentPath(path)
            self.start()
        return True
class OWImportSamples(OWWidget):
    # Widget metadata and its single "Data" output.
    name = "Import Samples"
    icon = "icons/import.svg"
    want_main_area = False
    resizing_enabled = False
    priority = 1
    outputs = [("Data", Table)]

    # Login settings.
    # NOTE(review): if Setting values are persisted with the widget
    # settings, the password is stored in plain text -- confirm.
    username = settings.Setting('')
    password = settings.Setting('')
    # index of the selected entry in the server combo box
    selected_server = settings.Setting(0)
    # texts of the entries in the (editable) server combo box
    combo_items = settings.Setting([])

    def __init__(self):
        """
        Build the server, credentials and info controls.

        Connects right away when both a username and a password are
        available from the stored settings.
        """
        super().__init__()
        self.res = None
        self.data = None
        self._datatask = None
        self._executor = ThreadExecutor()

        # Choose server
        box = gui.widgetBox(self.controlArea, 'Server')
        box.setSizePolicy(Policy.Minimum, Policy.Fixed)
        self.servers = gui.comboBox(box,
                                    self,
                                    "selected_server",
                                    editable=True,
                                    items=self.combo_items,
                                    callback=self.on_server_changed)

        # Set credentials
        box = gui.widgetBox(self.controlArea, 'Credentials')
        box.setSizePolicy(Policy.Minimum, Policy.Fixed)

        # NOTE(review): the two line edit definitions below were garbled in
        # the source (replaced by asterisks) and have been reconstructed from
        # how `name_field`, `pass_field` and `auth_changed` are used in the
        # rest of the class. Any further keyword arguments of the original
        # calls (sizes, orientation, ...) are unknown -- confirm.
        self.name_field = gui.lineEdit(box,
                                       self,
                                       "username",
                                       "Username:",
                                       callback=self.auth_changed)
        self.pass_field = gui.lineEdit(box,
                                       self,
                                       "password",
                                       "Password:",
                                       callback=self.auth_changed)
        # Do not echo the typed password.
        # NOTE(review): assumes the returned control exposes the QLineEdit
        # echo mode enum as an attribute -- confirm for the Qt binding in use.
        self.pass_field.setEchoMode(self.pass_field.Password)

        # Display info
        box = gui.vBox(self.controlArea, "Info")
        box.setSizePolicy(Policy.Minimum, Policy.Fixed)
        self.info = gui.widgetLabel(box, 'No data loaded.')

        gui.rubber(self.controlArea)
        self.auth_set()

        if self.username and self.password:
            self.connect()

    def _on_exception(self):
        """Report a failed download in the info label."""
        message = ('Error while downloading data...\n'
                   'Please check your connection.')
        self._update_info(error_msg=message)

    def _update_info(self, error_msg=None):
        """
        Update the info label.

        A non-empty `error_msg` is shown as is. Otherwise the label reports
        a download in progress (disabling the inputs) and/or the number of
        loaded samples (enabling the inputs again).
        """
        if error_msg:
            self.info.setText(error_msg)
            return

        task = self._datatask
        if task is not None and not task.future().done():
            self.info.setText('Retrieving data...')
            self._handle_inputs(False)

        if self.data:
            self._handle_inputs(True)
            summary = 'Data ready: {} samples loaded.'.format(len(self.data))
            self.info.setText(summary)

    def _handle_inputs(self, enable):
        """Enable or disable the username, password and server controls."""
        for control in (self.name_field, self.pass_field, self.servers):
            control.setEnabled(enable)

    def _handle_styles(self,
                       login=False,
                       server=False,
                       user=False,
                       passwd=False):
        """
        Mark the control(s) responsible for an error and focus the first.

        Only the first true flag, in the order `login`, `server`, `user`,
        `passwd`, takes effect. `login` marks both credential fields. With
        no flag set nothing happens.
        """
        if login:
            culprits = (self.name_field, self.pass_field)
        elif server:
            culprits = (self.servers, )
        elif user:
            culprits = (self.name_field, )
        elif passwd:
            culprits = (self.pass_field, )
        else:
            return

        culprits[0].setFocus()
        for control in culprits:
            control.setStyleSheet(error_red)

    def _reset_styles(self):
        """Remove the error style sheet from all input controls."""
        for control in (self.pass_field, self.name_field, self.servers):
            control.setStyleSheet('')

    def on_server_changed(self):
        """
        React to a change of the selected server.

        Connects when a server and both credentials are present; otherwise
        the first missing input (server, username, password) is marked.
        """
        if self.servers.itemText(self.selected_server) == '':
            self._handle_styles(server=True)
            return

        if not self.username:
            self._handle_styles(user=True)
        elif not self.password:
            self._handle_styles(passwd=True)
        else:
            self.connect()

    def auth_set(self):
        """Enable the password field only once a username is set; focus it."""
        has_username = bool(self.username)
        self.pass_field.setDisabled(not has_username)
        self.pass_field.setFocus()

    def auth_changed(self):
        """
        React to edited credentials.

        Marks the server control when no server is selected; otherwise
        clears the error styles and tries to connect.
        """
        self.auth_set()
        if self.servers.itemText(self.selected_server) != '':
            self._reset_styles()
            self.connect()
        else:
            self._handle_styles(server=True)

    def commit(self):
        """
        Take the result of the finished download task and send it on.

        Stores the task's result in ``self.data``, clears the task, updates
        the info label and, when there is data, sends it converted to an
        Orange table on the "Data" output.
        """
        if self._datatask is None:
            # `connect` wires every task it creates to this slot; when tasks
            # overlap, an earlier invocation may already have consumed and
            # cleared the current task.
            return

        self.data = self._datatask.result()
        self._datatask = None
        self._update_info()
        if self.data:
            self.send("Data", to_orange_table(self.data))

    def connect(self):
        """
        Log in to the selected server and start downloading the data.

        Any previous connection and data are dropped first. Nothing else
        happens unless both a username and a password are set. A failed
        login is reported in the info label and the offending control(s)
        are marked; on success a download task is submitted to the executor.
        """
        self.res = None
        self.data = None

        if not (self.username and self.password):
            return

        self._reset_styles()
        # Store widget settings (Login)
        self.combo_items = [
            self.servers.itemText(i) for i in range(self.servers.count())
        ]
        self.selected_server = self.servers.currentIndex()

        try:
            self.res = ResolweAPI(
                self.username, self.password,
                self.servers.itemText(self.selected_server))
        except Exception as e:
            # Every failure is reported; the exception type only decides
            # which control gets marked. The type is matched by name, so
            # 'MissingSchema' does not have to be importable here.
            error_name = type(e).__name__
            self._update_info(error_msg=str(e))

            if error_name == 'ResolweCredentialsException':
                self._handle_styles(login=True)
            elif error_name in ('ResolweServerException', 'MissingSchema'):
                self._handle_styles(server=True)

        if self.res:
            self._datatask = DownloadTask(self.res)
            self._datatask.finished.connect(self.commit)
            self._datatask.exception.connect(self._on_exception)
            self._executor.submit(self._datatask)
            self._update_info()

    def onDeleteWidget(self):
        """Shut the executor down, without waiting, when the widget is deleted."""
        super().onDeleteWidget()
        self._executor.shutdown(wait=False)
Ejemplo n.º 35
0
class OWTestLearners(OWWidget):
    # Widget metadata.
    name = "Test & Score"
    description = "Cross-validation accuracy estimation."
    icon = "icons/TestLearners1.svg"
    priority = 100

    # Input signals: train data (default), optional separate test data,
    # any number of learners, and a preprocessor.
    class Inputs:
        train_data = Input("Data", Table, default=True)
        test_data = Input("Test Data", Table)
        learner = Input("Learner", Learner, multiple=True)
        preprocessor = Input("Preprocessor", Preprocess)

    # Output signals.
    class Outputs:
        predictions = Output("Predictions", Table)
        evaluations_results = Output("Evaluation Results", Results)

    settings_version = 3
    UserAdviceMessages = [
        widget.Message("Click on the table header to select shown columns",
                       "click_header")
    ]

    settingsHandler = settings.PerfectDomainContextHandler()

    #: Resampling/testing types
    KFold, FeatureFold, ShuffleSplit, LeaveOneOut, TestOnTrain, TestOnTest \
        = 0, 1, 2, 3, 4, 5
    #: Numbers of folds
    NFolds = [2, 3, 5, 10, 20]
    #: Number of repetitions
    NRepeats = [2, 3, 5, 10, 20, 50, 100]
    #: Sample sizes
    SampleSizes = [5, 10, 20, 25, 30, 33, 40, 50, 60, 66, 70, 75, 80, 90, 95]

    # NOTE(review): n_folds, n_repeats and sample_size are bound to combo
    # boxes filled from NFolds, NRepeats and SampleSizes, so they look like
    # indices into those lists rather than the values themselves -- confirm.

    #: Selected resampling type
    resampling = settings.Setting(0)
    #: Number of folds for K-fold cross validation
    n_folds = settings.Setting(3)
    #: Stratified sampling for K-fold
    cv_stratified = settings.Setting(True)
    #: Number of repeats for ShuffleSplit sampling
    n_repeats = settings.Setting(3)
    #: ShuffleSplit sample size
    sample_size = settings.Setting(9)
    #: Stratified sampling for Random Sampling
    shuffle_stratified = settings.Setting(True)
    # CV where nr. of feature values determines nr. of folds
    fold_feature = settings.ContextSetting(None)
    fold_feature_selected = settings.ContextSetting(False)

    #: Target class selection; defaults to averaging over all classes
    TARGET_AVERAGE = "(Average over classes)"
    class_selection = settings.ContextSetting(TARGET_AVERAGE)

    #: Score names per target variable type
    BUILTIN_ORDER = {
        DiscreteVariable: ("AUC", "CA", "F1", "Precision", "Recall"),
        ContinuousVariable: ("MSE", "RMSE", "MAE", "R2")
    }

    #: Names of the scores shown in the results table (all built-in ones
    #: by default)
    shown_scores = \
        settings.Setting(set(chain(*BUILTIN_ORDER.values())))

    # Error messages the widget can show.
    class Error(OWWidget.Error):
        train_data_empty = Msg("Train data set is empty.")
        test_data_empty = Msg("Test data set is empty.")
        class_required = Msg("Train data input requires a target variable.")
        too_many_classes = Msg("Too many target variables.")
        class_required_test = Msg(
            "Test data input requires a target variable.")
        too_many_folds = Msg("Number of folds exceeds the data size")
        class_inconsistent = Msg("Test and train data sets "
                                 "have different target variables.")
        memory_error = Msg("Not enough memory.")
        only_one_class_var_value = Msg("Target variable has only one value.")

    # Warning messages; `missing_data` takes one format argument that is
    # inserted between "from" and "data".
    class Warning(OWWidget.Warning):
        missing_data = \
            Msg("Instances with unknown target values were removed from{}data.")
        test_data_missing = Msg("Missing separate test data input.")
        scores_not_computed = Msg("Some scores could not be computed.")
        test_data_unused = Msg("Test data is present but unused. "
                               "Select 'Test on test data' to use it.")

    # Informational messages.
    class Information(OWWidget.Information):
        data_sampled = Msg("Train data has been sampled")
        test_data_sampled = Msg("Test data has been sampled")

    def __init__(self):
        """
        Initialize the widget state and build its controls.

        The control area gets the "Sampling" radio buttons (with the
        options belonging to each resampling type) and the "Target Class"
        combo box; the main area gets the evaluation results table.
        """
        super().__init__()

        self.data = None
        self.test_data = None
        self.preprocessor = None
        self.train_data_missing_vals = False
        self.test_data_missing_vals = False
        self.scorers = []

        #: An Ordered dictionary with current inputs and their testing results.
        self.learners = OrderedDict()  # type: Dict[Any, Input]

        self.__state = State.Waiting
        # Do we need to [re]test any learners, set by _invalidate and
        # cleared by __update
        self.__needupdate = False
        self.__task = None  # type: Optional[Task]
        self.__executor = ThreadExecutor()

        # --- Sampling: one radio button per resampling type; the buttons
        # are appended in the order of the KFold ... TestOnTest constants
        sbox = gui.vBox(self.controlArea, "Sampling")
        rbox = gui.radioButtons(sbox,
                                self,
                                "resampling",
                                callback=self._param_changed)

        gui.appendRadioButton(rbox, "Cross validation")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_folds",
                     label="Number of folds: ",
                     items=[str(x) for x in self.NFolds],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.kfold_changed)
        gui.checkBox(ibox,
                     self,
                     "cv_stratified",
                     "Stratified",
                     callback=self.kfold_changed)
        gui.appendRadioButton(rbox, "Cross validation by feature")
        ibox = gui.indentedBox(rbox)
        # the fold feature is chosen among the discrete variables offered by
        # this model
        self.feature_model = DomainModel(order=DomainModel.METAS,
                                         valid_types=DiscreteVariable)
        self.features_combo = gui.comboBox(ibox,
                                           self,
                                           "fold_feature",
                                           model=self.feature_model,
                                           orientation=Qt.Horizontal,
                                           callback=self.fold_feature_changed)

        gui.appendRadioButton(rbox, "Random sampling")
        ibox = gui.indentedBox(rbox)
        gui.comboBox(ibox,
                     self,
                     "n_repeats",
                     label="Repeat train/test: ",
                     items=[str(x) for x in self.NRepeats],
                     maximumContentsLength=3,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.comboBox(ibox,
                     self,
                     "sample_size",
                     label="Training set size: ",
                     items=["{} %".format(x) for x in self.SampleSizes],
                     maximumContentsLength=5,
                     orientation=Qt.Horizontal,
                     callback=self.shuffle_split_changed)
        gui.checkBox(ibox,
                     self,
                     "shuffle_stratified",
                     "Stratified",
                     callback=self.shuffle_split_changed)

        gui.appendRadioButton(rbox, "Leave one out")

        gui.appendRadioButton(rbox, "Test on train data")
        gui.appendRadioButton(rbox, "Test on test data")

        # --- Target class selection
        self.cbox = gui.vBox(self.controlArea, "Target Class")
        self.class_selection_combo = gui.comboBox(
            self.cbox,
            self,
            "class_selection",
            items=[],
            sendSelectedValue=True,
            valueType=str,
            callback=self._on_target_class_changed,
            contentsLength=8)

        gui.rubber(self.controlArea)

        # --- Results table; a custom context menu on the header opens the
        # column chooser
        self.view = gui.TableView(wordWrap=True, )
        header = self.view.horizontalHeader()
        header.setSectionResizeMode(QHeaderView.ResizeToContents)
        header.setDefaultAlignment(Qt.AlignCenter)
        header.setStretchLastSection(False)
        header.setContextMenuPolicy(Qt.CustomContextMenu)
        header.customContextMenuRequested.connect(self.show_column_chooser)

        self.result_model = QStandardItemModel(self)
        self.result_model.setHorizontalHeaderLabels(["Method"])
        self.view.setModel(self.result_model)
        self.view.setItemDelegate(ItemDelegate())

        box = gui.vBox(self.mainArea, "Evaluation Results")
        box.layout().addWidget(self.view)

    def sizeHint(self):
        """Return the preferred size: 780 wide with a nominal height of 1."""
        return QSize(780, 1)

    def _update_controls(self):
        """
        Synchronize the fold feature controls with the current data.

        The feature model is reset and, when there is data, refilled from
        its domain, selecting the first feature by default. The "cross
        validation by feature" option is enabled only when the model is
        non-empty; if it was selected but is no longer available,
        resampling falls back to k-fold cross validation.
        """
        model = self.feature_model
        self.fold_feature = None
        model.set_domain(None)
        if self.data:
            model.set_domain(self.data.domain)
            if self.fold_feature is None and model:
                self.fold_feature = model[0]

        has_features = bool(model)
        fold_button = \
            self.controls.resampling.buttons[OWTestLearners.FeatureFold]
        fold_button.setEnabled(has_features)
        self.features_combo.setEnabled(has_features)

        if self.resampling == OWTestLearners.FeatureFold \
                and not has_features:
            self.resampling = OWTestLearners.KFold

    @Inputs.learner
    def set_learner(self, learner, key):
        """
        Set the input `learner` for `key`.

        A None learner removes the entry stored for `key`; a None learner
        for a key that is not stored is ignored.

        Parameters
        ----------
        learner : Optional[Orange.base.Learner]
        key : Any
        """
        if key in self.learners and learner is None:
            # Removed
            self._invalidate([key])
            del self.learners[key]
        elif learner is not None:
            # Added or replaced. A None learner for an unknown key must not
            # be stored: it would leave an entry without a learner behind.
            self.learners[key] = InputLearner(learner, None, None)
            self._invalidate([key])

    @Inputs.train_data
    def set_train_data(self, data):
        """
        Set the input training dataset.

        The data is rejected (treated as `None`) when it is empty or when
        its class definition is unusable; SQL tables are downloaded or
        sampled; rows without a class value are removed when missing class
        values are detected.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        for message in (self.Information.data_sampled,
                        self.Error.train_data_empty,
                        self.Error.class_required,
                        self.Error.too_many_classes,
                        self.Error.only_one_class_var_value):
            message.clear()

        if data is not None and not len(data):
            self.Error.train_data_empty()
            data = None

        if data:
            # Report only the first class-related problem that applies
            domain = data.domain
            if not domain.class_vars:
                problem = self.Error.class_required
            elif len(domain.class_vars) > 1:
                problem = self.Error.too_many_classes
            elif domain.has_discrete_class \
                    and len(domain.class_var.values) == 1:
                problem = self.Error.only_one_class_var_value
            else:
                problem = None
            if problem is not None:
                problem()
                data = None

        if isinstance(data, SqlTable):
            if data.approx_len() >= AUTO_DL_LIMIT:
                # Too large to download in full: work on a sample instead
                self.Information.data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)
            else:
                data = Table(data)

        self.train_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.data = data
        self.closeContext()
        self._update_scorers()
        self._update_controls()
        if data is not None:
            self._update_class_selection()
            self.openContext(data.domain)
            # Restore "fold by feature" if it was the selection saved in
            # `fold_feature_selected` and the new data supports it
            if self.fold_feature_selected and bool(self.feature_model):
                self.resampling = OWTestLearners.FeatureFold
        self._invalidate()

    @Inputs.test_data
    def set_test_data(self, data):
        # type: (Orange.data.Table) -> None
        """
        Set the input separate testing dataset.

        Empty data and data without a class variable are rejected (treated
        as `None`); SQL tables are downloaded or sampled; rows without a
        class value are removed when missing class values are detected.

        Parameters
        ----------
        data : Optional[Orange.data.Table]
        """
        self.Information.test_data_sampled.clear()
        self.Error.test_data_empty.clear()

        if data is not None and not len(data):
            self.Error.test_data_empty()
            data = None

        if data and not data.domain.class_var:
            self.Error.class_required_test()
            data = None
        else:
            self.Error.class_required_test.clear()

        if isinstance(data, SqlTable):
            if data.approx_len() >= AUTO_DL_LIMIT:
                # Too large to download in full: work on a sample instead
                self.Information.test_data_sampled()
                sample = data.sample_time(1, no_cache=True)
                sample.download_data(AUTO_DL_LIMIT, partial=True)
                data = Table(sample)
            else:
                data = Table(data)

        self.test_data_missing_vals = \
            data is not None and np.isnan(data.Y).any()
        if self.train_data_missing_vals or self.test_data_missing_vals:
            self.Warning.missing_data(self._which_missing_data())
            if data:
                data = HasClass()(data)
        else:
            self.Warning.missing_data.clear()

        self.test_data = data
        # The test data only affects results in "test on test data" mode
        if self.resampling == OWTestLearners.TestOnTest:
            self._invalidate()

    def _which_missing_data(self):
        return {
            (True, True): " ",  # both, don't specify
            (True, False): " train ",
            (False, True): " test "
        }[(self.train_data_missing_vals, self.test_data_missing_vals)]

    # List of scorers shouldn't be retrieved globally, when the module is
    # loading since add-ons could have registered additional scorers.
    # It could have been cached but
    # - we don't gain much with it
    # - it complicates the unit tests
    def _update_scorers(self):
        if self.data is None or self.data.domain.class_var is None:
            self.scorers = []
            return
        class_var = self.data and self.data.domain.class_var
        order = {
            name: i
            for i, name in enumerate(self.BUILTIN_ORDER[type(class_var)])
        }
        # 'abstract' is retrieved from __dict__ to avoid inheriting
        usable = (cls for cls in scoring.Score.registry.values()
                  if cls.is_scalar and not cls.__dict__.get("abstract")
                  and isinstance(class_var, cls.class_types))
        self.scorers = sorted(usable, key=lambda cls: order.get(cls.name, 99))

    @Inputs.preprocessor
    def set_preprocessor(self, preproc):
        """
        Set the input preprocessor to apply on the training data.

        All learner results are invalidated, since they were computed with
        the previous preprocessor.
        """
        self.preprocessor = preproc
        self._invalidate()

    def handleNewSignals(self):
        """Reimplemented from OWWidget.handleNewSignals."""
        # Bring the GUI in sync with the inputs first ...
        self._update_class_selection()
        self._update_header()
        self._update_stats_model()
        # ... then re-run the evaluation only if something was invalidated.
        if self.__needupdate:
            self.__update()

    def kfold_changed(self):
        """Select `KFold` resampling and re-run the evaluation."""
        self.resampling = OWTestLearners.KFold
        self._param_changed()

    def fold_feature_changed(self):
        """Select `FeatureFold` resampling and re-run the evaluation."""
        self.resampling = OWTestLearners.FeatureFold
        self._param_changed()

    def shuffle_split_changed(self):
        """Select `ShuffleSplit` resampling and re-run the evaluation."""
        self.resampling = OWTestLearners.ShuffleSplit
        self._param_changed()

    def _param_changed(self):
        """Invalidate all learner results and start a new evaluation."""
        self._invalidate()
        self.__update()

    def _update_header(self):
        """
        Set the horizontal header labels of `result_model`.

        Column 0 is left as is; each following column gets a header item
        showing the name of one scorer with the scorer's long name as the
        tooltip. Column visibility is refreshed afterwards.
        """
        header_model = self.result_model
        scorers = self.scorers
        header_model.setColumnCount(len(scorers) + 1)
        for column, scorer in enumerate(scorers, start=1):
            header_item = QStandardItem(scorer.name)
            header_item.setToolTip(scorer.long_name)
            header_model.setHorizontalHeaderItem(column, header_item)
        self._update_shown_columns()

    def _update_shown_columns(self):
        """Hide every score column whose label is not in `shown_scores`."""
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        model = self.result_model
        header_view = self.view.horizontalHeader()
        # Section 0 is skipped: only the score columns can be hidden
        for section in range(1, model.columnCount()):
            label = model.horizontalHeaderItem(section).data(Qt.DisplayRole)
            is_hidden = label not in self.shown_scores
            header_view.setSectionHidden(section, is_hidden)

    def _update_stats_model(self):
        """
        Rebuild the rows of the results table from `self.learners`.

        One row is appended per learner slot: the first cell holds the
        learner name (marked as an error when the evaluation failed), the
        following cells hold one score per scorer. When a specific target
        class is selected for a discrete class variable, the scores are
        recomputed here on one-vs-rest results; otherwise the scores stored
        in the slot are shown. Evaluation errors and missing scores are
        reported through the widget's error/warning messages.
        """
        # Update the results_model with up to date scores.
        # Note: The target class specific scores (if requested) are
        # computed as needed in this method.
        model = self.view.model()
        # clear the table model, but preserving the header labels
        for r in reversed(range(model.rowCount())):
            model.takeRow(r)

        # Index of the selected target class; None means "average over
        # classes" (or no discrete class at all)
        target_index = None
        if self.data is not None:
            class_var = self.data.domain.class_var
            if self.data.domain.has_discrete_class and \
                            self.class_selection != self.TARGET_AVERAGE:
                target_index = class_var.values.index(self.class_selection)
        else:
            class_var = None

        errors = []
        has_missing_scores = False

        for key, slot in self.learners.items():
            name = learner_name(slot.learner)
            head = QStandardItem(name)
            # The input key is stored on the row head so the row can be
            # matched back to its learner slot
            head.setData(key, Qt.UserRole)
            if isinstance(slot.results, Try.Fail):
                head.setToolTip(str(slot.results.exception))
                head.setText("{} (error)".format(name))
                head.setForeground(QtGui.QBrush(Qt.red))
                errors.append("{name} failed with error:\n"
                              "{exc.__class__.__name__}: {exc!s}".format(
                                  name=name, exc=slot.results.exception))

            row = [head]

            if class_var is not None and class_var.is_discrete and \
                    target_index is not None:
                if slot.results is not None and slot.results.success:
                    ovr_results = results_one_vs_rest(slot.results.value,
                                                      target_index)

                    # Cell variable is used immediately, it's not stored
                    # pylint: disable=cell-var-from-loop
                    stats = [
                        Try(scorer_caller(scorer, ovr_results))
                        for scorer in self.scorers
                    ]
                else:
                    stats = None
            else:
                stats = slot.stats

            if stats is not None:
                for stat in stats:
                    item = QStandardItem()
                    if stat.success:
                        item.setText("{:.3f}".format(stat.value[0]))
                    else:
                        # Leave the cell empty; the reason is in the tooltip
                        item.setToolTip(str(stat.exception))
                        has_missing_scores = True
                    row.append(item)

            model.appendRow(row)

        self.error("\n".join(errors), shown=bool(errors))
        self.Warning.scores_not_computed(shown=has_missing_scores)

    def _update_class_selection(self):
        self.class_selection_combo.setCurrentIndex(-1)
        self.class_selection_combo.clear()
        if not self.data:
            return

        if self.data.domain.has_discrete_class:
            self.cbox.setVisible(True)
            class_var = self.data.domain.class_var
            items = [self.TARGET_AVERAGE] + class_var.values
            self.class_selection_combo.addItems(items)

            class_index = 0
            if self.class_selection in class_var.values:
                class_index = class_var.values.index(self.class_selection) + 1

            self.class_selection_combo.setCurrentIndex(class_index)
            self.class_selection = items[class_index]
        else:
            self.cbox.setVisible(False)

    def _on_target_class_changed(self):
        """Refresh the results table after the target class changed."""
        self._update_stats_model()

    def _invalidate(self, which=None):
        """
        Invalidate learner results for the `which` input keys.

        The stored results and stats of each affected learner are reset to
        `None` and the matching score cells in the results table are
        blanked. The widget is then flagged as needing an update.

        Parameters
        ----------
        which : Optional[Iterable[Any]]
            Input keys to invalidate; `None` invalidates all learners.
        """
        self.fold_feature_selected = \
            self.resampling == OWTestLearners.FeatureFold
        if which is None:
            which = self.learners.keys()

        model = self.view.model()
        # Input key stored on the head item of every table row
        row_keys = []
        for row in range(model.rowCount()):
            row_keys.append(model.item(row, 0).data(Qt.UserRole))

        for key in which:
            slot = self.learners[key]
            self.learners[key] = slot._replace(results=None, stats=None)

            if key not in row_keys:
                continue
            row = row_keys.index(key)
            for column in range(1, model.columnCount()):
                cell = model.item(row, column)
                if cell is None:
                    continue
                cell.setData(None, Qt.DisplayRole)
                cell.setData(None, Qt.ToolTipRole)

        self.__needupdate = True

    def show_column_chooser(self, pos):
        """
        Pop up a menu for choosing which score columns are shown.

        Parameters
        ----------
        pos : position within the horizontal header at which the context
            menu was requested; it is mapped to global coordinates.
        """
        # pylint doesn't know that self.shown_scores is a set, not a Setting
        # pylint: disable=unsupported-membership-test
        def toggle(label, checked):
            if not checked:
                self.shown_scores.remove(label)
            else:
                self.shown_scores.add(label)
            self._update_shown_columns()

        chooser = QMenu()
        model = self.result_model
        header_view = self.view.horizontalHeader()
        # One checkable entry per score column (section 0 is skipped)
        for section in range(1, model.columnCount()):
            label = model.horizontalHeaderItem(section).data(Qt.DisplayRole)
            entry = chooser.addAction(label)
            entry.setCheckable(True)
            entry.setChecked(label in self.shown_scores)
            entry.triggered.connect(partial(toggle, label))
        chooser.exec(header_view.mapToGlobal(pos))

    def commit(self):
        """
        Commit the results to output.

        The results of all successfully evaluated learners are merged and
        sent on the evaluation results output; the augmented data built
        from them is sent on the predictions output. `None` is sent where
        nothing is available (no successful learner, or a `MemoryError`
        while building the predictions).
        """
        self.Error.memory_error.clear()

        successful = []
        for slot in self.learners.values():
            outcome = slot.results
            if outcome is not None and outcome.success:
                successful.append(slot)

        combined = predictions = None
        if successful:
            # Evaluation results
            combined = results_merge([s.results.value for s in successful])
            combined.learner_names = [
                learner_name(s.learner) for s in successful
            ]

            # Predictions & Probabilities
            try:
                predictions = combined.get_augmented_data(
                    combined.learner_names)
            except MemoryError:
                self.Error.memory_error()

        self.Outputs.evaluations_results.send(combined)
        self.Outputs.predictions.send(predictions)

    def send_report(self):
        """
        Report on the testing schema and results.

        Nothing is reported without data or learners. Otherwise the
        "Settings" section lists the sampling type (for every resampling
        mode, including cross validation by feature) and, for a discrete
        class, the target class; the "Scores" section holds the results
        table.
        """
        if not self.data or not self.learners:
            return
        if self.resampling == self.KFold:
            stratified = 'Stratified ' if self.cv_stratified else ''
            items = [("Sampling type", "{}{}-fold Cross validation".format(
                stratified, self.NFolds[self.n_folds]))]
        elif self.resampling == self.FeatureFold:
            items = [("Sampling type", "Cross validation by feature")]
            # The fold feature is only named when one is actually selected
            if self.fold_feature is not None:
                items.append(("Feature", self.fold_feature.name))
        elif self.resampling == self.LeaveOneOut:
            items = [("Sampling type", "Leave one out")]
        elif self.resampling == self.ShuffleSplit:
            stratified = 'Stratified ' if self.shuffle_stratified else ''
            items = [
                ("Sampling type",
                 "{}Shuffle split, {} random samples with {}% data ".format(
                     stratified, self.NRepeats[self.n_repeats],
                     self.SampleSizes[self.sample_size]))
            ]
        elif self.resampling == self.TestOnTrain:
            items = [("Sampling type", "No sampling, test on training data")]
        elif self.resampling == self.TestOnTest:
            items = [("Sampling type", "No sampling, test on testing data")]
        else:
            items = []
        if self.data.domain.has_discrete_class:
            # The parentheses are stripped from the selected target class
            # text before it is reported
            items += [("Target class", self.class_selection.strip("()"))]
        if items:
            self.report_items("Settings", items)
        self.report_table("Scores", self.view)

    @classmethod
    def migrate_settings(cls, settings_, version):
        """
        Upgrade stored widget settings, in place, to the current version.

        Parameters
        ----------
        settings_ : dict
            The stored settings; modified in place.
        version : int
            The version the settings were saved with.
        """
        if version < 2:
            # Every positive stored `resampling` value is shifted by one.
            # Settings stored without a `resampling` entry are left alone.
            resampling = settings_.get("resampling", 0)
            if resampling > 0:
                settings_["resampling"] = resampling + 1
        if version < 3:
            # Older version used an incompatible context handler
            settings_["context_settings"] = [
                c for c in settings_.get("context_settings", ())
                if not hasattr(c, 'classes')
            ]

    @Slot(float)
    def setProgressValue(self, value):
        """
        Set the progress bar to `value`.

        Declared as a slot so it can be invoked by name through
        `QMetaObject.invokeMethod` (see the progress callback in
        `__submit`).
        """
        self.progressBarSet(value, processEvents=False)

    def __update(self):
        """
        Start (or restart) the evaluation of learners without results.

        Any running evaluation is cancelled first and stale messages are
        cleared. If the preconditions are not met (no data, no learners,
        too few instances for the requested folds, missing or inconsistent
        test data) the widget enters the `Waiting` state and only commits
        the current outputs. Otherwise a testing function for the selected
        resampling method is assembled and submitted for execution.
        """
        self.__needupdate = False

        assert self.__task is None or self.__state == State.Running
        if self.__state == State.Running:
            self.cancel()

        self.Warning.test_data_unused.clear()
        self.Warning.test_data_missing.clear()
        self.warning()
        self.Error.class_inconsistent.clear()
        self.Error.too_many_folds.clear()
        self.error()

        # check preconditions and return early
        if self.data is None:
            self.__state = State.Waiting
            self.commit()
            return
        if not self.learners:
            self.__state = State.Waiting
            self.commit()
            return
        if self.resampling == OWTestLearners.KFold and \
                len(self.data) < self.NFolds[self.n_folds]:
            self.Error.too_many_folds()
            self.__state = State.Waiting
            self.commit()
            return

        elif self.resampling == OWTestLearners.TestOnTest:
            if self.test_data is None:
                # Don't stack a warning on top of the "empty test data"
                # error when that is already shown
                if not self.Error.test_data_empty.is_shown():
                    self.Warning.test_data_missing()
                self.__state = State.Waiting
                self.commit()
                return
            elif self.test_data.domain.class_var != self.data.domain.class_var:
                self.Error.class_inconsistent()
                self.__state = State.Waiting
                self.commit()
                return

        elif self.test_data is not None:
            # Test data is connected but the selected method ignores it
            self.Warning.test_data_unused()

        # Fixed random state passed to the randomized resampling methods
        rstate = 42
        common_args = dict(
            store_data=True,
            preprocessor=self.preprocessor,
        )
        # items in need of an update
        items = [(key, slot) for key, slot in self.learners.items()
                 if slot.results is None]
        learners = [slot.learner for _, slot in items]

        # deepcopy all learners as they are not thread safe (by virtue of
        # the base API). These will be the effective learner objects tested
        # but will be replaced with the originals on return (see
        # `replace_learners` below)
        learners_c = [copy.deepcopy(learner) for learner in learners]

        if self.resampling == OWTestLearners.KFold:
            folds = self.NFolds[self.n_folds]
            test_f = partial(Orange.evaluation.CrossValidation,
                             self.data,
                             learners_c,
                             k=folds,
                             random_state=rstate,
                             **common_args)
        elif self.resampling == OWTestLearners.FeatureFold:
            test_f = partial(Orange.evaluation.CrossValidationFeature,
                             self.data, learners_c, self.fold_feature,
                             **common_args)
        elif self.resampling == OWTestLearners.LeaveOneOut:
            test_f = partial(Orange.evaluation.LeaveOneOut, self.data,
                             learners_c, **common_args)
        elif self.resampling == OWTestLearners.ShuffleSplit:
            # SampleSizes holds percentages; train_size is a fraction
            train_size = self.SampleSizes[self.sample_size] / 100
            test_f = partial(Orange.evaluation.ShuffleSplit,
                             self.data,
                             learners_c,
                             n_resamples=self.NRepeats[self.n_repeats],
                             train_size=train_size,
                             test_size=None,
                             stratified=self.shuffle_stratified,
                             random_state=rstate,
                             **common_args)
        elif self.resampling == OWTestLearners.TestOnTrain:
            test_f = partial(Orange.evaluation.TestOnTrainingData, self.data,
                             learners_c, **common_args)
        elif self.resampling == OWTestLearners.TestOnTest:
            test_f = partial(Orange.evaluation.TestOnTestData, self.data,
                             self.test_data, learners_c, **common_args)
        else:
            assert False, "self.resampling %s" % self.resampling

        def replace_learners(evalfunc, *args, **kwargs):
            # Run the evaluation, then swap the deep-copied learners in the
            # results for the original input learners, so the results can
            # be matched to the input slots by learner
            res = evalfunc(*args, **kwargs)
            assert all(lc is lo for lc, lo in zip(learners_c, res.learners))
            res.learners[:] = learners
            return res

        test_f = partial(replace_learners, test_f)

        self.__submit(test_f)

    def __submit(self, testfunc):
        # type: (Callable[[Callable[float]], Results]) -> None
        """
        Submit a testing function for evaluation

        MUST not be called if an evaluation is already pending/running.
        Cancel the existing task first.

        The function is handed to `self.__executor`; the widget is set to
        blocking and its state becomes `State.Running` until
        `__task_complete` is invoked for the task.

        Parameters
        ----------
        testfunc : Callable[[Callable[float]], Results])
            Must be a callable taking a single `callback` argument and
            returning a Results instance
        """
        assert self.__state != State.Running
        # Setup the task
        task = Task()

        def progress_callback(finished):
            # Called by the evaluation with the finished fraction.
            # Raising here is how a cancelled task is interrupted.
            if task.cancelled:
                raise UserInterrupt()
            # Queued invocation: the progress bar is updated from the
            # widget's event loop, not from within this call
            QMetaObject.invokeMethod(self, "setProgressValue",
                                     Qt.QueuedConnection,
                                     Q_ARG(float, 100 * finished))

        def ondone(_):
            # Completion is likewise delivered as a queued invocation of
            # the `__task_complete` slot
            QMetaObject.invokeMethod(self, "__task_complete",
                                     Qt.QueuedConnection, Q_ARG(object, task))

        testfunc = partial(testfunc, callback=progress_callback)
        task.future = self.__executor.submit(testfunc)
        task.future.add_done_callback(ondone)

        self.progressBarInit(processEvents=None)
        self.setBlocking(True)
        self.setStatusMessage("Running")

        self.__state = State.Running
        self.__task = task

    @Slot(object)
    def __task_complete(self, task):
        """
        Handle a completed evaluation task.

        A task that is no longer the current one (it was cancelled) is
        ignored. Otherwise the widget is unblocked and the evaluation
        results are split per learner, scored with the current scorers and
        stored in the matching learner slots; the table is refreshed and
        the outputs are committed. An exception raised by the evaluation is
        logged and shown as a widget error.

        Parameters
        ----------
        task : Task
            The task whose future has finished.
        """
        # handle a completed task
        assert self.thread() is QThread.currentThread()
        if self.__task is not task:
            assert task.cancelled
            log.debug("Reaping cancelled task: %r", "<>")
            return

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)
        self.setStatusMessage("")
        result = task.future
        assert result.done()
        self.__task = None
        try:
            results = result.result()  # type: Results
            learners = results.learners  # type: List[Learner]
        except Exception as er:
            log.exception("testing error (in __task_complete):", exc_info=True)
            self.error("\n".join(traceback.format_exception_only(type(er),
                                                                 er)))
            self.__state = State.Done
            return

        self.__state = State.Done

        # Map each input learner back to the key of its slot
        learner_key = {
            slot.learner: key
            for key, slot in self.learners.items()
        }
        assert all(learner in learner_key for learner in learners)

        # Update the results for individual learners
        class_var = results.domain.class_var
        for learner, result in zip(learners, results.split_by_model()):
            stats = None
            if class_var.is_primitive():
                # `failed[0]` is the failure recorded for this (single)
                # learner, if any
                ex = result.failed[0]
                if ex:
                    stats = [Try.Fail(ex)] * len(self.scorers)
                    result = Try.Fail(ex)
                else:
                    stats = [
                        Try(scorer_caller(scorer, result))
                        for scorer in self.scorers
                    ]
                    result = Try.Success(result)
            key = learner_key.get(learner)
            self.learners[key] = \
                self.learners[key]._replace(results=result, stats=stats)

        self._update_header()
        self._update_stats_model()

        self.commit()

    def cancel(self):
        """
        Cancel the current/pending evaluation (if any).

        The task is detached from the widget before it is cancelled, so a
        later completion notification for it is recognised as stale.
        """
        pending = self.__task
        if pending is None:
            return
        assert self.__state == State.Running
        self.__state = State.Cancelled
        self.__task = None
        pending.cancel()
        assert pending.future.done()

    def onDeleteWidget(self):
        """Cancel any running evaluation before the widget is deleted."""
        self.cancel()
        super().onDeleteWidget()
Ejemplo n.º 36
0
class OWLearningCurveC(widget.OWWidget):
    """
    Widget showing learning curves of the input learners in a table.

    The learners are evaluated in a thread pool (see `_update`); progress
    and completion are reported back to the GUI thread, and a running
    evaluation can be cancelled cooperatively (see `cancel`).
    """
    name = "Learning Curve (C)"
    description = ("Takes a dataset and a set of learners and shows a "
                   "learning curve in a table")
    icon = "icons/LearningCurve.svg"
    priority = 1010

    inputs = [("Data", Orange.data.Table, "set_dataset", widget.Default),
              ("Test Data", Orange.data.Table, "set_testdataset"),
              ("Learner", Orange.classification.Learner, "set_learner",
               widget.Multiple + widget.Default)]

    #: cross validation folds
    folds = settings.Setting(5)
    #: points in the learning curve
    steps = settings.Setting(10)
    #: index of the selected scoring function
    scoringF = settings.Setting(0)
    #: compute curve on any change of parameters
    commitOnChange = settings.Setting(True)

    def __init__(self):
        super().__init__()

        # sets self.curvePoints, self.steps equidistant points from
        # 1/self.steps to 1
        self.updateCurvePoints()

        self.scoring = [
            ("Classification Accuracy", Orange.evaluation.scoring.CA),
            ("AUC", Orange.evaluation.scoring.AUC),
            ("Precision", Orange.evaluation.scoring.Precision),
            ("Recall", Orange.evaluation.scoring.Recall)
        ]
        #: input data on which to construct the learning curve
        self.data = None
        #: optional test data
        self.testdata = None
        #: A {input_id: Learner} mapping of current learners from input channel
        self.learners = OrderedDict()
        #: A {input_id: List[Results]} mapping of input id to evaluation
        #: results list, one for each curve point
        self.results = OrderedDict()
        #: A {input_id: List[float]} mapping of input id to learning curve
        #: point scores
        self.curves = OrderedDict()

        # [start-snippet-3]
        #: The current evaluating task (if any)
        self._task = None   # type: Optional[Task]
        #: An executor we use to submit learner evaluations into a thread pool
        self._executor = ThreadExecutor()
        # [end-snippet-3]

        # GUI
        box = gui.widgetBox(self.controlArea, "Info")
        self.infoa = gui.widgetLabel(box, 'No data on input.')
        self.infob = gui.widgetLabel(box, 'No learners.')

        gui.separator(self.controlArea)

        box = gui.widgetBox(self.controlArea, "Evaluation Scores")
        gui.comboBox(box, self, "scoringF",
                     items=[x[0] for x in self.scoring],
                     callback=self._invalidate_curves)

        gui.separator(self.controlArea)

        box = gui.widgetBox(self.controlArea, "Options")
        gui.spin(box, self, 'folds', 2, 100, step=1,
                 label='Cross validation folds:  ', keyboardTracking=False,
                 callback=lambda:
                    self._invalidate_results() if self.commitOnChange else None
        )
        gui.spin(box, self, 'steps', 2, 100, step=1,
                 label='Learning curve points:  ', keyboardTracking=False,
                 callback=[self.updateCurvePoints,
                           lambda: self._invalidate_results() if self.commitOnChange else None])
        gui.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = gui.button(box, self, "Apply Setting",
                                    callback=self._invalidate_results,
                                    disabled=True)

        gui.rubber(self.controlArea)

        # table widget
        self.table = gui.table(self.mainArea,
                               selectionMode=QTableWidget.NoSelection)

    ##########################################################################
    # slots: handle input signals

    def set_dataset(self, data):
        """Set the input train dataset."""
        # Clear all results/scores
        for key in list(self.results):
            self.results[key] = None
        for key in list(self.curves):
            self.curves[key] = None

        self.data = data

        if data is not None:
            self.infoa.setText('%d instances in input dataset' % len(data))
        else:
            self.infoa.setText('No data on input.')

        self.commitBtn.setEnabled(self.data is not None)

    def set_testdataset(self, testdata):
        """Set a separate test dataset."""
        # Clear all results/scores
        for key in list(self.results):
            self.results[key] = None
        for key in list(self.curves):
            self.curves[key] = None

        self.testdata = testdata

    def set_learner(self, learner, id):
        """Set the input learner for channel id."""
        if id in self.learners:
            if learner is None:
                # remove a learner and corresponding results
                del self.learners[id]
                del self.results[id]
                del self.curves[id]
            else:
                # update/replace a learner on a previously connected link
                self.learners[id] = learner
                # invalidate the cross-validation results and curve scores
                # (will be computed/updated in `_update`)
                self.results[id] = None
                self.curves[id] = None
        else:
            if learner is not None:
                self.learners[id] = learner
                # initialize the cross-validation results and curve scores
                # (will be computed/updated in `_update`)
                self.results[id] = None
                self.curves[id] = None

        if len(self.learners):
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")

        # setEnabled expects a bool, not a count
        self.commitBtn.setEnabled(bool(self.learners))

# [start-snippet-4]
    def handleNewSignals(self):
        self._update()
# [end-snippet-4]

    def _invalidate_curves(self):
        """Recompute the curve scores (scoring function changed)."""
        if self.data is not None:
            self._update_curve_points()
        self._update_table()

    def _invalidate_results(self):
        """Drop all results/scores and re-run the evaluation."""
        for key in self.learners:
            self.curves[key] = None
            self.results[key] = None
        self._update()

# [start-snippet-5]
    def _update(self):
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        if self.data is None:
            return
        # collect all learners for which results have not yet been computed
        need_update = [(id, learner) for id, learner in self.learners.items()
                       if self.results[id] is None]
        if not need_update:
            return
# [end-snippet-5]
# [start-snippet-6]
        learners = [learner for _, learner in need_update]
        # setup the learner evaluations as partial function capturing
        # the necessary arguments.
        if self.testdata is None:
            learning_curve_func = partial(
                learning_curve,
                learners, self.data, folds=self.folds,
                proportions=self.curvePoints,
            )
        else:
            learning_curve_func = partial(
                learning_curve_with_test_data,
                learners, self.data, self.testdata, times=self.folds,
                proportions=self.curvePoints,
            )
# [end-snippet-6]
# [start-snippet-7]
        # setup the task state
        self._task = task = Task()
        # The learning_curve[_with_test_data] also takes a callback function
        # to report the progress. We instrument this callback to both invoke
        # the appropriate slots on this widget for reporting the progress
        # (in a thread safe manner) and to implement cooperative cancellation.
        set_progress = methodinvoke(self, "setProgressValue", (float,))

        def callback(finished):
            # check if the task has been cancelled and raise an exception
            # from within. This 'strategy' can only be used with code that
            # properly cleans up after itself in the case of an exception
            # (does not leave any global locks, opened file descriptors, ...)
            if task.cancelled:
                raise KeyboardInterrupt()
            set_progress(finished * 100)

        # capture the callback in the partial function
        learning_curve_func = partial(learning_curve_func, callback=callback)
# [end-snippet-7]
# [start-snippet-8]
        self.progressBarInit()
        # Submit the evaluation function to the executor and fill in the
        # task with the resultant Future.
        task.future = self._executor.submit(learning_curve_func)
        # Setup the FutureWatcher to notify us of completion
        task.watcher = FutureWatcher(task.future)
        # by using FutureWatcher we ensure `_task_finished` slot will be
        # called from the main GUI thread by the Qt's event loop
        task.watcher.done.connect(self._task_finished)
# [end-snippet-8]

# [start-snippet-progress]
    @pyqtSlot(float)
    def setProgressValue(self, value):
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)
# [end-snippet-progress]

# [start-snippet-9]
    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Parameters
        ----------
        f : Future
            The future instance holding the result of learner evaluation.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        try:
            results = f.result()  # type: List[Results]
        except Exception as ex:
            # Log the exception with a traceback
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.error("Exception occurred during evaluation: {!r}"
                       .format(ex))
            # clear all results
            for key in self.results:
                self.results[key] = None
        else:
            # split the combined result into per learner/model results ...
            results = [list(Results.split_by_model(p_results))
                       for p_results in results]  # type: List[List[Results]]
            assert all(len(r.learners) == 1 for r1 in results for r in r1)
            assert len(results) == len(self.curvePoints)

            learners = [r.learners[0] for r in results[0]]
            learner_id = {learner: id_ for id_, learner in self.learners.items()}

            # ... and update self.results
            for i, learner in enumerate(learners):
                id_ = learner_id[learner]
                self.results[id_] = [p_results[i] for p_results in results]
# [end-snippet-9]
        # update the display
        self._update_curve_points()
        self._update_table()

# [start-snippet-10]
    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_task_finished` slot
            self._task.watcher.done.disconnect(self._task_finished)
            self._task = None
# [end-snippet-10]

# [start-snippet-11]
    def onDeleteWidget(self):
        self.cancel()
        super().onDeleteWidget()
# [end-snippet-11]

    def _update_curve_points(self):
        """
        Recompute `self.curves` from `self.results`.

        Learners without results (not yet evaluated, or cleared after a
        failed evaluation) get a `None` curve instead of raising.
        """
        scorer = self.scoring[self.scoringF][1]
        for key in self.learners:
            results = self.results[key]
            if results is None:
                self.curves[key] = None
            else:
                self.curves[key] = [scorer(x)[0] for x in results]

    def _update_table(self):
        """Fill the table: one row per curve point, one column per learner."""
        self.table.setRowCount(0)
        self.table.setRowCount(len(self.curvePoints))
        self.table.setColumnCount(len(self.learners))

        self.table.setHorizontalHeaderLabels(
            [learner.name for _, learner in self.learners.items()])
        self.table.setVerticalHeaderLabels(
            ["{:.2f}".format(p) for p in self.curvePoints])

        if self.data is None:
            return

        for column, curve in enumerate(self.curves.values()):
            if curve is None:
                # no scores available for this learner; leave column empty
                continue
            for row, point in enumerate(curve):
                self.table.setItem(
                    row, column, QTableWidgetItem("{:.5f}".format(point)))

        for i in range(len(self.learners)):
            sh = self.table.sizeHintForColumn(i)
            cwidth = self.table.columnWidth(i)
            self.table.setColumnWidth(i, max(sh, cwidth))

    def updateCurvePoints(self):
        """Set `curvePoints` to `steps` equidistant points up to 1."""
        self.curvePoints = [(x + 1.)/self.steps for x in range(self.steps)]
class OWGeneInfo(widget.OWWidget):
    """Widget that retrieves and displays information records for input genes.

    On construction the widget blocks itself and submits a background task
    running ``taxonomy.ensure_downloaded``; :meth:`initialize` finishes the
    set-up once that task reports a result.  Whenever the input data, the
    organism or the gene source changes, :meth:`updateInfoItems` submits
    another background task and :meth:`_onItemsCompleted` fills the tree view
    and sends the ``Data`` output.  Row selection in the tree view drives the
    ``Selected Genes`` output through :meth:`commit`.
    """
    name = "Gene Info"
    description = "Displays gene information from NCBI and other sources."
    icon = "../widgets/icons/OWGeneInfo.svg"
    priority = 5

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        selected_genes = Output("Selected Genes", Orange.data.Table)
        data = Output("Data", Orange.data.Table)

    settingsHandler = settings.DomainContextHandler()

    organism_index = settings.ContextSetting(0)
    taxid = settings.ContextSetting("9606")

    gene_attr = settings.ContextSetting(0)

    auto_commit = settings.Setting(False)
    search_string = settings.Setting("")

    useAttr = settings.ContextSetting(False)
    useAltSource = settings.ContextSetting(False)

    def __init__(
        self,
        parent=None,
    ):
        """Build the GUI and start the background download of required files.

        Parameters
        ----------
        parent : optional
            Forwarded unchanged to the base-class initializer.
        """
        # ``super()`` already binds ``self``; pass only ``parent``.
        super().__init__(parent)

        self.selectionChangedFlag = False

        self.__initialized = False
        self.initfuture = None
        self.itemsfuture = None

        self.map_input_to_ensembl = None
        self.infoLabel = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Info", addSpace=True),
            "Initializing\n")

        self.organisms = None
        self.organismBox = gui.widgetBox(self.controlArea,
                                         "Organism",
                                         addSpace=True)

        self.organismComboBox = gui.comboBox(
            self.organismBox,
            self,
            "organism_index",
            callback=self._onSelectedOrganismChanged)

        box = gui.widgetBox(self.controlArea, "Gene names", addSpace=True)
        self.geneAttrComboBox = gui.comboBox(box,
                                             self,
                                             "gene_attr",
                                             "Gene attribute",
                                             callback=self.updateInfoItems)
        self.geneAttrComboBox.setEnabled(not self.useAttr)

        self.geneAttrCheckbox = gui.checkBox(box,
                                             self,
                                             "useAttr",
                                             "Use column names",
                                             callback=self.updateInfoItems)
        # The attribute combo box is irrelevant while column names are used.
        self.geneAttrCheckbox.toggled[bool].connect(
            self.geneAttrComboBox.setDisabled)

        gui.auto_commit(self.controlArea, self, "auto_commit", "Commit")

        gui.rubber(self.controlArea)

        gui.lineEdit(self.mainArea,
                     self,
                     "search_string",
                     "Filter",
                     callbackOnType=True,
                     callback=self.searchUpdate)

        self.treeWidget = QTreeView(self.mainArea)

        self.treeWidget.setAlternatingRowColors(True)
        self.treeWidget.setSortingEnabled(True)
        self.treeWidget.setSelectionMode(QTreeView.ExtendedSelection)
        self.treeWidget.setUniformRowHeights(True)
        self.treeWidget.setRootIsDecorated(False)

        self.treeWidget.setItemDelegateForColumn(
            HEADER_SCHEMA['NCBI ID'],
            gui.LinkStyledItemDelegate(self.treeWidget))
        self.treeWidget.setItemDelegateForColumn(
            HEADER_SCHEMA['Ensembl ID'],
            gui.LinkStyledItemDelegate(self.treeWidget))

        self.treeWidget.viewport().setMouseTracking(True)
        self.mainArea.layout().addWidget(self.treeWidget)

        box = gui.widgetBox(self.mainArea, "", orientation="horizontal")
        gui.button(box, self, "Select Filtered", callback=self.selectFiltered)
        gui.button(box,
                   self,
                   "Clear Selection",
                   callback=self.treeWidget.clearSelection)

        self.geneinfo = []
        self.cells = []
        self.row2geneinfo = {}
        self.data = None
        # Give these a defined value before the first input arrives so that
        # methods reading them never hit an AttributeError.
        self.attributes = []
        self.genes = []

        # : (# input genes, # matches genes)
        self.matchedInfo = 0, 0

        self.setBlocking(True)
        self.executor = ThreadExecutor(self)

        self.progressBarInit()

        task = Task(
            function=partial(taxonomy.ensure_downloaded,
                             callback=methodinvoke(self, "advance", ())))

        task.resultReady.connect(self.initialize)
        task.exceptionReady.connect(self._onInitializeError)

        self.initfuture = self.executor.submit(task)

    def sizeHint(self):
        """Return ``QSize(1024, 720)`` as the widget's size hint."""
        return QSize(1024, 720)

    @Slot()
    def advance(self):
        """Advance the progress bar by one; must run in the widget's thread."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(self.progressBarValue + 1, processEvents=None)

    def _get_available_organisms(self):
        """Fill ``self.organisms`` and the organism combo box.

        Organisms come from ``taxonomy.common_taxids()`` and are ordered by
        their ``taxonomy.name``.
        """
        available_organism = sorted(
            ((tax_id, taxonomy.name(tax_id))
             for tax_id in taxonomy.common_taxids()),
            key=lambda x: x[1])

        self.organisms = [tax_id for tax_id, _ in available_organism]

        self.organismComboBox.addItems(
            [name for _, name in available_organism])

    def initialize(self):
        """Finish the set-up once the required files are available.

        Safe to call more than once; only the first call has an effect.
        """
        if self.__initialized:
            # Already initialized
            return
        self.__initialized = True

        self._get_available_organisms()
        self.organism_index = self.organisms.index(taxonomy.DEFAULT_ORGANISM)
        self.taxid = self.organisms[self.organism_index]

        self.infoLabel.setText("No data on input\n")
        self.initfuture = None

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)

    def _onInitializeError(self, exc):
        """Report a failure of the initial download task.

        Parameters
        ----------
        exc : BaseException
            The exception delivered by the task's ``exceptionReady`` signal.
        """
        # Hand over the traceback (when present) so the hook can show it.
        sys.excepthook(type(exc), exc, getattr(exc, "__traceback__", None))
        self.error(0, "Could not download the necessary files.")

    def _onSelectedOrganismChanged(self):
        """Store the taxonomy id of the selected organism and refresh."""
        # The index must address an existing entry, hence the strict bound.
        assert 0 <= self.organism_index < len(self.organisms)
        self.taxid = self.organisms[self.organism_index]

        if self.data is not None:
            self.updateInfoItems()

    @Inputs.data
    def setData(self, data=None):
        """Handle a new table on the ``Data`` input.

        Parameters
        ----------
        data : Orange.data.Table or None
            The new input; ``None`` clears the widget.

        Raises
        ------
        Exception
            If a previous information request is still being processed.
        """
        if not self.__initialized:
            # Wait for the download task, then finish the set-up by hand.
            self.initfuture.result()
            self.initialize()

        if self.itemsfuture is not None:
            raise Exception("Already processing")

        self.data = data

        if data is not None:
            self.geneAttrComboBox.clear()
            self.attributes = [
                attr for attr in data.domain.variables + data.domain.metas
                if isinstance(attr, (Orange.data.StringVariable,
                                     Orange.data.DiscreteVariable))
            ]

            for var in self.attributes:
                self.geneAttrComboBox.addItem(*gui.attributeItem(var))

            self.taxid = str(self.data.attributes.get(TAX_ID, ''))
            self.useAttr = self.data.attributes.get(GENE_AS_ATTRIBUTE_NAME,
                                                    self.useAttr)

            self.gene_attr = min(self.gene_attr, len(self.attributes) - 1)

            if self.taxid in self.organisms:
                self.organism_index = self.organisms.index(self.taxid)

            self.updateInfoItems()
        else:
            self.clear()

    def updateInfoItems(self):
        """Collect the gene names and request their info in the background.

        Gene names are the attribute names when ``useAttr`` is set, otherwise
        the non-missing values of the selected gene attribute.  The widget is
        blocked and disabled until :meth:`_onItemsCompleted` runs.
        """
        self.warning(0)
        if self.data is None:
            return

        if self.useAttr:
            genes = [attr.name for attr in self.data.domain.attributes]
        elif self.attributes:
            attr = self.attributes[self.gene_attr]
            genes = [
                str(ex[attr]) for ex in self.data if not math.isnan(ex[attr])
            ]
        else:
            genes = []
        if not genes:
            self.warning(0, "Could not extract genes from input dataset.")

        self.warning(1)
        org = self.organisms[min(self.organism_index, len(self.organisms) - 1)]
        info_getter = ncbi_info

        self.error(0)

        self.progressBarInit()
        self.setBlocking(True)
        self.setEnabled(False)
        self.infoLabel.setText("Retrieving info records.\n")

        self.genes = genes

        task = Task(function=partial(
            info_getter, org, genes, advance=methodinvoke(self, "advance", (
            ))))
        # Connect before submitting so a fast task cannot finish unnoticed.
        task.finished.connect(self._onItemsCompleted)
        self.itemsfuture = self.executor.submit(task)

    def _onItemsCompleted(self):
        """Show the retrieved records and send the ``Data`` output.

        Re-raises whatever exception the background task raised.
        """
        self.setBlocking(False)
        self.progressBarFinished()
        self.setEnabled(True)

        try:
            self.map_input_to_ensembl, geneinfo = self.itemsfuture.result()
        finally:
            # Always release the future so new requests are accepted again.
            self.itemsfuture = None

        self.geneinfo = geneinfo
        self.cells = cells = []
        self.row2geneinfo = {}

        for i, (input_name, gi) in enumerate(geneinfo):
            if gi:
                row = list(gi)

                # parse synonyms
                row[HEADER_SCHEMA['Synonyms']] = ','.join(
                    row[HEADER_SCHEMA['Synonyms']])
                cells.append(row)
                # Map the displayed row back to its position in ``geneinfo``.
                self.row2geneinfo[len(cells) - 1] = i

        model = TreeModel(cells, list(HEADER_SCHEMA.keys()), None)

        proxyModel = QSortFilterProxyModel(self)
        proxyModel.setSourceModel(model)
        self.treeWidget.setModel(proxyModel)
        self.treeWidget.selectionModel().selectionChanged.connect(self.commit)

        for i in range(len(HEADER_SCHEMA)):
            self.treeWidget.resizeColumnToContents(i)
            self.treeWidget.setColumnWidth(
                i, min(self.treeWidget.columnWidth(i), 200))

        self.infoLabel.setText("%i genes\n%i matched NCBI's IDs" %
                               (len(self.genes), len(cells)))
        self.matchedInfo = len(self.genes), len(cells)

        if self.useAttr:
            new_data = self.data.from_table(self.data.domain, self.data)

            for gene_var in new_data.domain.attributes:
                gene_var.attributes['Ensembl ID'] = str(
                    self.map_input_to_ensembl[gene_var.name])

            self.Outputs.data.send(new_data)

        elif self.attributes:
            ensembl_ids = []
            for gene_name in self.data.get_column_view(
                    self.attributes[self.gene_attr])[0]:
                if gene_name and gene_name in self.map_input_to_ensembl:
                    ensembl_ids.append(self.map_input_to_ensembl[gene_name])
                else:
                    ensembl_ids.append('')

            data_with_ensembl = append_columns(
                self.data,
                metas=[(Orange.data.StringVariable('Ensembl ID'), ensembl_ids)
                       ])
            self.Outputs.data.send(data_with_ensembl)

    def clear(self):
        """Reset the view and clear both outputs."""
        self.infoLabel.setText("No data on input\n")
        self.treeWidget.setModel(
            TreeModel([], [
                "NCBI ID", "Symbol", "Locus Tag", "Chromosome", "Description",
                "Synonyms", "Nomenclature"
            ], self.treeWidget))

        self.geneAttrComboBox.clear()
        self.Outputs.selected_genes.send(None)
        # Also drop the annotated table so no stale data stays on the output.
        self.Outputs.data.send(None)

    def commit(self):
        """Send the rows or columns of the selected genes.

        Sends ``None`` on ``Selected Genes`` when nothing is selected, and on
        both outputs when there is no input data.
        """
        if self.data is None:
            self.Outputs.selected_genes.send(None)
            self.Outputs.data.send(None)
            return

        model = self.treeWidget.model()
        selection = self.treeWidget.selectionModel().selection()
        selection = model.mapSelectionToSource(selection)
        selectedRows = list(
            chain.from_iterable(
                range(r.top(), r.bottom() + 1) for r in selection))
        model = model.sourceModel()

        # Input gene name -> source-model row, for the selected rows only.
        gene2row = {
            self.geneinfo[self.row2geneinfo[row]][0]: row
            for row in selectedRows
        }
        selectedIds = set(gene2row)

        if not selectedIds:
            self.Outputs.selected_genes.send(None)
            return

        if self.useAttr:
            attrs = [
                attr for attr in self.data.domain.attributes
                if attr.name in selectedIds
            ]
            domain = Orange.data.Domain(attrs, self.data.domain.class_vars,
                                        self.data.domain.metas)
            newdata = self.data.from_table(domain, self.data)

            self.Outputs.selected_genes.send(newdata)

        elif self.attributes:
            attr = self.attributes[self.gene_attr]
            gene_col = [
                attr.str_val(v) for v in self.data.get_column_view(attr)[0]
            ]
            gene_col = [(i, name) for i, name in enumerate(gene_col)
                        if name in selectedIds]
            indices = [i for i, _ in gene_col]

            # SELECTED GENES OUTPUT
            selected_genes_metas = [
                Orange.data.StringVariable(name)
                for name in gene.GENE_INFO_HEADER_LABELS
            ]
            selected_genes_domain = Orange.data.Domain(
                self.data.domain.attributes, self.data.domain.class_vars,
                self.data.domain.metas + tuple(selected_genes_metas))

            selected_genes_data = self.data.from_table(
                selected_genes_domain, self.data)[indices]

            # Copy the displayed columns into the newly added meta columns.
            model_rows = [gene2row[gene_name] for _, gene_name in gene_col]
            for col, meta in zip(range(model.columnCount()),
                                 selected_genes_metas):
                col_data = [
                    str(model.index(row, col).data(Qt.DisplayRole))
                    for row in model_rows
                ]
                col_data = np.array(col_data, dtype=object, ndmin=2).T
                selected_genes_data[:, meta] = col_data

            if not len(selected_genes_data):
                selected_genes_data = None

            self.Outputs.selected_genes.send(selected_genes_data)

    def rowFiltered(self, row):
        """Return True when row ``row`` of ``self.cells`` fails the filter.

        A row passes when every whitespace-separated term of the search
        string occurs (case-insensitively) in the row's joined cell text.
        """
        searchStrings = self.search_string.lower().split()
        row = " ".join(self.cells[row]).lower()
        return not all(s in row for s in searchStrings)

    def searchUpdate(self):
        """Hide the tree rows that do not match the current search string."""
        if not self.data:
            return
        searchStrings = self.search_string.lower().split()
        index = self.treeWidget.model().sourceModel().index
        mapFromSource = self.treeWidget.model().mapFromSource
        for i, row in enumerate(self.cells):
            row = " ".join(row).lower()
            self.treeWidget.setRowHidden(
                mapFromSource(index(i, 0)).row(), QModelIndex(),
                not all(s in row for s in searchStrings))

    def selectFiltered(self):
        """Add every row that matches the search string to the selection."""
        if not self.data:
            return
        itemSelection = QItemSelection()

        index = self.treeWidget.model().sourceModel().index
        mapFromSource = self.treeWidget.model().mapFromSource
        for i in range(len(self.cells)):
            if not self.rowFiltered(i):
                itemSelection.select(mapFromSource(index(i, 0)),
                                     mapFromSource(index(i, 0)))
        self.treeWidget.selectionModel().select(
            itemSelection,
            QItemSelectionModel.Select | QItemSelectionModel.Rows)

    def onAltSourceChange(self):
        """Re-query the gene information."""
        self.updateInfoItems()

    def onDeleteWidget(self):
        """Cancel pending tasks and shut the executor down without waiting."""
        # try to cancel pending tasks
        if self.initfuture:
            self.initfuture.cancel()
        if self.itemsfuture:
            self.itemsfuture.cancel()

        self.executor.shutdown(wait=False)
        super().onDeleteWidget()
Ejemplo n.º 38
0
class OWNNLearner(OWBaseLearner):
    """Learner widget for a multi-layer perceptron.

    The model is fitted in a background thread: :meth:`__update` submits the
    fit to a ``ThreadExecutor`` and :meth:`_task_finished` collects the
    result.  A callback installed on a copy of the learner reports progress
    and aborts the fit when cancellation was requested.
    """
    name = "Neural Network"
    description = "A multi-layer perceptron (MLP) algorithm with " \
                  "backpropagation."
    icon = "icons/NN.svg"
    priority = 90
    keywords = ["mlp"]

    LEARNER = NNLearner

    activation = ["identity", "logistic", "tanh", "relu"]
    act_lbl = ["Identity", "Logistic", "tanh", "ReLu"]
    solver = ["lbfgs", "sgd", "adam"]
    solv_lbl = ["L-BFGS-B", "SGD", "Adam"]

    learner_name = Setting("Neural Network")
    hidden_layers_input = Setting("100,")
    activation_index = Setting(3)
    solver_index = Setting(2)
    max_iterations = Setting(200)
    alpha_index = Setting(0)
    settings_version = 1

    # Regularization strengths selectable with the slider, addressed by
    # ``alpha_index``.
    # NOTE(review): the last two ranges overlap (100..190 is followed by
    # 100..1000 again), so the values are not monotonic.  Left unchanged
    # because saved ``alpha_index`` settings refer to positions in this list.
    alphas = list(
        chain([x / 10000 for x in range(1, 10)],
              [x / 1000 for x in range(1, 10)],
              [x / 100 for x in range(1, 10)], [x / 10 for x in range(1, 10)],
              range(1, 10), range(10, 100, 5), range(100, 200, 10),
              range(100, 1001, 50)))

    def add_main_layout(self):
        """Create the form with the network's hyper-parameter controls."""
        form = QFormLayout()
        form.setFieldGrowthPolicy(form.AllNonFixedFieldsGrow)
        form.setVerticalSpacing(25)
        gui.widgetBox(self.controlArea, True, orientation=form)
        form.addRow(
            "Neurons in hidden layers:",
            gui.lineEdit(
                None,
                self,
                "hidden_layers_input",
                orientation=Qt.Horizontal,
                callback=self.settings_changed,
                tooltip="A list of integers defining neurons. Length of list "
                "defines the number of layers. E.g. 4, 2, 2, 3.",
                placeholderText="e.g. 100,"))
        form.addRow(
            "Activation:",
            gui.comboBox(None,
                         self,
                         "activation_index",
                         orientation=Qt.Horizontal,
                         label="Activation:",
                         items=list(self.act_lbl),
                         callback=self.settings_changed))

        form.addRow(" ", gui.separator(None, 16))
        form.addRow(
            "Solver:",
            gui.comboBox(None,
                         self,
                         "solver_index",
                         orientation=Qt.Horizontal,
                         label="Solver:",
                         items=list(self.solv_lbl),
                         callback=self.settings_changed))
        self.reg_label = QLabel()
        slider = gui.hSlider(None,
                             self,
                             "alpha_index",
                             minValue=0,
                             maxValue=len(self.alphas) - 1,
                             callback=lambda:
                             (self.set_alpha(), self.settings_changed()),
                             createLabel=False)
        form.addRow(self.reg_label, slider)
        self.set_alpha()

        form.addRow(
            "Maximal number of iterations:",
            gui.spin(None,
                     self,
                     "max_iterations",
                     10,
                     10000,
                     step=10,
                     label="Max iterations:",
                     orientation=Qt.Horizontal,
                     alignment=Qt.AlignRight,
                     callback=self.settings_changed))

    def set_alpha(self):
        """Store the selected alpha and show it in the slider's label."""
        self.strength_C = self.alpha
        self.reg_label.setText("Regularization, α={}:".format(self.strength_C))

    @property
    def alpha(self):
        """The regularization strength selected by ``alpha_index``."""
        return self.alphas[self.alpha_index]

    def setup_layout(self):
        """Extend the base layout with task state and a cancel button."""
        super().setup_layout()

        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        # just a test cancel button
        gui.button(self.apply_button, self, "Cancel", callback=self.cancel)

    def create_learner(self):
        """Return a ``LEARNER`` instance configured from the current GUI."""
        return self.LEARNER(hidden_layer_sizes=self.get_hidden_layers(),
                            activation=self.activation[self.activation_index],
                            solver=self.solver[self.solver_index],
                            alpha=self.alpha,
                            max_iter=self.max_iterations,
                            preprocessors=self.preprocessors)

    def get_learner_parameters(self):
        """Return ``(label, value)`` pairs describing the current settings."""
        return (("Hidden layers", ', '.join(map(str,
                                                self.get_hidden_layers()))),
                ("Activation", self.act_lbl[self.activation_index]),
                ("Solver", self.solv_lbl[self.solver_index]),
                ("Alpha", self.alpha), ("Max iterations", self.max_iterations))

    def get_hidden_layers(self):
        """Parse ``hidden_layers_input`` into a tuple of layer sizes.

        Every run of digits in the text becomes one layer.  When none is
        found the input is reset to ``"100,"`` and ``(100,)`` is returned.
        """
        layers = tuple(map(int, re.findall(r'\d+', self.hidden_layers_input)))
        if not layers:
            layers = (100, )
            self.hidden_layers_input = "100,"
        return layers

    def update_model(self):
        """Start fitting a model, or send ``None`` when the data is unusable."""
        self.show_fitting_failed(None)
        self.model = None
        if self.check_data():
            self.__update()
        else:
            self.Outputs.model.send(self.model)

    @Slot(float)
    def setProgressValue(self, value):
        """Show ``value`` in the progress bar; must run in the widget's thread."""
        assert self.thread() is QThread.currentThread()
        self.progressBarSet(value)

    def __update(self):
        """Cancel any running fit and submit a new one to the executor."""
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        max_iter = self.learner.kwargs["max_iter"]

        # Setup the task state
        task = Task()
        lastemitted = 0.

        def callback(iteration):
            """Report progress; abort the fit when cancellation is requested."""
            nonlocal lastemitted
            if task.isInterruptionRequested():
                raise CancelTaskException()
            progress = round(iteration / max_iter * 100)
            # Emit only when the rounded percentage actually changed.
            if progress != lastemitted:
                task.emitProgressUpdate(progress)
                lastemitted = progress

        # copy to set the callback so that the learner output is not modified
        # (currently we can not pass callbacks to learners __call__)
        learner = copy.copy(self.learner)
        learner.callback = callback

        def build_model(data, learner):
            """Fit ``learner`` on ``data``; return None when cancelled."""
            try:
                return learner(data)
            except CancelTaskException:
                return None

        build_model_func = partial(build_model, self.data, learner)

        task.setFuture(self._executor.submit(build_model_func))
        task.done.connect(self._task_finished)
        task.progressChanged.connect(self.setProgressValue)

        self._task = task
        self.progressBarInit()
        self.setBlocking(True)

    @Slot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Store and send the fitted model once the background task is done.

        Parameters
        ----------
        f : Future
            The future instance holding the built model
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()
        self._task.deleteLater()
        self._task = None
        self.setBlocking(False)
        self.progressBarFinished()

        try:
            self.model = f.result()
        except Exception as ex:  # pylint: disable=broad-except
            # Log the exception with a traceback
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.model = None
            self.show_fitting_failed(ex)
        else:
            # The fit yields None when it was interrupted; there is nothing
            # to decorate in that case, so just forward the None.
            if self.model is not None:
                self.model.name = self.learner_name
                self.model.instances = self.data
                # remove unpicklable callback
                self.model.skl_model.orange_callback = None
            self.Outputs.model.send(self.model)

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect from the task
            self._task.done.disconnect(self._task_finished)
            self._task.progressChanged.disconnect(self.setProgressValue)
            self._task.deleteLater()
            self._task = None

        self.progressBarFinished()
        self.setBlocking(False)

    def onDeleteWidget(self):
        """Cancel the current task (if any), then run the base-class hook."""
        self.cancel()
        super().onDeleteWidget()

    @classmethod
    def migrate_settings(cls, settings, version):
        """Convert a stored ``alpha`` value into the closest ``alpha_index``.

        Only applied to settings saved without a version.
        """
        if not version:
            alpha = settings.pop("alpha", None)
            if alpha is not None:
                # Store a plain int rather than a numpy integer.
                settings["alpha_index"] = int(
                    np.argmin(np.abs(np.array(cls.alphas) - alpha)))
Ejemplo n.º 39
0
class OWSetEnrichment(widget.OWWidget):
    # Widget metadata.
    name = "Set Enrichment"
    description = ""
    icon = "../widgets/icons/GeneSetEnrichment.svg"
    priority = 5000

    # Input channels as (name, type, handler method name[, flags]) and
    # output channels as (name, type).
    inputs = [("Data", Orange.data.Table, "setData", widget.Default),
              ("Reference", Orange.data.Table, "setReference")]
    outputs = [("Data subset", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    # Values declared as context settings.
    taxid = settings.ContextSetting(None)
    speciesIndex = settings.ContextSetting(0)
    genesinrows = settings.ContextSetting(False)
    geneattr = settings.ContextSetting(0)
    categoriesCheckState = settings.ContextSetting({})

    # Values declared as plain settings: reference choice, result filters
    # and their thresholds, and the auto-commit flag.
    useReferenceData = settings.Setting(False)
    useMinCountFilter = settings.Setting(True)
    useMaxPValFilter = settings.Setting(True)
    useMaxFDRFilter = settings.Setting(True)
    minClusterCount = settings.Setting(3)
    maxPValue = settings.Setting(0.01)
    maxFDR = settings.Setting(0.01)
    autocommit = settings.Setting(False)

    # Widget state constants; `clear` masks RunningEnrichment out of the
    # current state with `& ~`, so they appear to be used as bit flags.
    Ready, Initializing, Loading, RunningEnrichment = 0, 1, 2, 4

    def __init__(self, parent=None):
        """Build the GUI and start downloading the required database files.

        The widget is blocked until the ``EnsureDownloaded`` task submitted
        at the end of this method finishes and ``__initialize_finish`` runs.

        Parameters
        ----------
        parent : optional
            Forwarded unchanged to the base-class initializer.
        """
        super().__init__(parent)

        self.geneMatcherSettings = [False, False, True, False]

        self.data = None
        self.referenceData = None
        self.taxid_list = []

        self.__genematcher = (None, fulfill(gene.matcher([])))
        self.__invalidated = False

        self.currentAnnotatedCategories = []
        self.state = None
        self.__state = OWSetEnrichment.Initializing

        box = gui.widgetBox(self.controlArea, "Info")
        self.infoBox = gui.widgetLabel(box, "Info")
        self.infoBox.setText("No data on input.\n")

        self.speciesComboBox = gui.comboBox(
            self.controlArea,
            self,
            "speciesIndex",
            "Species",
            callback=self.__on_speciesIndexChanged)

        box = gui.widgetBox(self.controlArea, "Entity names")
        self.geneAttrComboBox = gui.comboBox(box,
                                             self,
                                             "geneattr",
                                             "Entity feature",
                                             sendSelectedValue=0,
                                             callback=self.updateAnnotations)

        cb = gui.checkBox(box,
                          self,
                          "genesinrows",
                          "Use feature names",
                          callback=self.updateAnnotations,
                          disables=[(-1, self.geneAttrComboBox)])
        cb.makeConsistent()

        #         gui.button(box, self, "Gene matcher settings",
        #                    callback=self.updateGeneMatcherSettings,
        #                    tooltip="Open gene matching settings dialog")

        self.referenceRadioBox = gui.radioButtonsInBox(
            self.controlArea,
            self,
            "useReferenceData", ["All entities", "Reference set (input)"],
            tooltips=[
                "Use entire genome (for gene set enrichment) or all " +
                "available entities for reference",
                "Use entities from Reference Examples input signal " +
                "as reference"
            ],
            box="Reference",
            callback=self.updateAnnotations)

        box = gui.widgetBox(self.controlArea, "Entity Sets")
        self.groupsWidget = QTreeWidget(self)
        self.groupsWidget.setHeaderLabels(["Category"])
        box.layout().addWidget(self.groupsWidget)

        # Row of result filters shown above the annotations view.
        hLayout = QHBoxLayout()
        hLayout.setSpacing(10)
        hWidget = gui.widgetBox(self.mainArea, orientation=hLayout)
        gui.spin(hWidget,
                 self,
                 "minClusterCount",
                 0,
                 100,
                 label="Entities",
                 tooltip="Minimum entity count",
                 callback=self.filterAnnotationsChartView,
                 callbackOnReturn=True,
                 checked="useMinCountFilter",
                 checkCallback=self.filterAnnotationsChartView)

        pvalfilterbox = gui.widgetBox(hWidget, orientation="horizontal")
        cb = gui.checkBox(pvalfilterbox,
                          self,
                          "useMaxPValFilter",
                          "p-value",
                          callback=self.filterAnnotationsChartView)

        sp = gui.doubleSpin(
            pvalfilterbox,
            self,
            "maxPValue",
            0.0,
            1.0,
            0.0001,
            tooltip="Maximum p-value",
            callback=self.filterAnnotationsChartView,
            callbackOnReturn=True,
        )
        # The spin box follows its own check box, i.e. the p-value filter.
        sp.setEnabled(self.useMaxPValFilter)
        cb.toggled[bool].connect(sp.setEnabled)

        pvalfilterbox.layout().setAlignment(cb, Qt.AlignRight)
        pvalfilterbox.layout().setAlignment(sp, Qt.AlignLeft)

        fdrfilterbox = gui.widgetBox(hWidget, orientation="horizontal")
        cb = gui.checkBox(fdrfilterbox,
                          self,
                          "useMaxFDRFilter",
                          "FDR",
                          callback=self.filterAnnotationsChartView)

        sp = gui.doubleSpin(
            fdrfilterbox,
            self,
            "maxFDR",
            0.0,
            1.0,
            0.0001,
            tooltip="Maximum False discovery rate",
            callback=self.filterAnnotationsChartView,
            callbackOnReturn=True,
        )
        sp.setEnabled(self.useMaxFDRFilter)
        cb.toggled[bool].connect(sp.setEnabled)

        fdrfilterbox.layout().setAlignment(cb, Qt.AlignRight)
        fdrfilterbox.layout().setAlignment(sp, Qt.AlignLeft)

        self.filterLineEdit = QLineEdit(self, placeholderText="Filter ...")

        self.filterCompleter = QCompleter(self.filterLineEdit)
        self.filterCompleter.setCaseSensitivity(Qt.CaseInsensitive)
        self.filterLineEdit.setCompleter(self.filterCompleter)

        hLayout.addWidget(self.filterLineEdit)
        self.mainArea.layout().addWidget(hWidget)

        self.filterLineEdit.textChanged.connect(
            self.filterAnnotationsChartView)

        self.annotationsChartView = QTreeView(
            alternatingRowColors=True,
            sortingEnabled=True,
            selectionMode=QTreeView.ExtendedSelection,
            rootIsDecorated=False,
            editTriggers=QTreeView.NoEditTriggers,
        )
        self.annotationsChartView.viewport().setMouseTracking(True)
        self.mainArea.layout().addWidget(self.annotationsChartView)

        contextEventFilter = gui.VisibleHeaderSectionContextEventFilter(
            self.annotationsChartView)
        self.annotationsChartView.header().installEventFilter(
            contextEventFilter)

        self.groupsWidget.itemClicked.connect(self.subsetSelectionChanged)
        gui.auto_commit(self.controlArea, self, "autocommit", "Commit")

        # Stay blocked until the download task below has finished.
        self.setBlocking(True)

        task = EnsureDownloaded([(taxonomy.Taxonomy.DOMAIN,
                                  taxonomy.Taxonomy.FILENAME),
                                 (geneset.sfdomain, "index.pck")])

        task.finished.connect(self.__initialize_finish)
        self.setStatusMessage("Initializing")
        self._executor = ThreadExecutor(parent=self,
                                        threadPool=QThreadPool(self))
        self._executor.submit(task)

    def sizeHint(self):
        """Return ``QSize(1024, 600)`` as the widget's size hint."""
        return QSize(1024, 600)

    def __initialize_finish(self):
        """Finalize the widget's initialization.

        Connected to the ``finished`` signal of the download task, so it
        preferably runs after all required databases have been downloaded.
        Fills the species combo box, restores the stored organism (or falls
        back to the first entry) and unblocks the widget.
        """
        available_sets = geneset.list_all()

        # Candidate organisms: the common ones plus every non-empty taxonomy
        # id that some gene set refers to.
        set_taxids = [tid for _, tid, _ in available_sets if tid]
        candidate_taxids = set(taxonomy.common_taxids() + set_taxids)

        # Keep only the organisms that have a name.
        named = []
        for tid in candidate_taxids:
            organism_name = name_or_none(tid)
            if organism_name is not None:
                named.append((tid, organism_name))

        entries = [(None, "None")] + sorted(named)
        self.taxid_list = [tid for tid, _ in entries]

        self.speciesComboBox.clear()
        self.speciesComboBox.addItems([label for _, label in entries])
        self.genesets = available_sets

        if self.taxid in self.taxid_list:
            initial_taxid = self.taxid
        else:
            initial_taxid = self.taxid_list[0]

        # Reset first so that `setCurrentOrganism` compares against None.
        self.taxid = None
        self.setCurrentOrganism(initial_taxid)
        self.setBlocking(False)
        self.__state = OWSetEnrichment.Ready
        self.setStatusMessage("")

    def setCurrentOrganism(self, taxid):
        """
        Make `taxid` the current organism.

        A `taxid` that is not in the organism list is replaced by the entry
        at the current species index (clamped to the end of the list).
        When the organism actually changes, the hierarchy is rebuilt and
        both the gene matcher and the results are invalidated.
        """
        known = self.taxid_list
        if taxid not in known:
            fallback_index = min(self.speciesIndex, len(known) - 1)
            taxid = known[fallback_index]

        if self.taxid == taxid:
            # Nothing changed
            return

        self.taxid = taxid
        self.speciesIndex = known.index(taxid)
        self.refreshHierarchy()
        self._invalidateGeneMatcher()
        self._invalidate()

    def currentOrganism(self):
        """Return the taxonomy id of the organism currently in use."""
        current = self.taxid
        return current

    def __on_speciesIndexChanged(self):
        """Switch to the organism at the newly selected species index."""
        selected = self.taxid_list[self.speciesIndex]
        # Sentinel that never equals a real taxonomy id, so that
        # setCurrentOrganism always treats the selection as a change.
        self.taxid = "< Do not look >"
        self.setCurrentOrganism(selected)
        if not self.__invalidated or self.data is None:
            return
        self.updateAnnotations()

    def clear(self):
        """
        Clear/reset the widget state.

        Cancels pending tasks, drops the current run state, clears the
        RunningEnrichment flag, empties the report view and the gene
        attribute combo box, and refreshes the summary.
        """
        self._cancelPending()
        self.state = None

        self.__state &= ~OWSetEnrichment.RunningEnrichment

        # _clearView clears the report model when one is set, so no further
        # clearing of the view's model is needed here.
        self._clearView()

        self.geneAttrComboBox.clear()
        self.geneAttrs = []
        self._updatesummary()

    def _cancelPending(self):
        """Cancel the tasks of the current run, if there is one."""
        pending = self.state
        if pending is None:
            return
        pending.results.cancel()
        pending.namematcher.cancel()
        # Checked by the running enrichment loop to stop early.
        pending.cancelled = True

    def _clearView(self):
        """Clear the model of the enrichment report view, if it has one."""
        view = self.annotationsChartView
        if view.model() is None:
            return
        view.model().clear()

    def setData(self, data=None):
        """
        Set the input dataset with query gene names.

        The widget is reset first. For a non-None `data`, its string
        variables become the candidate gene name columns, the "taxid" and
        "genesinrows" data hints are applied, the context settings are
        opened, the hierarchy is rebuilt and the results are marked as
        invalidated. Passing None only resets the widget.
        """
        if self.__state & OWSetEnrichment.Initializing:
            # Input arrived while still initializing: finish initialization
            # right away, the organism list it builds is needed below.
            self.__initialize_finish()

        self.error(0)
        self.closeContext()
        self.clear()

        self.groupsWidget.clear()
        self.data = data

        if data is not None:
            # Only string variables (including metas) can hold gene names.
            varlist = [
                var for var in data.domain.variables + data.domain.metas
                if isinstance(var, Orange.data.StringVariable)
            ]

            self.geneAttrs = varlist
            for var in varlist:
                self.geneAttrComboBox.addItem(*gui.attributeItem(var))

            oldtaxid = self.taxid
            # Clamp the stored column index to the available columns.
            self.geneattr = min(self.geneattr, len(self.geneAttrs) - 1)

            # An organism hint attached to the data takes precedence, but
            # only when it is one of the known organisms.
            taxid = data_hints.get_hint(data, "taxid", "")
            if taxid in self.taxid_list:
                self.speciesIndex = self.taxid_list.index(taxid)
                self.taxid = taxid

            self.genesinrows = data_hints.get_hint(data, "genesinrows",
                                                   self.genesinrows)

            self.openContext(data)
            if oldtaxid != self.taxid:
                # The organism changed (by the hint, or presumably by the
                # restored context); reset taxid to a sentinel so that
                # setCurrentOrganism registers the change.
                self.taxid = "< Do not look >"
                self.setCurrentOrganism(taxid)

            self.refreshHierarchy()
            self._invalidate()

    def setReference(self, data=None):
        """
        Set the (optional) input dataset with reference gene names.

        The reference radio box is enabled only for a truthy `data`; the
        results are invalidated when the reference data is in use.
        """
        self.referenceData = data
        has_reference = bool(data)
        self.referenceRadioBox.setEnabled(has_reference)
        if not self.useReferenceData:
            return
        self._invalidate()

    def handleNewSignals(self):
        """Recompute the annotations if the results have been invalidated."""
        if not self.__invalidated:
            return
        self.updateAnnotations()

    def _invalidateGeneMatcher(self):
        """Cancel the current gene matcher future and install an empty one."""
        _current, pending = self.__genematcher
        pending.cancel()
        empty_matcher = fulfill(gene.matcher([]))
        self.__genematcher = (None, empty_matcher)

    def _invalidate(self):
        """Flag the current enrichment results as out of date."""
        self.__invalidated = True

    def genesFromTable(self, table):
        """
        Return the list of gene names contained in `table`.

        With `genesinrows` set, the names are the names of the table's
        domain attributes (columns); otherwise they are the string values
        of the currently selected gene name column, one per row.
        """
        if self.genesinrows:
            return [column.name for column in table.domain.attributes]

        name_column = self.geneAttrs[self.geneattr]
        return [str(row[name_column]) for row in table]

    def getHierarchy(self, taxid):
        """
        Build the gene set category trees for `taxid` and for no organism.

        Every gene set hierarchy is inserted as a path of nested dicts
        under its organism id. Returns a pair
        `((taxid, tree), (None, tree_without_organism))`.
        """
        def make_node():
            # A dict whose missing keys are created as nested nodes.
            return defaultdict(make_node)

        tree = make_node()

        for path, organism, _ in self.genesets:
            node = tree[organism]
            # Descend along the path, creating the nodes on the way.
            for key in (path or ()):
                node = node[key]

        return (taxid, tree[taxid]), (None, tree[None])

    def setHierarchy(self, hierarchy, hierarchy_noorg):
        """
        Rebuild the gene set category tree widget.

        `hierarchy` and `hierarchy_noorg` are `(organism, tree)` pairs (as
        returned by `getHierarchy`). Every category becomes a checkable
        tree item whose check state is taken from `categoriesCheckState`
        (checked by default).
        """
        self.groupsWidgetItems = {}

        def populate(tree, parent, prefix=(), org=""):
            for name, subtree in sorted(tree.items()):
                category = prefix + (name, )
                node = QTreeWidgetItem(parent, [name])
                flags = (node.flags() | Qt.ItemIsUserCheckable
                         | Qt.ItemIsSelectable | Qt.ItemIsEnabled)
                node.setFlags(flags)
                if subtree:
                    # Items with children are tristate
                    node.setFlags(node.flags() | Qt.ItemIsTristate)

                state = self.categoriesCheckState.get((category, org),
                                                      Qt.Checked)
                node.setData(0, Qt.CheckStateRole, state)
                node.setExpanded(True)
                node.category = category
                node.organism = org
                self.groupsWidgetItems[category] = node
                populate(subtree, node, category, org=org)

        self.groupsWidget.clear()
        for pair in (hierarchy, hierarchy_noorg):
            populate(pair[1], self.groupsWidget, org=pair[0])

    def refreshHierarchy(self):
        """Rebuild the category tree for the organism at the species index."""
        current = self.taxid_list[self.speciesIndex]
        trees = self.getHierarchy(taxid=current)
        self.setHierarchy(*trees)

    def selectedCategories(self):
        """
        Return a list of the currently checked hierarchy keys.

        The keys are those of `getHierarchyCheckState`: a
        `(category, organism)` pair, where `category` is a tuple of
        identifiers from the root of the hierarchy tree to the item.
        """
        checked = []
        for key, state in self.getHierarchyCheckState().items():
            if state == Qt.Checked:
                checked.append(key)
        return checked

    def getHierarchyCheckState(self):
        """
        Return the check state of every item in the category tree.

        The result maps `(category, organism)` to the item's check state,
        where `category` is the tuple of display names from the top level
        item down to the item itself.
        """
        def walk(node, prefix=()):
            # Yield the node's own entry first, then those of its subtree.
            state = node.checkState(0)
            label = str(node.data(0, Qt.DisplayRole))
            category = prefix + (label, )
            yield (category, node.organism), state
            for row in range(node.childCount()):
                yield from walk(node.child(row), category)

        tree = self.groupsWidget
        roots = [tree.topLevelItem(row)
                 for row in range(tree.topLevelItemCount())]

        states = {}
        for root in roots:
            states.update(walk(root))
        return states

    def subsetSelectionChanged(self, item, column):
        """
        Handle a user change of the selected gene set (hierarchy) subset.

        Stores the new check state and updates the displayed results: the
        enrichment is recomputed when gene matching is off or a selected
        category has not been annotated yet, otherwise the existing
        results are only re-filtered.
        """
        # Update the stored state (persistent settings)
        self.categoriesCheckState = self.getHierarchyCheckState()
        selected = self.selectedCategories()

        if self.data is None:
            return

        needs_rerun = self._nogenematching() or \
            not set(selected) <= set(self.currentAnnotatedCategories)
        if needs_rerun:
            self.updateAnnotations()
        else:
            self.filterAnnotationsChartView()

    def updateGeneMatcherSettings(self):
        """
        Edit the gene matcher settings (not implemented).

        Always raises NotImplementedError.
        """
        raise NotImplementedError

    def _genematcher(self):
        """
        Return a Future[gene.SequenceMatcher]

        The future cached for the current organism is reused unless it
        was cancelled; otherwise a new matcher is created by the executor
        (or an empty matcher is used when no organism is selected).
        """
        taxid = self.taxid_list[self.speciesIndex]
        cached_taxid, cached_f = self.__genematcher

        if taxid == cached_taxid and not cached_f.cancelled():
            return cached_f

        self._invalidateGeneMatcher()

        if taxid is None:
            future = fulfill(gene.matcher([]))
            self.__genematcher = (None, future)
            return future

        # Keep only the matchers enabled in the settings.
        available = [gene.GMGO, gene.GMKEGG, gene.GMNCBI, gene.GMAffy]
        enabled = [
            m for m, use in zip(available, self.geneMatcherSettings) if use
        ]

        def create():
            return gene.matcher([m(taxid) for m in enabled])

        future = self._executor.submit(create)
        self.__genematcher = (taxid, future)
        return future

    def _nogenematching(self):
        """
        Return True when no gene name matching takes place, i.e. when no
        organism is selected or no gene matcher is enabled.
        """
        if self.taxid is None:
            return True
        return not any(self.geneMatcherSettings)

    def updateAnnotations(self):
        """
        Start (or restart) the enrichment computation for the input data.

        Does nothing without input data. Pending tasks are cancelled, the
        view and the widget messages are cleared, and the gene set
        collections, the reference set, the name matcher and finally the
        enrichment itself are submitted to the executor. The run state is
        stored in `self.state`; the results are delivered to
        `__on_enrichment_finished` / `__on_enrichment_failed`.
        """
        if self.data is None:
            return

        assert not self.__state & OWSetEnrichment.Initializing
        self._cancelPending()
        self._clearView()

        self.information(0)
        self.warning(0)
        self.error(0)

        if not self.genesinrows and not self.geneAttrs:
            self.error(0, "Input data contains no columns with gene names")
            return

        self.__state = OWSetEnrichment.RunningEnrichment

        taxid = self.taxid_list[self.speciesIndex]
        self.taxid = taxid

        categories = self.selectedCategories()

        clusterGenes = self.genesFromTable(self.data)

        if self.referenceData is not None and self.useReferenceData:
            referenceGenes = self.genesFromTable(self.referenceData)
        else:
            referenceGenes = None

        self.currentAnnotatedCategories = categories

        genematcher = self._genematcher()

        self.progressBarInit()

        ## Load collections in a worker thread
        # TODO: Use cached collections if already loaded and
        # use ensure_genesetsdownloaded with progress report (OWSelectGenes)
        collections = self._executor.submit(geneset.collections, *categories)

        def refset_null():
            """Return the default background reference set"""
            col = collections.result()
            return reduce(operator.ior, (set(g.genes) for g in col), set())

        def refset_ncbi():
            """Return all NCBI gene names"""
            geneinfo = gene.NCBIGeneInfo(taxid)
            return set(geneinfo.keys())

        def match_names():
            """Return the gene matcher targeted at the reference set"""
            matcher = genematcher.result()
            match = matcher.set_targets(ref_set.result())
            match.umatch = memoize(match.umatch)
            return match

        def map_unames():
            """Return the matched (query, reference) gene names"""
            matcher = namematcher.result()
            query = list(filter(None, map(matcher.umatch, clusterGenes)))
            reference = list(
                filter(None, map(matcher.umatch, ref_set.result())))
            return query, reference

        # An explicit reference always wins; otherwise the background is
        # the union of the loaded gene sets (no gene matching) or all NCBI
        # gene names of the organism.
        if referenceGenes is not None:
            ref_set = fulfill(referenceGenes)
        elif self._nogenematching():
            ref_set = self._executor.submit(refset_null)
        else:
            ref_set = self._executor.submit(refset_ncbi)

        namematcher = self._executor.submit(match_names)

        state = types.SimpleNamespace()
        state.query_set = clusterGenes
        state.reference_set = referenceGenes
        state.namematcher = namematcher
        state.query_count = len(set(clusterGenes))
        state.reference_count = (len(set(referenceGenes))
                                 if referenceGenes is not None else None)

        state.cancelled = False

        # Thread safe callbacks into the widget
        progress = methodinvoke(self, "_setProgress", (float, ))
        info = methodinvoke(self, "_setRunInfo", (str, ))

        @withtraceback
        def run():
            info("Loading data")
            match = namematcher.result()
            query, reference = map_unames()
            gscollections = collections.result()

            results = []
            info("Running enrichment")
            p = 0
            for i, gset in enumerate(gscollections):
                genes = set(filter(None, map(match.umatch, gset.genes)))
                enr = set_enrichment(genes, reference, query)
                results.append((gset, enr))

                # Set by _cancelPending
                if state.cancelled:
                    raise UserInteruptException

                # Report only whole percent changes
                pnew = int(100 * i / len(gscollections))
                if pnew != p:
                    progress(pnew)
                    p = pnew
            progress(100)
            info("")
            return query, reference, results

        task = Task(function=run)
        task.resultReady.connect(self.__on_enrichment_finished)
        task.exceptionReady.connect(self.__on_enrichment_failed)
        result = self._executor.submit(task)
        state.results = result

        self.state = state
        self._updatesummary()

    def __on_enrichment_failed(self, exception):
        """
        Handle a failed enrichment task.

        Anything but a user interrupt is reported on stderr (together with
        its stored traceback); the progress bar, the status message and
        the RunningEnrichment flag are reset in any case.
        """
        user_interrupt = isinstance(exception, UserInteruptException)
        if not user_interrupt:
            print("ERROR:", exception, file=sys.stderr)
            print(exception._traceback, file=sys.stderr)

        self.progressBarFinished()
        self.setStatusMessage("")
        self.__state = self.__state & ~OWSetEnrichment.RunningEnrichment

    def __on_enrichment_finished(self, results):
        """
        Display the results of a finished enrichment task.

        `results` is the `(query, reference, results)` triple returned by
        the task, where the last item is a list of `(gene set, enrichment)`
        pairs. Builds the report model (one row per gene set with at least
        one mapped query gene), updates the filter completer and the item
        delegates, sizes the columns and applies the current filters.
        Must run in the widget's thread (asserted).
        """
        assert QThread.currentThread() is self.thread()
        self.__state &= ~OWSetEnrichment.RunningEnrichment

        query, reference, results = results

        if self.annotationsChartView.model():
            self.annotationsChartView.model().clear()

        nquery = len(query)
        nref = len(reference)
        maxcount = max((len(e.query_mapped) for _, e in results), default=1)
        maxrefcount = max((len(e.reference_mapped) for _, e in results),
                          default=1)
        # Pad the counts to the number of decimal digits of the largest
        # count in the column, so that they line up.
        nspaces = len(str(maxcount))
        refspaces = len(str(maxrefcount))
        query_fmt = "%" + str(nspaces) + "s  (%.2f%%)"
        ref_fmt = "%" + str(refspaces) + "s  (%.2f%%)"

        def fmt_count(fmt, count, total):
            # `total or 1` guards against a division by zero
            return fmt % (count, 100.0 * count / (total or 1))

        fmt_query_count = partial(fmt_count, query_fmt)
        fmt_ref_count = partial(fmt_count, ref_fmt)

        linkFont = QFont(self.annotationsChartView.viewOptions().font)
        linkFont.setUnderline(True)

        def item(value=None, tooltip=None, user=None):
            # Create a model item; the user role (used for sorting)
            # defaults to the displayed value.
            si = QStandardItem()
            if value is not None:
                si.setData(value, Qt.DisplayRole)
            if tooltip is not None:
                si.setData(tooltip, Qt.ToolTipRole)
            if user is not None:
                si.setData(user, Qt.UserRole)
            else:
                si.setData(value, Qt.UserRole)
            return si

        model = QStandardItemModel()
        model.setSortRole(Qt.UserRole)
        model.setHorizontalHeaderLabels([
            "Category", "Term", "Count", "Reference count", "p-value", "FDR",
            "Enrichment"
        ])
        for gset, enrich in results:
            if len(enrich.query_mapped) == 0:
                continue
            nquery_mapped = len(enrich.query_mapped)
            nref_mapped = len(enrich.reference_mapped)

            row = [
                item(", ".join(gset.hierarchy)),
                item(gsname(gset), tooltip=gset.link),
                item(fmt_query_count(nquery_mapped, nquery),
                     tooltip=nquery_mapped,
                     user=nquery_mapped),
                item(fmt_ref_count(nref_mapped, nref),
                     tooltip=nref_mapped,
                     user=nref_mapped),
                item(fmtp(enrich.p_value), user=enrich.p_value),
                # column 5, FDR, is computed in filterAnnotationsChartView
                item(),
                item(enrich.enrichment_score,
                     tooltip="%.3f" % enrich.enrichment_score,
                     user=enrich.enrichment_score)
            ]
            # Keep the source objects on the first item (read by commit)
            row[0].geneset = gset
            row[0].enrichment = enrich
            row[1].setData(gset.link, gui.LinkRole)
            row[1].setFont(linkFont)
            row[1].setForeground(QColor(Qt.blue))

            model.appendRow(row)

        self.annotationsChartView.setModel(model)
        self.annotationsChartView.selectionModel().selectionChanged.connect(
            self.commit)

        if not model.rowCount():
            self.warning(0, "No enriched sets found.")
        else:
            self.warning(0)

        # Completions for the filter line edit: the gene set names and
        # the words they are made of.
        allnames = set(
            gsname(gset) for gset, (count, _, _, _) in results if count)

        allnames |= reduce(operator.ior,
                           (set(word_split(name)) for name in allnames), set())

        self.filterCompleter.setModel(None)
        self.completerModel = QStringListModel(sorted(allnames))
        self.filterCompleter.setModel(self.completerModel)

        if results:
            # Scale the enrichment bars to the largest finite score
            max_score = max(
                (e.enrichment_score
                 for _, e in results if np.isfinite(e.enrichment_score)),
                default=1)

            self.annotationsChartView.setItemDelegateForColumn(
                6, BarItemDelegate(self, scale=(0.0, max_score)))

        self.annotationsChartView.setItemDelegateForColumn(
            1, gui.LinkStyledItemDelegate(self.annotationsChartView))

        # Fit the columns to their contents, within [30, 300] pixels
        header = self.annotationsChartView.header()
        for i in range(model.columnCount()):
            sh = self.annotationsChartView.sizeHintForColumn(i)
            sh = max(sh, header.sectionSizeHint(i))
            self.annotationsChartView.setColumnWidth(i, max(min(sh, 300), 30))

        self.filterAnnotationsChartView()

        self.progressBarFinished()
        self.setStatusMessage("")

    def _updatesummary(self):
        """
        Refresh the text of the info box from the current run state.

        Without a state the error and warning messages are reset and a
        "no data" note is shown. Otherwise the number of unique input
        names is followed by the number of matched names, "..." while the
        results are not done, or the error of a failed run.
        """
        state = self.state
        if state is None:
            self.error(0, )
            self.warning(0)
            self.infoBox.setText("No data on input.\n")
            return

        header = "{.query_count} unique names on input\n".format(state)
        results = state.results

        if not results.done():
            details = "..."
        elif results.exception():
            details = "<Error {}>".format(str(results.exception()))
        else:
            mapped, _, _ = results.result()
            if state.query_count:
                ratio_mapped = len(mapped) / state.query_count
            else:
                ratio_mapped = 0
            details = ("%i (%.1f%%) gene names matched" %
                       (len(mapped), 100.0 * ratio_mapped))
        self.infoBox.setText(header + details)

        # TODO: warn on no enriched sets found (i.e no query genes
        # mapped to any set)

    def filterAnnotationsChartView(self, filterString=""):
        """
        Apply the category selection and the count/p-value/FDR filters to
        the rows of the report view.

        The FDR is recomputed over the rows of the currently selected
        categories only and stored in column 5. Does nothing while an
        enrichment is running or when the view has no model yet.
        `filterString` is currently unused (string search is disabled).
        """
        if self.__state & OWSetEnrichment.RunningEnrichment:
            return

        model = self.annotationsChartView.model()
        if model is None:
            # No results have been displayed yet; nothing to filter.
            return

        # TODO: Move filtering to a filter proxy model
        # TODO: Re-enable string search

        categories = set(", ".join(cat)
                         for cat, _ in self.selectedCategories())

        def ishidden(index):
            # Is item at index (row) hidden
            item = model.item(index)
            item_cat = item.data(Qt.DisplayRole)
            return item_cat not in categories

        hidemask = [ishidden(i) for i in range(model.rowCount())]

        # compute FDR according the selected categories
        pvals = [
            model.item(i, 4).data(Qt.UserRole)
            for i, hidden in enumerate(hidemask) if not hidden
        ]
        fdrs = utils.stats.FDR(pvals)

        # update FDR for the selected collections and apply filtering rules
        itemsHidden = []
        fdriter = iter(fdrs)
        for index, hidden in enumerate(hidemask):
            if not hidden:
                fdr = next(fdriter)
                pval = model.index(index, 4).data(Qt.UserRole)
                # The tooltip role of the count column holds the raw count
                count = model.index(index, 2).data(Qt.ToolTipRole)

                hidden = (self.useMinCountFilter and count < self.minClusterCount) or \
                         (self.useMaxPValFilter and pval > self.maxPValue) or \
                         (self.useMaxFDRFilter and fdr > self.maxFDR)

                if not hidden:
                    fdr_item = model.item(index, 5)
                    fdr_item.setData(fmtpdet(fdr), Qt.ToolTipRole)
                    fdr_item.setData(fmtp(fdr), Qt.DisplayRole)
                    fdr_item.setData(fdr, Qt.UserRole)

            self.annotationsChartView.setRowHidden(index, QModelIndex(),
                                                   hidden)

            itemsHidden.append(hidden)

        if model.rowCount() and all(itemsHidden):
            self.information(0, "All sets were filtered out.")
        else:
            self.information(0)

        self._updatesummary()

    @Slot(float)
    def _setProgress(self, value):
        """Set the progress bar to `value`; must run in the widget's thread."""
        assert QThread.currentThread() is self.thread()
        self.progressBarSet(value, processEvents=None)

    @Slot(str)
    def _setRunInfo(self, text):
        """Show `text` as the widget's status message."""
        self.setStatusMessage(text)

    def commit(self):
        """
        Send the part of the input data that maps to the selected sets.

        The query genes mapped to any of the selected rows are collected;
        the output keeps the matching columns (when the gene names are the
        column names) or the matching rows. Does nothing without data or
        while an enrichment is running.
        """
        if self.data is None or \
                self.__state & OWSetEnrichment.RunningEnrichment:
            return

        view = self.annotationsChartView
        model = view.model()
        selected_items = [
            model.item(index.row(), 0)
            for index in view.selectionModel().selectedRows(0)
        ]

        # Union of the mapped query genes over all the selected rows
        mapped = set()
        for item in selected_items:
            mapped |= set(item.enrichment.query_mapped)

        assert self.state.namematcher.done()
        matcher = self.state.namematcher.result()

        if self.genesinrows:
            columns = [
                attr for attr in self.data.domain.attributes
                if matcher.umatch(attr.name) in mapped
            ]

            newdomain = Orange.data.Domain(columns, self.data.domain.class_vars,
                                           self.data.domain.metas)
            data = self.data.from_table(newdomain, self.data)
        else:
            geneattr = self.geneAttrs[self.geneattr]
            indices = [
                i for i, ex in enumerate(self.data)
                if matcher.umatch(str(ex[geneattr])) in mapped
            ]
            data = self.data[indices]
        self.send("Data subset", data)

    def onDeleteWidget(self):
        """Cancel any pending run and shut the executor down without waiting."""
        has_state = self.state is not None
        if has_state:
            self._cancelPending()
            self.state = None
        self._executor.shutdown(wait=False)
class OWGOEnrichmentAnalysis(widget.OWWidget):
    """Widget for enrichment analysis of Gene Ontology terms ("GO Browser")."""
    # Widget meta data
    name = "GO Browser"
    description = "Enrichment analysis for Gene Ontology terms."
    icon = "../widgets/icons/GOBrowser.svg"
    priority = 2020

    # Input signals: (name, type, handler method name[, flags])
    inputs = [("Cluster Data", Orange.data.Table,
               "setDataset", widget.Single + widget.Default),
              ("Reference Data", Orange.data.Table,
               "setReferenceDataset")]

    # Output signals: (name, type)
    outputs = [("Data on Selected Genes", Orange.data.Table),
               ("Data on Unselected Genes", Orange.data.Table),
               ("Data on Unknown Genes", Orange.data.Table),
               ("Enrichment Report", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    # Context settings (handled by the DomainContextHandler above)
    annotationIndex = settings.ContextSetting(0)
    geneAttrIndex = settings.ContextSetting(0)
    useAttrNames = settings.ContextSetting(False)
    # One flag per gene matcher
    # NOTE(review): the order of the four matchers is not visible here
    geneMatcherSettings = settings.Setting([True, False, False, False])
    useReferenceDataset = settings.Setting(False)
    aspectIndex = settings.Setting(0)

    # All evidence types are used by default
    useEvidenceType = settings.Setting(
        {et: True for et in go.evidenceTypesOrdered})

    # Term filters. Note that `filterByPValue`/`maxPValue` back the FDR
    # filter controls, while the `*_nofdr` settings back the p-value ones.
    filterByNumOfInstances = settings.Setting(False)
    minNumOfInstances = settings.Setting(1)
    filterByPValue = settings.Setting(True)
    maxPValue = settings.Setting(0.2)
    filterByPValue_nofdr = settings.Setting(False)
    maxPValue_nofdr = settings.Setting(0.01)
    # Index of the significance test (0: binomial, 1: hypergeometric)
    probFunc = settings.Setting(0)

    # Output selection options
    selectionDirectAnnotation = settings.Setting(0)
    selectionDisjoint = settings.Setting(0)
    selectionAddTermAsClass = settings.Setting(0)

    # Widget states
    Ready, Initializing, Running = 0, 1, 2

    def __init__(self, parent=None):
        """
        Build the widget's GUI and start the download of required files.

        The widget is blocked until the `EnsureDownloaded` task submitted
        here finishes and `__initialize_finish` runs.
        """
        super().__init__(parent)

        self.clusterDataset = None
        self.referenceDataset = None
        self.ontology = None
        self.annotations = None
        self.loadedAnnotationCode = "---"
        self.treeStructRootKey = None
        self.probFunctions = [stats.Binomial(), stats.Hypergeometric()]
        self.selectedTerms = []

        self.selectionChanging = 0
        self.__state = OWGOEnrichmentAnalysis.Initializing

        self.annotationCodes = []

        #############
        ## GUI
        #############
        self.tabs = gui.tabWidget(self.controlArea)
        ## Input tab
        self.inputTab = gui.createTabPage(self.tabs, "Input")
        box = gui.widgetBox(self.inputTab, "Info")
        self.infoLabel = gui.widgetLabel(box, "No data on input\n")

        gui.button(box, self, "Ontology/Annotation Info",
                   callback=self.ShowInfo,
                   tooltip="Show information on loaded ontology and annotations")

        box = gui.widgetBox(self.inputTab, "Organism")
        self.annotationComboBox = gui.comboBox(
            box, self, "annotationIndex", items=self.annotationCodes,
            callback=self._updateEnrichment, tooltip="Select organism")

        genebox = gui.widgetBox(self.inputTab, "Gene Names")
        self.geneAttrIndexCombo = gui.comboBox(
            genebox, self, "geneAttrIndex", callback=self._updateEnrichment,
            tooltip="Use this attribute to extract gene names from input data")
        self.geneAttrIndexCombo.setDisabled(self.useAttrNames)

        # The gene attribute combo is disabled while column names are used
        cb = gui.checkBox(genebox, self, "useAttrNames", "Use column names",
                          tooltip="Use column names for gene names",
                          callback=self._updateEnrichment)
        cb.toggled[bool].connect(self.geneAttrIndexCombo.setDisabled)

        gui.button(genebox, self, "Gene matcher settings",
                   callback=self.UpdateGeneMatcher,
                   tooltip="Open gene matching settings dialog")

        self.referenceRadioBox = gui.radioButtonsInBox(
            self.inputTab, self, "useReferenceDataset",
            ["Entire genome", "Reference set (input)"],
            tooltips=["Use entire genome for reference",
                      "Use genes from Referece Examples input signal as reference"],
            box="Reference", callback=self._updateEnrichment)

        # No reference data on input yet
        self.referenceRadioBox.buttons[1].setDisabled(True)
        gui.radioButtonsInBox(
            self.inputTab, self, "aspectIndex",
            ["Biological process", "Cellular component", "Molecular function"],
            box="Aspect", callback=self._updateEnrichment)

        ## Filter tab
        self.filterTab = gui.createTabPage(self.tabs, "Filter")
        box = gui.widgetBox(self.filterTab, "Filter GO Term Nodes")
        gui.checkBox(box, self, "filterByNumOfInstances", "Genes",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by number of input genes mapped to a term")
        ibox = gui.indentedBox(box)
        gui.spin(ibox, self, 'minNumOfInstances', 1, 100,
                 step=1, label='#:', labelWidth=15,
                 callback=self.FilterAndDisplayGraph,
                 callbackOnReturn=True,
                 tooltip="Min. number of input genes mapped to a term")

        gui.checkBox(box, self, "filterByPValue_nofdr", "p-value",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by term p-value")

        gui.doubleSpin(gui.indentedBox(box), self, 'maxPValue_nofdr', 1e-8, 1,
                       step=1e-8,  label='p:', labelWidth=15,
                       callback=self.FilterAndDisplayGraph,
                       callbackOnReturn=True,
                       tooltip="Max term p-value")

        #use filterByPValue for FDR, as it was the default in prior versions
        gui.checkBox(box, self, "filterByPValue", "FDR",
                     callback=self.FilterAndDisplayGraph,
                     tooltip="Filter by term FDR")
        gui.doubleSpin(gui.indentedBox(box), self, 'maxPValue', 1e-8, 1,
                       step=1e-8,  label='p:', labelWidth=15,
                       callback=self.FilterAndDisplayGraph,
                       callbackOnReturn=True,
                       tooltip="Max term p-value")

        box = gui.widgetBox(box, "Significance test")

        gui.radioButtonsInBox(box, self, "probFunc", ["Binomial", "Hypergeometric"],
                              tooltips=["Use binomial distribution test",
                                        "Use hypergeometric distribution test"],
                              callback=self._updateEnrichment)
        box = gui.widgetBox(self.filterTab, "Evidence codes in annotation",
                              addSpace=True)
        # One check box per evidence type, keyed by the evidence code
        self.evidenceCheckBoxDict = {}
        for etype in go.evidenceTypesOrdered:
            ecb = QCheckBox(
                etype, toolTip=go.evidenceTypes[etype],
                checked=self.useEvidenceType[etype])
            ecb.toggled.connect(self.__on_evidenceChanged)
            box.layout().addWidget(ecb)
            self.evidenceCheckBoxDict[etype] = ecb

        ## Select tab
        self.selectTab = gui.createTabPage(self.tabs, "Select")
        gui.radioButtonsInBox(
            self.selectTab, self, "selectionDirectAnnotation",
            ["Directly or Indirectly", "Directly"],
            box="Annotated genes",
            callback=self.ExampleSelection)

        box = gui.widgetBox(self.selectTab, "Output", addSpace=True)
        gui.radioButtonsInBox(
            box, self, "selectionDisjoint",
            btnLabels=["All selected genes",
                       "Term-specific genes",
                       "Common term genes"],
            tooltips=["Outputs genes annotated to all selected GO terms",
                      "Outputs genes that appear in only one of selected GO terms",
                      "Outputs genes common to all selected GO terms"],
            callback=[self.ExampleSelection,
                      self.UpdateAddClassButton])

        self.addClassCB = gui.checkBox(
            box, self, "selectionAddTermAsClass", "Add GO Term as class",
            callback=self.ExampleSelection)

        # ListView for DAG, and table for significant GOIDs
        self.DAGcolumns = ['GO term', 'Cluster', 'Reference', 'p-value',
                           'FDR', 'Genes', 'Enrichment']

        self.splitter = QSplitter(Qt.Vertical, self.mainArea)
        self.mainArea.layout().addWidget(self.splitter)

        # list view
        self.listView = GOTreeWidget(self.splitter)
        self.listView.setSelectionMode(QTreeView.ExtendedSelection)
        self.listView.setAllColumnsShowFocus(True)
        self.listView.setColumnCount(len(self.DAGcolumns))
        self.listView.setHeaderLabels(self.DAGcolumns)

        self.listView.header().setSectionsClickable(True)
        self.listView.header().setSortIndicatorShown(True)
        self.listView.setSortingEnabled(True)
        # Column 6 is the 'Enrichment' column
        self.listView.setItemDelegateForColumn(
            6, EnrichmentColumnItemDelegate(self))
        self.listView.setRootIsDecorated(True)

        self.listView.itemSelectionChanged.connect(self.ViewSelectionChanged)

        # table of significant GO terms
        self.sigTerms = QTreeWidget(self.splitter)
        self.sigTerms.setColumnCount(len(self.DAGcolumns))
        self.sigTerms.setHeaderLabels(self.DAGcolumns)
        self.sigTerms.setSortingEnabled(True)
        self.sigTerms.setSelectionMode(QTreeView.ExtendedSelection)
        self.sigTerms.setItemDelegateForColumn(
            6, EnrichmentColumnItemDelegate(self))

        self.sigTerms.itemSelectionChanged.connect(self.TableSelectionChanged)

        self.sigTableTermsSorted = []
        self.graph = {}

        self.inputTab.layout().addStretch(1)
        self.filterTab.layout().addStretch(1)
        self.selectTab.layout().addStretch(1)

        # Block the widget until the required files have been downloaded;
        # __initialize_finish unblocks it.
        self.setBlocking(True)
        self._executor = ThreadExecutor()
        self._init = EnsureDownloaded(
            [(taxonomy.Taxonomy.DOMAIN, taxonomy.Taxonomy.FILENAME),
             ("GO", "taxonomy.pickle")]
        )
        self._init.finished.connect(self.__initialize_finish)
        self._executor.submit(self._init)


    def sizeHint(self):
        """Return the size hint of the widget: 1000 x 700."""
        hint = QSize(1000, 700)
        return hint

    def __initialize_finish(self):
        """Unblock the widget and fill the annotation selector.

        When listing the available annotation files raises
        ``ConnectTimeout``, an error is shown and the three tabs and both
        term views are disabled. Otherwise the annotation combo box is
        populated and the widget state is set to ``Ready``.
        """
        self.setBlocking(False)

        try:
            available = listAvailable()
        except ConnectTimeout:
            self.error(2, "Internet connection error, unable to load data. "
                          "Check connection and create a new GO Browser widget.")
            for part in (self.filterTab, self.inputTab, self.selectTab,
                         self.listView, self.sigTerms):
                part.setEnabled(False)
            return

        self.annotationFiles = available
        self.annotationCodes = sorted(self.annotationFiles.keys())
        combo = self.annotationComboBox
        combo.clear()
        combo.addItems(self.annotationCodes)
        combo.setCurrentIndex(self.annotationIndex)
        self.__state = OWGOEnrichmentAnalysis.Ready

    def __on_evidenceChanged(self):
        """Store every evidence check box state in ``useEvidenceType`` and
        recompute the enrichment."""
        boxes = self.evidenceCheckBoxDict
        for code in boxes:
            self.useEvidenceType[code] = boxes[code].isChecked()
        self._updateEnrichment()

    def UpdateGeneMatcher(self):
        """Show the modal gene matcher dialog and apply the chosen settings.

        Nothing changes when the dialog is rejected. The matcher and the
        enrichment are refreshed only when annotations are loaded.
        """
        dialog = GeneMatcherDialog(self, defaults=self.geneMatcherSettings, modal=True)
        if dialog.exec_() == QDialog.Rejected:
            return

        self.geneMatcherSettings = [getattr(dialog, entry[0]) for entry in dialog.items]
        if not self.annotations:
            return
        self.SetGeneMatcher()
        self._updateEnrichment()

    def clear(self):
        """Reset the info label, warnings 0 and 1, the gene attribute combo
        and the term views, then send None on every output channel."""
        self.infoLabel.setText("No data on input\n")
        for warning_id in (0, 1):
            self.warning(warning_id)
        self.geneAttrIndexCombo.clear()
        self.ClearGraph()

        for channel in ("Data on Selected Genes",
                        "Data on Unselected Genes",
                        "Data on Unknown Genes",
                        "Enrichment Report"):
            self.send(channel, None)

    def setDataset(self, data=None):
        """Set the cluster (input) dataset and recompute the enrichment.

        The previous context is closed and the widget is cleared first.
        For a non-None dataset the string-valued variables (including
        metas) become the candidate gene attributes, the organism is guessed
        from the "taxid" data hint, the context is opened and the
        enrichment is updated.

        :param data: the input data table, or None to clear the widget.
        """
        if self.__state == OWGOEnrichmentAnalysis.Initializing:
            self.__initialize_finish()

        self.closeContext()
        self.clear()
        self.clusterDataset = data

        if data is None:
            return

        domain = data.domain
        allvars = domain.variables + domain.metas
        self.candidateGeneAttrs = [var for var in allvars if isstring(var)]

        self.geneAttrIndexCombo.clear()
        for var in self.candidateGeneAttrs:
            self.geneAttrIndexCombo.addItem(*gui.attributeItem(var))

        taxid = data_hints.get_hint(data, "taxid", "")
        code = None
        try:
            code = go.from_taxid(taxid)
        except KeyError:
            # Unknown taxonomy id: keep the current annotation selection.
            pass
        except Exception as ex:
            print(ex)

        if code is not None:
            filename = "gene_association.%s.tar.gz" % code
            if filename in self.annotationFiles.values():
                # Select the (last) annotation code stored under that file.
                self.annotationIndex = [
                    i for i, name in enumerate(self.annotationCodes)
                    if self.annotationFiles[name] == filename].pop()

        self.useAttrNames = data_hints.get_hint(data, "genesinrows",
                                                self.useAttrNames)
        self.openContext(data)

        n_candidates = len(self.candidateGeneAttrs)
        if n_candidates == 0:
            self.useAttrNames = True
            self.geneAttrIndex = -1
        else:
            # Keep the index restored by openContext() when it is valid.
            # The previous check (`index < len(candidates)`) was always true
            # after clamping and so always forced the last attribute.
            self.geneAttrIndex = min(self.geneAttrIndex, n_candidates - 1)
            if self.geneAttrIndex < 0:
                self.geneAttrIndex = n_candidates - 1

        self._updateEnrichment()

    def setReferenceDataset(self, data=None):
        """Set the reference dataset and refresh the reference radio button.

        When a cluster dataset is present and the reference set is in use,
        the enrichment is recomputed and displayed; otherwise only the
        button label is updated.
        """
        self.referenceDataset = data
        reference_button = self.referenceRadioBox.buttons[1]
        reference_button.setDisabled(not bool(data))
        reference_button.setText("Reference set")

        has_cluster_data = self.clusterDataset is not None
        if has_cluster_data and self.useReferenceDataset:
            self.useReferenceDataset = 1 if data else 0
            self.SetGraph(self.Enrichment())
        elif self.clusterDataset:
            self.__updateReferenceSetButton()

    def handleNewSignals(self):
        """Delegate to the base class; no widget-specific work is done."""
        super().handleNewSignals()

    def _updateEnrichment(self):
        """Load the required data, recompute the enrichment and display it.

        Does nothing unless a cluster dataset is set and the widget state
        is ``Ready``.
        """
        if self.clusterDataset is None:
            return
        if self.__state != OWGOEnrichmentAnalysis.Ready:
            return

        progress = gui.ProgressBar(self, 100)
        self.Load(pb=progress)
        enriched = self.Enrichment(pb=progress)
        self.FilterUnknownGenes()
        self.SetGraph(enriched)

    def __updateReferenceSetButton(self):
        """Refresh the enabled state and label of the reference radio button.

        The label shows the number of reference genes and how many of them
        matched the annotations, when both counts are non-zero.
        """
        allgenes = None
        refgenes = None
        if self.referenceDataset:
            try:
                allgenes = self.genesFromTable(self.referenceDataset)
            except Exception:
                allgenes = []
            refgenes, unknown = self.FilterAnnotatedGenes(allgenes)

        button = self.referenceRadioBox.buttons[1]
        button.setDisabled(not bool(allgenes))
        if allgenes and refgenes:
            suffix = "(%i genes, %i matched)" % (len(allgenes), len(refgenes))
        else:
            suffix = ""
        button.setText("Reference set " + suffix)

    def genesFromTable(self, data):
        """Return the list of gene names found in ``data``.

        With ``useAttrNames`` set, the names of the domain variables are
        returned. Otherwise the values of the selected gene attribute are
        used (NaN values are skipped); if any value contains a comma, every
        value is split on commas and an information message is shown.
        """
        if self.useAttrNames:
            return [var.name for var in data.domain.variables]

        last = len(self.candidateGeneAttrs) - 1
        attr = self.candidateGeneAttrs[min(self.geneAttrIndex, last)]
        genes = [str(row[attr]) for row in data if not numpy.isnan(row[attr])]
        if any("," in name for name in genes):
            self.information(0, "Separators detected in gene names. Assuming multiple genes per example.")
            genes = [part for name in genes for part in name.split(",")]
        return genes

    def FilterAnnotatedGenes(self, genes):
        """Split ``genes`` into matched and unmatched names.

        Returns a pair: the values of the annotations' gene name translator
        for ``genes``, and the list of genes not found among those values.
        """
        translator = self.annotations.get_gene_names_translator(genes)
        matched = translator.values()
        unmatched = [name for name in genes if name not in matched]
        return matched, unmatched

    def FilterUnknownGenes(self):
        """Send the rows whose genes are all unmatched on "Data on Unknown Genes".

        None is sent when gene names come from attribute names, when there
        is no candidate gene attribute, or when every row has at least one
        gene accepted by the annotations' gene matcher.
        """
        unknown = None
        if not self.useAttrNames and self.candidateGeneAttrs:
            last = len(self.candidateGeneAttrs) - 1
            gene_attr = self.candidateGeneAttrs[min(self.geneAttrIndex, last)]
            rows = [
                row_index
                for row_index, row in enumerate(self.clusterDataset)
                if not any(self.annotations.genematcher.match(name.strip())
                           for name in str(row[gene_attr]).split(","))
            ]
            if rows:
                unknown = self.clusterDataset[rows]
        self.send("Data on Unknown Genes", unknown)

    def Load(self, pb=None):
        """Download missing server files and load the ontology and annotations.

        Nothing happens unless the widget state is ``Ready``. The ontology
        is loaded once; the annotations are reloaded whenever the selected
        organism differs from ``loadedAnnotationCode``. After a reload the
        evidence check boxes are relabelled with per-evidence annotation
        and gene counts and disabled when the count is zero.

        :param pb: progress bar to advance. When None a new one is created
            here and finished before returning.
        """
        if self.__state != OWGOEnrichmentAnalysis.Ready:
            return

        go_files = serverfiles.listfiles("GO")
        tax_files = serverfiles.listfiles(taxonomy.Taxonomy.DOMAIN)
        calls = []
        pb, finish = (gui.ProgressBar(self, 0), True) if pb is None else (pb, False)
        steps = 0
        if not tax_files:
            # Use the same constants as the initial download in __init__;
            # the literal used here before misspelled the file name
            # ("ncbi_taxnomy.tar.gz").
            calls.append((taxonomy.Taxonomy.DOMAIN, taxonomy.Taxonomy.FILENAME))
            steps += 1

        org = self.annotationCodes[min(self.annotationIndex, len(self.annotationCodes) - 1)]
        reload_annotations = org != self.loadedAnnotationCode
        if reload_annotations:
            steps += 1
            if self.annotationFiles[org] not in go_files:
                calls.append(("GO", self.annotationFiles[org]))
                steps += 1

        if "gene_ontology_edit.obo.tar.gz" not in go_files:
            calls.append(("GO", "gene_ontology_edit.obo.tar.gz"))
            steps += 1
        if not self.ontology:
            steps += 1
        # Each download/load step is budgeted 100 progress iterations.
        pb.iter += steps * 100

        for args in calls:
            serverfiles.localpath_download(*args, callback=pb.advance)

        if not self.ontology:
            self.ontology = go.Ontology(progress_callback=lambda value: pb.advance())

        if reload_annotations:
            # Drop the old annotations before loading the new ones.
            self.annotations = None
            gc.collect()  # Force run garbage collection
            code = self.annotationFiles[org].split(".")[-3]
            self.annotations = go.Annotations(code, genematcher=gene.GMDirect(), progress_callback=lambda value: pb.advance())
            self.loadedAnnotationCode = org

            evidence_counts = defaultdict(int)
            evidence_genes = defaultdict(set)
            for anno in self.annotations.annotations:
                evidence_counts[anno.evidence] += 1
                evidence_genes[anno.evidence].add(anno.geneName)

            for etype in go.evidenceTypesOrdered:
                ecb = self.evidenceCheckBoxDict[etype]
                ecb.setEnabled(bool(evidence_counts[etype]))
                ecb.setText(etype + ": %i annots(%i genes)" % (evidence_counts[etype], len(evidence_genes[etype])))

        if finish:
            pb.finish()

    def SetGeneMatcher(self):
        """Build the gene matcher from ``geneMatcherSettings`` and install
        it on the loaded annotations.

        Does nothing when no annotations are loaded. A matcher whose
        construction fails is reported with ``print`` and left out.
        """
        if not self.annotations:
            return

        taxid = self.annotations.taxid
        factories = [gene.GMGO, gene.GMKEGG, gene.GMNCBI, gene.GMAffy]
        matchers = []
        for factory, enabled in zip(factories, self.geneMatcherSettings):
            if not enabled:
                continue
            try:
                if taxid != "352472":
                    matchers.append(factory(taxid))
                else:
                    # The matchers are listed individually first so that
                    # either one can match a gene on its own; the joint
                    # matcher is only used when both of them fail.
                    matchers.extend([factory(taxid), gene.GMDicty(),
                                     [factory(taxid), gene.GMDicty()]])
            except Exception as ex:
                print(ex)

        self.annotations.genematcher = gene.matcher(matchers)
        self.annotations.genematcher.set_targets(self.annotations.gene_names)

    def Enrichment(self, pb=None):
        """Compute the enriched GO terms for the genes of the cluster dataset.

        Gene names are taken from the attribute names or from the selected
        gene attribute, translated through the annotations and tested
        against the reference genes (all annotated genes, or the genes of
        the reference dataset). An FDR value is appended to every term
        tuple, and the parent/child structure among the enriched terms is
        stored in ``self.treeStructDict`` / ``self.treeStructRootKey``.

        :param pb: progress bar to advance; a new one is created when None.
            It is finished before this method returns, including on the
            early error returns.
        :return: dict mapping a term id to a tuple ending with the FDR, or
            an empty dict when gene names or a reference set are missing.
        """
        assert self.clusterDataset is not None

        pb = gui.ProgressBar(self, 100) if pb is None else pb
        if not self.annotations.ontology:
            self.annotations.ontology = self.ontology

        if isinstance(self.annotations.genematcher, gene.GMDirect):
            self.SetGeneMatcher()
        self.error(1)
        self.warning([0, 1])

        if self.useAttrNames:
            clusterGenes = [v.name for v in self.clusterDataset.domain.attributes]
            self.information(0)
        elif 0 <= self.geneAttrIndex < len(self.candidateGeneAttrs):
            geneAttr = self.candidateGeneAttrs[self.geneAttrIndex]
            clusterGenes = [str(ex[geneAttr]) for ex in self.clusterDataset
                            if not numpy.isnan(ex[geneAttr])]
            if any("," in name for name in clusterGenes):
                self.information(0, "Separators detected in cluster gene names. Assuming multiple genes per example.")
                clusterGenes = [part for name in clusterGenes
                                for part in name.split(",")]
            else:
                self.information(0)
        else:
            self.error(1, "Failed to extract gene names from input dataset!")
            # Finish the progress bar on this exit too, as the normal path does.
            pb.finish()
            return {}

        genesSetCount = len(set(clusterGenes))

        self.clusterGenes = clusterGenes = self.annotations.get_gene_names_translator(clusterGenes).values()

        self.infoLabel.setText("%i unique genes on input\n%i (%.1f%%) genes with known annotations" % (genesSetCount, len(clusterGenes), 100.0*len(clusterGenes)/genesSetCount if genesSetCount else 0.0))

        referenceGenes = None
        if not self.useReferenceDataset or self.referenceDataset is None:
            self.information(2)
            self.information(1)
            referenceGenes = self.annotations.gene_names
        else:
            if self.useAttrNames:
                referenceGenes = [v.name for v in self.referenceDataset.domain.attributes]
                self.information(1)
            elif geneAttr in (self.referenceDataset.domain.variables +
                              self.referenceDataset.domain.metas):
                referenceGenes = [str(ex[geneAttr]) for ex in self.referenceDataset
                                  if not numpy.isnan(ex[geneAttr])]
                # Look for separators in the reference names themselves;
                # this used to test the (already translated) cluster genes.
                if any("," in name for name in referenceGenes):
                    self.information(1, "Separators detected in reference gene names. Assuming multiple genes per example.")
                    referenceGenes = [part for name in referenceGenes
                                      for part in name.split(",")]
                else:
                    self.information(1)
            else:
                self.information(1)
                referenceGenes = None

            if referenceGenes is None:
                referenceGenes = list(self.annotations.gene_names)
                self.referenceRadioBox.buttons[1].setText("Reference set")
                self.referenceRadioBox.buttons[1].setDisabled(True)
                self.information(2, "Unable to extract gene names from reference dataset. Using entire genome for reference")
                self.useReferenceDataset = 0
            else:
                refc = len(referenceGenes)
                referenceGenes = self.annotations.get_gene_names_translator(referenceGenes).values()
                self.referenceRadioBox.buttons[1].setText("Reference set (%i genes, %i matched)" % (refc, len(referenceGenes)))
                self.referenceRadioBox.buttons[1].setDisabled(False)
                self.information(2)

        if not referenceGenes:
            self.error(1, "No valid reference set")
            pb.finish()
            return {}

        self.referenceGenes = referenceGenes
        evidences = [etype for etype in go.evidenceTypesOrdered
                     if self.useEvidenceType[etype]]
        aspect = ["P", "C", "F"][self.aspectIndex]

        if clusterGenes:
            self.terms = terms = self.annotations.get_enriched_terms(
                clusterGenes, referenceGenes, evidences, aspect=aspect,
                prob=self.probFunctions[self.probFunc], use_fdr=False,
                progress_callback=lambda value: pb.advance())
            term_ids = []
            pvals = []
            for term_id, term_data in self.terms.items():
                term_ids.append(term_id)
                pvals.append(term_data[1])
            # Save the FDR as the last element of each term tuple.
            for term_id, fdr in zip(term_ids, stats.FDR(pvals)):
                terms[term_id] = tuple(list(terms[term_id]) + [fdr])
        else:
            self.terms = terms = {}

        if not self.terms:
            self.warning(0, "No enriched terms found.")
        else:
            self.warning(0)

        pb.finish()
        self.treeStructDict = {}
        self.treeStructRootKey = None

        # children[t] holds the enriched terms that list t among their
        # related terms; built in one pass over the enriched terms.
        children = {term: set() for term in self.terms}
        for term_id in self.terms:
            for _, parent in self.ontology[term_id].related:
                if parent in children:
                    children[parent].add(term_id)

        for term in self.terms:
            self.treeStructDict[term] = TreeNode(self.terms[term], children[term])
            # A non-obsolete term without related terms is taken as the root.
            if not self.ontology[term].related and not getattr(self.ontology[term], "is_obsolete", False):
                self.treeStructRootKey = term
        return terms

    def FilterGraph(self, graph):
        """Return ``graph`` restricted by the enabled filters.

        The filters are applied in order: raw p-value (delegated to
        ``go.filterByPValue``), FDR (tuple index 3) and minimum number of
        genes (length of tuple index 0).
        """
        if self.filterByPValue_nofdr:
            graph = go.filterByPValue(graph, self.maxPValue_nofdr)
        if self.filterByPValue:  # FDR
            graph = {term: info for term, info in graph.items()
                     if info[3] <= self.maxPValue}
        if self.filterByNumOfInstances:
            graph = {term: info for term, info in graph.items()
                     if len(info[0]) >= self.minNumOfInstances}
        return graph

    def FilterAndDisplayGraph(self):
        """Filter the original graph and redraw both term views.

        Does nothing without a cluster dataset. Warning 1 is raised when
        terms were found but the filters removed all of them.
        """
        if not self.clusterDataset:
            return

        self.graph = self.FilterGraph(self.originalGraph)
        all_filtered_out = self.originalGraph and not self.graph
        if all_filtered_out:
            self.warning(1, "All found terms were filtered out.")
        else:
            self.warning(1)
        self.ClearGraph()
        self.DisplayGraph()

    def SetGraph(self, graph=None):
        """Store ``graph`` as the original graph and display it.

        An empty or missing graph resets ``self.graph`` and clears the
        views instead.
        """
        self.originalGraph = graph
        if not graph:
            self.graph = {}
            self.ClearGraph()
            return
        self.FilterAndDisplayGraph()

    def ClearGraph(self):
        """Empty the term tree, the list of its items and the terms table."""
        tree, table = self.listView, self.sigTerms
        tree.clear()
        self.listViewItems = list()
        table.clear()

    def DisplayGraph(self):
        """Populate the term tree and the terms table from ``self.graph``
        and send the "Enrichment Report" table.

        Each graph value is unpacked as ``(genes, p_value, refCount, fdr)``.
        """
        # (parent, term) pairs already added to the tree.
        fromParentDict = {}
        # term id -> list of tree items showing that term.
        self.termListViewItemDict = {}
        self.listViewItems = []
        # Fold enrichment: (cluster hits / reference count) scaled by
        # (reference size / cluster size).
        enrichment = lambda t: len(t[0]) / t[2] * (len(self.referenceGenes) / len(self.clusterGenes))
        # Falls back to 1 when the graph is empty.
        maxFoldEnrichment = max([enrichment(term) for term in self.graph.values()] or [1])

        def addNode(term, parent, parentDisplayNode):
            # Recursively add `term` and its children below `parentDisplayNode`.
            if (parent, term) in fromParentDict:
                return
            if term in self.graph:
                displayNode = GOTreeWidgetItem(self.ontology[term], self.graph[term], len(self.clusterGenes), len(self.referenceGenes), maxFoldEnrichment, parentDisplayNode)
                displayNode.goId = term
                self.listViewItems.append(displayNode)
                if term in self.termListViewItemDict:
                    self.termListViewItemDict[term].append(displayNode)
                else:
                    self.termListViewItemDict[term] = [displayNode]
                fromParentDict[(parent, term)] = True
                parent = term
            else:
                # Terms not in the (filtered) graph get no item of their own;
                # their children attach to the nearest displayed ancestor.
                displayNode = parentDisplayNode

            for c in self.treeStructDict[term].children:
                addNode(c, parent, displayNode)

        if self.treeStructDict:
            addNode(self.treeStructRootKey, None, self.listView)

        # The table lists the terms ordered by p-value.
        terms = self.graph.items()
        terms = sorted(terms, key=lambda item: item[1][1])
        self.sigTableTermsSorted = [t[0] for t in terms]

        self.sigTerms.clear()
        for i, (t_id, (genes, p_value, refCount, fdr)) in enumerate(terms):
            # Passing self.sigTerms as the last argument mirrors the parent
            # argument used for the tree items above.
            item = GOTreeWidgetItem(self.ontology[t_id],
                                    (genes, p_value, refCount, fdr),
                                    len(self.clusterGenes),
                                    len(self.referenceGenes),
                                    maxFoldEnrichment,
                                    self.sigTerms)
            item.goId = t_id

        self.listView.expandAll()
        for i in range(5):
            self.listView.resizeColumnToContents(i)
            self.sigTerms.resizeColumnToContents(i)
        self.sigTerms.resizeColumnToContents(6)
        # Cap the width of the first column at 350 in both views.
        width = min(self.listView.columnWidth(0), 350)
        self.listView.setColumnWidth(0, width)
        self.sigTerms.setColumnWidth(0, width)

        # Create and send the enrichment report table.
        termsDomain = Orange.data.Domain(
            [], [],
            # All is meta!
            [Orange.data.StringVariable("GO Term Id"),
             Orange.data.StringVariable("GO Term Name"),
             Orange.data.ContinuousVariable("Cluster Frequency"),
             Orange.data.ContinuousVariable("Genes in Cluster", number_of_decimals=0),
             Orange.data.ContinuousVariable("Reference Frequency"),
             Orange.data.ContinuousVariable("Genes in Reference", number_of_decimals=0),
             Orange.data.ContinuousVariable("p-value"),
             Orange.data.ContinuousVariable("FDR"),
             Orange.data.ContinuousVariable("Enrichment"),
             Orange.data.StringVariable("Genes")])

        # One row per term, in the same column order as termsDomain's metas.
        terms = [[t_id,
                  self.ontology[t_id].name,
                  len(genes) / len(self.clusterGenes),
                  len(genes),
                  r_count / len(self.referenceGenes),
                  r_count,
                  p_value,
                  fdr,
                  len(genes) / len(self.clusterGenes) * \
                  len(self.referenceGenes) / r_count,
                  ",".join(genes)
                  ]
                 for t_id, (genes, p_value, r_count, fdr) in terms]

        if terms:
            # No regular columns: X has zero width, all data goes to metas.
            X = numpy.empty((len(terms), 0))
            M = numpy.array(terms, dtype=object)
            termsTable = Orange.data.Table.from_numpy(termsDomain, X, metas=M)
        else:
            termsTable = Orange.data.Table(termsDomain)
        self.send("Enrichment Report", termsTable)

    def ViewSelectionChanged(self):
        """Update ``selectedTerms`` from the tree selection and commit.

        ``selectionChanging`` guards against re-entrant calls. It is reset
        in a ``finally`` clause so that an exception raised while
        committing cannot leave selection handling disabled.
        """
        if self.selectionChanging:
            return

        self.selectionChanging = 1
        try:
            self.selectedTerms = []
            selected = self.listView.selectedItems()
            # A term can appear under several parents; keep each id once.
            self.selectedTerms = list({lvi.term.id for lvi in selected})
            self.ExampleSelection()
        finally:
            self.selectionChanging = 0

    def TableSelectionChanged(self):
        """Mirror the terms table selection onto the term tree and commit.

        Every tree item of a selected term is selected and expanded, and
        the items of the other terms are deselected. ``selectionChanging``
        guards against re-entrant calls and is reset in a ``finally``
        clause, so an exception cannot leave selection handling disabled.
        """
        if self.selectionChanging:
            return

        self.selectionChanging = 1
        try:
            self.selectedTerms = []
            selectedIds = {self.sigTerms.itemFromIndex(index).goId
                           for index in self.sigTerms.selectedIndexes()}

            for row in range(self.sigTerms.topLevelItemCount()):
                term = self.sigTerms.topLevelItem(row).goId
                selected = term in selectedIds

                if selected:
                    self.selectedTerms.append(term)

                # A table term without tree items is skipped instead of
                # raising KeyError.
                for lvi in self.termListViewItemDict.get(term, ()):
                    try:
                        lvi.setSelected(selected)
                        if selected:
                            lvi.setExpanded(True)
                    except RuntimeError:  # Underlying C/C++ object deleted
                        pass

            self.ExampleSelection()
        finally:
            self.selectionChanging = 0

    def UpdateAddClassButton(self):
        """Enable the "add term as class" check box only when
        ``selectionDisjoint`` equals 1."""
        enabled = self.selectionDisjoint == 1
        self.addClassCB.setEnabled(enabled)

    def ExampleSelection(self):
        """Callback for selection option changes; forwards to ``commit``."""
        return self.commit()

    def commit(self):
        """Send the data for the genes of the currently selected terms.

        Does nothing without a cluster dataset. When gene names come from
        attribute names, the selected genes become the attributes of the
        output table and nothing is sent as unselected. Otherwise the rows
        of the cluster dataset are split into selected and unselected by
        the values of the gene attribute.
        """
        if self.clusterDataset is None:
            return

        # Union of the cluster genes of every selected term.
        terms = set(self.selectedTerms)
        genes = reduce(operator.ior,
                       (set(self.graph[term][0]) for term in terms), set())

        evidences = []
        for etype in go.evidenceTypesOrdered:
            if self.useEvidenceType[etype]:
                evidences.append(etype)
        allTerms = self.annotations.get_annotated_terms(
            genes, direct_annotation_only=self.selectionDirectAnnotation,
            evidence_codes=evidences)

        if self.selectionDisjoint > 0:
            # Count in how many selected terms each gene occurs.
            count = defaultdict(int)
            for term in self.selectedTerms:
                for g in allTerms.get(term, []):
                    count[g] += 1
            # Mode 1 keeps genes found in exactly one selected term; the
            # other mode keeps genes found in all of them.
            ccount = 1 if self.selectionDisjoint == 1 else len(self.selectedTerms)
            selectedGenes = [gene for gene, c in count.items()
                             if c == ccount and gene in genes]
        else:
            selectedGenes = reduce(
                operator.ior,
                (set(allTerms.get(term, [])) for term in self.selectedTerms),
                set())

        if self.useAttrNames:
            # Genes are columns: build a domain from the selected ones.
            vars = [self.clusterDataset.domain[gene]
                    for gene in set(selectedGenes)]
            domain = Orange.data.Domain(
                vars, self.clusterDataset.domain.class_vars,
                self.clusterDataset.domain.metas)
            newdata = self.clusterDataset.from_table(domain, self.clusterDataset)

            self.send("Data on Selected Genes", newdata)
            self.send("Data on Unselected Genes", None)
        elif self.candidateGeneAttrs:
            # Genes are rows: collect the indices of (un)selected rows.
            selectedExamples = []
            unselectedExamples = []

            geneAttr = self.candidateGeneAttrs[min(self.geneAttrIndex, len(self.candidateGeneAttrs)-1)]

            if self.selectionDisjoint == 1:
                # Class variable with one value per selected term.
                goVar = Orange.data.DiscreteVariable(
                    "GO Term", values=list(self.selectedTerms))
                newDomain = Orange.data.Domain(
                    self.clusterDataset.domain.variables, goVar,
                    self.clusterDataset.domain.metas)
                goColumn = []
            for i, ex in enumerate(self.clusterDataset):
                if not numpy.isnan(ex[geneAttr]) and any(gene in selectedGenes for gene in str(ex[geneAttr]).split(",")):
                    if self.selectionDisjoint == 1 and self.selectionAddTermAsClass:
                        # Label the row with the first (in sorted order)
                        # selected term containing one of its genes.
                        terms = filter(lambda term: any(gene in self.graph[term][0] for gene in str(ex[geneAttr]).split(",")) , self.selectedTerms)
                        term = sorted(terms)[0]
                        goColumn.append(goVar.values.index(term))
                    selectedExamples.append(i)
                else:
                    unselectedExamples.append(i)

            if selectedExamples:
                selectedExamples = self.clusterDataset[selectedExamples]
                if self.selectionDisjoint == 1 and self.selectionAddTermAsClass:
                    selectedExamples = Orange.data.Table.from_table(newDomain, selectedExamples)
                    # Write the term indices into the new class column.
                    view, issparse = selectedExamples.get_column_view(goVar)
                    assert not issparse
                    view[:] = goColumn
            else:
                selectedExamples = None

            if unselectedExamples:
                unselectedExamples = self.clusterDataset[unselectedExamples]
            else:
                unselectedExamples = None

            self.send("Data on Selected Genes", selectedExamples)
            self.send("Data on Unselected Genes", unselectedExamples)

    def ShowInfo(self):
        """Show a non-modal dialog with the ontology and annotation headers."""
        dialog = QDialog(self)
        dialog.setModal(False)
        dialog.setLayout(QVBoxLayout())

        ontology_label = QLabel(dialog)
        if self.ontology:
            ontology_text = "Ontology:\n" + self.ontology.header
        else:
            ontology_text = "Ontology not loaded!"
        ontology_label.setText(ontology_text)
        dialog.layout().addWidget(ontology_label)

        annotations_label = QLabel(dialog)
        if self.annotations:
            annotations_text = "Annotations:\n" + self.annotations.header.replace("!", "")
        else:
            annotations_text = "Annotations not loaded!"
        annotations_label.setText(annotations_text)
        dialog.layout().addWidget(annotations_label)

        dialog.show()

    def onDeleteWidget(self):
        """Called before the widget is removed from the canvas.

        Drops the references to the annotations and the ontology and
        forces a garbage collection run.
        """
        self.annotations = self.ontology = None
        gc.collect()  # Force collection
Ejemplo n.º 41
0
class OWLearningCurveC(widget.OWWidget):
    """Widget showing learning curves of the input learners in a table.

    The learners are evaluated on increasing proportions of the input data
    (optionally against a separate test dataset). The evaluation is submitted
    to a thread executor and the table is updated once it finishes.
    """
    name = "Learning Curve (C)"
    description = ("Takes a dataset and a set of learners and shows a "
                   "learning curve in a table")
    icon = "icons/LearningCurve.svg"
    priority = 1010

    # Input channels: (name, type, handler method name[, flags])
    inputs = [
        ("Data", Orange.data.Table, "set_dataset", widget.Default),
        ("Test Data", Orange.data.Table, "set_testdataset"),
        (
            "Learner",
            Orange.classification.Learner,
            "set_learner",
            widget.Multiple + widget.Default,
        ),
    ]

    #: cross validation folds
    folds = settings.Setting(5)
    #: points in the learning curve
    steps = settings.Setting(10)
    #: index of the selected scoring function
    scoringF = settings.Setting(0)
    #: compute curve on any change of parameters
    commitOnChange = settings.Setting(True)

    def __init__(self):
        """Initialize the widget state and build the control/main areas."""
        super().__init__()

        # Fill in self.curvePoints: self.steps equidistant proportions
        # running from 1/self.steps up to 1.
        self.updateCurvePoints()

        #: (label, scoring function) pairs; `scoringF` indexes into this list
        self.scoring = [
            ("Classification Accuracy", Orange.evaluation.scoring.CA),
            ("AUC", Orange.evaluation.scoring.AUC),
            ("Precision", Orange.evaluation.scoring.Precision),
            ("Recall", Orange.evaluation.scoring.Recall),
        ]
        #: input data on which to construct the learning curve
        self.data = None
        #: optional test data
        self.testdata = None
        #: {input_id: Learner} - learners currently on the input channel
        self.learners = OrderedDict()
        #: {input_id: List[Results]} - one evaluation result per curve point
        self.results = OrderedDict()
        #: {input_id: List[float]} - the scores of the learning curve points
        self.curves = OrderedDict()

        # [start-snippet-3]
        #: The current evaluating task (if any)
        self._task = None  # type: Optional[Task]
        #: An executor we use to submit learner evaluations into a thread pool
        self._executor = ThreadExecutor()
        # [end-snippet-3]

        def invalidate_if_auto():
            # Re-evaluate immediately only when "commit on change" is on;
            # otherwise the user has to press the apply button.
            if self.commitOnChange:
                return self._invalidate_results()
            return None

        # GUI: info box
        info_box = gui.widgetBox(self.controlArea, "Info")
        self.infoa = gui.widgetLabel(info_box, "No data on input.")
        self.infob = gui.widgetLabel(info_box, "No learners.")

        gui.separator(self.controlArea)

        # GUI: scoring function selection
        scores_box = gui.widgetBox(self.controlArea, "Evaluation Scores")
        score_labels = [label for label, _ in self.scoring]
        gui.comboBox(scores_box, self, "scoringF", items=score_labels,
                     callback=self._invalidate_curves)

        gui.separator(self.controlArea)

        # GUI: evaluation options
        options_box = gui.widgetBox(self.controlArea, "Options")
        gui.spin(options_box, self, "folds", 2, 100, step=1,
                 label="Cross validation folds:  ",
                 keyboardTracking=False,
                 callback=invalidate_if_auto)
        gui.spin(options_box, self, "steps", 2, 100, step=1,
                 label="Learning curve points:  ",
                 keyboardTracking=False,
                 callback=[self.updateCurvePoints, invalidate_if_auto])
        gui.checkBox(options_box, self, "commitOnChange",
                     "Apply setting on any change")
        self.commitBtn = gui.button(options_box, self, "Apply Setting",
                                    callback=self._invalidate_results,
                                    disabled=True)

        gui.rubber(self.controlArea)

        # table widget
        self.table = gui.table(self.mainArea,
                               selectionMode=QTableWidget.NoSelection)

    ##########################################################################
    # slots: handle input signals

    def set_dataset(self, data):
        """Set the input train dataset."""
        # Previously computed results/scores do not apply to the new data;
        # reset them in place (the keys are kept).
        for store in (self.results, self.curves):
            for key in list(store):
                store[key] = None

        self.data = data

        if data is None:
            info_text = "No data on input."
        else:
            info_text = "%d instances in input dataset" % len(data)
        self.infoa.setText(info_text)

        self.commitBtn.setEnabled(self.data is not None)

    def set_testdataset(self, testdata):
        """Set a separate test dataset."""
        # Results/scores computed with the previous test data are stale;
        # reset them in place (the keys are kept).
        for store in (self.results, self.curves):
            for key in list(store):
                store[key] = None

        self.testdata = testdata

    def set_learner(self, learner, id):
        """Set the input learner for channel id.

        Parameters
        ----------
        learner : Optional[Learner]
            The learner on the channel; ``None`` removes a previously
            connected learner (and is ignored for an unknown `id`).
        id : object
            Identifier of the input link the learner arrived on.
        """
        if learner is None:
            if id in self.learners:
                # remove a learner and corresponding results
                del self.learners[id]
                del self.results[id]
                del self.curves[id]
        else:
            # add a new learner or replace the one on an existing link
            self.learners[id] = learner
            # (re)initialize the cross-validation results and curve scores
            # (will be computed/updated in `_update`)
            self.results[id] = None
            self.curves[id] = None

        if self.learners:
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")

        # setEnabled takes a bool; pass one explicitly instead of the count
        self.commitBtn.setEnabled(bool(self.learners))

    # [start-snippet-4]
    def handleNewSignals(self):
        """Start (or restart) the evaluation via `_update`."""
        self._update()

    # [end-snippet-4]

    def _invalidate_curves(self):
        """Rescore the curves (only when data is present) and redraw the
        table.
        """
        has_data = self.data is not None
        if has_data:
            self._update_curve_points()
        self._update_table()

    def _invalidate_results(self):
        """Discard the results and scores of all learners and re-evaluate."""
        for key in self.learners:
            self.curves[key] = self.results[key] = None
        self._update()

    # [start-snippet-5]
    def _update(self):
        """Evaluate, in the executor, all learners that have no results yet.

        Any task still in flight is cancelled first. Nothing is started when
        there is no input data or no learner is missing results.
        """
        if self._task is not None:
            # First make sure any pending tasks are cancelled.
            self.cancel()
        assert self._task is None

        if self.data is None:
            return
        # collect all learners for which results have not yet been computed
        learners = [learner for key, learner in self.learners.items()
                    if self.results[key] is None]
        if not learners:
            return
        # [end-snippet-5]
        # [start-snippet-6]
        # setup the learner evaluations as partial function capturing
        # the necessary arguments.
        if self.testdata is not None:
            evaluate = partial(
                learning_curve_with_test_data,
                learners,
                self.data,
                self.testdata,
                times=self.folds,
                proportions=self.curvePoints,
            )
        else:
            evaluate = partial(
                learning_curve,
                learners,
                self.data,
                folds=self.folds,
                proportions=self.curvePoints,
            )
        # [end-snippet-6]
        # [start-snippet-7]
        # setup the task state
        task = Task()
        self._task = task
        # The learning_curve[_with_test_data] also takes a callback function
        # to report the progress. We instrument this callback to both invoke
        # the appropriate slots on this widget for reporting the progress
        # (in a thread safe manner) and to implement cooperative cancellation.
        set_progress = methodinvoke(self, "setProgressValue", (float, ))

        def callback(finished):
            # Raising from within the callback aborts a cancelled task. This
            # 'strategy' can only be used with code that properly cleans up
            # after itself in the case of an exception (does not leave any
            # global locks, opened file descriptors, ...)
            if task.cancelled:
                raise KeyboardInterrupt()
            set_progress(finished * 100)

        # capture the callback in the partial function
        evaluate = partial(evaluate, callback=callback)
        # [end-snippet-7]
        # [start-snippet-8]
        self.progressBarInit()
        # Submit the evaluation function to the executor and fill in the
        # task with the resultant Future.
        future = self._executor.submit(evaluate)
        task.future = future
        # Setup the FutureWatcher to notify us of completion
        task.watcher = FutureWatcher(future)
        # by using FutureWatcher we ensure `_task_finished` slot will be
        # called from the main GUI thread by the Qt's event loop
        task.watcher.done.connect(self._task_finished)

    # [end-snippet-8]

    @pyqtSlot(float)
    def setProgressValue(self, value):
        """Set the progress bar to `value`; must run in the widget's thread."""
        assert QThread.currentThread() is self.thread()
        self.progressBarSet(value)

    # [start-snippet-9]
    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """
        Collect the outcome of the finished evaluation task.

        Parameters
        ----------
        f : Future
            The future instance holding the result of learner evaluation.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        try:
            results = f.result()  # type: List[Results]
        except Exception as ex:
            # Log the exception with a traceback
            logging.getLogger().exception(__name__, exc_info=True)
            self.error("Exception occurred during evaluation: {!r}".format(ex))
            # clear all results
            for key in self.results:
                self.results[key] = None
        else:
            # split the combined result into per learner/model results ...
            per_point = [
                list(Results.split_by_model(point_results))
                for point_results in results
            ]  # type: List[List[Results]]
            assert all(len(res.learners) == 1
                       for point in per_point for res in point)
            assert len(per_point) == len(self.curvePoints)

            # map each learner back to the input id it arrived on
            id_by_learner = {
                learner: key
                for key, learner in self.learners.items()
            }

            # ... and update self.results
            for column, first_point_res in enumerate(per_point[0]):
                key = id_by_learner[first_point_res.learners[0]]
                self.results[key] = [point[column] for point in per_point]
        # [end-snippet-9]
        # update the display
        self._update_curve_points()
        self._update_table()

    # [end-snippet-9]

    # [start-snippet-10]
    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is None:
            return

        self._task.cancel()
        assert self._task.future.done()
        # `_task_finished` must not be invoked for the cancelled task
        self._task.watcher.done.disconnect(self._task_finished)
        self._task = None

    # [end-snippet-10]

    # [start-snippet-11]
    def onDeleteWidget(self):
        """Cancel any running task, then defer to the base class cleanup."""
        self.cancel()
        super().onDeleteWidget()

    # [end-snippet-11]

    def _update_curve_points(self):
        """Recompute every learner's curve with the selected scoring function.

        A learner's entry in `self.results` is ``None`` until its evaluation
        has finished, and again after a failed evaluation or invalidation.
        Such learners get an empty curve; iterating ``None`` would raise
        ``TypeError``.
        """
        scorer = self.scoring[self.scoringF][1]
        for key in self.learners:
            point_results = self.results[key]
            if point_results is None:
                # nothing to score (yet)
                self.curves[key] = []
            else:
                self.curves[key] = [scorer(res)[0] for res in point_results]

    def _update_table(self):
        """Rebuild the table: one row per curve point, one column per learner.

        Only the headers are set when there is no input data. Columns of
        learners whose curve is ``None`` (not computed yet) are left empty
        instead of raising ``TypeError``.
        """
        # Dropping to zero rows first discards all existing items
        self.table.setRowCount(0)
        self.table.setRowCount(len(self.curvePoints))
        self.table.setColumnCount(len(self.learners))

        self.table.setHorizontalHeaderLabels(
            [learner.name for learner in self.learners.values()])
        self.table.setVerticalHeaderLabels(
            ["{:.2f}".format(p) for p in self.curvePoints])

        if self.data is None:
            return

        for column, curve in enumerate(self.curves.values()):
            if curve is None:
                # scores for this learner are not available yet
                continue
            for row, point in enumerate(curve):
                self.table.setItem(row, column,
                                   QTableWidgetItem("{:.5f}".format(point)))

        # Only ever widen a column to fit its contents, never shrink it
        for i in range(len(self.learners)):
            sh = self.table.sizeHintForColumn(i)
            cwidth = self.table.columnWidth(i)
            self.table.setColumnWidth(i, max(sh, cwidth))

    def updateCurvePoints(self):
        """Set `self.curvePoints` to `self.steps` equidistant proportions,
        running from ``1 / self.steps`` up to ``1``.
        """
        points = []
        for index in range(self.steps):
            points.append((index + 1.0) / self.steps)
        self.curvePoints = points
Ejemplo n.º 42
0
class OWKMeans(widget.OWWidget):
    """k-Means clustering widget.

    Clusters the input data either with a fixed number of clusters (`k`) or
    for every k in the `k_from`..`k_to` range, in which case the silhouette
    scores of the clusterings are listed in a table. Clusterings are computed
    in a thread executor.
    """
    name = "k-Means"
    description = "k-Means clustering algorithm with silhouette-based " \
                  "quality estimation."
    icon = "icons/KMeans.svg"
    priority = 2100

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME,
                                Table,
                                default=True,
                                replaces=["Annotated Data"])
        centroids = Output("Centroids", Table)

    class Error(widget.OWWidget.Error):
        failed = widget.Msg("Clustering failed\nError: {}")
        not_enough_data = widget.Msg(
            "Too few ({}) unique data instances for {} clusters")

    class Warning(widget.OWWidget.Warning):
        no_silhouettes = widget.Msg(
            "Silhouette scores are not computed for >{} samples".format(
                SILHOUETTE_MAX_SAMPLES))
        not_enough_data = widget.Msg(
            "Too few ({}) unique data instances for {} clusters")

    # (label shown in the combo box, value passed as KMeans `init`)
    INIT_METHODS = (("Initialize with KMeans++", "k-means++"),
                    ("Random initialization", "random"))

    resizing_enabled = False
    buttons_area_orientation = Qt.Vertical

    # Number of clusters used when k is fixed
    k = Setting(3)
    # Range of k used when optimizing the number of clusters
    k_from = Setting(2)
    k_to = Setting(8)
    # Selects between the fixed k and the k_from..k_to range
    optimize_k = Setting(False)
    max_iterations = Setting(300)
    n_init = Setting(10)
    # Index into INIT_METHODS
    smart_init = Setting(0)  # KMeans++
    auto_commit = Setting(True)

    settings_version = 2

    @classmethod
    def migrate_settings(cls, settings, version):
        # type: (Dict, int) -> None
        """Rename the 'auto_apply' setting of pre-version-2 settings to
        'auto_commit'.
        """
        if version < 2 and 'auto_apply' in settings:
            # move the stored value over to the new key
            settings['auto_commit'] = settings.pop('auto_apply')

    def __init__(self):
        """Initialize the widget state and build the GUI."""
        super().__init__()

        self.data = None  # type: Optional[Table]
        # Maps k to its clustering, or to an error message (str) when the
        # clustering for that k failed
        self.clusterings = {}

        self.__executor = ThreadExecutor(parent=self)
        self.__task = None  # type: Optional[Task]

        # options shared by all three spin boxes
        spin_options = dict(controlWidth=60, alignment=Qt.AlignRight)

        # --- "Number of Clusters": fixed k vs. a range of ks ---
        clusters_grid = QGridLayout()
        radio_group = gui.radioButtonsInBox(
            self.controlArea, self, "optimize_k",
            orientation=clusters_grid,
            box="Number of Clusters",
            callback=self.update_method,
        )

        fixed_button = gui.appendRadioButton(
            radio_group, "Fixed:", addToLayout=False)
        clusters_grid.addWidget(fixed_button, 1, 1)
        fixed_box = gui.hBox(None, margin=0)
        gui.spin(fixed_box, self, "k", minv=2, maxv=30,
                 callback=self.update_k, **spin_options)
        gui.rubber(fixed_box)
        clusters_grid.addWidget(fixed_box, 1, 2)

        range_button = gui.appendRadioButton(
            radio_group, "From", addToLayout=False)
        clusters_grid.addWidget(range_button, 2, 1)
        range_box = gui.hBox(None)
        range_box.layout().setContentsMargins(0, 0, 0, 0)
        clusters_grid.addWidget(range_box, 2, 2)
        gui.spin(range_box, self, "k_from", minv=2, maxv=29,
                 callback=self.update_from, **spin_options)
        gui.widgetLabel(range_box, "to")
        gui.spin(range_box, self, "k_to", minv=3, maxv=30,
                 callback=self.update_to, **spin_options)
        gui.rubber(range_box)

        # --- "Initialization": method, re-runs and iteration limit ---
        init_box = gui.vBox(self.controlArea, "Initialization")
        gui.comboBox(init_box, self, "smart_init",
                     items=[label for label, _ in self.INIT_METHODS],
                     callback=self.invalidate)

        init_grid = QGridLayout()
        gui.widgetBox(init_box, orientation=init_grid)
        int_fields = (("Re-runs: ", "n_init"),
                      ("Maximum iterations: ", "max_iterations"))
        for row, (label_text, setting_name) in enumerate(int_fields):
            init_grid.addWidget(gui.widgetLabel(None, label_text), row, 0,
                                Qt.AlignLeft)
            field_box = gui.hBox(None, margin=0)
            init_grid.addWidget(field_box, row, 1)
            gui.lineEdit(field_box, self, setting_name,
                         controlWidth=60,
                         valueType=int,
                         validator=QIntValidator(),
                         callback=self.invalidate)

        self.apply_button = gui.auto_commit(
            self.buttonsArea, self, "auto_commit", "Apply",
            box=None, commit=self.commit)
        gui.rubber(self.controlArea)

        # --- main area: table of silhouette scores ---
        scores_box = gui.vBox(self.mainArea, box="Silhouette Scores")
        self.mainArea.setVisible(self.optimize_k)
        self.table_model = ClusterTableModel(self)
        view = QTableView(self.mainArea)
        self.table_view = view
        view.setModel(self.table_model)
        view.setSelectionMode(QTableView.SingleSelection)
        view.setSelectionBehavior(QTableView.SelectRows)
        view.setItemDelegate(gui.ColoredBarItemDelegate(self, color=Qt.cyan))
        view.selectionModel().selectionChanged.connect(self.select_row)
        view.setMaximumWidth(200)
        header = view.horizontalHeader()
        header.setStretchLastSection(True)
        header.hide()
        view.setShowGrid(False)
        scores_box.layout().addWidget(view)

    def adjustSize(self):
        """Polish the widget and resize it to its size hint."""
        self.ensurePolished()
        self.resize(self.sizeHint())

    def update_method(self):
        """Callback of the fixed/range radio buttons: clear the score table
        and commit.
        """
        self.table_model.clear_scores()
        self.commit()

    def update_k(self):
        """Callback of the fixed-k spin box: switch to the fixed k mode,
        clear the score table and commit.
        """
        self.optimize_k = False
        self.table_model.clear_scores()
        self.commit()

    def update_from(self):
        """Callback of the "from" spin box: keep `k_to` above `k_from`,
        switch to the range mode and commit.
        """
        lowest_allowed_to = self.k_from + 1
        self.k_to = max(lowest_allowed_to, self.k_to)
        self.optimize_k = True
        self.table_model.clear_scores()
        self.commit()

    def update_to(self):
        """Callback of the "to" spin box: keep `k_from` below `k_to`,
        switch to the range mode and commit.
        """
        highest_allowed_from = self.k_to - 1
        self.k_from = min(self.k_from, highest_allowed_from)
        self.optimize_k = True
        self.table_model.clear_scores()
        self.commit()

    def enough_data_instances(self, k):
        """Return True when the data has at least `k` instances
        (k cannot be larger than the number of data instances).
        """
        n_instances = len(self.data)
        return n_instances >= k

    @staticmethod
    def _compute_clustering(data, k, init, n_init, max_iter, silhouette,
                            random_state):
        # type: (Table, int, str, int, int, bool, object) -> KMeansModel
        """Fit k-means with `k` clusters on `data` and return the model.

        Raises `NotEnoughData` when `k` exceeds the number of instances.
        """
        if len(data) < k:
            raise NotEnoughData()

        learner = KMeans(
            n_clusters=k,
            init=init,
            n_init=n_init,
            max_iter=max_iter,
            compute_silhouette_score=silhouette,
            random_state=random_state,
        )
        return learner(data)

    @Slot(int, int)
    def __progress_changed(self, n, d):
        """Set the progress bar to the percentage ``100 * n / d``."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        percent = 100 * n / d
        self.progressBarSet(percent)

    @Slot(int, Exception)
    def __on_exception(self, idx, ex):
        """Record a failed clustering and show the matching error message.

        Parameters
        ----------
        idx : int
            Index reported by the watcher's `exceptionReadyAt` signal;
            presumably the position of the failed future in the list passed
            to the watcher.
        ex : Exception
            The exception raised by the clustering.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None

        # With a fixed k a single clustering (for `self.k`) is submitted, so
        # the failure belongs to `self.k`; offsetting from `k_from` is only
        # meaningful when optimizing over the k_from..k_to range.
        # NOTE(review): `run_optimization` submits only the ks that are not
        # cached yet, so `k_from + idx` can still name the wrong k when that
        # list does not start at `k_from` -- confirm, and keep the submitted
        # ks with the task if so.
        k = self.k_from + idx if self.optimize_k else self.k

        if isinstance(ex, NotEnoughData):
            self.Error.not_enough_data(len(self.data), k)

        # Only show failed message if there is only 1 k to compute
        elif not self.optimize_k:
            self.Error.failed(str(ex))

        # A str entry marks the clustering for this k as failed
        self.clusterings[k] = str(ex)

    @Slot(int, object)
    def __clustering_complete(self, _, result):
        # type: (int, KMeansModel) -> None
        """Store the finished clustering under its number of clusters."""
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None

        n_clusters = result.k
        self.clusterings[n_clusters] = result

    @Slot()
    def __commit_finished(self):
        """Finish the task: unblock the widget, refresh the score table
        (when optimizing k) and send the outputs.
        """
        assert QThread.currentThread() is self.thread()
        assert self.__task is not None
        assert self.data is not None

        self.__task = None
        self.setBlocking(False)
        self.progressBarFinished()

        if self.optimize_k:
            self.update_results()

        if self.optimize_k:
            tried_ks = range(self.k_from, self.k_to + 1)
            every_k_failed = all(
                isinstance(self.clusterings[k], str) for k in tried_ks)
            if every_k_failed:
                # Show the error of the last clustering
                self.Error.failed(self.clusterings[self.k_to])

        self.send_data()

    def __launch_tasks(self, ks):
        # type: (List[int]) -> None
        """Execute clustering in separate threads for all given ks."""
        init_method = self.INIT_METHODS[self.smart_init][1]

        futures = []
        for k in ks:
            future = self.__executor.submit(
                self._compute_clustering,
                data=self.data,
                k=k,
                init=init_method,
                n_init=self.n_init,
                max_iter=self.max_iterations,
                silhouette=True,
                random_state=RANDOM_STATE,
            )
            futures.append(future)

        # Route the watcher's notifications to the slots of this widget
        watcher = FutureSetWatcher(futures)
        watcher.resultReadyAt.connect(self.__clustering_complete)
        watcher.progressChanged.connect(self.__progress_changed)
        watcher.exceptionReadyAt.connect(self.__on_exception)
        watcher.doneAll.connect(self.__commit_finished)

        self.__task = Task(futures, watcher)
        self.progressBarInit(processEvents=False)
        self.setBlocking(True)

    def cancel(self):
        """Cancel the running task (if any) and detach its notifications."""
        task = self.__task
        if task is None:
            return

        self.__task = None
        task.cancel()

        # The cancelled task must no longer reach the slots of this widget
        watcher = task.watcher
        watcher.resultReadyAt.disconnect(self.__clustering_complete)
        watcher.progressChanged.disconnect(self.__progress_changed)
        watcher.exceptionReadyAt.disconnect(self.__on_exception)
        watcher.doneAll.disconnect(self.__commit_finished)

        self.progressBarFinished()
        self.setBlocking(False)

    def run_optimization(self):
        """Compute the missing clusterings for every k in k_from..k_to.

        Nothing is computed when the data has fewer instances than `k_from`
        (shown as an error) or than `k_to` (shown as a warning).
        """
        bound_checks = (
            (self.k_from, self.Error.not_enough_data),
            (self.k_to, self.Warning.not_enough_data),
        )
        for bound, show_message in bound_checks:
            if not self.enough_data_instances(bound):
                show_message(len(self.data), bound)
                return

        # ks without a cached clustering (or cached failure)
        needed_ks = []
        for k in range(self.k_from, self.k_to + 1):
            if k not in self.clusterings:
                needed_ks.append(k)

        if not needed_ks:
            # If we don't need to recompute anything, just set the results to
            # what they were before
            self.update_results()
        else:
            self.__launch_tasks(needed_ks)

    def cluster(self):
        """Compute (or reuse) the clustering for the fixed `self.k`."""
        k = self.k

        if k in self.clusterings:
            # Already computed: just send it out
            self.send_data()
        elif not self.enough_data_instances(k):
            self.Error.not_enough_data(len(self.data), k)
        else:
            self.__launch_tasks([k])

    def commit(self):
        """Cancel any running task and (re)start clustering for the current
        settings.
        """
        self.cancel()
        self.clear_messages()

        # Some time may pass before the new scores are computed, so clear the
        # old scores to avoid potential confusion. Hiding the mainArea could
        # cause flickering when the clusters are computed quickly, so this is
        # the better alternative
        self.table_model.clear_scores()
        self.mainArea.setVisible(self.optimize_k and self.data is not None)

        if self.data is None:
            self.send_data()
            return

        start = self.run_optimization if self.optimize_k else self.cluster
        start()

        QTimer.singleShot(100, self.adjustSize)

    def invalidate(self):
        """Drop all cached clusterings and messages, then commit anew."""
        self.cancel()
        for messages in (self.Error, self.Warning):
            messages.clear()
        self.clusterings = {}
        self.table_model.clear_scores()

        self.commit()

    def update_results(self):
        """Show the scores for k_from..k_to and select the best-scoring row.

        Failed clusterings are stored as error strings; they are shown as is
        and count as a score of 0 when looking for the best row.
        """
        scores = []
        for k in range(self.k_from, self.k_to + 1):
            entry = self.clusterings[k]
            scores.append(entry if isinstance(entry, str)
                          else entry.silhouette)

        def effective_score(row):
            score = scores[row]
            return 0 if isinstance(score, str) else score

        best_row = max(range(len(scores)), default=0, key=effective_score)

        self.table_model.set_scores(scores, self.k_from)
        self.table_view.selectRow(best_row)
        self.table_view.setFocus(Qt.OtherFocusReason)
        self.table_view.resizeRowsToContents()

    def selected_row(self):
        """Return the row of the first selected index, or None when nothing
        is selected.
        """
        selected = self.table_view.selectedIndexes()
        if not selected:
            return None
        return selected[0].row()

    def select_row(self):
        """Slot for the table's selection change: resend the outputs."""
        self.send_data()

    def send_data(self):
        """Send the data annotated with the selected clustering, and the
        centroids of that clustering.

        ``None`` is sent on both outputs when there is no data or no usable
        clustering (missing or failed) for the selected k.
        """
        if not self.optimize_k:
            k = self.k
        else:
            row = self.selected_row()
            k = None if row is None else self.k_from + row

        km = self.clusterings.get(k)
        if self.data is None or km is None or isinstance(km, str):
            self.Outputs.annotated_data.send(None)
            self.Outputs.centroids.send(None)
            return

        domain = self.data.domain
        cluster_labels = ["C%d" % number for number in range(1, km.k + 1)]
        cluster_var = DiscreteVariable(
            get_next_name(domain, "Cluster"), values=cluster_labels)
        clust_ids = km(self.data)
        silhouette_var = ContinuousVariable(
            get_next_name(domain, "Silhouette"))

        if km.silhouette_samples is None:
            self.Warning.no_silhouettes()
            scores = np.nan
        else:
            self.Warning.no_silhouettes.clear()
            scores = np.arctan(km.silhouette_samples) / np.pi + 0.5

        # Append both variables as metas and fill their columns in place
        new_domain = add_columns(domain, metas=[cluster_var, silhouette_var])
        new_table = self.data.transform(new_domain)
        cluster_column, _ = new_table.get_column_view(cluster_var)
        cluster_column[:] = clust_ids.X.ravel()
        silhouette_column, _ = new_table.get_column_view(silhouette_var)
        silhouette_column[:] = scores

        centroids = Table(Domain(km.pre_domain.attributes), km.centroids)

        self.Outputs.annotated_data.send(new_table)
        self.Outputs.centroids.send(centroids)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the input data; recluster only when its X has changed."""
        old_data, self.data = self.data, data

        # Do not needlessly recluster the data if X hasn't changed
        same_x = (old_data and self.data
                  and np.array_equal(self.data.X, old_data.X))
        if not same_x:
            self.invalidate()
        elif self.auto_commit:
            self.send_data()

    def send_report(self):
        """Report the clustering settings, the data and (when optimizing k)
        the silhouette score table.
        """
        # False positives (Setting is not recognized as int)
        # pylint: disable=invalid-sequence-index
        row = self.selected_row() if self.optimize_k else None
        if row is not None:
            k_clusters = self.k_from + row
        else:
            k_clusters = self.k

        # Lower-case the first letter: the label continues a sentence
        label = self.INIT_METHODS[self.smart_init][0]
        init_method = label[0].lower() + label[1:]
        optimization = "{}, {} re-runs limited to {} steps".format(
            init_method, self.n_init, self.max_iterations)
        self.report_items((("Number of clusters", k_clusters),
                           ("Optimization", optimization)))

        if self.data is None:
            return
        self.report_data("Data", self.data)
        if self.optimize_k:
            self.report_table(
                "Silhouette scores for different numbers of clusters",
                self.table_view)

    def onDeleteWidget(self):
        """Cancel any running task, then defer to the base class cleanup."""
        self.cancel()
        super().onDeleteWidget()
Ejemplo n.º 43
0
class OWResolweDataObject(widget.OWWidget):
    """Viewer for a Resolwe data object.

    Shows the object's, its process' and the user's permission info in the
    control area, and downloads the object's data table in a thread executor
    to send it on the "Data" output.
    """
    name = "Resolwe Data Object"
    description = "Resolwe Data object viewer"
    icon = "icons/OWResolweDataObject.svg"
    priority = 20
    # Only the control area is used
    want_control_area = True
    want_main_area = False

    auto_commit = settings.Setting(True)

    class Inputs:
        data = widget.Input("Data", resolwe.Data)

    class Outputs:
        data = widget.Output("Data", Table)

    def __init__(self):
        """Initialize the widget state and build the control area."""
        super().__init__()
        self.data_table_object = None  # type: Optional[resolwe.Data]

        # threading
        self._task = None  # type: Optional[ResolweTask]
        self._executor = ThreadExecutor()

        def add_info_label(title):
            # A titled box in the control area holding a single label
            box = gui.widgetBox(self.controlArea, title)
            label = QLabel(box)
            box.layout().addWidget(label)
            return label

        self._data_obj = add_info_label('Data Object')
        self._proc_info = add_info_label('Process info')
        self._usr_perm = add_info_label('User permissions')

        min_width = self.controlArea.sizeHint().width()
        self.controlArea.setMinimumWidth(min_width)
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        gui.auto_commit(self.controlArea, self, "auto_commit", "Download data")

        self.res = ResolweHelper()

    @staticmethod
    def pack_table(info):
        """Render (label, value) pairs as the rows of an HTML table.

        Values are converted with ``str`` and shortened to at most 100
        characters, using "..." as the placeholder.
        """
        row_template = ('<tr><td align="right" width="120">%s:</td>\n'
                        '<td width="200">%s</td></tr>\n')
        rows = []
        for label, value in info:
            text = textwrap.shorten(str(value), width=100, placeholder="...")
            rows.append(row_template % (label, text))
        return '<table>\n' + "\n".join(rows) + "</table>\n"

    @Inputs.data
    def set_data(self, data):
        # type: (Optional[resolwe.Data]) -> None
        """Store the input data object and refresh the info labels.

        The labels are left untouched when the input is ``None``.
        """
        self.data_table_object = data

        if data is None:
            return
        self.setup()

    def handleNewSignals(self):
        """Commit via `commit`."""
        self.commit()

    def __setup_data_object_info(self):
        """Show the id and name of the data object."""
        data_object = self.data_table_object
        rows = (
            ('Id', '{}'.format(data_object.id)),
            ('Name', '{}'.format(data_object.name)),
        )
        self._data_obj.setText(self.pack_table(rows))

    def __setup_proces_info(self):
        """Show the id, name and type of the data object's process."""
        data_object = self.data_table_object
        rows = (
            ('Id', '{}'.format(data_object.process)),
            ('Name', '{}'.format(data_object.process_name)),
            ('Category', '{}'.format(data_object.process_type)),
        )
        self._proc_info.setText(self.pack_table(rows))

    def __setup_usr_permissions(self):
        """Show the first entry of the data object's current user
        permissions; the label is left untouched when there are none.
        """
        all_permissions = self.data_table_object.current_user_permissions
        if not all_permissions:
            return

        entry = all_permissions[0]
        perms = entry.get('permissions', None)
        if perms:
            # show the individual permissions as one comma separated string
            perms = ','.join(perms)

        rows = (
            ('Id', '{}'.format(entry.get('id', None))),
            ('Name', '{}'.format(entry.get('name', None))),
            ('Type', '{}'.format(entry.get('type', None))),
            ('Permissions', '{}'.format(perms)),
        )
        self._usr_perm.setText(self.pack_table(rows))

    def setup(self):
        """Refresh the data object, process and user permission labels."""
        self.__setup_data_object_info()
        self.__setup_proces_info()
        self.__setup_usr_permissions()

    def commit(self):
        """Start the download task, or send ``None`` when there is no
        (truthy) data object.
        """
        if self.data_table_object:
            self.run_task()
        else:
            self.Outputs.data.send(None)

    def run_task(self):
        """Submit the download of the data table to the executor."""
        if self._task is not None:
            # NOTE(review): `cancel` is not defined in the visible part of
            # this class; presumably it is inherited -- verify.
            self.cancel()
        assert self._task is None

        self.progressBarInit()
        download = partial(self.res.download_data_table,
                           self.data_table_object)

        task = ResolweTask('download')
        self._task = task
        task.future = self._executor.submit(download)
        # `task_finished` is notified through the watcher's `done` signal
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self.task_finished)

    @Slot(Future, name='Future')
    def task_finished(self, future):
        """Send the downloaded table once the task's future is done.

        An exception raised by the task is re-raised. The progress bar is
        finished and the task reset in either case.
        """
        assert threading.current_thread() == threading.main_thread()
        assert self._task is not None
        assert self._task.future is future
        assert future.done()

        try:
            future_result = future.result()
        except Exception as ex:
            # TODO: raise exceptions
            raise ex
        else:
            is_download = self._task.slug == 'download'
            if is_download:
                self.Outputs.data.send(future_result)
        finally:
            self.progressBarFinished()
            self._task = None
Ejemplo n.º 44
0
class OWGeneInfo(widget.OWWidget):
    """Widget that displays gene information records for genes in the input data.

    Genes are read either from a selected column of the input table or from
    its attribute names; the selected genes determine the "Data Subset"
    output.
    """
    # Widget metadata.
    name = "Gene Info"
    description = "Displays gene information from NCBI and other sources."
    icon = "../widgets/icons/GeneInfo.svg"
    priority = 2010

    # Signals: one input table (handled by ``setData``) and one output table.
    inputs = [("Data", Orange.data.Table, "setData")]
    outputs = [("Data Subset", Orange.data.Table)]

    settingsHandler = settings.DomainContextHandler()

    # Index into ``self.organisms`` and the matching taxonomy id.
    organism_index = settings.ContextSetting(0)
    taxid = settings.ContextSetting("9606")

    # Index into ``self.attributes`` of the column holding gene names.
    gene_attr = settings.ContextSetting(0)

    auto_commit = settings.Setting(False)
    # Text of the "Filter" line edit.
    search_string = settings.Setting("")

    # useAttr: take gene names from attribute names instead of a column.
    # useAltSource: use the alternative info source (index into INFO_SOURCES).
    useAttr = settings.ContextSetting(False)
    useAltSource = settings.ContextSetting(False)

    def __init__(self, parent=None, ):
        """Build the widget GUI and start the background taxonomy download.

        Parameters
        ----------
        parent : optional
            Parent object forwarded to the base widget constructor.

        The widget stays blocked until ``initialize`` runs, which is
        triggered by the ``resultReady`` signal of the initialization task.
        """
        # ``self`` is bound implicitly by ``super()``; passing it explicitly
        # would shift ``parent`` to the wrong positional slot.
        super().__init__(parent)

        self.selectionChangedFlag = False

        self.__initialized = False
        self.initfuture = None
        self.itemsfuture = None

        self.infoLabel = gui.widgetLabel(
            gui.widgetBox(self.controlArea, "Info", addSpace=True),
            "Initializing\n"
        )

        # Filled in by ``initialize`` once the taxonomy data is available.
        self.organisms = None
        self.organismBox = gui.widgetBox(
            self.controlArea, "Organism", addSpace=True)

        self.organismComboBox = gui.comboBox(
            self.organismBox, self, "organism_index",
            callback=self._onSelectedOrganismChanged)

        # For now only support one alt source, with a checkbox
        # In the future this can be extended to multiple selections
        self.altSourceCheck = gui.checkBox(
            self.organismBox, self, "useAltSource",
            "Show information from dictyBase",
            callback=self.onAltSourceChange)

        self.altSourceCheck.hide()

        box = gui.widgetBox(self.controlArea, "Gene names", addSpace=True)
        self.geneAttrComboBox = gui.comboBox(
            box, self, "gene_attr",
            "Gene atttibute", callback=self.updateInfoItems
        )
        # The column selector is only meaningful when gene names do not
        # come from the attribute names.
        self.geneAttrComboBox.setEnabled(not self.useAttr)
        cb = gui.checkBox(box, self, "useAttr", "Use attribute names",
                          callback=self.updateInfoItems)
        cb.toggled[bool].connect(self.geneAttrComboBox.setDisabled)

        gui.auto_commit(self.controlArea, self, "auto_commit", "Commit")

        # A label for the dictyExpress link
        self.dictyExpressBox = gui.widgetBox(
            self.controlArea, "Dicty Express")
        self.linkLabel = gui.widgetLabel(self.dictyExpressBox, "")
        # Links are handled by ``onDictyExpressLink`` instead of being
        # opened directly by the label.
        self.linkLabel.setOpenExternalLinks(False)
        self.linkLabel.linkActivated.connect(self.onDictyExpressLink)

        self.dictyExpressBox.hide()

        gui.rubber(self.controlArea)

        gui.lineEdit(self.mainArea, self, "search_string", "Filter",
                     callbackOnType=True, callback=self.searchUpdate)

        self.treeWidget = QTreeView(
            self.mainArea,
            selectionMode=QTreeView.ExtendedSelection,
            rootIsDecorated=False,
            uniformRowHeights=True,
            sortingEnabled=True)

        self.treeWidget.setItemDelegate(
            gui.LinkStyledItemDelegate(self.treeWidget))
        self.treeWidget.viewport().setMouseTracking(True)
        self.mainArea.layout().addWidget(self.treeWidget)

        box = gui.widgetBox(self.mainArea, "", orientation="horizontal")
        gui.button(box, self, "Select Filtered", callback=self.selectFiltered)
        gui.button(box, self, "Clear Selection",
                   callback=self.treeWidget.clearSelection)

        self.geneinfo = []
        self.cells = []
        self.row2geneinfo = {}
        self.data = None

        # : (# input genes, # matches genes)
        self.matchedInfo = 0, 0

        self.setBlocking(True)
        self.executor = ThreadExecutor(self)

        self.progressBarInit()

        # Download the taxonomy files in the background; ``advance`` is
        # invoked through ``methodinvoke`` to report progress.
        task = Task(
            function=partial(
                taxonomy.ensure_downloaded,
                callback=methodinvoke(self, "advance", ())
            )
        )

        task.resultReady.connect(self.initialize)
        task.exceptionReady.connect(self._onInitializeError)

        self.initfuture = self.executor.submit(task)

    def sizeHint(self):
        """Return the preferred widget size (1024 x 720)."""
        width, height = 1024, 720
        return QSize(width, height)

    @Slot()
    def advance(self):
        """Advance the progress bar by one step.

        Asserts that it runs on the thread the widget lives in.
        """
        assert self.thread() is QThread.currentThread()
        next_value = self.progressBarValue + 1
        self.progressBarSet(next_value, processEvents=None)

    def initialize(self):
        """Populate the organism list and unblock the widget (runs only once).

        The organism ids are the union of the ids parsed from the
        "NCBI_geneinfo" server file names and the common taxonomy ids.
        The stored ``taxid`` is kept when it is among them; otherwise the
        first organism is selected.
        """
        if self.__initialized:
            # Already initialized
            return
        self.__initialized = True

        # The taxonomy id is the second-to-last dot-separated file name part.
        file_taxids = [name.split(".")[-2]
                       for name in serverfiles.listfiles("NCBI_geneinfo")]
        self.organisms = sorted(
            set(file_taxids + gene.NCBIGeneInfo.common_taxids()))

        organism_names = [taxonomy.name(tax_id) for tax_id in self.organisms]
        self.organismComboBox.addItems(organism_names)

        if self.taxid not in self.organisms:
            self.organism_index = 0
            self.taxid = self.organisms[self.organism_index]
        else:
            self.organism_index = self.organisms.index(self.taxid)

        is_dicty = self.taxid == DICTY_TAXID
        self.altSourceCheck.setVisible(is_dicty)
        self.dictyExpressBox.setVisible(is_dicty)

        self.infoLabel.setText("No data on input\n")
        self.initfuture = None

        self.setBlocking(False)
        self.progressBarFinished(processEvents=None)

    def _onInitializeError(self, exc):
        """Report a failure of the initialization task.

        Parameters
        ----------
        exc : BaseException
            The exception raised by the initialization task.
        """
        # ``sys.excepthook`` expects (type, value, traceback); the value must
        # be the exception instance itself, not its ``args`` tuple.
        sys.excepthook(type(exc), exc, exc.__traceback__)
        self.error(0, "Could not download the necessary files.")

    def _onSelectedOrganismChanged(self):
        """React to a new organism selection in the combo box.

        Updates ``taxid``, shows the dictyBase controls only for
        ``DICTY_TAXID``, and refreshes the info items when data is present.
        """
        # Valid indices are 0 .. len - 1; ``<= len`` would let an
        # out-of-range index pass the check and fail on the lookup below.
        assert 0 <= self.organism_index < len(self.organisms)
        self.taxid = self.organisms[self.organism_index]

        is_dicty = self.taxid == DICTY_TAXID
        self.altSourceCheck.setVisible(is_dicty)
        self.dictyExpressBox.setVisible(is_dicty)

        if self.data is not None:
            self.updateInfoItems()

    def setData(self, data=None):
        """Set the input data table (or clear the widget when ``data`` is None).

        Waits for the initialization task if the widget is not initialized
        yet.

        Raises
        ------
        Exception
            If an info retrieval task is still pending.
        """
        if not self.__initialized:
            self.initfuture.result()
            self.initialize()

        if self.itemsfuture is not None:
            raise Exception("Already processing")

        self.closeContext()
        self.data = data

        if data is None:
            self.clear()
            return

        self.geneAttrComboBox.clear()
        # Only string and discrete variables can hold gene names.
        gene_var_types = (Orange.data.StringVariable,
                          Orange.data.DiscreteVariable)
        self.attributes = [
            var for var in data.domain.variables + data.domain.metas
            if isinstance(var, gene_var_types)
        ]

        for var in self.attributes:
            self.geneAttrComboBox.addItem(*gui.attributeItem(var))

        # Data hints provide the defaults; the current values are the fallback.
        self.taxid = data_hints.get_hint(self.data, "taxid", self.taxid)
        self.useAttr = data_hints.get_hint(
            self.data, "genesinrows", self.useAttr)

        self.openContext(data)
        self.gene_attr = min(self.gene_attr, len(self.attributes) - 1)

        if self.taxid not in self.organisms:
            self.organism_index = 0
            self.taxid = self.organisms[self.organism_index]
        else:
            self.organism_index = self.organisms.index(self.taxid)

        self.updateInfoItems()

    def infoSource(self):
        """Return the ``(name, getter)`` pair of the selected info source.

        The sources are looked up in ``INFO_SOURCES`` under the current
        organism, falling back to the ``"default"`` entry. Both the organism
        index and the source index are clamped to the last valid position.
        """
        last_organism = len(self.organisms) - 1
        org = self.organisms[min(self.organism_index, last_organism)]
        key = org if org in INFO_SOURCES else "default"
        sources = INFO_SOURCES[key]
        source_index = min(self.useAltSource, len(sources) - 1)
        name, func = sources[source_index]
        return name, func

    def inputGenes(self):
        """Return the list of gene names taken from the input data.

        With ``useAttr`` set these are the attribute names; otherwise they
        are the non-NaN values of the selected gene column. The list is
        empty when there is no candidate column.
        """
        if self.useAttr:
            return [attr.name for attr in self.data.domain.attributes]
        if not self.attributes:
            return []

        attr = self.attributes[self.gene_attr]
        return [str(ex[attr]) for ex in self.data
                if not math.isnan(ex[attr])]

    def updateInfoItems(self):
        """Start a background task retrieving info records for the input genes.

        Does nothing without input data. The widget is blocked and disabled
        until ``_onItemsCompleted`` runs on the task's ``finished`` signal.
        """
        self.warning(0)
        if self.data is None:
            return

        # ``inputGenes`` is the single source for gene extraction; the
        # extraction logic is not repeated here.
        genes = self.inputGenes()
        if not genes:
            self.warning(0, "Could not extract genes from input dataset.")

        self.warning(1)
        org = self.organisms[min(self.organism_index, len(self.organisms) - 1)]
        source_name, info_getter = self.infoSource()

        self.error(0)

        self.updateDictyExpressLink(genes, show=org == DICTY_TAXID)
        self.altSourceCheck.setVisible(org == DICTY_TAXID)

        self.progressBarInit()
        self.setBlocking(True)
        self.setEnabled(False)
        self.infoLabel.setText("Retrieving info records.\n")

        self.genes = genes

        task = Task(
            function=partial(
                info_getter, org, genes,
                advance=methodinvoke(self, "advance", ()))
        )
        # Connect before submitting (as ``__init__`` does) so the completion
        # signal cannot be emitted before the slot is attached.
        task.finished.connect(self._onItemsCompleted)
        self.itemsfuture = self.executor.submit(task)

    def _onItemsCompleted(self):
        """Display the retrieved info records in the tree view.

        Unblocks the widget, reads ``(schema, records)`` from the pending
        future (re-raising its exception, if any) and rebuilds the view
        model from the records that are not empty.
        """
        self.setBlocking(False)
        self.progressBarFinished()
        self.setEnabled(True)

        try:
            schema, records = self.itemsfuture.result()
        finally:
            self.itemsfuture = None

        self.geneinfo = list(zip(self.genes, records))
        self.cells = []
        self.row2geneinfo = {}
        links = []
        for gene_index, (_, record) in enumerate(self.geneinfo):
            if not record:
                continue
            row = []
            for _, item in zip(schema, record):
                if isinstance(item, Link):
                    # TODO: This should be handled by delegates
                    row.append(item.text)
                    links.append(item.link)
                else:
                    row.append(item)
            # Map the display row back to its position in ``self.geneinfo``.
            self.row2geneinfo[len(self.cells)] = gene_index
            self.cells.append(row)

        header = [str(col) for col in schema]
        model = TreeModel(self.cells, header, None)

        model.setColumnLinks(0, links)
        proxyModel = QSortFilterProxyModel(self)
        proxyModel.setSourceModel(model)
        self.treeWidget.setModel(proxyModel)
        self.treeWidget.selectionModel().selectionChanged.connect(self.commit)

        # Fit the first seven columns to their contents, capped at 200.
        for column in range(7):
            self.treeWidget.resizeColumnToContents(column)
            fitted = min(self.treeWidget.columnWidth(column), 200)
            self.treeWidget.setColumnWidth(column, fitted)

        n_genes = len(self.genes)
        n_matched = len(self.cells)
        self.infoLabel.setText("%i genes\n%i matched NCBI's IDs" %
                               (n_genes, n_matched))
        self.matchedInfo = len(self.genes), len(self.cells)

    def clear(self):
        """Reset the view to an empty table and send ``None`` on the output."""
        self.infoLabel.setText("No data on input\n")

        header = ["NCBI ID", "Symbol", "Locus Tag",
                  "Chromosome", "Description", "Synonyms",
                  "Nomenclature"]
        empty_model = TreeModel([], header, self.treeWidget)
        self.treeWidget.setModel(empty_model)

        self.geneAttrComboBox.clear()
        self.send("Data Subset", None)

    def commit(self):
        """Send the part of the input data that matches the selected genes.

        With ``useAttr`` set, the output keeps only the attributes whose
        names are selected. Otherwise the output holds the rows whose gene
        column value is selected, extended with one string meta column per
        column of the info table. ``None`` is sent when there is no data,
        no candidate gene column, or no matching row.
        """
        if self.data is None:
            self.send("Data Subset", None)
            return

        model = self.treeWidget.model()
        selection = self.treeWidget.selectionModel().selection()
        selection = model.mapSelectionToSource(selection)
        selectedRows = list(chain.from_iterable(
            range(r.top(), r.bottom() + 1) for r in selection))

        model = model.sourceModel()

        # Gene name -> source-model row, for every selected row.
        gene2row = {self.geneinfo[self.row2geneinfo[row]][0]: row
                    for row in selectedRows}
        selectedIds = set(gene2row)

        if self.useAttr:
            attrs = [attr for attr in self.data.domain.attributes
                     if attr.name in selectedIds]
            domain = Orange.data.Domain(
                attrs, self.data.domain.class_vars, self.data.domain.metas)
            newdata = self.data.from_table(domain, self.data)
            self.send("Data Subset", newdata)

        elif self.attributes:
            attr = self.attributes[self.gene_attr]
            gene_col = [attr.str_val(v)
                        for v in self.data.get_column_view(attr)[0]]
            gene_col = [(i, name) for i, name in enumerate(gene_col)
                        if name in selectedIds]
            indices = [i for i, _ in gene_col]

            # Add a gene info columns to the output
            headers = [str(model.headerData(i, Qt.Horizontal, Qt.DisplayRole))
                       for i in range(model.columnCount())]
            metas = [Orange.data.StringVariable(name) for name in headers]
            domain = Orange.data.Domain(
                self.data.domain.attributes, self.data.domain.class_vars,
                self.data.domain.metas + tuple(metas))

            newdata = self.data.from_table(domain, self.data)[indices]

            # Fill each new meta column from the matching model column.
            model_rows = [gene2row[gene] for _, gene in gene_col]
            for col, meta in zip(range(model.columnCount()), metas):
                col_data = [str(model.index(row, col).data(Qt.DisplayRole))
                            for row in model_rows]
                newdata[:, meta] = col_data

            if not len(newdata):
                newdata = None

            self.send("Data Subset", newdata)
        else:
            self.send("Data Subset", None)

    def rowFiltered(self, row):
        """Return True when row ``row`` of ``self.cells`` fails the filter.

        A row passes when every whitespace-separated, lower-cased search
        term occurs in the lower-cased, space-joined text of the row.
        """
        terms = self.search_string.lower().split()
        text = " ".join(self.cells[row]).lower()
        return any(term not in text for term in terms)

    def searchUpdate(self):
        """Hide the tree rows that do not contain every search term."""
        if not self.data:
            return

        terms = self.search_string.lower().split()
        proxy = self.treeWidget.model()
        source_index = proxy.sourceModel().index
        to_view = proxy.mapFromSource
        for i, cells in enumerate(self.cells):
            text = " ".join(cells).lower()
            hidden = any(term not in text for term in terms)
            # ``self.cells`` is indexed by source row; map to the view row.
            view_row = to_view(source_index(i, 0)).row()
            self.treeWidget.setRowHidden(view_row, QModelIndex(), hidden)

    def selectFiltered(self):
        """Add every row that passes the search filter to the selection."""
        if not self.data:
            return

        passing = QItemSelection()
        proxy = self.treeWidget.model()
        source_index = proxy.sourceModel().index
        to_view = proxy.mapFromSource
        for i in range(len(self.cells)):
            if self.rowFiltered(i):
                continue
            passing.select(to_view(source_index(i, 0)),
                           to_view(source_index(i, 0)))

        flags = QItemSelectionModel.Select | QItemSelectionModel.Rows
        self.treeWidget.selectionModel().select(passing, flags)

    def updateDictyExpressLink(self, genes, show=False):
        """Set the dictyExpress link text and show or hide its box.

        The box is shown only when ``show`` is true and at least one gene
        name starts with ``"DDB"``. The link targets keep a ``%s``
        placeholder in place of the gene list.
        """
        def fix(ddb):
            # Normalize "DDB..." ids to the "DDB_G..." form; others -> None.
            if not ddb.startswith("DDB"):
                return None
            if ddb.startswith("DDB_G"):
                return ddb
            return ddb.replace("DDB", "DDB_G")

        if show:
            genes = [fixed for fixed in map(fix, genes) if fixed]
            link1 = '<a href="http://dictyexpress.biolab.si/run/index.php?gene=%s">Microarray profile</a>'
            link2 = '<a href="http://dictyexpress.biolab.si/run/index.php?gene=%s&db=rnaseq">RNA-Seq profile</a>'
            self.linkLabel.setText(link1 + "<br/>" + link2)

            show = any(genes)

        toggle = self.dictyExpressBox.show if show else self.dictyExpressBox.hide
        toggle()

    def onDictyExpressLink(self, link):
        """Open the activated dictyExpress link for the selected genes.

        Parameters
        ----------
        link : str
            Link target containing a ``%s`` placeholder, which is replaced
            by the space-separated ids of the selected genes.

        An information dialog is shown instead when nothing is selected.
        """
        if not self.data:
            return

        selectedIndexes = self.treeWidget.selectedIndexes()
        if not selectedIndexes:
            QMessageBox.information(
                self, "No gene ids selected",
                "Please select some genes and try again."
            )
            return

        # The indexes collected above are reused; a set collapses repeated
        # rows among them.
        mapToSource = self.treeWidget.model().mapToSource
        selectedRows = {mapToSource(index).row() for index in selectedIndexes}
        selectedIds = {self.geneinfo[self.row2geneinfo[row]][0]
                       for row in selectedRows}

        def fix(ddb):
            # Normalize "DDB..." ids to the "DDB_G..." form; others -> None.
            if ddb.startswith("DDB"):
                if not ddb.startswith("DDB_G"):
                    ddb = ddb.replace("DDB", "DDB_G")
                return ddb
            return None

        # Sorted so the generated URL does not depend on set iteration order.
        genes = sorted(fixed for fixed in map(fix, selectedIds) if fixed)
        url = str(link) % " ".join(genes)
        QDesktopServices.openUrl(QUrl(url))

    def onAltSourceChange(self):
        """Refresh the info items after the alternative-source checkbox changes."""
        self.updateInfoItems()

    def onDeleteWidget(self):
        """Cancel pending futures and shut the executor down without waiting."""
        # try to cancel pending tasks
        for pending in (self.initfuture, self.itemsfuture):
            if pending:
                pending.cancel()

        self.executor.shutdown(wait=False)
        super().onDeleteWidget()
class OWExplainPredictions(OWWidget):
    """Widget that explains a model's prediction for a single sample.

    It takes a data table, a model and a one-row sample table, and outputs
    a table of explanations.
    """

    # Widget metadata.
    name = "Explain Predictions"
    description = "Computes attribute contributions to the final prediction with an approximation algorithm for shapely value"
    icon = "icons/ExplainPredictions.svg"
    priority = 200
    # Stopping criteria (bound to the "Error <" and "Error p-value <" spins).
    gui_error = settings.Setting(0.05)
    gui_p_val = settings.Setting(0.05)
    # Number of attributes to show and the ranking criterion.
    gui_num_atr = settings.Setting(20)
    sort_index = settings.Setting(SortBy.ABSOLUTE)

    class Inputs:
        data = Input("Data", Table, default=True)
        model = Input("Model", Model, multiple=False)
        sample = Input("Sample", Table)

    class Outputs:
        explanations = Output("Explanations", Table)

    class Error(OWWidget.Error):
        # Raised by ``set_sample`` when the sample has more or less than one row.
        sample_too_big = widget.Msg("Can only explain one sample at the time.")

    class Warning(OWWidget.Warning):
        # Raised by ``commit_calc`` when transforming the sample to the data
        # domain changes its number of NaN values.
        unknowns_increased = widget.Msg(
            "Number of unknown values increased, Data and Sample domains mismatch.")

    def __init__(self):
        """Initialize the widget state and build the control and main areas.

        The main area holds three graphics views (header, body, footer)
        that share one ``QGraphicsScene``.
        """
        super().__init__()
        # Current inputs and results.
        self.data = None
        self.model = None
        self.to_explain = None
        self.explanations = None
        # Toggled by ``toggle_button``; True while the button reads
        # "Stop Computation".
        self.stop = True
        # The ``ExplainPredictions`` instance, created in ``commit_calc``.
        self.e = None

        # Background task bookkeeping.
        self._task = None
        self._executor = ThreadExecutor()

        info_box = gui.vBox(self.controlArea, "Info")
        self.data_info = gui.widgetLabel(info_box, "Data: N/A")
        self.model_info = gui.widgetLabel(info_box, "Model: N/A")
        self.sample_info = gui.widgetLabel(info_box, "Sample: N/A")

        criteria_box = gui.vBox(self.controlArea, "Stopping criteria")
        self.error_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_error",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error < ",
                                   spinType=float,
                                   callback=self._update_error_spin,
                                   controlWidth=80,
                                   keyboardTracking=False)

        self.p_val_spin = gui.spin(criteria_box,
                                   self,
                                   "gui_p_val",
                                   0.01,
                                   1,
                                   step=0.01,
                                   label="Error p-value < ",
                                   spinType=float,
                                   callback=self._update_p_val_spin,
                                   controlWidth=80, keyboardTracking=False)

        plot_properties_box = gui.vBox(self.controlArea, "Display features")
        self.num_atr_spin = gui.spin(plot_properties_box,
                                     self,
                                     "gui_num_atr",
                                     1,
                                     100,
                                     step=1,
                                     label="Show attributes",
                                     callback=self._update_num_atr_spin,
                                     controlWidth=80,
                                     keyboardTracking=False)

        self.sort_combo = gui.comboBox(plot_properties_box,
                                       self,
                                       "sort_index",
                                       label="Rank by",
                                       items=SortBy.items(),
                                       orientation=Qt.Horizontal,
                                       callback=self._update_combo)

        gui.rubber(self.controlArea)

        # Disabled here; ``commit_calc`` enables it when a task is started.
        self.cancel_button = gui.button(self.controlArea,
                                        self,
                                        "Stop Computation",
                                        callback=self.toggle_button,
                                        autoDefault=True,
                                        tooltip="Stops and restarts computation")
        self.cancel_button.setDisabled(True)

        predictions_box = gui.vBox(self.mainArea, "Model prediction")
        self.predict_info = gui.widgetLabel(predictions_box, "")

        self.mainArea.setMinimumWidth(700)
        self.resize(700, 400)

        class _GraphicsView(QGraphicsView):
            # Base view: fills in default keyword properties that the
            # subclasses below may override.
            def __init__(self, scene, parent, **kwargs):
                for k, v in dict(verticalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                                 horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOff,
                                 viewportUpdateMode=QGraphicsView.BoundingRectViewportUpdate,
                                 renderHints=(QPainter.Antialiasing |
                                              QPainter.TextAntialiasing |
                                              QPainter.SmoothPixmapTransform),
                                 alignment=(Qt.AlignTop |
                                            Qt.AlignLeft),
                                 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                                        QSizePolicy.MinimumExpanding)).items():
                    kwargs.setdefault(k, v)
                super().__init__(scene, parent, **kwargs)

        class GraphicsView(_GraphicsView):
            # Body view: redraws the widget's scene on every resize.
            def __init__(self, scene, parent):
                super().__init__(scene, parent,
                                 verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
                                 styleSheet='QGraphicsView {background: white}')
                self.viewport().setMinimumWidth(500)
                self._is_resizing = False

            # Class attribute holding the enclosing widget so that
            # ``resizeEvent`` can call its ``draw``.
            w = self

            def resizeEvent(self, resizeEvent):
                self._is_resizing = True
                self.w.draw()
                self._is_resizing = False
                return super().resizeEvent(resizeEvent)

            def is_resizing(self):
                return self._is_resizing

            def sizeHint(self):
                return QSize(600, 300)

        class FixedSizeGraphicsView(_GraphicsView):
            # View used for the header and the footer.
            def __init__(self, scene, parent):
                super().__init__(scene, parent,
                                 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding,
                                                        QSizePolicy.Minimum))

            def sizeHint(self):
                return QSize(600, 30)

        """all will share the same scene, but will show different parts of it"""
        self.box_scene = QGraphicsScene(self)

        self.box_view = GraphicsView(self.box_scene, self)
        self.header_view = FixedSizeGraphicsView(self.box_scene, self)
        self.footer_view = FixedSizeGraphicsView(self.box_scene, self)

        self.mainArea.layout().addWidget(self.header_view)
        self.mainArea.layout().addWidget(self.box_view)
        self.mainArea.layout().addWidget(self.footer_view)

        # Created in ``draw`` once explanations are available.
        self.painter = None

    def draw(self):
        """Redraw the explanations and lay out the header, body and footer views.

        The scene is cleared first. When explanations exist they are painted
        with ``GraphAttributes``, limited to ``gui_num_atr`` attributes.
        """
        self.box_scene.clear()
        wp = self.box_view.viewport().rect()
        header_height = 30
        if self.explanations is not None:
            shown = min(self.gui_num_atr, self.explanations.Y.shape[0])
            self.painter = GraphAttributes(self.box_scene, shown)
            self.painter.paint(wp, self.explanations, header_h=header_height)

        # Set the scene rectangle each view shows.
        bounds = self.box_scene.itemsBoundingRect()
        rect = QRectF(bounds.x(), bounds.y(), bounds.width(), bounds.height())
        left, top = rect.x(), rect.y()
        width, height = rect.width(), rect.height()

        self.box_scene.setSceneRect(rect)
        self.box_view.setSceneRect(
            left, top + header_height + 2, width, height - 80)
        self.header_view.setSceneRect(left, top, width, 10)
        self.header_view.setFixedHeight(header_height)
        self.footer_view.setSceneRect(left, top + height - 50, width, 35)

    def sort_explanations(self):
        """Reorder ``self.explanations`` by the criterion chosen in the combo box.

        POSITIVE and ABSOLUTE sort descending by the first X column (raw or
        absolute value), NEGATIVE sorts ascending, and BY_NAME sorts by the
        lower-cased first meta column. Any other value leaves the order
        untouched.
        """
        criterion = self.sort_index
        if criterion == SortBy.POSITIVE:
            order = np.argsort(self.explanations.X[:, 0])
            self.explanations = self.explanations[order][::-1]
        elif criterion == SortBy.NEGATIVE:
            order = np.argsort(self.explanations.X[:, 0])
            self.explanations = self.explanations[order]
        elif criterion == SortBy.ABSOLUTE:
            order = np.argsort(np.abs(self.explanations.X[:, 0]))
            self.explanations = self.explanations[order][::-1]
        elif criterion == SortBy.BY_NAME:
            names = np.array(
                list(map(np.chararray.lower, self.explanations.metas[:, 0])))
            self.explanations = self.explanations[np.argsort(names)]

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Set the 'Data' input and update the info label.

        Parameters
        ----------
        data : Table or None
            The new data table; ``None`` clears the input.

        Stored explanations and the explainer are discarded, since both
        were computed for the previous data.
        """
        self.data = data
        self.explanations = None
        self.data_info.setText("Data: N/A")
        self.e = None
        if data is None:
            return

        # No table model is built here: only the shape of ``data.X`` is
        # needed for the label.
        n_instances, n_features = data.X.shape[0], data.X.shape[1]
        if n_instances == 1:
            inst = "1 instance and "
        else:
            inst = str(n_instances) + " instances and "
        if n_features == 1:
            feat = "1 feature "
        else:
            feat = str(n_features) + " features"
        self.data_info.setText("Data: " + inst + feat)

    @Inputs.model
    def set_predictor(self, model):
        """Set the 'Model' input, discard old results and update the info label."""
        self.model = model
        self.explanations = None
        self.e = None

        label = "Model: N/A"
        self.model_info.setText(label)
        if model is not None:
            label = "Model: " + str(model.name)
            self.model_info.setText(label)

    @Inputs.sample
    @check_sql_input
    def set_sample(self, sample):
        """Set the 'Sample' input; only a table with exactly one row is accepted.

        For any other row count the sample is dropped and the
        ``sample_too_big`` error is shown.
        """
        self.to_explain = sample
        self.explanations = None
        self.Error.sample_too_big.clear()
        self.sample_info.setText("Sample: N/A")
        if sample is None:
            return

        if len(sample.X) != 1:
            self.to_explain = None
            self.Error.sample_too_big()
            return

        n_features = sample.X.shape[1]
        if n_features == 1:
            feat = "1 feature"
        else:
            feat = str(n_features) + " features"
        self.sample_info.setText("Sample: " + feat)

        explainer = self.e
        if explainer is not None:
            explainer.saved = False

    def handleNewSignals(self):
        """Cancel any running task, reset the display and start over."""
        if self._task is not None:
            self.cancel()
        assert self._task is None

        # Reset the prediction label, the warning and the stop button.
        self.predict_info.setText("")
        self.Warning.unknowns_increased.clear()
        self.stop = True
        self.cancel_button.setText("Stop Computation")

        self.commit_calc_or_output()

    def commit_calc_or_output(self):
        """Start the computation when data and sample are set; otherwise just send the output."""
        inputs_ready = self.data is not None and self.to_explain is not None
        if not inputs_ready:
            self.commit_output()
        else:
            self.commit_calc()

    def commit_calc(self):
        """Transform the sample to the data domain and start the explanation task.

        A warning is shown when the transformation changes the number of
        NaN values in the sample. Nothing is submitted without a model.
        The callbacks passed to the explainer forward their arguments to
        widget slots through queued ``QMetaObject.invokeMethod`` calls.
        """
        nan_before = np.count_nonzero(np.isnan(self.to_explain.X[0]))

        self.to_explain = self.to_explain.transform(self.data.domain)
        nan_after = np.count_nonzero(np.isnan(self.to_explain.X[0]))
        if nan_before != nan_after:
            self.Warning.unknowns_increased()

        if self.model is None:
            return

        # calculate contributions
        if self.e is None:
            batch = min(len(self.data.X), 500)
            self.e = ExplainPredictions(self.data,
                                        self.model,
                                        batch_size=batch,
                                        p_val=self.gui_p_val,
                                        error=self.gui_error)
        self._task = task = Task()

        def callback(progress):
            # update progress bar
            QMetaObject.invokeMethod(
                self, "set_progress_value", Qt.QueuedConnection,
                Q_ARG(int, progress))
            # The returned flag reports whether the task was cancelled.
            return True if task.canceled else False

        def callback_update(table):
            QMetaObject.invokeMethod(
                self, "update_view", Qt.QueuedConnection,
                Q_ARG(Orange.data.Table, table))

        def callback_prediction(class_value):
            QMetaObject.invokeMethod(
                self, "update_model_prediction", Qt.QueuedConnection,
                Q_ARG(float, class_value))

        self.was_canceled = False
        explain_func = partial(
            self.e.anytime_explain,
            self.to_explain[0],
            callback=callback,
            update_func=callback_update,
            update_prediction=callback_prediction)

        self.progressBarInit(processEvents=None)
        task.future = self._executor.submit(explain_func)
        task.watcher = FutureWatcher(task.future)
        task.watcher.done.connect(self._task_finished)
        self.cancel_button.setDisabled(False)

    @pyqtSlot(Orange.data.Table)
    def update_view(self, table):
        """Store ``table`` as the explanations, then sort, redraw and send them."""
        self.explanations = table
        for refresh in (self.sort_explanations, self.draw, self.commit_output):
            refresh()

    @pyqtSlot(float)
    def update_model_prediction(self, value):
        """Slot that displays the model prediction ``value``."""
        self._print_prediction(value)

    @pyqtSlot(int)
    def set_progress_value(self, value):
        """Slot that sets the progress bar to ``value`` without processing events."""
        self.progressBarSet(value, processEvents=False)

    @pyqtSlot(concurrent.futures.Future)
    def _task_finished(self, f):
        """Handle completion of the explanation task.

        Parameters
        ----------
        f : concurrent.futures.Future
            Finished future of the current task. On success its result is
            indexed with ``[1]`` to obtain the explanations table.

        An exception stored in the future is logged and reported through
        ``self.error``; it is not re-raised.
        """
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None

        if not self.was_canceled:
            self.cancel_button.setDisabled(True)

        try:
            results = f.result()
        except Exception as ex:
            log = logging.getLogger()
            log.exception(__name__, exc_info=True)
            self.error("Exception occured during evaluation: {!r}".format(ex))

            # ``results`` is not set in ``__init__``; reset it only if it
            # exists, so the error path cannot fail with AttributeError.
            stored = getattr(self, "results", None)
            if stored is not None:
                for key in stored.keys():
                    stored[key] = None
        else:
            self.update_view(results[1])

        self.progressBarFinished(processEvents=False)

    def commit_output(self):
        """
        Send the current (best-so-far) explanations to the `explanations`
        output.
        """
        self.Outputs.explanations.send(self.explanations)

    def toggle_button(self):
        """
        Flip the `stop` flag and act on the new state.

        When the flag was set, it is cleared, the button is relabelled
        "Restart Computation" and the current task is cancelled. Otherwise
        the flag is set, the button is relabelled "Stop Computation" and
        `commit_calc_or_output` is called.
        """
        previously_set = bool(self.stop)
        self.stop = not previously_set

        if previously_set:
            label, action = "Restart Computation", self.cancel
        else:
            label, action = "Stop Computation", self.commit_calc_or_output

        self.cancel_button.setText(label)
        action()

    def cancel(self):
        """
        Abort the task in flight, if there is one.

        Does nothing when no task is registered. Otherwise the task is
        cancelled, `was_canceled` is set and `_task_finished` is invoked
        directly with the task's future.
        """
        if self._task is None:
            return

        self._task.cancel()
        assert self._task.future.done()
        # Detach `_task_finished` from the watcher's `done` signal; the
        # handler is called explicitly below instead.
        self._task.watcher.done.disconnect(self._task_finished)
        self.was_canceled = True
        self._task_finished(self._task.future)

    def _print_prediction(self, class_value):
        """
        Show the prediction for the first class variable in `predict_info`.

        Parameters
        ----------
        class_value : float
            For a ContinuousVariable class the number is shown as is;
            otherwise it is cast to int and used as an index into the class
            variable's `values`.
        """
        target = self.data.domain.class_vars[0]

        if isinstance(target, ContinuousVariable):
            shown = str(class_value)
        else:
            shown = target.values[int(class_value)]

        self.predict_info.setText(target.name + ":      " + shown)

    def _update_error_spin(self):
        """
        Cancel the current task, copy `gui_error` onto `self.e` (when it is
        set) and call `handleNewSignals`.
        """
        self.cancel()
        if self.e is not None:
            self.e.error = self.gui_error
        self.handleNewSignals()

    def _update_p_val_spin(self):
        """
        Cancel the current task, copy `gui_p_val` onto `self.e` (when it is
        set) and call `handleNewSignals`.
        """
        self.cancel()
        if self.e is not None:
            self.e.p_val = self.gui_p_val
        self.handleNewSignals()

    def _update_num_atr_spin(self):
        """
        Cancel the current task and call `handleNewSignals`.
        """
        self.cancel()
        self.handleNewSignals()

    def _update_combo(self):
        if self.explanations != None:
            self.sort_explanations()
            self.draw()
            self.commit_output()

    def onDeleteWidget(self):
        """
        Cancel the current task (if any) before delegating to the parent
        class's `onDeleteWidget`.
        """
        self.cancel()
        super().onDeleteWidget()