Beispiel #1
0
class OWDotMatrix(widget.OWWidget):
    """Show aggregated per-cluster gene expression as a dot matrix.

    Rows (cells) of the input table are grouped by a discrete cluster
    variable.  For every group one of ``AGGREGATE_F`` is applied column-wise
    to (at most) the first ``GENE_MAXIMUM`` attributes (genes), giving a
    clusters x genes matrix.  The matrix is optionally biclustered
    (permuted), rescaled for display and shown in a ``ContingencyTable``;
    cells selected there are sent on as subsets of the input data.
    """
    name = "Dot Matrix"
    description = "Perform cluster analysis."
    icon = "icons/DotMatrix.svg"
    priority = 410

    class Inputs:
        data = Input("Data", Table, default=True)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        contingency = Output("Contingency Table", Table)

    # Only the first GENE_MAXIMUM attributes (genes) are aggregated and shown.
    GENE_MAXIMUM = 100
    # Cell sizes handed to ContingencyTable for the "S", "M" and "L" buttons.
    CELL_SIZES = (14, 22, 30)
    # Column-wise aggregations over the cells (rows) of one cluster, selected
    # by ``aggregate_ix``; AGGREGATE_NAME holds the matching display names.
    AGGREGATE_F = [
        lambda x: np.mean(x, axis=0),
        lambda x: np.median(x, axis=0),
        lambda x: np.min(x, axis=0),
        lambda x: np.max(x, axis=0),
        lambda x: np.mean(x > 0, axis=0),
    ]
    AGGREGATE_NAME = [
        "Mean expression", "Median expression", "Min expression",
        "Max expression", "Fraction expressing"
    ]

    settingsHandler = DomainContextHandler(metas_in_res=True)
    cluster_var = ContextSetting(None)
    aggregate_ix = ContextSetting(0)  # type: int
    biclustering = ContextSetting(True)
    transpose = ContextSetting(False)
    log_scale = ContextSetting(False)
    normalize = ContextSetting(True)
    cell_size_ix = ContextSetting(2)  # type: int
    selection = ContextSetting(set())
    auto_apply = Setting(True)

    want_main_area = True

    def __init__(self):
        super().__init__()

        self.data = None  # type: Table
        # Aggregated clusters x genes matrix, already permuted by
        # cluster_order (rows) and gene_order (columns).
        self.matrix = None
        # Cluster labels and gene names in their original (unpermuted) order.
        self.clusters = None
        self.cluster_order = None
        self.genes = None
        self.gene_order = None
        # Header labels currently shown (clusters/genes, swapped if transposed).
        self.rows = None
        self.columns = None
        self.feature_model = DomainModel(valid_types=DiscreteVariable)

        box = gui.vBox(self.controlArea, "Info")
        self.infobox = gui.widgetLabel(box, self._get_info_string())

        box = gui.vBox(self.controlArea, "Cluster Variable")
        gui.comboBox(box,
                     self,
                     "cluster_var",
                     sendSelectedValue=True,
                     model=self.feature_model,
                     callback=self._calculate_table_values)

        box = gui.vBox(self.controlArea, "Aggregation")
        gui.comboBox(box,
                     self,
                     "aggregate_ix",
                     sendSelectedValue=False,
                     items=self.AGGREGATE_NAME,
                     callback=self._calculate_table_values)

        box = gui.vBox(self.controlArea, "Options")
        gui.checkBox(box,
                     self,
                     "biclustering",
                     "Biclustering of cells and genes",
                     callback=self._calculate_table_values)
        gui.checkBox(box,
                     self,
                     "transpose",
                     "Transpose",
                     callback=self._refresh_table)
        gui.checkBox(box,
                     self,
                     "log_scale",
                     "Log scale",
                     callback=self._refresh_table)
        gui.checkBox(box,
                     self,
                     "normalize",
                     "Normalize data",
                     callback=self._refresh_table)

        box = gui.vBox(self.controlArea, "Plot Size")
        gui.radioButtons(box,
                         self,
                         "cell_size_ix",
                         btnLabels=("S", "M", "L"),
                         callback=lambda: self.tableview.set_cell_size(
                             self.CELL_SIZES[self.cell_size_ix]),
                         orientation=Qt.Horizontal)

        gui.rubber(self.controlArea)

        self.apply_button = gui.auto_commit(self.controlArea,
                                            self,
                                            "auto_apply",
                                            "&Apply",
                                            box=False)

        self.tableview = ContingencyTable(self)
        self.mainArea.layout().addWidget(self.tableview)

    def _get_info_string(self):
        """Return the text for the "Info" box (gene, cell and cluster counts)."""
        formatstr = "{} genes, {} cells\n{} clusters"
        if self.data:
            return formatstr.format(len(self.data.domain.attributes),
                                    len(self.data), len(self.clusters))
        else:
            return formatstr.format(*([0] * 3))

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive new input data, reset derived state and recompute.

        Data without any discrete variable is rejected with an error, since
        a cluster variable is required.
        """
        if self.feature_model:
            self.closeContext()

        self.data = data
        self.matrix = None
        self.feature_model.set_domain(None)
        self.cluster_var = None
        self.clusters = None
        self.cluster_order = None
        self.genes = None
        self.gene_order = None
        self.rows = None
        self.columns = None

        if self.data:
            self.feature_model.set_domain(self.data.domain)
            if self.feature_model:
                self.Error.clear()
                self.openContext(self.data)
                if self.cluster_var is None:
                    self.cluster_var = self.feature_model[0]
                self._calculate_table_values()
            else:
                self.tableview.clear()
                self.error("No discrete variables in data.")
                self.data = None
        else:
            self.tableview.clear()
            self.Error.clear()

    @staticmethod
    def _group_by(table: Table, var: DiscreteVariable):
        """Yield the sub-tables of *table*, one per distinct value of *var*.

        Groups are produced in ``np.unique`` order of the column values, which
        is the same order used to build ``self.clusters``.
        """
        column = table.get_column_view(var)[0]
        return (table[column == value] for value in np.unique(column))

    def _calculate_table_values(self):
        """Recompute the aggregated matrix and its row/column order.

        Afterwards the table view is refreshed and the outputs are re-sent.
        """
        if self.data is None:
            self.Warning.clear()
            return

        self.clusters = [
            self.cluster_var.values[int(ix)] for ix in np.unique(
                self.data.get_column_view(self.cluster_var)[0])
        ]
        attributes = self.data.domain.attributes
        # Keep the gene labels in step with the truncated matrix columns.
        self.genes = [var.name for var in attributes[:self.GENE_MAXIMUM]]
        self.infobox.setText(self._get_info_string())

        if len(attributes) > self.GENE_MAXIMUM:
            self.warning(
                "Too many genes on input, first {} genes displayed.".
                format(self.GENE_MAXIMUM))
        else:
            self.Warning.clear()

        aggregate = self.AGGREGATE_F[self.aggregate_ix]
        # Aggregate over the cells (rows) of each cluster, restricted to the
        # first GENE_MAXIMUM genes (columns).  np.stack needs a sequence,
        # not a generator.
        self.matrix = np.stack(
            [aggregate(cluster.X[:, :self.GENE_MAXIMUM])
             for cluster in self._group_by(self.data, self.cluster_var)],
            axis=0)

        if self.biclustering:
            self.cluster_order, self.gene_order = ClusterAnalysis.biclustering(
                self.matrix, ClusterAnalysis.neighbor_distance)
        else:
            self.cluster_order = np.arange(len(self.clusters))
            self.gene_order = np.arange(len(self.genes))
        self.matrix = self.matrix[self.cluster_order][:, self.gene_order]

        self._refresh_table()
        self._invalidate()

    def _scaled_matrix(self):
        """Return ``self.matrix`` rescaled to [0, 1] for display.

        With ``log_scale`` values are first mapped through log(x + 1).  With
        ``normalize`` each column is standardized, clipped to [-3, 3] and
        min-max scaled on its own; otherwise the whole matrix is min-max
        scaled at once.  Constant columns (or a constant matrix) map to 0
        instead of producing NaN from a 0/0 division.
        """
        matrix = self.matrix
        if self.log_scale:
            matrix = np.log(matrix + 1)
        if self.normalize:
            std = np.std(matrix, axis=0, keepdims=True)
            std[std == 0] = 1
            matrix = (matrix - np.mean(matrix, axis=0, keepdims=True)) / std
            matrix = np.clip(matrix, -3, 3)
            matrix = matrix - matrix.min(axis=0, keepdims=True)
            span = matrix.max(axis=0, keepdims=True)
        else:
            matrix = matrix - matrix.min()
            span = matrix.max()
        return matrix / np.where(span == 0, 1, span)

    def _refresh_table(self):
        """Push headers, scaled values and tooltips to the table view."""
        if self.matrix is None:
            return
        if not self.transpose:
            self.rows, self.columns = self.clusters, self.genes
            row_order, column_order = self.cluster_order, self.gene_order
        else:
            self.rows, self.columns = self.genes, self.clusters
            row_order, column_order = self.gene_order, self.cluster_order
        self.tableview.set_headers(
            np.array(self.rows)[row_order],
            np.array(self.columns)[column_order],
            circles=True,
            cell_size=self.CELL_SIZES[self.cell_size_ix],
            bold_headers=False)
        if self.matrix.size == 0:
            return

        matrix = self._scaled_matrix()
        if self.transpose:
            matrix = matrix.T

        def tooltip(i, j):
            # (i, j) index the displayed table; bring them to (cluster, gene)
            # positions of the permuted self.matrix, then map those back to
            # the unpermuted label lists through the order arrays.
            if self.transpose:
                i, j = j, i
            return "Cluster: {}\nGene: {}\n{}: {:.1f}".format(
                self.clusters[self.cluster_order[i]],
                self.genes[self.gene_order[j]],
                self.AGGREGATE_NAME[self.aggregate_ix],
                self.matrix[i, j])

        self.tableview.update_table(matrix, tooltip=tooltip)

    def commit(self):
        """Send the selected subset, annotated data and contingency table."""
        if self.selection:
            cluster_ids = set()
            gene_ids = set()
            for row, col in self.selection:
                if self.transpose:
                    row, col = col, row
                # Selection indices address the displayed (permuted) table;
                # map them back to indices into self.clusters / self.genes.
                cluster_ids.add(int(self.cluster_order[row]))
                gene_ids.add(int(self.gene_order[col]))

            new_domain = Domain(
                [self.data.domain[self.genes[i]] for i in sorted(gene_ids)],
                self.data.domain.class_vars, self.data.domain.metas)
            selected_data = Values(
                [FilterDiscrete(self.cluster_var, [self.clusters[i]])
                 for i in sorted(cluster_ids)],
                conjunction=False)(self.data)
            selected_data = selected_data.transform(new_domain)
            annotated_data = create_annotated_table(
                self.data.transform(new_domain),
                np.where(np.isin(self.data.ids, selected_data.ids,
                                 assume_unique=True)))
        else:
            selected_data = None
            annotated_data = create_annotated_table(self.data, [])
        if self.matrix is not None:
            table = ClusterAnalysis.contingency_table(
                self.matrix,
                DiscreteVariable(self.cluster_var.name,
                                 np.array(self.clusters)),
                np.array(self.genes)[self.gene_order],
                self.cluster_order[..., np.newaxis])
        else:
            table = None
        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)
        self.Outputs.contingency.send(table)

    def _invalidate(self):
        """Read the current table selection and re-send outputs."""
        self.selection = self.tableview.get_selection()
        self.commit()
class OWClusterAnalysis(widget.OWWidget):
    """Show genes enriched (over/under-expressed) in clusters of cells.

    A ``ClusterAnalysis`` is built in a background thread from the input data
    and a discrete cluster variable.  A second background task then selects
    the genes to show: the top genes per cluster, the top enrichments
    overall, or a user-provided reference list from the "Genes" input.  The
    result is displayed in a ``ContingencyTable``, and the selected cells
    are sent on as data subsets.
    """
    name = "Cluster Analysis"
    description = "Perform cluster analysis."
    icon = "icons/ClusterAnalysis.svg"
    priority = 2010

    class Inputs:
        data = Input("Data", Table, default=True)
        genes = Input("Genes", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        contingency = Output("Contingency Table", Table)

    # Upper bounds of the two "Top ..." spin boxes.
    N_GENES_PER_CLUSTER_MAX = 10
    N_MOST_ENRICHED_MAX = 50
    # Cell sizes handed to ContingencyTable for the "S", "M" and "L" buttons.
    CELL_SIZES = (14, 22, 30)

    settingsHandler = DomainContextHandler(metas_in_res=True)
    cluster_var = ContextSetting(None)
    selection = ContextSetting(set())
    # 0: top genes per cluster, 1: top enrichments, 2: user-provided list.
    gene_selection = ContextSetting(0)
    # Index into _diff_exprs; passed to ClusterAnalysis as `enrichment`.
    differential_expression = ContextSetting(0)
    cell_size_ix = ContextSetting(2)
    _diff_exprs = ("high", "low", "either")
    n_genes_per_cluster = ContextSetting(3)
    n_most_enriched = ContextSetting(20)
    biclustering = ContextSetting(True)
    auto_apply = Setting(True)

    want_main_area = True

    def __init__(self):
        super().__init__()

        self.ca = None  # ClusterAnalysis produced by the "init" task
        self.clusters = None  # row labels of the displayed table
        self.columns = None  # DiscreteVariable "Gene" with column labels
        self.data = None
        self.feature_model = DomainModel(valid_types=DiscreteVariable)
        self.gene_list = None  # gene IDs from the "Genes" input, or None
        self.model = None  # displayed values (cluster x gene)
        self.pvalues = None  # p-values shown in the tooltips

        self._executor = ThreadExecutor()
        # (current, previous) gene_selection; the previous mode is restored
        # when the user-provided gene list becomes unusable.
        self._gene_selection_history = (self.gene_selection,
                                        self.gene_selection)
        self._task = None  # currently running Task, if any

        box = gui.vBox(self.controlArea, "Info")
        self.infobox = gui.widgetLabel(box, self._get_info_string())

        box = gui.vBox(self.controlArea, "Cluster Variable")
        gui.comboBox(box,
                     self,
                     "cluster_var",
                     sendSelectedValue=True,
                     model=self.feature_model,
                     callback=self._run_cluster_analysis)

        layout = QGridLayout()
        self.gene_selection_radio_group = gui.radioButtonsInBox(
            self.controlArea,
            self,
            "gene_selection",
            orientation=layout,
            box="Gene Selection",
            callback=self._gene_selection_changed)

        def conditional_set_gene_selection(button_id):
            # Spin-box callback: only recompute if its mode is the active one.
            def f():
                if self.gene_selection == button_id:
                    return self._set_gene_selection()

            return f

        layout.addWidget(
            gui.appendRadioButton(self.gene_selection_radio_group,
                                  "",
                                  addToLayout=False), 1, 1)
        cb = gui.hBox(None, margin=0)
        gui.widgetLabel(cb, "Top")
        self.n_genes_per_cluster_spin = gui.spin(
            cb,
            self,
            "n_genes_per_cluster",
            minv=1,
            maxv=self.N_GENES_PER_CLUSTER_MAX,
            controlWidth=60,
            alignment=Qt.AlignRight,
            callback=conditional_set_gene_selection(0))
        gui.widgetLabel(cb, "genes per cluster")
        gui.rubber(cb)
        layout.addWidget(cb, 1, 2, Qt.AlignLeft)

        layout.addWidget(
            gui.appendRadioButton(self.gene_selection_radio_group,
                                  "",
                                  addToLayout=False), 2, 1)
        mb = gui.hBox(None, margin=0)
        gui.widgetLabel(mb, "Top")
        self.n_most_enriched_spin = gui.spin(
            mb,
            self,
            "n_most_enriched",
            minv=1,
            maxv=self.N_MOST_ENRICHED_MAX,
            controlWidth=60,
            alignment=Qt.AlignRight,
            callback=conditional_set_gene_selection(1))
        gui.widgetLabel(mb, "highest enrichments")
        gui.rubber(mb)
        layout.addWidget(mb, 2, 2, Qt.AlignLeft)

        # Enabled only once a usable "Genes" input arrives (see set_genes).
        layout.addWidget(
            gui.appendRadioButton(self.gene_selection_radio_group,
                                  "",
                                  addToLayout=False,
                                  disabled=True), 3, 1)
        sb = gui.hBox(None, margin=0)
        gui.widgetLabel(sb, "User-provided list of genes")
        gui.rubber(sb)
        layout.addWidget(sb, 3, 2)

        layout = QGridLayout()
        self.differential_expression_radio_group = gui.radioButtonsInBox(
            self.controlArea,
            self,
            "differential_expression",
            orientation=layout,
            box="Differential Expression",
            callback=self._set_gene_selection)

        layout.addWidget(
            gui.appendRadioButton(self.differential_expression_radio_group,
                                  "Overexpressed in cluster",
                                  addToLayout=False), 1, 1)
        layout.addWidget(
            gui.appendRadioButton(self.differential_expression_radio_group,
                                  "Underexpressed in cluster",
                                  addToLayout=False), 2, 1)
        layout.addWidget(
            gui.appendRadioButton(self.differential_expression_radio_group,
                                  "Either",
                                  addToLayout=False), 3, 1)

        box = gui.vBox(self.controlArea, "Sorting and Zoom")
        gui.checkBox(box,
                     self,
                     "biclustering",
                     "Biclustering of analysis results",
                     callback=self._set_gene_selection)
        gui.radioButtons(box,
                         self,
                         "cell_size_ix",
                         btnLabels=("S", "M", "L"),
                         callback=lambda: self.tableview.set_cell_size(
                             self.CELL_SIZES[self.cell_size_ix]),
                         orientation=Qt.Horizontal)

        gui.rubber(self.controlArea)

        self.apply_button = gui.auto_commit(self.controlArea,
                                            self,
                                            "auto_apply",
                                            "&Apply",
                                            box=False)

        self.tableview = ContingencyTable(self)
        self.mainArea.layout().addWidget(self.tableview)

    def _get_current_gene_selection(self):
        return self._gene_selection_history[0]

    def _get_previous_gene_selection(self):
        return self._gene_selection_history[1]

    def _progress_gene_selection_history(self, new_gene_selection):
        """Record *new_gene_selection* as current; the old current becomes previous."""
        self._gene_selection_history = (new_gene_selection,
                                        self._gene_selection_history[0])

    def _get_info_string(self):
        """Return the text for the "Info" box (cell, gene and cluster counts)."""
        formatstr = "Cells: {0}\nGenes: {1}\nClusters: {2}"
        if self.data:
            return formatstr.format(len(self.data),
                                    len(self.data.domain.attributes),
                                    len(self.cluster_var.values))
        else:
            return formatstr.format(*["No input data"] * 3)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive new input data, reset derived state and restart analysis.

        The reference gene list is left untouched because it comes from the
        independent "Genes" input.  Resetting it here would leave the
        user-provided-list mode enabled but without a list.
        """
        if self.feature_model:
            self.closeContext()
        self.data = data
        self.feature_model.set_domain(None)
        self.ca = None
        self.cluster_var = None
        self.columns = None
        self.clusters = None
        self.model = None
        self.pvalues = None
        self.n_genes_per_cluster_spin.setMaximum(self.N_GENES_PER_CLUSTER_MAX)
        self.n_most_enriched_spin.setMaximum(self.N_MOST_ENRICHED_MAX)
        if self.data:
            self.feature_model.set_domain(self.data.domain)
            if self.feature_model:
                self.openContext(self.data)
                if self.cluster_var is None:
                    self.cluster_var = self.feature_model[0]
                self._run_cluster_analysis()
            else:
                self.tableview.clear()
        else:
            self.tableview.clear()
            # Otherwise the info box keeps describing the previous data.
            self.infobox.setText(self._get_info_string())

    @Inputs.genes
    def set_genes(self, data):
        """Receive the reference gene list used by the third selection mode.

        The input is rejected (and the mode disabled) when it is missing the
        table attribute ``GENE_AS_ATTRIBUTE_NAME``.  It is also rejected when
        it lacks the attribute that names where the gene IDs are stored:
        ``GENE_ID_ATTRIBUTE`` when genes are attributes, ``GENE_ID_COLUMN``
        when genes are rows.
        """
        self.Error.clear()
        gene_list_radio = self.gene_selection_radio_group.group.buttons()[2]

        if (data is None or GENE_AS_ATTRIBUTE_NAME not in data.attributes
                or not data.attributes[GENE_AS_ATTRIBUTE_NAME]
                and GENE_ID_COLUMN not in data.attributes
                or data.attributes[GENE_AS_ATTRIBUTE_NAME]
                and GENE_ID_ATTRIBUTE not in data.attributes):
            if data is not None:
                self.error(
                    "Gene annotations missing in the input data. Use Gene Name Matching widget."
                )
            self.gene_list = None
            gene_list_radio.setDisabled(True)
            if self.gene_selection == 2:
                self.gene_selection_radio_group.group.buttons()[
                    self._get_previous_gene_selection()].click()
        else:
            if data.attributes[GENE_AS_ATTRIBUTE_NAME]:
                # Genes are columns: IDs live in each variable's attributes.
                gene_id_attribute = data.attributes.get(
                    GENE_ID_ATTRIBUTE, None)

                self.gene_list = tuple(
                    str(var.attributes[gene_id_attribute])
                    for var in data.domain.attributes
                    if gene_id_attribute in var.attributes
                    and var.attributes[gene_id_attribute] != "?")
            else:
                # Genes are rows: IDs live in the named column.
                gene_id_column = data.attributes.get(GENE_ID_COLUMN, None)
                self.gene_list = tuple(
                    str(v) for v in data.get_column_view(gene_id_column)[0]
                    if v not in ("", "?"))
            gene_list_radio.setDisabled(False)
            if self.gene_selection == 2:
                self._set_gene_selection()
            else:
                gene_list_radio.click()

    def _run_cluster_analysis(self):
        """Update spin-box limits and start building the ClusterAnalysis."""
        self.infobox.setText(self._get_info_string())
        gene_count = len(self.data.domain.attributes)
        # Guard against a cluster variable without values (division by zero).
        cluster_count = max(len(self.cluster_var.values), 1)
        # Never drop a maximum below the spin boxes' minimum of 1.
        self.n_genes_per_cluster_spin.setMaximum(
            max(1, min(self.N_GENES_PER_CLUSTER_MAX,
                       gene_count // cluster_count)))
        self.n_most_enriched_spin.setMaximum(
            max(1, min(self.N_MOST_ENRICHED_MAX, gene_count)))
        # TODO: what happens if error occurs? If CA fails, widget should properly handle it.
        self._start_task_init(
            partial(ClusterAnalysis, self.data, self.cluster_var.name))

    def _start_task_init(self, f):
        """Run *f* (building the ClusterAnalysis) in the executor.

        Progress of this task is reported in the 0-50 range.
        """
        if self._task is not None:
            self.cancel()
        assert self._task is None

        self._task = Task("init")

        def callback(finished):
            # Called from the worker; aborting via an exception lets a
            # cancelled task stop early.  Presumably finished is in [0, 1].
            if self._task.cancelled:
                raise KeyboardInterrupt()
            self.progressBarSet(finished * 50)

        f = partial(f, callback=callback)

        self.progressBarInit()
        self._task.future = self._executor.submit(f)
        self._task.watcher = FutureWatcher(self._task.future)
        self._task.watcher.done.connect(self._init_task_finished)

    def _start_task_gene_selection(self, f):
        """Run gene-selection function *f* in the executor.

        Progress of this task is reported in the 50-100 range.
        """
        if self._task is not None:
            self.cancel()
        assert self._task is None

        self._task = Task("gene_selection")

        def callback(finished):
            if self._task.cancelled:
                raise KeyboardInterrupt()
            self.progressBarSet(50 + finished * 50)

        f = partial(f, callback=callback)

        self.progressBarInit()
        self.progressBarSet(50)
        self._task.future = self._executor.submit(f)
        self._task.watcher = FutureWatcher(self._task.future)
        self._task.watcher.done.connect(self._gene_selection_task_finished)

    @Slot(concurrent.futures.Future)
    def _init_task_finished(self, f):
        """Store the finished ClusterAnalysis and start gene selection."""
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        self.ca = f.result()
        self._set_gene_selection()

    @Slot(concurrent.futures.Future)
    def _gene_selection_task_finished(self, f):
        """Display the selected genes' values and p-values in the table."""
        assert self.thread() is QThread.currentThread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progressBarFinished()

        self.clusters, genes, self.model, self.pvalues = f.result()
        genes = [str(gene) for gene in genes]
        self.columns = DiscreteVariable("Gene", genes, ordered=True)
        self.tableview.set_headers(
            self.clusters,
            self.columns.values,
            circles=True,
            cell_size=self.CELL_SIZES[self.cell_size_ix],
            bold_headers=False)

        def tooltip(i, j):
            return (
                "<b>cluster</b>: {}<br /><b>gene</b>: {}<br /><b>fraction expressing</b>: {:.2f}<br />\
                                <b>p-value</b>: {:.2e}".format(
                    self.clusters[i], self.columns.values[j], self.model[i, j],
                    self.pvalues[i, j]))

        self.tableview.update_table(self.model, tooltip=tooltip)
        self._invalidate()

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_task_finished` slot
            if self._task.type == "init":
                self._task.watcher.done.disconnect(self._init_task_finished)
            else:
                self._task.watcher.done.disconnect(
                    self._gene_selection_task_finished)
            self._task = None

    def onDeleteWidget(self):
        self.cancel()
        super().onDeleteWidget()

    def _gene_selection_changed(self):
        """Radio-button callback: record the new mode and recompute."""
        if self.gene_selection != self._get_current_gene_selection():
            self._progress_gene_selection_history(self.gene_selection)
            # The user-provided list ignores the differential-expression mode.
            self.differential_expression_radio_group.setDisabled(
                self.gene_selection == 2)
            self._set_gene_selection()

    def _fall_back_from_gene_list(self):
        """Leave the user-provided-list mode, which has no usable list.

        Switch to the previously used mode (or mode 0 if that was the list
        too) and recompute.
        """
        fallback = self._get_previous_gene_selection()
        if fallback == 2:
            fallback = 0
        self.gene_selection = fallback
        self._progress_gene_selection_history(fallback)
        self.differential_expression_radio_group.setDisabled(False)
        self._set_gene_selection()

    def _set_gene_selection(self):
        """Start the gene-selection task for the current options.

        If no ClusterAnalysis is available (or it is still being built),
        only re-send the outputs.
        """
        self.Warning.clear()
        if self.ca is not None and (self._task is None
                                    or self._task.type != "init"):
            if self.gene_selection == 0:
                f = partial(self.ca.enriched_genes_per_cluster,
                            self.n_genes_per_cluster)
            elif self.gene_selection == 1:
                f = partial(self.ca.enriched_genes_data, self.n_most_enriched)
            else:
                if self.data is not None and GENE_ID_ATTRIBUTE not in self.data.attributes:
                    self.error(
                        "Gene annotations missing in the input data. Use Gene Name Matching widget."
                    )
                    self.gene_selection_radio_group.group.buttons()[
                        self._get_previous_gene_selection()].click()
                    return
                if self.gene_list is None:
                    # E.g. mode restored from a saved context without a
                    # "Genes" input; intersection(None) would fail.
                    self._fall_back_from_gene_list()
                    return
                relevant_genes = tuple(self.ca.intersection(self.gene_list))
                if len(relevant_genes) > self.N_MOST_ENRICHED_MAX:
                    self.warning("Only first {} reference genes shown.".format(
                        self.N_MOST_ENRICHED_MAX))
                f = partial(self.ca.enriched_genes,
                            relevant_genes[:self.N_MOST_ENRICHED_MAX])
            f = partial(
                f,
                enrichment=self._diff_exprs[self.differential_expression],
                biclustering=self.biclustering)
            self._start_task_gene_selection(f)
        else:
            self._invalidate()

    def handleNewSignals(self):
        self._invalidate()

    def commit(self):
        """Send the selected subset, annotated data and contingency table."""
        if self.selection:
            cluster_ids = set()
            column_ids = set()
            for (ir, ic) in self.selection:
                cluster_ids.add(ir)
                column_ids.add(ic)
            new_domain = Domain([
                self.data.domain[self.columns.values[col]]
                for col in column_ids
            ], self.data.domain.class_vars, self.data.domain.metas)
            selected_data = Values([
                FilterDiscrete(self.cluster_var, [self.clusters[ir]])
                for ir in cluster_ids
            ],
                                   conjunction=False)(self.data)
            selected_data = selected_data.transform(new_domain)
            annotated_data = create_annotated_table(
                self.data.transform(new_domain),
                np.where(np.isin(self.data.ids, selected_data.ids,
                                 assume_unique=True)))
        else:
            selected_data = None
            annotated_data = create_annotated_table(self.data, [])
        # The contingency table is only sent when no task is in progress.
        if self.ca is not None and self._task is None:
            table = self.ca.create_contingency_table()
        else:
            table = None
        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)
        self.Outputs.contingency.send(table)

    def _invalidate(self):
        """Read the current table selection and re-send outputs."""
        self.selection = self.tableview.get_selection()
        self.commit()

    def send_report(self):
        rows = None
        columns = None
        if self.data is not None:
            rows = self.cluster_var
            if rows in self.data.domain:
                rows = self.data.domain[rows]
            columns = self.columns
            if columns in self.data.domain:
                columns = self.data.domain[columns]
        self.report_items((
            ("Rows", rows),
            ("Columns", columns),
        ))
Beispiel #3
0
class OWDotMatrix(widget.OWWidget):
    # Widget metadata: display name, description, icon path and priority
    # (presumably consumed by the widget framework / catalogue).
    name = "Dot Matrix"
    description = "Perform cluster analysis."
    icon = "icons/DotMatrix.svg"
    priority = 410

    # A single input table; its discrete variables are offered as the
    # clustering variable.
    class Inputs:
        data = Input("Data", Table, default=True)

    # selected_data: input rows of the selected clusters, restricted to the
    # selected genes; annotated_data: the whole input with those rows flagged;
    # contingency: a table built from the displayed (scaled) matrix.
    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        contingency = Output("Contingency Table", Table)

    class Error(OWWidget.Error):
        # Shown when the input offers no discrete variable to cluster by.
        no_discrete_variable = Msg("No discrete variables in data.")

    class Warning(OWWidget.Warning):
        # Shown when the input has more than GENE_MAXIMUM attributes; the
        # attribute name (spelled "to_many") is referenced elsewhere, keep it.
        to_many_attributes = Msg("Too many genes on input, first {} genes displayed.")

    # Only the first GENE_MAXIMUM attributes are aggregated and displayed.
    GENE_MAXIMUM = 100
    # Cell sizes selected by the "S" / "M" / "L" plot-size buttons.
    CELL_SIZES = (14, 22, 30)
    # Per-cluster aggregations, index-aligned with AGGREGATE_NAME. Each one
    # reduces a cluster's rows (axis 0) to one value per gene; the last one
    # gives the fraction of values greater than zero.
    AGGREGATE_F = [
        lambda x: np.mean(x, axis=0),
        lambda x: np.median(x, axis=0),
        lambda x: np.min(x, axis=0),
        lambda x: np.max(x, axis=0),
        lambda x: np.mean(x > 0, axis=0),
    ]
    # Labels for the aggregation combo box and the cell tooltips.
    AGGREGATE_NAME = [
        "Mean expression",
        "Median expression",
        "Min expression",
        "Max expression",
        "Fraction expressing"
    ]

    # ContextSettings are stored per input domain by this handler
    # (presumably matched on the domain's variables - TODO confirm).
    settingsHandler = DomainContextHandler()
    # Discrete variable whose values define the clusters.
    cluster_var = ContextSetting(None)
    # Index into AGGREGATE_F / AGGREGATE_NAME.
    aggregate_ix = ContextSetting(0)  # type: int
    # Reorder rows and columns by hierarchical clustering.
    biclustering = ContextSetting(True)
    # Show genes as rows and clusters as columns.
    transpose = ContextSetting(False)
    # Apply log(x + 1) to the aggregated values before scaling.
    log_scale = ContextSetting(False)
    # Z-score every column and clip it to [-3, 3].
    normalize = ContextSetting(True)
    # Index into CELL_SIZES.
    cell_size_ix = ContextSetting(2)  # type: int
    # (row, column) indices of the cells selected in the table view.
    selection_indices = ContextSetting(set())
    # State of the auto-commit "Apply" control.
    auto_apply = Setting(True)

    want_main_area = True

    def __init__(self):
        """Build the control panel, the auto-apply button and the table view.

        Controls are created top to bottom in the order they are shown.
        """
        super().__init__()

        self.feature_model = DomainModel(valid_types=DiscreteVariable)
        self._init_vars()
        self._set_info_string()

        cluster_box = gui.vBox(self.controlArea, "Cluster Variable")
        gui.comboBox(cluster_box, self, "cluster_var", sendSelectedValue=True,
                     model=self.feature_model, callback=self._aggregate_data)

        aggregation_box = gui.vBox(self.controlArea, "Aggregation")
        gui.comboBox(aggregation_box, self, "aggregate_ix", sendSelectedValue=False,
                     items=self.AGGREGATE_NAME, callback=self._aggregate_data)

        # Every option only changes how the aggregated matrix is displayed.
        options_box = gui.vBox(self.controlArea, "Options")
        for setting_name, label in (("biclustering", "Order cells and genes"),
                                    ("transpose", "Transpose"),
                                    ("log_scale", "Log scale"),
                                    ("normalize", "Normalize data")):
            gui.checkBox(options_box, self, setting_name, label,
                         callback=self._calculate_table_values)

        def resize_cells():
            # tableview is created below; it is looked up when clicked.
            self.tableview.set_cell_size(self.CELL_SIZES[self.cell_size_ix])

        size_box = gui.vBox(self.controlArea, "Plot Size")
        gui.radioButtons(size_box, self, "cell_size_ix", btnLabels=("S", "M", "L"),
                         callback=resize_cells,
                         orientation=Qt.Horizontal)

        gui.rubber(self.controlArea)

        self.apply_button = gui.auto_commit(
            self.controlArea, self, "auto_apply", "&Apply", box=False)

        self.tableview = ContingencyTable(self)
        self.mainArea.layout().addWidget(self.tableview)

    def _init_vars(self):
        """Reset all state derived from the current input.

        Clears the data, the aggregated / displayed / raw-tooltip matrices,
        the cluster labels, the clustering variable and the remembered
        selection names.
        """
        self.data = None  # type: Table
        self.matrix = None
        # Raw (pre-imputation, pre-scaling) values used by the cell tooltips.
        self.matrix_before_norm = None
        self.clusters = None
        self.clusters_unordered = None
        self.aggregated_data = None
        self.cluster_var = None
        # Set of (gene name, cluster name) pairs, as built by save_selection_names.
        self.selected_names = set()

    def _set_info_string(self):
        """Update the input summary.

        With (non-empty) data the summary is the number of rows, and the
        details list the gene, cell and cluster counts; otherwise the
        summary shows "no input".
        """
        if not self.data:
            self.info.set_input_summary(self.info.NoInput)
            return

        n_cells = len(self.data)
        details = "{} genes\n{} cells\n{} clusters".format(
            len(self.data.domain.attributes), n_cells, len(self.clusters_unordered))
        self.info.set_input_summary(str(n_cells), details)

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Receive a new input table and rebuild the view.

        A context is opened only when the table has at least one discrete
        variable; that variable (or the one restored from the context) becomes
        the clustering variable. Without rows, or without a discrete variable,
        the view is cleared and outputs are recommitted.
        """
        if self.feature_model:
            self.closeContext()

        self._init_vars()
        self.Error.no_discrete_variable.clear()
        self.Warning.clear()
        self.data = data

        has_rows = bool(self.data)
        if has_rows:
            self.feature_model.set_domain(self.data.domain)
            if self.feature_model:
                self.openContext(self.data)
                if self.cluster_var is None:
                    self.cluster_var = self.feature_model[0]
                self._aggregate_data()
                return

        # Nothing to show: either no rows, or no discrete variable.
        self.tableview.clear()
        if has_rows:
            self.Error.no_discrete_variable()
            self.data = None
        self._set_info_string()
        self.commit()

    @staticmethod
    def _group_by(table: Table, var: DiscreteVariable):
        """Lazily yield one sub-table of ``table`` per distinct value of ``var``.

        The distinct values are computed immediately; groups come in
        ascending order of the raw column value.
        """
        view = table.get_column_view(var)
        values = view[0]
        distinct = np.unique(values)
        return (table[values == group] for group in distinct)

    @staticmethod
    def _transpose(matrix: np.ndarray, clusters, genes):
        """Swap the roles of clusters and genes.

        Returns the transposed matrix, the gene names as the new row labels
        (a NumPy string array), and one ContinuousVariable per cluster to
        serve as the new columns.
        """
        flipped = matrix.T
        row_labels = np.array([str(gene) for gene in genes])
        column_vars = [ContinuousVariable(cluster) for cluster in clusters]
        return flipped, row_labels, column_vars

    @staticmethod
    def _normalize(matrix: np.ndarray):
        """Z-score every column, then clip the result to [-3, 3].

        A tiny constant in the denominator keeps constant columns from
        dividing by zero; NaNs pass through unchanged.
        """
        centered = matrix - np.mean(matrix, axis=0, keepdims=True)
        spread = np.std(matrix, axis=0, keepdims=True) + 1e-10
        scores = centered / spread
        return np.clip(scores, -3, 3, out=scores)

    @staticmethod
    def _norm_min_max(matrix: np.ndarray):
        """Linearly rescale ``matrix`` into [0, 1] using its global min and max.

        A small epsilon keeps a constant matrix from dividing by zero (it maps
        to all zeros). A zero-size matrix is returned unchanged, because
        ``min``/``max`` raise ValueError on empty arrays.
        """
        if matrix.size == 0:
            return matrix
        shifted = matrix - matrix.min()
        return shifted / (shifted.max() + 1e-12)

    def _aggregate_data(self):
        """Aggregate gene values per cluster of ``cluster_var``, then redraw.

        Only the first ``GENE_MAXIMUM`` attributes are aggregated, with the
        function selected by ``aggregate_ix``. Rows whose cluster value is
        missing are left out: they belong to no cluster, and their code cannot
        be mapped to a value name (``int(nan)`` raised ValueError before).
        """
        self.Warning.clear()
        if self.data is None:
            return

        # Column entries are codes into cluster_var.values; missing values are
        # presumably NaN (Orange convention). Cast to float so that a meta
        # column stored with object dtype can be tested with np.isnan as well.
        codes = np.asarray(self.data.get_column_view(self.cluster_var)[0], dtype=float)
        known = ~np.isnan(codes)
        clustered = self.data if known.all() else self.data[known]

        self.clusters_unordered = np.array(
            [self.cluster_var.values[int(code)] for code in np.unique(codes[known])])
        self._set_info_string()

        if len(self.data.domain.attributes) > self.GENE_MAXIMUM:
            self.Warning.to_many_attributes(self.GENE_MAXIMUM)

        aggregate = self.AGGREGATE_F[self.aggregate_ix]
        # NOTE(review): if no row has a known cluster value there are no groups
        # and np.stack raises; that case is not handled here.
        self.aggregated_data = np.stack(
            [aggregate(cluster.X[:, :self.GENE_MAXIMUM])
             for cluster in self._group_by(clustered, self.cluster_var)],
            axis=0)
        self._calculate_table_values()

    def _calculate_table_values(self):
        """Turn the aggregated matrix into the displayed, reordered table.

        Steps: optional transpose, mean imputation, optional log(x + 1),
        optional column z-score, global min-max scaling to [0, 1], and an
        optional biclustering order. It also serves directly as the callback
        of the option check boxes, so it returns quietly when there is no
        data (it used to raise AttributeError on ``self.data.domain``).
        """
        if self.data is None or self.aggregated_data is None:
            return

        genes = self.data.domain.attributes[:self.GENE_MAXIMUM]
        matrix = self.aggregated_data
        clusters = self.clusters_unordered
        if self.transpose:
            matrix, clusters, genes = self._transpose(matrix, clusters, genes)

        # create data table since imputation of nan values is required
        matrix = Table(Domain(genes), matrix)
        matrix_before_norm = matrix.copy()  # raw values for the tooltips
        matrix = SklImpute()(matrix)

        # SklImpute presumably drops columns that are entirely NaN (as
        # scikit-learn's SimpleImputer does). Keep the tooltip table's columns
        # aligned with the imputed ones, otherwise tooltips show the values
        # of the wrong gene/cluster.
        raw_attributes = matrix_before_norm.domain.attributes
        if len(matrix.domain.attributes) != len(raw_attributes):
            position = {var.name: j for j, var in enumerate(raw_attributes)}
            kept = np.array([position[var.name] for var in matrix.domain.attributes],
                            dtype=int)
            matrix_before_norm = matrix_before_norm[:, kept]

        if self.log_scale:
            matrix.X = np.log(matrix.X + 1)
        if self.normalize:
            matrix.X = self._normalize(matrix.X)

        # values must be in range [0, 1] for visualisation
        matrix.X = self._norm_min_max(matrix.X)

        if self.biclustering:
            cluster_order, gene_order = self.cluster_data(matrix)
        else:
            cluster_order, gene_order = np.arange(matrix.X.shape[0]), np.arange(matrix.X.shape[1])

        # Reorder the displayed table, the raw tooltip table and the labels alike.
        self.matrix = matrix[cluster_order][:, gene_order]
        self.matrix_before_norm = matrix_before_norm[cluster_order][:, gene_order]
        self.clusters = clusters[cluster_order]

        self._refresh_table()
        self._update_selection()
        self._invalidate()

    def cluster_data(self, matrix):
        """Return ``(row_order, columns_order)`` from hierarchical clustering.

        Rows and columns are each clustered on Euclidean distances, and the
        leaves are taken in optimal leaf order. With fewer than two rows (or
        columns) the identity order ``np.arange(n)`` is used. It equals the old
        ``[0]`` for a single item, but is also correct for zero columns, where
        ``[0]`` pointed past the end.
        """
        def leaf_order(distances):
            # Cluster on the distance matrix; return leaf indices in optimal order.
            tree = hierarchical.dist_matrix_clustering(distances)
            tree = hierarchical.optimal_leaf_ordering(
                tree, distances, progress_callback=self.progressBarSet)
            return np.array([leaf.value.index for leaf in leaves(tree)])

        with self.progressBar():
            n_rows = len(matrix)
            if n_rows > 1:
                row_order = leaf_order(Euclidean(matrix))
            else:
                row_order = np.arange(n_rows)

            n_columns = matrix.X.shape[1]
            if n_columns > 1:
                columns_order = leaf_order(Euclidean(matrix, axis=0))
            else:
                columns_order = np.arange(n_columns)
        return row_order, columns_order

    def _refresh_table(self):
        """Push headers, cell values and tooltips of ``self.matrix`` to the view.

        Rows are labelled with ``self.clusters`` and columns with the matrix's
        attribute names. In the transposed table the rows are genes and the
        columns clusters, so the tooltip labels follow the orientation the
        headers were built for. Before, they always called the row a
        "Cluster".
        """
        if self.matrix is None:
            return

        columns = np.array([str(x) for x in self.matrix.domain.attributes])
        rows = self.clusters
        self.tableview.set_headers(rows, columns, circles=True,
                                   cell_size=self.CELL_SIZES[self.cell_size_ix], bold_headers=False)
        if self.matrix.X.size == 0:
            return

        transposed = self.transpose  # orientation of the headers set above

        def tooltip(i, j):
            if transposed:
                cluster, gene = columns[j], rows[i]
            else:
                cluster, gene = rows[i], columns[j]
            value = self.matrix_before_norm[i, j]
            return "Cluster: {}\nGene: {}\n{}: {:.1f}".format(
                cluster, gene, self.AGGREGATE_NAME[self.aggregate_ix], value)

        self.tableview.update_table(self.matrix.X, tooltip=tooltip)

    def commit(self):
        """Send the selected subset, the annotated input and the contingency table.

        A selected cell ``(row, column)`` picks one cluster and one gene; the
        cluster is the row unless the table is transposed. The selected subset
        holds the input rows of every picked cluster, restricted to the picked
        genes (plus class and meta variables). The annotated output is the
        whole input with those rows flagged. All outputs are cleared while no
        table has been computed (no input, no discrete variable, or an empty
        input table).
        """
        if self.data is None or self.matrix is None:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(None)
            self.Outputs.contingency.send(None)
            return

        attribute_names = [str(x) for x in self.matrix.domain.attributes]
        # Rows are labelled by self.clusters and columns by the attribute
        # names; transposing swaps which of them are genes and which clusters.
        gene_names = self.clusters if self.transpose else attribute_names
        cluster_names = attribute_names if self.transpose else self.clusters

        if len(self.selection_indices):
            cluster_ids = set()
            gene_ids = set()
            for ir, ic in self.selection_indices:
                cluster_ix, gene_ix = (ic, ir) if self.transpose else (ir, ic)
                cluster_ids.add(cluster_ix)
                gene_ids.add(gene_ix)

            new_domain = Domain([self.data.domain[gene_names[i]] for i in sorted(gene_ids)],
                                self.data.domain.class_vars,
                                self.data.domain.metas)
            selected_data = Values([FilterDiscrete(self.cluster_var, [cluster_names[i]])
                                    for i in sorted(cluster_ids)],
                                   conjunction=False)(self.data)
            selected_data = selected_data.transform(new_domain)
            is_selected = np.isin(self.data.ids, selected_data.ids, assume_unique=True)
            annotated_data = create_annotated_table(self.data, np.where(is_selected)[0])
        else:
            selected_data = None
            annotated_data = create_annotated_table(self.data, [])

        # dict.fromkeys de-duplicates while keeping the displayed row order, so
        # the output variable's values are deterministic (a set of strings
        # iterates in a hash-dependent order that changes between runs).
        clusters_values = list(dict.fromkeys(self.clusters))
        value_index = {value: i for i, value in enumerate(clusters_values)}
        table = ClusterAnalysis.contingency_table(
            self.matrix,
            DiscreteVariable("Gene" if self.transpose else self.cluster_var.name, clusters_values),
            attribute_names,
            [[value_index[c]] for c in self.clusters]
        )

        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)
        self.Outputs.contingency.send(table)

    def _update_selection(self):
        """Re-apply the remembered selection after the table was rebuilt.

        ``selected_names`` holds (gene, cluster) name pairs, which survive
        reordering and transposition. They are mapped back to (row, column)
        indices of the current table. Pairs whose names are no longer displayed
        are skipped instead of raising ValueError from ``list.index``, e.g.
        after switching the cluster variable while cells were selected.
        """
        # setdefault keeps the first position of a name, like list.index did.
        row_index = {}
        for i, name in enumerate(self.clusters.tolist()):
            row_index.setdefault(name, i)
        column_index = {}
        for j, var in enumerate(self.matrix.domain.attributes):
            column_index.setdefault(str(var), j)

        new_selection = set()
        for gene, cluster in self.selected_names:
            # Genes label the rows only when the table is transposed.
            row_name, column_name = (gene, cluster) if self.transpose else (cluster, gene)
            if row_name in row_index and column_name in column_index:
                new_selection.add((row_index[row_name], column_index[column_name]))
        self.tableview.set_selection(new_selection)

    def _invalidate(self):
        """Remember the current table selection (indices and names), then commit."""
        self.save_selection_names()
        self.commit()

    def save_selection_names(self):
        """
        Store the table view's current selection twice: as (row, column)
        indices in ``selection_indices`` and as (gene name, cluster name)
        pairs in ``selected_names``. The option check boxes reorder or
        transpose the rows and columns; the names let ``_update_selection``
        restore the same cells afterwards.
        """
        self.selection_indices = self.tableview.get_selection()
        # Row labels are self.clusters: gene names when transposed, cluster
        # names otherwise; the column labels are the other kind.
        genes = self.clusters if self.transpose else [str(x) for x in self.matrix.domain.attributes]
        clusters = self.clusters if not self.transpose else [str(x) for x in self.matrix.domain.attributes]

        if self.transpose:
            # Selected pairs are (row = gene, column = cluster).
            self.selected_names = {(genes[g], clusters[c]) for g, c in self.selection_indices}
        else:
            # Selected pairs are (row = cluster, column = gene).
            self.selected_names = {(genes[g], clusters[c]) for c, g in self.selection_indices}