Ejemplo n.º 1
0
class OWMap(widget.OWWidget):
    """Widget that shows data points on a Leaflet map.

    Inputs are a data table, an optional data subset and an optional
    learner.  When a learner and a target attribute are available, a model
    is trained on the latitude/longitude columns and handed to the map.
    The rows selected on the map are sent on the "Selected Data" output,
    and the annotated table on ``ANNOTATED_DATA_SIGNAL_NAME``.
    """

    name = 'Geo Map'
    description = 'Show data points on a world map.'
    icon = "icons/GeoMap.svg"
    priority = 100

    inputs = [("Data", Table, "set_data", widget.Default),
              ("Data Subset", Table, "set_subset"),
              ("Learner", Learner, "set_learner")]

    outputs = [("Selected Data", Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Table)]

    replaces = [
        "Orange.widgets.visualize.owmap.OWMap",
    ]

    settingsHandler = settings.DomainContextHandler()

    want_main_area = True

    autocommit = settings.Setting(True)
    tile_provider = settings.Setting('Black and white')
    lat_attr = settings.ContextSetting('')
    lon_attr = settings.ContextSetting('')
    class_attr = settings.ContextSetting('(None)')
    color_attr = settings.ContextSetting('')
    label_attr = settings.ContextSetting('')
    shape_attr = settings.ContextSetting('')
    size_attr = settings.ContextSetting('')
    opacity = settings.Setting(100)
    zoom = settings.Setting(100)
    jittering = settings.Setting(0)
    cluster_points = settings.Setting(False)
    show_legend = settings.Setting(True)

    # Combo-box label -> provider identifier passed to the map.
    TILE_PROVIDERS = OrderedDict((
        ('Black and white', 'OpenStreetMap.BlackAndWhite'),
        ('OpenStreetMap', 'OpenStreetMap.Mapnik'),
        ('Topographic', 'Thunderforest.OpenCycleMap'),
        ('Topographic 2', 'Thunderforest.Outdoors'),
        ('Satellite', 'Esri.WorldImagery'),
        ('Print', 'Stamen.TonerLite'),
        ('Dark', 'CartoDB.DarkMatter'),
        ('Watercolor', 'Stamen.Watercolor'),
    ))

    class Error(widget.OWWidget.Error):
        model_error = widget.Msg("Error predicting: {}")
        learner_error = widget.Msg("Error modelling: {}")

    UserAdviceMessages = [
        widget.Message(
            'Select markers by holding <b><kbd>Shift</kbd></b> key and dragging '
            'a rectangle around them. Clear the selection by clicking anywhere.',
            'shift-selection')
    ]

    graph_name = "map"

    def __init__(self):
        """Build the map view and the control-area widgets."""
        super().__init__()
        self.map = map_view = LeafletMap(self)
        self.mainArea.layout().addWidget(map_view)
        self.selection = None
        self.data = None
        self.learner = None
        # Initialised here so that commit() is safe before the map has ever
        # reported a selection.
        self._indices = []

        def selectionChanged(indices):
            """Store the rows selected on the map and send them on."""
            self.selection = self.data[indices] if self.data is not None and indices else None
            self._indices = indices
            self.commit()

        map_view.selectionChanged.connect(selectionChanged)

        def _set_map_provider():
            map_view.set_map_provider(self.TILE_PROVIDERS[self.tile_provider])

        box = gui.vBox(self.controlArea, 'Map')
        gui.comboBox(box, self, 'tile_provider',
                     orientation=Qt.Horizontal,
                     label='Map:',
                     items=tuple(self.TILE_PROVIDERS.keys()),
                     sendSelectedValue=True,
                     callback=_set_map_provider)

        self._latlon_model = DomainModel(
            parent=self, valid_types=ContinuousVariable)
        self._class_model = DomainModel(
            parent=self, placeholder='(None)', valid_types=DomainModel.PRIMITIVE)
        self._color_model = DomainModel(
            parent=self, placeholder='(Same color)', valid_types=DomainModel.PRIMITIVE)
        self._shape_model = DomainModel(
            parent=self, placeholder='(Same shape)', valid_types=DiscreteVariable)
        self._size_model = DomainModel(
            parent=self, placeholder='(Same size)', valid_types=ContinuousVariable)
        self._label_model = DomainModel(
            parent=self, placeholder='(No labels)')

        def _set_lat_long():
            self.map.set_data(self.data, self.lat_attr, self.lon_attr)
            self.train_model()

        # Latitude and longitude combos share one model.
        self._combo_lat = combo = gui.comboBox(
            box, self, 'lat_attr', orientation=Qt.Horizontal,
            label='Latitude:', sendSelectedValue=True, callback=_set_lat_long)
        combo.setModel(self._latlon_model)
        self._combo_lon = combo = gui.comboBox(
            box, self, 'lon_attr', orientation=Qt.Horizontal,
            label='Longitude:', sendSelectedValue=True, callback=_set_lat_long)
        combo.setModel(self._latlon_model)

        def _toggle_legend():
            self.map.toggle_legend(self.show_legend)

        gui.checkBox(box, self, 'show_legend', label='Show legend',
                     callback=_toggle_legend)

        box = gui.vBox(self.controlArea, 'Overlay')
        self._combo_class = combo = gui.comboBox(
            box, self, 'class_attr', orientation=Qt.Horizontal,
            label='Target:', sendSelectedValue=True, callback=self.train_model
        )
        self.controls.class_attr.setModel(self._class_model)
        # No learner yet: this disables the target combo and sets its tooltip.
        self.set_learner(self.learner)

        box = gui.vBox(self.controlArea, 'Points')
        self._combo_color = combo = gui.comboBox(
            box, self, 'color_attr',
            orientation=Qt.Horizontal,
            label='Color:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_color(self.color_attr))
        combo.setModel(self._color_model)
        self._combo_label = combo = gui.comboBox(
            box, self, 'label_attr',
            orientation=Qt.Horizontal,
            label='Label:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_label(self.label_attr))
        combo.setModel(self._label_model)
        self._combo_shape = combo = gui.comboBox(
            box, self, 'shape_attr',
            orientation=Qt.Horizontal,
            label='Shape:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_shape(self.shape_attr))
        combo.setModel(self._shape_model)
        self._combo_size = combo = gui.comboBox(
            box, self, 'size_attr',
            orientation=Qt.Horizontal,
            label='Size:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_size(self.size_attr))
        combo.setModel(self._size_model)

        def _set_opacity():
            map_view.set_marker_opacity(self.opacity)

        def _set_zoom():
            map_view.set_marker_size_coefficient(self.zoom)

        def _set_jittering():
            map_view.set_jittering(self.jittering)

        def _set_clustering():
            map_view.set_clustering(self.cluster_points)

        self._opacity_slider = gui.hSlider(
            box, self, 'opacity', None, 1, 100, 5,
            label='Opacity:', labelFormat=' %d%%',
            callback=_set_opacity)
        self._zoom_slider = gui.valueSlider(
            box, self, 'zoom', None, values=(20, 50, 100, 200, 300, 400, 500, 700, 1000),
            label='Symbol size:', labelFormat=' %d%%',
            callback=_set_zoom)
        self._jittering = gui.valueSlider(
            box, self, 'jittering', label='Jittering:', values=(0, .5, 1, 2, 5),
            labelFormat=' %.1f%%', ticks=True,
            callback=_set_jittering)
        self._clustering_check = gui.checkBox(
            box, self, 'cluster_points', label='Cluster points',
            callback=_set_clustering)

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, 'autocommit', 'Send Selection')

        # Push the stored settings to the map via zero-timeout single-shot
        # timers rather than calling the setters directly from __init__.
        QTimer.singleShot(0, _set_map_provider)
        QTimer.singleShot(0, _toggle_legend)
        QTimer.singleShot(0, _set_opacity)
        QTimer.singleShot(0, _set_zoom)
        QTimer.singleShot(0, _set_jittering)
        QTimer.singleShot(0, _set_clustering)

    def __del__(self):
        self.progressBarFinished(None)
        self.map = None

    def commit(self):
        """Send the current selection and the annotated data table."""
        self.send('Selected Data', self.selection)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(self.data, self._indices))

    def set_data(self, data):
        """Handle a new table on the "Data" input.

        ``None`` or an empty table clears the widget.  Otherwise the domain
        models are refreshed, default attributes are chosen (including any
        latitude/longitude pair reported by ``find_lat_lon``), the saved
        context is opened and the map is updated.
        """
        self.data = data

        self.closeContext()

        if data is None or not len(data):
            return self.clear()

        domain = data.domain
        for model in (self._latlon_model,
                      self._class_model,
                      self._color_model,
                      self._shape_model,
                      self._size_model,
                      self._label_model):
            model.set_domain(domain)

        lat, lon = find_lat_lon(data)
        if lat or lon:
            # Either coordinate may be missing on its own; handle each one
            # independently instead of dereferencing a possible None.
            self._combo_lat.setCurrentIndex(
                -1 if lat is None else self._latlon_model.indexOf(lat))
            self._combo_lon.setCurrentIndex(
                -1 if lon is None else self._latlon_model.indexOf(lon))
            self.lat_attr = None if lat is None else lat.name
            self.lon_attr = None if lon is None else lon.name

        if data.domain.class_var:
            self.color_attr = data.domain.class_var.name
        elif len(self._color_model):
            self._combo_color.setCurrentIndex(0)
        if len(self._shape_model):
            self._combo_shape.setCurrentIndex(0)
        if len(self._size_model):
            self._combo_size.setCurrentIndex(0)
        if len(self._label_model):
            self._combo_label.setCurrentIndex(0)
        if len(self._class_model):
            self._combo_class.setCurrentIndex(0)

        self.openContext(data)

        self.map.set_data(self.data, self.lat_attr, self.lon_attr)
        # Only the last marker setter asks the map to update.
        self.map.set_marker_color(self.color_attr, update=False)
        self.map.set_marker_label(self.label_attr, update=False)
        self.map.set_marker_shape(self.shape_attr, update=False)
        self.map.set_marker_size(self.size_attr, update=True)

    def set_subset(self, subset):
        """Pass the ids of the subset rows (or an empty array) to the map."""
        self.map.set_subset_ids(subset.ids if subset is not None else np.array([]))

    def handleNewSignals(self):
        """Retrain the model after the inputs have been updated."""
        super().handleNewSignals()
        self.train_model()

    def set_learner(self, learner):
        """Store the learner; the target combo is enabled only with one."""
        self.learner = learner
        self.controls.class_attr.setEnabled(learner is not None)
        self.controls.class_attr.setToolTip(
            'Needs a Learner input for modelling.' if learner is None else '')

    def train_model(self):
        """Fit the learner on (latitude, longitude) -> target and show it.

        The map receives ``None`` when data, learner, coordinates or target
        are missing, or when the learner raises; in the latter case
        ``Error.learner_error`` is shown.
        """
        model = None
        self.Error.clear()
        if self.data and self.learner and self.class_attr != '(None)':
            domain = self.data.domain
            if self.lat_attr and self.lon_attr and self.class_attr in domain:
                domain = Domain([domain[self.lat_attr], domain[self.lon_attr]],
                                [domain[self.class_attr]])
                train = Table.from_table(domain, self.data)
                try:
                    model = self.learner(train)
                except Exception as e:
                    self.Error.learner_error(e)
        self.map.set_model(model)

    def disable_some_controls(self, disabled):
        """Enable/disable the label, shape and clustering controls.

        While disabled, the controls carry a tooltip that mentions the
        map's ``N_POINTS_PER_ITER`` limit.
        """
        tooltip = (
            "Available when the zoom is close enough to have "
            "<{} points in the viewport.".format(self.map.N_POINTS_PER_ITER)
            if disabled else '')
        for control in (self._combo_label,
                        self._combo_shape,
                        self._clustering_check):
            control.setDisabled(disabled)
            control.setToolTip(tooltip)

    def clear(self):
        """Remove the data from the map and reset models and attributes."""
        self.map.set_data(None, '', '')
        for model in (self._latlon_model,
                      self._class_model,
                      self._color_model,
                      self._shape_model,
                      self._size_model,
                      self._label_model):
            model.set_domain(None)
        self.lat_attr = self.lon_attr = self.class_attr = self.color_attr = \
        self.label_attr = self.shape_attr = self.size_attr = None
Ejemplo n.º 2
0
class OWSelectRows(widget.OWWidget):
    name = "Select Rows"
    id = "Orange.widgets.data.file"
    description = "Select rows from the data based on values of variables."
    icon = "icons/SelectRows.svg"
    priority = 100
    category = "Data"
    keywords = ["filter"]

    class Inputs:
        # Input signal named "Data" carrying a Table; the handler is
        # registered with the @Inputs.data decorator on set_data.
        data = Input("Data", Table)

    class Outputs:
        # Three Table outputs; "Matching Data" is declared as the default.
        matching_data = Output("Matching Data", Table, default=True)
        unmatched_data = Output("Unmatched Data", Table)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    want_main_area = False

    settingsHandler = SelectRowsContextHandler()
    conditions = ContextSetting([])
    update_on_change = Setting(True)
    purge_attributes = Setting(False, schema_only=True)
    purge_classes = Setting(False, schema_only=True)
    auto_commit = Setting(True)

    settings_version = 2

    Operators = {
        ContinuousVariable: [
            (FilterContinuous.Equal, "equals"),
            (FilterContinuous.NotEqual, "is not"),
            (FilterContinuous.Less, "is below"),
            (FilterContinuous.LessEqual, "is at most"),
            (FilterContinuous.Greater, "is greater than"),
            (FilterContinuous.GreaterEqual, "is at least"),
            (FilterContinuous.Between, "is between"),
            (FilterContinuous.Outside, "is outside"),
            (FilterContinuous.IsDefined, "is defined"),
        ],
        DiscreteVariable: [(FilterDiscreteType.Equal, "is"),
                           (FilterDiscreteType.NotEqual, "is not"),
                           (FilterDiscreteType.In, "is one of"),
                           (FilterDiscreteType.IsDefined, "is defined")],
        StringVariable: [
            (FilterString.Equal, "equals"),
            (FilterString.NotEqual, "is not"),
            (FilterString.Less, "is before"),
            (FilterString.LessEqual, "is equal or before"),
            (FilterString.Greater, "is after"),
            (FilterString.GreaterEqual, "is equal or after"),
            (FilterString.Between, "is between"),
            (FilterString.Outside, "is outside"),
            (FilterString.Contains, "contains"),
            (FilterString.StartsWith, "begins with"),
            (FilterString.EndsWith, "ends with"),
            (FilterString.IsDefined, "is defined"),
        ]
    }

    Operators[TimeVariable] = Operators[ContinuousVariable]

    AllTypes = {}
    for _all_name, _all_type, _all_ops in (("All variables", 0, [
        (None, "are defined")
    ]), ("All numeric variables", 2, [
        (v, _plural(t)) for v, t in Operators[ContinuousVariable]
    ]), ("All string variables", 3, [(v, _plural(t))
                                     for v, t in Operators[StringVariable]])):
        Operators[_all_name] = _all_ops
        AllTypes[_all_name] = _all_type

    operator_names = {
        vtype: [name for _, name in filters]
        for vtype, filters in Operators.items()
    }

    class Error(widget.OWWidget.Error):
        # The message text is supplied entirely by the caller (commit()
        # passes the text of the ValueError raised while parsing values).
        parsing_error = Msg("{}")

    def __init__(self):
        """Build the conditions table, its buttons and the option boxes."""
        super().__init__()

        self.old_purge_classes = True

        self.conditions = []
        self.last_output_conditions = None
        self.data = None
        self.data_desc = self.match_desc = self.nonmatch_desc = None
        # Variable chooser contents: the "All ..." pseudo-variables, a
        # separator, then class variables, attributes and metas.
        self.variable_model = DomainModel([
            list(self.AllTypes), DomainModel.Separator, DomainModel.CLASSES,
            DomainModel.ATTRIBUTES, DomainModel.METAS
        ])

        # Conditions table: columns 0-2 hold the variable, operator and
        # value widgets; column 3 holds the per-row remove button.
        box = gui.vBox(self.controlArea, 'Conditions', stretch=100)
        self.cond_list = QTableWidget(box,
                                      showGrid=False,
                                      selectionMode=QTableWidget.NoSelection)
        box.layout().addWidget(self.cond_list)
        self.cond_list.setColumnCount(4)
        self.cond_list.setRowCount(0)
        self.cond_list.verticalHeader().hide()
        self.cond_list.horizontalHeader().hide()
        for i in range(3):
            self.cond_list.horizontalHeader().setSectionResizeMode(
                i, QHeaderView.Stretch)
        self.cond_list.horizontalHeader().resizeSection(3, 30)
        self.cond_list.viewport().setBackgroundRole(QPalette.Window)

        # Row of buttons below the table.
        box2 = gui.hBox(box)
        gui.rubber(box2)
        self.add_button = gui.button(box2,
                                     self,
                                     "Add Condition",
                                     callback=self.add_row)
        self.add_all_button = gui.button(box2,
                                         self,
                                         "Add All Variables",
                                         callback=self.add_all)
        self.remove_all_button = gui.button(box2,
                                            self,
                                            "Remove All",
                                            callback=self.remove_all)
        gui.rubber(box2)

        # Bottom strip: purge options, the buttons area and auto-send box.
        boxes = gui.widgetBox(self.controlArea, orientation=QHBoxLayout())
        layout = boxes.layout()

        box_setting = gui.vBox(boxes, addToLayout=False, box=True)
        self.cb_pa = gui.checkBox(box_setting,
                                  self,
                                  "purge_attributes",
                                  "Remove unused features",
                                  callback=self.conditions_changed)
        gui.separator(box_setting, height=1)
        self.cb_pc = gui.checkBox(box_setting,
                                  self,
                                  "purge_classes",
                                  "Remove unused classes",
                                  callback=self.conditions_changed)
        layout.addWidget(box_setting, 1)

        self.report_button.setFixedWidth(120)
        gui.rubber(self.buttonsArea.layout())
        layout.addWidget(self.buttonsArea)

        acbox = gui.auto_send(None, self, "auto_commit")
        layout.addWidget(acbox, 1)
        layout.setAlignment(acbox, Qt.AlignBottom)

        self.info.set_input_summary(self.info.NoInput)
        self.info.set_output_summary(self.info.NoOutput)

        # Start in the "no data" state.
        self.set_data(None)
        self.resize(600, 400)

    def add_row(self, attr=None, condition_type=None, condition_value=None):
        """Append a condition row to the table.

        ``attr`` preselects a variable; without it the first entry after
        the "All ..." items and the separator is selected.
        ``condition_type`` and ``condition_value`` are forwarded to
        ``set_new_operators`` as the operator index and the values.
        """
        table_model = self.cond_list.model()
        new_row = table_model.rowCount()
        table_model.insertRow(new_row)

        # Column 0: variable chooser.
        variable_combo = ComboBoxSearch(
            minimumContentsLength=12,
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
        variable_combo.setModel(self.variable_model)
        variable_combo.row = new_row
        if attr:
            current = self.variable_model.indexOf(attr)
        else:
            current = len(self.AllTypes) + 1
        variable_combo.setCurrentIndex(current)
        self.cond_list.setCellWidget(new_row, 0, variable_combo)

        # Column 3: remove button.  The persistent index keeps tracking the
        # row even after other rows have been removed.
        persistent = QPersistentModelIndex(table_model.index(new_row, 3))
        remove_button = QPushButton(
            '×',
            self,
            flat=True,
            styleSheet='* {font-size: 16pt; color: silver}'
            '*:hover {color: black}')
        remove_button.clicked.connect(
            lambda: self.remove_one(persistent.row()))
        self.cond_list.setCellWidget(new_row, 3, remove_button)

        self.remove_all_button.setDisabled(False)
        # Columns 1 and 2 (operator and value widgets) are built here.
        self.set_new_operators(variable_combo, attr is not None,
                               condition_type, condition_value)
        variable_combo.currentIndexChanged.connect(
            lambda _: self.set_new_operators(variable_combo, False))

        self.cond_list.resizeRowToContents(new_row)

    def add_all(self):
        """Replace the conditions with one row per variable in the model.

        If rows already exist, the user is asked to confirm; cancelling
        leaves everything unchanged.
        """
        if self.cond_list.rowCount():
            answer = QMessageBox.question(
                self, "Remove existing filters",
                "This will replace the existing filters with "
                "filters for all variables.",
                QMessageBox.Ok | QMessageBox.Cancel)
            if answer != QMessageBox.Ok:
                return
            self.remove_all()
        # Skip the "All ..." pseudo-variables and the separator.
        first_variable = len(self.AllTypes) + 1
        for variable in self.variable_model[first_variable:]:
            self.add_row(variable)
        self.conditions_changed()

    def remove_one(self, rownum):
        """Remove the row ``rownum`` and re-evaluate the conditions."""
        self.remove_one_row(rownum)
        self.conditions_changed()

    def remove_all(self):
        """Remove every row and re-evaluate the conditions."""
        self.remove_all_rows()
        self.conditions_changed()

    def remove_one_row(self, rownum):
        """Remove row ``rownum`` from the conditions table.

        The variable and operator combos cache their row number in a
        ``row`` attribute that ``set_new_operators`` and ``set_new_values``
        use to address cells.  Rows below the removed one shift up, so the
        cached numbers are refreshed here; otherwise later edits would be
        applied to the wrong row.  The "Remove All" button is disabled when
        no rows remain.
        """
        self.cond_list.removeRow(rownum)
        for row in range(max(rownum, 0), self.cond_list.rowCount()):
            for col in (0, 1):
                cell = self.cond_list.cellWidget(row, col)
                if cell is not None:
                    cell.row = row
        if self.cond_list.model().rowCount() == 0:
            self.remove_all_button.setDisabled(True)

    def remove_all_rows(self):
        """Empty the conditions table and disable the "Remove All" button."""
        table = self.cond_list
        # Disconnect signals to avoid stray emits when changing variable_model
        for row_index in range(table.rowCount()):
            for column in (0, 1):
                combo = table.cellWidget(row_index, column)
                if combo:
                    combo.currentIndexChanged.disconnect()
        table.clear()
        table.setRowCount(0)
        self.remove_all_button.setDisabled(True)

    def set_new_operators(self,
                          attr_combo,
                          adding_all,
                          selected_index=None,
                          selected_values=None):
        """Rebuild the operator combo (column 1) for ``attr_combo``'s row.

        The operator names depend on the selected variable.  When
        ``selected_index`` is not given, the operator whose text matches the
        previous combo's text is kept, falling back to the first one.
        ``adding_all`` and ``selected_values`` are forwarded to
        ``set_new_values``.
        """
        row = attr_combo.row
        previous = self.cond_list.cellWidget(row, 1)
        previous_text = previous.currentText() if previous else ""

        oper_combo = QComboBox()
        oper_combo.row = row
        oper_combo.attr_combo = attr_combo

        attr_name = attr_combo.currentText()
        if attr_name in self.AllTypes:
            names_key = attr_name
        else:
            names_key = type(self.data.domain[attr_name])
        oper_combo.addItems(self.operator_names[names_key])

        if selected_index is None:
            # findText gives -1 when the previous operator is not offered.
            selected_index = max(oper_combo.findText(previous_text), 0)
        oper_combo.setCurrentIndex(selected_index)

        self.cond_list.setCellWidget(row, 1, oper_combo)
        self.set_new_values(oper_combo, adding_all, selected_values)
        oper_combo.currentIndexChanged.connect(
            lambda _: self.set_new_values(oper_combo, False))

    @staticmethod
    def _get_lineedit_contents(box):
        """Return the values held by the line edits and date/time widgets.

        ``box`` is either a container with a ``controls`` list or a single
        widget.  Line edits contribute their text; a ``DateTimeWidget``
        contributes its time, date or date-time depending on its ``format``.
        """
        datetime_getters = (
            ((0, 1), "time"),
            ((1, 0), "date"),
            ((1, 1), "dateTime"),
        )
        collected = []
        for control in getattr(box, "controls", [box]):
            if isinstance(control, QLineEdit):
                collected.append(control.text())
            elif isinstance(control, DateTimeWidget):
                for fmt, getter in datetime_getters:
                    if control.format == fmt:
                        collected.append(getattr(control, getter)())
                        break
        return collected

    @staticmethod
    def _get_value_contents(box):
        """Return the values entered in a value cell as a tuple.

        ``box`` is either a container with a ``controls`` list or a single
        widget.  Line edits give their text, combo boxes their current
        index, tool buttons the 1-based rows of the checked items of their
        popup list (the button text is refreshed as a side effect), and
        date/time widgets their time, date or date-time.  Labels and
        ``None`` contribute nothing.

        Raises:
            TypeError: for any other kind of child widget.
        """
        datetime_getters = (
            ((0, 1), "time"),
            ((1, 0), "date"),
            ((1, 1), "dateTime"),
        )
        values = []
        checked_names = []
        for control in getattr(box, "controls", [box]):
            if isinstance(control, QLineEdit):
                values.append(control.text())
            elif isinstance(control, QComboBox):
                values.append(control.currentIndex())
            elif isinstance(control, QToolButton):
                if control.popup is None:
                    continue
                list_model = control.popup.list_view.model()
                for row_index in range(list_model.rowCount()):
                    entry = list_model.item(row_index)
                    if entry.checkState():
                        values.append(row_index + 1)
                        checked_names.append(entry.text())
                control.desc_text = ', '.join(checked_names)
                control.set_text()
            elif isinstance(control, DateTimeWidget):
                for fmt, getter in datetime_getters:
                    if control.format == fmt:
                        values.append(getattr(control, getter)())
                        break
            elif not (isinstance(control, QLabel) or control is None):
                raise TypeError('Type %s not supported.' % type(control))
        return tuple(values)

    class QDoubleValidatorEmpty(QDoubleValidator):
        """Double validator that accepts empty input and rejects any text
        containing the locale's group separator."""

        def validate(self, input_, pos):
            """Return ``(state, input_, pos)`` for the given text."""
            if input_ and self.locale().groupSeparator() not in input_:
                return super().validate(input_, pos)
            if input_:
                state = QDoubleValidator.Invalid
            else:
                state = QDoubleValidator.Acceptable
            return state, input_, pos

    def set_new_values(self, oper_combo, adding_all, selected_values=None):
        """Rebuild the value widget (column 2) for ``oper_combo``'s row.

        The widget depends on the selected variable and operator: an empty
        label for "... defined" operators, a drop-down button or combo for
        discrete variables, and one or two line edits or date/time widgets
        otherwise.  Values already typed into a widget of the same variable
        type are carried over.  Unless ``adding_all`` is true,
        ``conditions_changed`` is called at the end.
        """

        def add_textual(contents):
            # Right-aligned line edit inside the current ``box``.
            le = gui.lineEdit(box,
                              self,
                              None,
                              sizePolicy=QSizePolicy(QSizePolicy.Expanding,
                                                     QSizePolicy.Expanding))
            if contents:
                le.setText(contents)
            le.setAlignment(Qt.AlignRight)
            le.editingFinished.connect(self.conditions_changed)
            return le

        def add_numeric(contents):
            # Same as add_textual, restricted to (possibly empty) numbers.
            le = add_textual(contents)
            le.setValidator(OWSelectRows.QDoubleValidatorEmpty())
            return le

        # Existing value widget of this row, if any.
        box = self.cond_list.cellWidget(oper_combo.row, 2)
        # Initial contents of the (up to two) value fields.
        lc = ["", ""]
        oper = oper_combo.currentIndex()
        attr_name = oper_combo.attr_combo.currentText()
        if attr_name in self.AllTypes:
            vtype = self.AllTypes[attr_name]
            var = None
        else:
            var = self.data.domain[attr_name]
            var_idx = self.data.domain.index(attr_name)
            vtype = vartype(var)
            if selected_values is not None:
                lc = list(selected_values) + ["", ""]
                # Time values (vtype 4) are kept as they are, not stringified.
                lc = [str(x) if vtype != 4 else x for x in lc[:2]]
        if box and vtype == box.var_type:
            # Same variable type as before: keep what the user had entered.
            lc = self._get_lineedit_contents(box) + lc

        if oper_combo.currentText().endswith(" defined"):
            # "is defined" / "are defined" take no value.
            label = QLabel()
            label.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, label)
        elif var is not None and var.is_discrete:
            if oper_combo.currentText().endswith(" one of"):
                if selected_values:
                    lc = list(selected_values)
                button = DropDownToolButton(self, var, lc)
                button.var_type = vtype
                self.cond_list.setCellWidget(oper_combo.row, 2, button)
            else:
                combo = ComboBoxSearch()
                # Index 0 is the empty choice, so value i sits at index i + 1.
                combo.addItems(("", ) + var.values)
                if lc[0]:
                    combo.setCurrentIndex(int(lc[0]))
                else:
                    combo.setCurrentIndex(0)
                combo.var_type = vartype(var)
                self.cond_list.setCellWidget(oper_combo.row, 2, combo)
                combo.currentIndexChanged.connect(self.conditions_changed)
        else:
            box = gui.hBox(self, addToLayout=False)
            box.var_type = vtype
            self.cond_list.setCellWidget(oper_combo.row, 2, box)
            if vtype == 2:  # continuous:
                box.controls = [add_numeric(lc[0])]
                # Operators past index 5 get a second field; in the
                # continuous operator list those are "is between" and
                # "is outside" ("is defined" was handled above).
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_numeric(lc[1]))
            elif vtype == 3:  # string:
                box.controls = [add_textual(lc[0])]
                # Indices 6 and 7 are "is between" and "is outside" in the
                # string operator list.
                if oper in [6, 7]:
                    gui.widgetLabel(box, " and ")
                    box.controls.append(add_textual(lc[1]))
            elif vtype == 4:  # time:

                def invalidate_datetime():
                    # ``w_`` is bound further below; it is None when only
                    # one date/time widget is shown.
                    if w_:
                        # Keep the second value from being earlier than
                        # the first.
                        if w.dateTime() > w_.dateTime():
                            w_.setDateTime(w.dateTime())
                        if w.format == (1, 1):
                            w.calendarWidget.timeedit.setTime(w.time())
                            w_.calendarWidget.timeedit.setTime(w_.time())
                    elif w.format == (1, 1):
                        w.calendarWidget.timeedit.setTime(w.time())

                def datetime_changed():
                    self.conditions_changed()
                    invalidate_datetime()

                datetime_format = (var.have_date, var.have_time)
                column = self.data.get_column_view(var_idx)[0]
                w = DateTimeWidget(self, column, datetime_format)
                w.set_datetime(lc[0])
                box.controls = [w]
                box.layout().addWidget(w)
                w.dateTimeChanged.connect(datetime_changed)
                if oper > 5:
                    gui.widgetLabel(box, " and ")
                    w_ = DateTimeWidget(self, column, datetime_format)
                    w_.set_datetime(lc[1])
                    box.layout().addWidget(w_)
                    box.controls.append(w_)
                    invalidate_datetime()
                    w_.dateTimeChanged.connect(datetime_changed)
                else:
                    w_ = None
            else:
                box.controls = []
        if not adding_all:
            self.conditions_changed()
    @Inputs.data
    def set_data(self, data):
        """Handle a new table on the "Data" input.

        The condition rows are rebuilt from the context stored for the new
        data (a single default row is added when there is none) and the
        output is committed.  With no data, the widget is reset and
        ``commit`` is called.
        """
        self.closeContext()
        self.data = data

        not_sql = not isinstance(data, SqlTable)
        self.cb_pa.setEnabled(not_sql)
        self.cb_pc.setEnabled(not_sql)

        self.remove_all_rows()

        missing = data is None
        self.add_button.setDisabled(missing)
        # "Add All Variables" is not offered for more than 100 columns.
        self.add_all_button.setDisabled(
            missing
            or len(data.domain.variables) + len(data.domain.metas) > 100)

        if not data:
            self.info.set_input_summary(self.info.NoInput)
            self.data_desc = None
            self.variable_model.set_domain(None)
            self.commit()
            return

        self.data_desc = report.describe_data_brief(data)
        self.variable_model.set_domain(data.domain)

        self.conditions = []
        self.openContext(data)
        for variable, oper_index, cond_values in self.conditions:
            if variable in self.variable_model:
                self.add_row(variable, oper_index, cond_values)
        if not self.cond_list.model().rowCount():
            self.add_row()

        self.info.set_input_summary(data.approx_len(),
                                    format_summary_details(data))
        self.unconditional_commit()

    def conditions_changed(self):
        """Re-read the conditions from the table and commit if needed.

        Each row yields ``(variable or its name, operator index, values)``.
        ``commit`` is called when ``update_on_change`` is set and the
        conditions differ from ``last_output_conditions``.
        """
        try:
            table = self.cond_list
            gathered = []
            for row_index in range(table.rowCount()):
                var_cell, oper_cell, val_cell = (
                    table.cellWidget(row_index, column)
                    for column in range(3))
                variable = (var_cell.currentData(gui.TableVariable)
                            or var_cell.currentText())
                gathered.append((variable,
                                 oper_cell.currentIndex(),
                                 self._get_value_contents(val_cell)))
            self.conditions = gathered

            if self.update_on_change:
                previous = self.last_output_conditions
                if previous is None or previous != self.conditions:
                    self.commit()
        except AttributeError:
            # Attribute error appears if the signal is triggered when the
            # controls are being constructed
            pass

    @staticmethod
    def _values_to_floats(attr, values):
        """Convert the entered condition values to floats.

        Args:
            attr: the variable the condition refers to.  For a
                ``TimeVariable`` the values are formatted as ISO dates and
                parsed with ``attr.parse``; otherwise they are parsed with
                ``QLocale().toDouble``.
            values: the sequence of entered values.

        Returns:
            ``values`` unchanged when it is empty, ``None`` when any value
            is falsy (e.g. a field left empty), otherwise the parsed floats.

        Raises:
            ValueError: if a value cannot be parsed in the current locale.
        """
        if len(values) == 0:
            return values
        if not all(values):
            return None
        if isinstance(attr, TimeVariable):
            # Materialised so the values can still be used (in the error
            # message and the fallback below) after parsing consumed them.
            values = [value.toString(format=Qt.ISODate) for value in values]

            def parse(text):
                return attr.parse(text), True
        else:
            parse = QLocale().toDouble

        try:
            floats, ok = zip(*[parse(v) for v in values])
            if not all(ok):
                raise ValueError('Some values could not be parsed as floats '
                                 'in the current locale: {}'.format(values))
        except TypeError:
            floats = values  # values already floats
        assert all(isinstance(v, float) for v in floats)
        return floats

    def commit(self):
        """Filter ``self.data`` by ``self.conditions`` and send the outputs.

        One filter is built per condition; incomplete conditions are skipped.
        If a numeric value cannot be parsed, a parsing error is shown and the
        method returns without sending anything.  Matching, non-matching and
        annotated tables are optionally purged of constant features / unused
        values, then sent; the report descriptions and the output summary are
        refreshed afterwards.
        """
        matched = self.data
        unmatched = annotated = None

        self.Error.clear()
        if self.data:
            domain = self.data.domain
            row_filters = []
            for attr_name, oper_idx, values in self.conditions:
                if attr_name in self.AllTypes:
                    # Key of self.AllTypes: no single domain variable.
                    attr = attr_index = None
                    attr_type = self.AllTypes[attr_name]
                    operators = self.Operators[attr_name]
                else:
                    attr_index = domain.index(attr_name)
                    attr = domain[attr_index]
                    attr_type = vartype(attr)
                    operators = self.Operators[type(attr)]
                opertype, _ = operators[oper_idx]

                if attr_type == 0:
                    filt = data_filter.IsDefined()
                elif attr_type in (2, 4):  # continuous, time
                    try:
                        floats = self._values_to_floats(attr, values)
                    except ValueError as err:
                        self.Error.parsing_error(err.args[0])
                        return
                    if floats is None:
                        # Condition is not completely filled in yet.
                        continue
                    filt = data_filter.FilterContinuous(
                        attr_index, opertype, *floats)
                elif attr_type == 3:  # string
                    filt = data_filter.FilterString(
                        attr_index, opertype, *map(str, values))
                elif opertype == FilterDiscreteType.IsDefined:
                    filt = data_filter.FilterDiscrete(attr_index, None)
                else:
                    if not (values and values[0]):
                        continue
                    # Stored indices are 1-based.
                    chosen = [attr.values[i - 1] for i in values]
                    if opertype == FilterDiscreteType.Equal:
                        f_values = {chosen[0]}
                    elif opertype == FilterDiscreteType.NotEqual:
                        f_values = set(attr.values) - {chosen[0]}
                    elif opertype == FilterDiscreteType.In:
                        f_values = set(chosen)
                    else:
                        raise ValueError("invalid operand")
                    filt = data_filter.FilterDiscrete(attr_index, f_values)
                row_filters.append(filt)

            if row_filters:
                filters = data_filter.Values(row_filters)
                matched = filters(self.data)
                # The same filter object, negated, yields the complement.
                filters.negate = True
                unmatched = filters(self.data)

                annotated = create_annotated_table(
                    self.data, np.in1d(self.data.ids, matched.ids))

            purge_attrs = self.purge_attributes
            purge_classes = self.purge_classes
            if (purge_attrs or purge_classes) and \
                    not isinstance(self.data, SqlTable):
                attr_flags = sum([
                    Remove.RemoveConstant * purge_attrs,
                    Remove.RemoveUnusedValues * purge_attrs
                ])
                class_flags = sum([
                    Remove.RemoveConstant * purge_classes,
                    Remove.RemoveUnusedValues * purge_classes
                ])
                # same settings used for attributes and meta features
                remover = Remove(attr_flags, class_flags, attr_flags)

                matched, unmatched, annotated = (
                    remover(table)
                    for table in (matched, unmatched, annotated))

        # Empty tables are sent as None.
        matched = matched if matched else None
        unmatched = unmatched if unmatched else None
        annotated = annotated if annotated else None

        self.Outputs.matching_data.send(matched)
        self.Outputs.unmatched_data.send(unmatched)
        self.Outputs.annotated_data.send(annotated)

        self.match_desc = report.describe_data_brief(matched)
        self.nonmatch_desc = report.describe_data_brief(unmatched)

        if matched:
            summary = matched.approx_len()
            details = format_summary_details(matched)
        else:
            summary = self.info.NoOutput
            details = ""
        self.info.set_output_summary(summary, details)

    def send_report(self):
        """Add the data description, the conditions and the output to the report.

        The full domain descriptions are reported only when the non-empty
        tables (input, matching, non-matching) differ in anything other than
        the number of instances; otherwise only instance counts are listed.
        """
        if not self.data:
            self.report_paragraph("No data.")
            return

        # Descriptions of all non-empty tables, with the instance count
        # removed so that only the domain part is compared.
        trimmed = []
        for desc in (self.data_desc, self.match_desc, self.nonmatch_desc):
            if desc and desc["Data instances"]:
                without_count = desc.copy()
                del without_count["Data instances"]
                trimmed.append(without_count)
        describe_domain = any(
            earlier != later for earlier, later in zip(trimmed, trimmed[1:]))

        conditions = []
        for attr, oper, values in self.conditions:
            if isinstance(attr, str):
                attr_name, var_type = attr, self.AllTypes[attr]
                names = self.operator_names[attr_name]
            else:
                attr_name, var_type = attr.name, vartype(attr)
                names = self.operator_names[type(attr)]
            name = names[oper]

            if oper == len(names) - 1:
                # The last operator takes no value.
                conditions.append("{} {}".format(attr_name, name))
            elif var_type == 1:  # discrete
                if name != "is one of":
                    if values and values[0]:
                        value = values[0] - 1
                        conditions.append(
                            f"{attr} {name} {attr.values[value]}")
                    continue
                valnames = [attr.values[v - 1] for v in values]
                if not valnames:
                    continue
                *leading, last = valnames
                if leading:
                    valstr = f"{', '.join(leading)} or {last}"
                else:
                    valstr = last
                conditions.append(f"{attr} is {valstr}")
            elif var_type == 3:  # string variable
                quoted = ' and '.join(map(repr, values))
                conditions.append(f"{attr} {name} {quoted}")
            elif var_type == 4:  # time
                stamps = ' and '.join(
                    value.toString(format=Qt.ISODate) for value in values)
                conditions.append(f"{attr} {name} {stamps}")
            elif all(values):  # numeric variable
                conditions.append(f"{attr} {name} {' and '.join(values)}")

        if describe_domain:
            items = OrderedDict(self.data_desc)
        else:
            items = OrderedDict(
                [("Instances", self.data_desc["Data instances"])])
        items["Condition"] = " AND ".join(conditions) or "no conditions"
        self.report_items("Data", items)

        if describe_domain:
            self.report_items("Matching data", self.match_desc)
            self.report_items("Non-matching data", self.nonmatch_desc)
            return

        match_inst = \
            self.match_desc["Data instances"] if self.match_desc else False
        nonmatch_inst = \
            self.nonmatch_desc["Data instances"] if self.nonmatch_desc \
            else False
        matching_text = \
            "{} instances".format(match_inst) if match_inst else "None"
        nonmatching_text = \
            nonmatch_inst > 0 and "{} instances".format(nonmatch_inst)
        self.report_items(
            "Output",
            (("Matching data", matching_text),
             ("Non-matching data", nonmatching_text)))
Ejemplo n.º 3
0
class OWChoropleth(widget.OWWidget):
    """Widget that aggregates a variable by map region and shades the regions.

    Rows are assigned to regions from their latitude/longitude columns, the
    chosen attribute is aggregated per region with the chosen function, and
    the result is handed to the embedded ``LeafletChoropleth`` view.  Rows
    belonging to the regions selected on the map are sent to the outputs.
    """
    name = 'Choropleth Map'
    description = 'A thematic map in which areas are shaded in proportion ' \
                  'to the measurement of the statistical variable being displayed.'
    icon = "icons/Choropleth.svg"
    priority = 120

    class Inputs:
        data = Input("Data", Table, default=True)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    settingsHandler = settings.DomainContextHandler()

    want_main_area = True

    AGG_FUNCS = (
        'Count',
        'Count defined',
        'Sum',
        'Mean',
        'Median',
        'Mode',
        'Max',
        'Min',
        'Std',
    )
    # Aggregations whose pandas name is not simply the lower-cased label.
    AGG_FUNCS_TRANSFORM = {
        'Count': 'size',
        'Count defined': 'count',
        'Mode': lambda x: stats.mode(x, nan_policy='omit').mode[0],
    }
    # Aggregations allowed for categorical attributes.
    AGG_FUNCS_DISCRETE = ('Count', 'Count defined', 'Mode')
    # Aggregations whose result is not rendered as a time value.
    AGG_FUNCS_CANT_TIME = ('Count', 'Count defined', 'Sum', 'Std')

    autocommit = settings.Setting(True)
    lat_attr = settings.ContextSetting('')
    lon_attr = settings.ContextSetting('')
    attr = settings.ContextSetting('')
    agg_func = settings.ContextSetting(AGG_FUNCS[0])
    admin = settings.Setting(0)
    opacity = settings.Setting(70)
    color_steps = settings.Setting(5)
    color_quantization = settings.Setting('equidistant')
    show_labels = settings.Setting(True)
    show_legend = settings.Setting(True)
    show_details = settings.Setting(True)
    selection = settings.ContextSetting([])

    class Error(widget.OWWidget.Error):
        aggregation_discrete = widget.Msg("Only certain types of aggregation defined on categorical attributes: {}")

    class Warning(widget.OWWidget.Warning):
        logarithmic_nonpositive = widget.Msg("Logarithmic quantization requires all values > 0. Using 'equidistant' quantization instead.")

    graph_name = "map"

    def __init__(self):
        """Build the map view and the control area, then push settings to the map."""
        super().__init__()
        self.map = leaflet = LeafletChoropleth(self)
        self.mainArea.layout().addWidget(leaflet)
        self.selection = []
        self.data = None
        self.latlon = None
        # Region id of every data row (pandas Series); set by get_regions()
        # and refreshed by aggregate().
        self.ids = None
        # Row indices belonging to the currently selected regions.
        self._indices = []
        self.result_min_nonpositive = False
        self._should_fit_bounds = False

        def selectionChanged(selection):
            if self.ids is None:
                self._indices = []
            else:
                # ``.values`` gives an ndarray; Series.nonzero() no longer
                # exists in current pandas.
                self._indices = self.ids.isin(selection).values.nonzero()[0]
            self.selection = selection
            self.commit()

        leaflet.selectionChanged.connect(selectionChanged)

        box = gui.vBox(self.controlArea, 'Aggregation')

        self._latlon_model = DomainModel(parent=self, valid_types=ContinuousVariable)
        self._combo_lat = combo = gui.comboBox(
            box, self, 'lat_attr', orientation=Qt.Horizontal,
            label='Latitude:', sendSelectedValue=True, callback=self.aggregate)
        combo.setModel(self._latlon_model)

        self._combo_lon = combo = gui.comboBox(
            box, self, 'lon_attr', orientation=Qt.Horizontal,
            label='Longitude:', sendSelectedValue=True, callback=self.aggregate)
        combo.setModel(self._latlon_model)

        self._combo_attr = combo = gui.comboBox(
            box, self, 'attr', orientation=Qt.Horizontal,
            label='Attribute:', sendSelectedValue=True, callback=self.aggregate)
        combo.setModel(DomainModel(parent=self, valid_types=(ContinuousVariable, DiscreteVariable)))

        gui.comboBox(
            box, self, 'agg_func', orientation=Qt.Horizontal, items=self.AGG_FUNCS,
            label='Aggregation:', sendSelectedValue=True, callback=self.aggregate)

        self._detail_slider = gui.hSlider(
            box, self, 'admin', None, 0, 2, 1,
            label='Administrative level:', labelFormat=' %d',
            callback=self.aggregate)

        box = gui.vBox(self.controlArea, 'Visualization')

        gui.spin(box, self, 'color_steps', 3, 15, 1, label='Color steps:',
                 callback=lambda: self.map.set_color_steps(self.color_steps))

        def _set_quantization():
            self.Warning.logarithmic_nonpositive(
                shown=(self.color_quantization.startswith('log') and
                       self.result_min_nonpositive))
            self.map.set_quantization(self.color_quantization)

        gui.comboBox(box, self, 'color_quantization', label='Color quantization:',
                     orientation=Qt.Horizontal, sendSelectedValue=True,
                     items=('equidistant', 'logarithmic', 'quantile', 'k-means'),
                     callback=_set_quantization)

        self._opacity_slider = gui.hSlider(
            box, self, 'opacity', None, 20, 100, 5,
            label='Opacity:', labelFormat=' %d%%',
            callback=lambda: self.map.set_opacity(self.opacity))

        gui.checkBox(box, self, 'show_legend', label='Show legend',
                     callback=lambda: self.map.toggle_legend(self.show_legend))
        gui.checkBox(box, self, 'show_labels', label='Show map labels',
                     callback=lambda: self.map.toggle_map_labels(self.show_labels))
        gui.checkBox(box, self, 'show_details', label='Show region details in tooltip',
                     callback=lambda: self.map.toggle_tooltip_details(self.show_details))

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, 'autocommit', 'Send Selection')

        # Apply the stored settings to the freshly created map.
        self.map.toggle_legend(self.show_legend)
        self.map.toggle_map_labels(self.show_labels)
        self.map.toggle_tooltip_details(self.show_details)
        self.map.set_quantization(self.color_quantization)
        self.map.set_color_steps(self.color_steps)
        self.map.set_opacity(self.opacity)

    def __del__(self):
        self.map = None

    def commit(self):
        """Send the rows of the selected regions and the annotated table."""
        if self.data is not None and self.selection:
            selected_data = self.data[self._indices]
            annotated_data = create_annotated_table(self.data, self._indices)
        else:
            selected_data = annotated_data = None

        self.Outputs.selected_data.send(selected_data)
        self.Outputs.annotated_data.send(annotated_data)

    @Inputs.data
    def set_data(self, data):
        """Accept a new input table (or ``None``) and redraw the map."""
        self.data = data

        self.closeContext()

        self.clear()
        # Coordinates of the previous table must not outlive it; aggregate()
        # treats a non-None ``latlon`` as "coordinates are available".
        self.latlon = None

        if data is None:
            return

        self._combo_attr.model().set_domain(data.domain)
        self._latlon_model.set_domain(data.domain)

        lat, lon = find_lat_lon(data)
        if lat or lon:
            self._combo_lat.setCurrentIndex(-1 if lat is None else self._latlon_model.indexOf(lat))
            self._combo_lon.setCurrentIndex(-1 if lon is None else self._latlon_model.indexOf(lon))
            self.lat_attr = lat.name if lat else None
            self.lon_attr = lon.name if lon else None
            if lat and lon:
                self.latlon = np.c_[self.data.get_column_view(self.lat_attr)[0],
                                    self.data.get_column_view(self.lon_attr)[0]]

        if data.domain.class_var:
            self.attr = data.domain.class_var.name
        else:
            self.attr = self._combo_attr.itemText(0)

        self.openContext(data)

        if self.selection:
            self.map.preset_region_selection(self.selection)
        self.aggregate()

        self.map.set_opacity(self.opacity)

        if self.isVisible():
            self.map.fit_to_bounds()
        else:
            # Postpone fitting until the widget is shown (see showEvent).
            self._should_fit_bounds = True

    def showEvent(self, event):
        """Fit the map to its bounds once the widget becomes visible."""
        super().showEvent(event)
        if self._should_fit_bounds:
            QTimer.singleShot(500, self.map.fit_to_bounds)
            self._should_fit_bounds = False

    def aggregate(self):
        """Aggregate the chosen attribute per region and replot the map."""
        if (self.data is None or self.latlon is None
                or self.attr not in self.data.domain):
            self.clear(caches=False)
            return

        attr = self.data.domain[self.attr]

        if attr.is_discrete and self.agg_func not in self.AGG_FUNCS_DISCRETE:
            self.Error.aggregation_discrete(', '.join(map(str.lower, self.AGG_FUNCS_DISCRETE)))
            self.Warning.logarithmic_nonpositive.clear()
            self.clear(caches=False)
            return
        else:
            self.Error.aggregation_discrete.clear()

        try:
            _regions, adm0, result, self.map.bounds = \
                self.get_grouped(self.lat_attr, self.lon_attr, self.admin, self.attr, self.agg_func)
        except ValueError:
            # This might happen if widget scheme File→Choropleth, and
            # some attr is selected in choropleth, and then the same attr
            # is set to string attr in File and dataset reloaded.
            # Nothing sensible can be drawn in that case.
            return

        # get_regions() is memoized, so its side effect of storing
        # ``self.ids`` is skipped on cache hits; refresh it explicitly so
        # that selection indices match the current settings.
        self.ids = pd.Series(
            self.get_regions(self.lat_attr, self.lon_attr, self.admin)[1])

        # Only show discrete values that are contained in aggregated results
        discrete_values = []
        discrete_colors = []
        if attr.is_discrete and not self.agg_func.startswith('Count'):
            subset = sorted(result.drop_duplicates().dropna().astype(int))
            discrete_values = np.array(attr.values)[subset].tolist()
            discrete_colors = np.array(attr.colors)[subset].tolist()
            # Renumber the values so that they index ``discrete_values``.
            result.replace(subset, list(range(len(subset))), inplace=True)

        self.result_min_nonpositive = attr.is_continuous and result.min() <= 0
        force_quantization = self.color_quantization.startswith('log') and self.result_min_nonpositive
        self.Warning.logarithmic_nonpositive(shown=force_quantization)

        repr_time = isinstance(attr, TimeVariable) and self.agg_func not in self.AGG_FUNCS_CANT_TIME

        self.map.exposeObject(
            'results',
            dict(discrete=discrete_values,
                 colors=[color_to_hex(i)
                         for i in (discrete_colors if discrete_values else
                                   ((0, 0, 255), (255, 255, 0)) if attr.is_discrete else
                                   # NOTE(review): why the last entry of
                                   # attr.colors is dropped is unclear.
                                   attr.colors[:-1])],
                 regions=list(adm0),
                 attr=attr.name,
                 have_nonpositive=self.result_min_nonpositive or bool(discrete_values),
                 values=result.to_dict(),
                 repr_vals=result.map(attr.repr_val).to_dict() if repr_time else {},
                 minmax=([result.min(), result.max()] if attr.is_discrete and not discrete_values else
                         [attr.repr_val(result.min()), attr.repr_val(result.max())] if repr_time or not discrete_values else
                         [])))

        self.map.evalJS('replot();')

    @memoize_method(3)
    def get_regions(self, lat_attr, lon_attr, admin):
        """Map every row to a region at the given administrative level.

        Returns:
            ``(regions, ids, adm0, bounds)``: the set of region ids found,
            the per-row list of region ids (``None`` where the region has no
            id), the set of prefixed country codes, and the bounding
            rectangle of the regions (``None`` if there are none).
        """
        latlon = np.c_[self.data.get_column_view(lat_attr)[0],
                       self.data.get_column_view(lon_attr)[0]]
        regions = latlon2region(latlon, admin)
        adm0 = ({'0'} if admin == 0 else
                {'1-' + a3 for a3 in (i.get('adm0_a3') for i in regions) if a3} if admin == 1 else
                {('2-' if a3 in ADMIN2_COUNTRIES else '1-') + a3
                 for a3 in (i.get('adm0_a3') for i in regions) if a3})
        ids = [i.get('_id') for i in regions]
        self.ids = pd.Series(ids)
        regions = set(ids) - {None}
        bounds = get_bounding_rect(regions) if regions else None
        return regions, ids, adm0, bounds

    @memoize_method(6)
    def get_grouped(self, lat_attr, lon_attr, admin, attr, agg_func):
        """Aggregate ``attr`` per region; returns ``(regions, adm0, result, bounds)``."""
        log.debug('Grouping %s(%s) by (%s, %s; admin%d)',
                  agg_func, attr, lat_attr,  lon_attr, admin)
        regions, ids, adm0, bounds = self.get_regions(lat_attr, lon_attr, admin)
        attr = self.data.domain[attr]
        result = pd.Series(self.data.get_column_view(attr)[0], dtype=float)\
            .groupby(ids)\
            .agg(self.AGG_FUNCS_TRANSFORM.get(agg_func, agg_func.lower()))
        return regions, adm0, result, bounds

    def clear(self, caches=True):
        """Drop the selection and plotted results; optionally clear the caches."""
        if caches:
            try:
                self.get_regions.cache_clear()
                self.get_grouped.cache_clear()
            except AttributeError:
                pass  # back-compat https://github.com/biolab/orange3/pull/2229
        self.selection = []
        self.map.exposeObject('results', {})
        self.map.evalJS('replot();')
Ejemplo n.º 4
0
class OWMap(widget.OWWidget):
    """Widget that shows data points as markers on a Leaflet map.

    Latitude and longitude columns are auto-detected from the input data;
    marker color, label, shape and size can be bound to variables.  When a
    learner is connected, a model trained on (latitude, longitude) → target
    is passed to the map.  Selected markers are sent to the outputs.
    """
    name = 'Geo Map'
    description = 'Show data points on a world map.'
    icon = "icons/Map.svg"

    inputs = [("Data", Table, "set_data", widget.Default),
              ("Data Subset", Table, "set_subset"),
              ("Learner", Learner, "set_learner")]

    outputs = [("Selected Data", Table, widget.Default),
               (ANNOTATED_DATA_SIGNAL_NAME, Table)]

    settingsHandler = settings.DomainContextHandler()

    want_main_area = True

    autocommit = settings.Setting(True)
    tile_provider = settings.Setting('Black and white')
    lat_attr = settings.ContextSetting('')
    lon_attr = settings.ContextSetting('')
    class_attr = settings.ContextSetting('(None)')
    color_attr = settings.ContextSetting('')
    label_attr = settings.ContextSetting('')
    shape_attr = settings.ContextSetting('')
    size_attr = settings.ContextSetting('')
    opacity = settings.Setting(100)
    zoom = settings.Setting(100)
    jittering = settings.Setting(0)
    cluster_points = settings.Setting(False)
    show_legend = settings.Setting(True)

    # Combo-box label → tile provider identifier passed to the map.
    TILE_PROVIDERS = OrderedDict((
        ('Black and white', 'OpenStreetMap.BlackAndWhite'),
        ('OpenStreetMap', 'OpenStreetMap.Mapnik'),
        ('Topographic', 'Thunderforest.OpenCycleMap'),
        ('Topographic 2', 'Thunderforest.Outdoors'),
        ('Satellite', 'Esri.WorldImagery'),
        ('Print', 'Stamen.TonerLite'),
        ('Dark', 'CartoDB.DarkMatter'),
        ('Watercolor', 'Stamen.Watercolor'),
    ))

    class Error(widget.OWWidget.Error):
        model_error = widget.Msg("Error predicting: {}")
        learner_error = widget.Msg("Error modelling: {}")

    UserAdviceMessages = [
        widget.Message(
            'Select markers by holding <b><kbd>Shift</kbd></b> key and dragging '
            'a rectangle around them. Clear the selection by clicking anywhere.',
            'shift-selection')
    ]

    graph_name = "map"

    def __init__(self):
        """Build the map view and all controls; settings are applied on the next event-loop pass."""
        super().__init__()
        self.map = leaflet = LeafletMap(self)  # type: LeafletMap
        self.mainArea.layout().addWidget(leaflet)
        self.selection = None
        # Indices of the selected rows; empty until the map reports a
        # selection, so that commit() is safe to call at any time.
        self._indices = []
        self.data = None
        self.learner = None

        def selectionChanged(indices):
            self.selection = self.data[indices] if self.data is not None and indices else None
            self._indices = indices
            self.commit()

        leaflet.selectionChanged.connect(selectionChanged)

        def _set_map_provider():
            leaflet.set_map_provider(self.TILE_PROVIDERS[self.tile_provider])

        box = gui.vBox(self.controlArea, 'Map')
        gui.comboBox(box, self, 'tile_provider',
                     orientation=Qt.Horizontal,
                     label='Map:',
                     items=tuple(self.TILE_PROVIDERS.keys()),
                     sendSelectedValue=True,
                     callback=_set_map_provider)

        self._latlon_model = DomainModel(
            parent=self, valid_types=ContinuousVariable)
        self._class_model = DomainModel(
            parent=self, placeholder='(None)', valid_types=DomainModel.PRIMITIVE)
        self._color_model = DomainModel(
            parent=self, placeholder='(Same color)', valid_types=DomainModel.PRIMITIVE)
        self._shape_model = DomainModel(
            parent=self, placeholder='(Same shape)', valid_types=DiscreteVariable)
        self._size_model = DomainModel(
            parent=self, placeholder='(Same size)', valid_types=ContinuousVariable)
        self._label_model = DomainModel(
            parent=self, placeholder='(No labels)')

        def _set_lat_long():
            self.map.set_data(self.data, self.lat_attr, self.lon_attr)
            self.train_model()

        self._combo_lat = combo = gui.comboBox(
            box, self, 'lat_attr', orientation=Qt.Horizontal,
            label='Latitude:', sendSelectedValue=True, callback=_set_lat_long)
        combo.setModel(self._latlon_model)
        self._combo_lon = combo = gui.comboBox(
            box, self, 'lon_attr', orientation=Qt.Horizontal,
            label='Longitude:', sendSelectedValue=True, callback=_set_lat_long)
        combo.setModel(self._latlon_model)

        def _toggle_legend():
            self.map.toggle_legend(self.show_legend)

        gui.checkBox(box, self, 'show_legend', label='Show legend',
                     callback=_toggle_legend)

        box = gui.vBox(self.controlArea, 'Overlay')
        self._combo_class = combo = gui.comboBox(
            box, self, 'class_attr', orientation=Qt.Horizontal,
            label='Target:', sendSelectedValue=True, callback=self.train_model
        )
        self.controls.class_attr.setModel(self._class_model)
        # Puts the target combo into its "no learner" state.
        self.set_learner(self.learner)

        box = gui.vBox(self.controlArea, 'Points')
        self._combo_color = combo = gui.comboBox(
            box, self, 'color_attr',
            orientation=Qt.Horizontal,
            label='Color:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_color(self.color_attr))
        combo.setModel(self._color_model)
        self._combo_label = combo = gui.comboBox(
            box, self, 'label_attr',
            orientation=Qt.Horizontal,
            label='Label:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_label(self.label_attr))
        combo.setModel(self._label_model)
        self._combo_shape = combo = gui.comboBox(
            box, self, 'shape_attr',
            orientation=Qt.Horizontal,
            label='Shape:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_shape(self.shape_attr))
        combo.setModel(self._shape_model)
        self._combo_size = combo = gui.comboBox(
            box, self, 'size_attr',
            orientation=Qt.Horizontal,
            label='Size:',
            sendSelectedValue=True,
            callback=lambda: self.map.set_marker_size(self.size_attr))
        combo.setModel(self._size_model)

        def _set_opacity():
            leaflet.set_marker_opacity(self.opacity)

        def _set_zoom():
            leaflet.set_marker_size_coefficient(self.zoom)

        def _set_jittering():
            leaflet.set_jittering(self.jittering)

        def _set_clustering():
            leaflet.set_clustering(self.cluster_points)

        self._opacity_slider = gui.hSlider(
            box, self, 'opacity', None, 1, 100, 5,
            label='Opacity:', labelFormat=' %d%%',
            callback=_set_opacity)
        self._zoom_slider = gui.valueSlider(
            box, self, 'zoom', None, values=(20, 50, 100, 200, 300, 400, 500, 700, 1000),
            label='Symbol size:', labelFormat=' %d%%',
            callback=_set_zoom)
        self._jittering = gui.valueSlider(
            box, self, 'jittering', label='Jittering:', values=(0, .5, 1, 2, 5),
            labelFormat=' %.1f%%', ticks=True,
            callback=_set_jittering)
        self._clustering_check = gui.checkBox(
            box, self, 'cluster_points', label='Cluster points',
            callback=_set_clustering)

        gui.rubber(self.controlArea)
        gui.auto_commit(self.controlArea, self, 'autocommit', 'Send Selection')

        # Push the stored settings to the map on the next event-loop pass.
        QTimer.singleShot(0, _set_map_provider)
        QTimer.singleShot(0, _toggle_legend)
        QTimer.singleShot(0, _set_opacity)
        QTimer.singleShot(0, _set_zoom)
        QTimer.singleShot(0, _set_jittering)
        QTimer.singleShot(0, _set_clustering)

    def __del__(self):
        self.progressBarFinished(None)
        self.map = None

    def commit(self):
        """Send the selected rows and the annotated input table."""
        self.send('Selected Data', self.selection)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(self.data, self._indices))

    @staticmethod
    def _find_lat_lon(data):
        """Guess the latitude and longitude variables of ``data``.

        A continuous variable whose lower-cased name starts with a known
        prefix is preferred; otherwise the first continuous variable whose
        values (NaNs counted as 0) all lie within the valid coordinate range
        is taken.  Either element of the returned pair may be ``None``.
        """
        all_vars = list(chain(data.domain.variables, data.domain.metas))

        def by_name(prefixes):
            return next(
                (attr for attr in all_vars
                 if attr.is_continuous and
                 attr.name.lower().startswith(prefixes)), None)

        def by_range(low, high):
            for attr in all_vars:
                if attr.is_continuous:
                    values = np.nan_to_num(data.get_column_view(attr)[0].astype(float))
                    if np.all((low <= values) & (values <= high)):
                        return attr
            return None

        lat_attr = by_name(('latitude', 'lat')) or by_range(-90, 90)
        lon_attr = (by_name(('longitude', 'lng', 'long', 'lon'))
                    or by_range(-180, 180))
        return lat_attr, lon_attr

    def set_data(self, data):
        """Accept a new input table; an empty table or ``None`` clears the widget."""
        self.data = data

        self.closeContext()

        if data is None or not len(data):
            return self.clear()

        domain = data.domain
        for model in (self._latlon_model,
                      self._class_model,
                      self._color_model,
                      self._shape_model,
                      self._size_model,
                      self._label_model):
            model.set_domain(domain)

        lat, lon = self._find_lat_lon(data)
        if lat or lon:
            self._combo_lat.setCurrentIndex(-1 if lat is None else self._latlon_model.indexOf(lat))
            self._combo_lon.setCurrentIndex(-1 if lon is None else self._latlon_model.indexOf(lon))
            # Only one of the two may have been detected.
            self.lat_attr = lat.name if lat else ''
            self.lon_attr = lon.name if lon else ''

        if data.domain.class_var:
            self.color_attr = data.domain.class_var.name
        elif len(self._color_model):
            self._combo_color.setCurrentIndex(0)
        if len(self._shape_model):
            self._combo_shape.setCurrentIndex(0)
        if len(self._size_model):
            self._combo_size.setCurrentIndex(0)
        if len(self._label_model):
            self._combo_label.setCurrentIndex(0)
        if len(self._class_model):
            self._combo_class.setCurrentIndex(0)

        self.openContext(data)

        self.map.set_data(self.data, self.lat_attr, self.lon_attr)
        # Only the last call asks for an update, after all properties are set.
        self.map.set_marker_color(self.color_attr, update=False)
        self.map.set_marker_label(self.label_attr, update=False)
        self.map.set_marker_shape(self.shape_attr, update=False)
        self.map.set_marker_size(self.size_attr, update=True)

    def set_subset(self, subset):
        """Pass the ids of the subset rows (or an empty array) to the map."""
        self.map.set_subset_ids(subset.ids if subset is not None else np.array([]))

    def handleNewSignals(self):
        """Retrain the model after the inputs have changed."""
        super().handleNewSignals()
        self.train_model()

    def set_learner(self, learner):
        """Store the learner; the target combo is enabled only when one is given."""
        self.learner = learner
        self.controls.class_attr.setEnabled(learner is not None)
        self.controls.class_attr.setToolTip(
            'Needs a Learner input for modelling.' if learner is None else '')

    def train_model(self):
        """Fit the learner on (latitude, longitude) → target and pass the model to the map.

        ``None`` is passed when data, learner or target are missing or when
        training fails; a failure is reported through ``Error.learner_error``.
        """
        model = None
        self.Error.clear()
        if self.data and self.learner and self.class_attr != '(None)':
            domain = self.data.domain
            if self.lat_attr and self.lon_attr and self.class_attr in domain:
                domain = Domain([domain[self.lat_attr], domain[self.lon_attr]],
                                [domain[self.class_attr]])
                train = self.data.transform(domain)
                try:
                    model = self.learner(train)
                except Exception as e:
                    self.Error.learner_error(e)
        self.map.set_model(model)

    def disable_some_controls(self, disabled):
        """Enable or disable the label, shape and clustering controls."""
        tooltip = (
            "Available when the zoom is close enough to have "
            "<{} points in the viewport.".format(self.map.N_POINTS_PER_ITER)
            if disabled else '')
        for control in (self._combo_label,
                        self._combo_shape,
                        self._clustering_check):
            control.setDisabled(disabled)
            control.setToolTip(tooltip)

    def clear(self):
        """Remove the data from the map and reset all models and attribute settings."""
        self.map.set_data(None, '', '')
        for model in (self._latlon_model,
                      self._class_model,
                      self._color_model,
                      self._shape_model,
                      self._size_model,
                      self._label_model):
            model.set_domain(None)
        self.lat_attr = self.lon_attr = self.class_attr = self.color_attr = \
        self.label_attr = self.shape_attr = self.size_attr = ''
Ejemplo n.º 5
0
class OWGeneSets(OWWidget):
    """Browse gene sets that overlap the input genes.

    Gene sets from the selected hierarchies are intersected with the genes
    found in the input data (in a worker thread, see ``init_gene_sets``) and
    listed in a filterable tree view. Selecting rows sends the input data
    restricted to the matched genes of the selected sets to the output.
    """

    name = "Gene Sets"
    description = ""
    icon = "icons/OWGeneSets.svg"
    priority = 9
    want_main_area = True

    # Column indices of the result model; order matches DATA_HEADER_LABELS.
    COUNT, GENES, CATEGORY, TERM = range(4)
    DATA_HEADER_LABELS = ["Count", 'Genes In Set', 'Category', 'Term']

    organism = Setting(None, schema_only=True)
    stored_gene_sets_selection = Setting([], schema_only=True)
    selected_rows = Setting([], schema_only=True)
    custom_gene_set_indicator = Setting(None, schema_only=True)

    min_count = Setting(5)
    use_min_count = Setting(True)
    auto_commit = Setting(True)

    class Inputs:
        genes = Input("Data", Table)
        custom_sets = Input('Custom Gene Sets', Table)

    class Outputs:
        matched_genes = Output("Matched Genes", Table)

    class Information(OWWidget.Information):
        pass

    class Warning(OWWidget.Warning):
        all_sets_filtered = Msg('All sets were filtered out.')

    class Error(OWWidget.Error):
        organism_mismatch = Msg('Organism in input data and custom gene sets does not match')
        missing_annotation = Msg(ERROR_ON_MISSING_ANNOTATION)
        missing_gene_id = Msg(ERROR_ON_MISSING_GENE_ID)
        missing_tax_id = Msg(ERROR_ON_MISSING_TAX_ID)
        cant_reach_host = Msg("Host orange.biolab.si is unreachable.")
        cant_load_organisms = Msg("No available organisms, please check your connection.")

    def __init__(self):
        """Initialise all widget state to empty defaults and build the GUI."""
        super().__init__()

        # commit
        self.commit_button = None

        # progress bar
        self.progress_bar = None
        self.progress_bar_iterations = None

        # data
        self.input_data = None
        self.input_genes = []
        self.tax_id = None
        self.use_attr_names = None
        self.gene_id_attribute = None
        self.gene_id_column = None

        # custom gene sets
        self.custom_data = None
        self.feature_model = DomainModel(valid_types=(DiscreteVariable, StringVariable))
        self.custom_gs_col_box = None
        self.gs_label_combobox = None
        self.custom_tax_id = None
        self.custom_use_attr_names = None
        self.custom_gene_id_attribute = None
        self.custom_gene_id_column = None
        self.num_of_custom_sets = None

        # Gene Sets widget
        self.gs_widget = None

        # info box
        self.input_info = None
        self.num_of_sel_genes = 0

        # filter
        self.line_edit_filter = None
        self.search_pattern = ''
        self.organism_select_combobox = None

        # data model view
        self.data_view = None
        self.data_model = None

        # gene matcher NCBI
        self.gene_matcher = None

        # filter proxy model
        self.filter_proxy_model = None

        # hierarchy widget
        self.hierarchy_widget = None
        self.hierarchy_state = None

        # spinbox
        self.spin_widget = None

        # threads
        self.threadpool = QThreadPool(self)
        self.workers = None

        self._task = None  # type: Optional[Task]
        self._executor = ThreadExecutor()

        # gui
        self.setup_gui()

    def __reset_widget_state(self):
        """Refresh the info box and start from an empty result model/filter."""
        self.update_info_box()
        # clear data view
        self.init_item_model()
        # reset filters
        self.setup_filter_model()

    def cancel(self):
        """
        Cancel the current task (if any).
        """
        if self._task is not None:
            self._task.cancel()
            assert self._task.future.done()
            # disconnect the `_task_finished` slot
            self._task.watcher.done.disconnect(self._init_gene_sets_finished)
            self._task = None

    @Slot()
    def progress_advance(self):
        """Advance the progress bar by one step (invoked via ``methodinvoke``)."""
        # GUI should be updated in main thread. That's why we are calling advance method here
        if self.progress_bar:
            self.progress_bar.advance()

    def __get_input_genes(self):
        """Collect gene identifiers (as strings) from the input data.

        Genes are read from the column (attribute) annotations when
        ``use_attr_names`` is set, otherwise from the ``gene_id_column``.
        """
        self.input_genes = []

        if self.use_attr_names:
            self.input_genes = [
                str(variable.attributes.get(self.gene_id_attribute, '?'))
                for variable in self.input_data.domain.attributes
            ]
        else:
            genes, _ = self.input_data.get_column_view(self.gene_id_column)
            self.input_genes = [str(g) for g in genes]

    def handle_custom_gene_sets(self, select_customs_flag=False):
        """Register the custom gene sets with the gene-set selection widget.

        :param select_customs_flag: forwarded to ``gs_widget.add_custom_sets``.

        Sets ``Error.organism_mismatch`` and returns early when the organisms
        of the two inputs differ.
        """
        if self.custom_gene_set_indicator:
            if self.custom_data is not None and self.custom_gene_id_column is not None:

                if self.__check_organism_mismatch():
                    # self.gs_label_combobox.setDisabled(True)
                    self.Error.organism_mismatch()
                    self.gs_widget.update_gs_hierarchy()
                    return

                if isinstance(self.custom_gene_set_indicator, DiscreteVariable):
                    # discrete columns store indices into the variable's values
                    labels = self.custom_gene_set_indicator.values
                    gene_sets_names = [
                        labels[int(idx)] for idx in self.custom_data.get_column_view(self.custom_gene_set_indicator)[0]
                    ]
                else:
                    gene_sets_names, _ = self.custom_data.get_column_view(self.custom_gene_set_indicator)

                self.num_of_custom_sets = len(set(gene_sets_names))
                gene_names, _ = self.custom_data.get_column_view(self.custom_gene_id_column)
                hierarchy_title = (self.custom_data.name if self.custom_data.name else 'Custom sets',)
                try:
                    self.gs_widget.add_custom_sets(
                        gene_sets_names,
                        gene_names,
                        hierarchy_title=hierarchy_title,
                        select_customs_flag=select_customs_flag,
                    )
                except geneset.GeneSetException:
                    # NOTE(review): failures to add custom sets are silently
                    # ignored here; consider reporting them to the user.
                    pass
                # self.gs_label_combobox.setDisabled(False)
            else:
                self.gs_widget.update_gs_hierarchy()

        self.update_info_box()

    def update_tree_view(self):
        """Recompute the gene-set rows for the current hierarchy selection."""
        self.init_gene_sets()

    def invalidate(self):
        """Clear the current results and, if there is input data, recompute."""
        # clear (this also refreshes the info box)
        self.__reset_widget_state()

        if self.input_data is not None:
            # setup
            self.__get_input_genes()
            self.update_tree_view()

    def __check_organism_mismatch(self):
        """ Check if organisms from different inputs match.

        :return: True if there is a mismatch
        """
        if self.tax_id is not None and self.custom_tax_id is not None:
            return self.tax_id != self.custom_tax_id
        return False

    def __get_reference_genes(self):
        """Collect gene identifiers from ``self.reference_data``.

        NOTE(review): none of the ``reference_*`` attributes used here are
        assigned anywhere in this class and nothing in this class calls this
        method; it looks like dead code and would raise ``AttributeError``
        as-is. Confirm before relying on it.
        """
        self.reference_genes = []

        if self.reference_attr_names:
            for variable in self.reference_data.domain.attributes:
                self.reference_genes.append(str(variable.attributes.get(self.reference_gene_id_attribute, '?')))
        else:
            genes, _ = self.reference_data.get_column_view(self.reference_gene_id_column)
            self.reference_genes = [str(g) for g in genes]

    @Inputs.custom_sets
    def handle_custom_input(self, data):
        """Handle the 'Custom Gene Sets' input.

        Resets all custom-set state, reads the table annotations, (lazily)
        creates the column selector and re-registers the custom sets.
        """
        self.Error.clear()
        self.__reset_widget_state()
        self.custom_data = None
        self.custom_tax_id = None
        self.custom_use_attr_names = None
        self.custom_gene_id_attribute = None
        self.custom_gene_id_column = None
        self.feature_model.set_domain(None)

        if data:
            self.custom_data = data
            self.feature_model.set_domain(self.custom_data.domain)
            # Keep a missing tax id as None; str(None) would yield the string
            # 'None' and make the organism-mismatch check report a mismatch.
            custom_tax_id = self.custom_data.attributes.get(TAX_ID, None)
            self.custom_tax_id = None if custom_tax_id is None else str(custom_tax_id)
            self.custom_use_attr_names = self.custom_data.attributes.get(GENE_AS_ATTRIBUTE_NAME, None)
            self.custom_gene_id_attribute = self.custom_data.attributes.get(GENE_ID_ATTRIBUTE, None)
            self.custom_gene_id_column = self.custom_data.attributes.get(GENE_ID_COLUMN, None)

            if self.gs_label_combobox is None:
                self.gs_label_combobox = comboBox(
                    self.custom_gs_col_box,
                    self,
                    "custom_gene_set_indicator",
                    sendSelectedValue=True,
                    model=self.feature_model,
                    callback=self.on_gene_set_indicator_changed,
                )
            self.custom_gs_col_box.show()

            if self.custom_gene_set_indicator in self.feature_model:
                index = self.feature_model.indexOf(self.custom_gene_set_indicator)
                self.custom_gene_set_indicator = self.feature_model[index]
            elif len(self.feature_model) > 0:
                self.custom_gene_set_indicator = self.feature_model[0]
            else:
                # no discrete/string column to choose from
                self.custom_gene_set_indicator = None
        else:
            self.custom_gs_col_box.hide()

        self.gs_widget.clear_custom_sets()
        self.handle_custom_gene_sets(select_customs_flag=self.custom_gene_set_indicator is not None)
        self.invalidate()

    @Inputs.genes
    def handle_genes_input(self, data):
        """Handle the 'Data' input.

        Clears the output and all input-derived state, then validates the
        table annotations. ``use_attr_names`` must be present and exactly one
        of the gene-id attribute / gene-id column must be given; the tax id
        must be present and must match the custom gene sets' organism. On any
        validation failure the corresponding error is shown and nothing is
        loaded.
        """
        self.Error.clear()
        self.__reset_widget_state()
        # clear output
        self.Outputs.matched_genes.send(None)
        # clear input values
        self.input_genes = []
        self.input_data = None
        self.tax_id = None
        self.use_attr_names = None
        self.gene_id_attribute = None
        self.gene_id_column = None
        self.gs_widget.clear()
        self.gs_widget.clear_gene_sets()
        self.update_info_box()

        if data:
            self.input_data = data
            # Keep a missing tax id as None so the `is None` checks below can
            # fire; str(None) would turn it into the string 'None'.
            tax_id = self.input_data.attributes.get(TAX_ID, None)
            self.tax_id = None if tax_id is None else str(tax_id)
            self.use_attr_names = self.input_data.attributes.get(GENE_AS_ATTRIBUTE_NAME, None)
            self.gene_id_attribute = self.input_data.attributes.get(GENE_ID_ATTRIBUTE, None)
            self.gene_id_column = self.input_data.attributes.get(GENE_ID_COLUMN, None)
            self.update_info_box()

            if not (
                self.use_attr_names is not None and ((self.gene_id_attribute is None) ^ (self.gene_id_column is None))
            ):

                if self.tax_id is None:
                    self.Error.missing_annotation()
                    return

                self.Error.missing_gene_id()
                return

            elif self.tax_id is None:
                self.Error.missing_tax_id()
                return

            if self.__check_organism_mismatch():
                self.Error.organism_mismatch()
                return

            self.gs_widget.load_gene_sets(self.tax_id)

            # if input data change, we need to refresh custom sets
            if self.custom_data:
                self.gs_widget.clear_custom_sets()
                self.handle_custom_gene_sets()

            self.invalidate()

    def update_info_box(self):
        """Rebuild the text of the info label from the current widget state."""
        info_string = ''
        if self.input_genes:
            # count distinct names, as the message says "unique"
            info_string += '{} unique gene names on input.\n'.format(len(set(self.input_genes)))
            info_string += '{} genes on output.\n'.format(self.num_of_sel_genes)
        else:
            if self.input_data:
                if not any([self.gene_id_column, self.gene_id_attribute]):
                    info_string += 'Input data with incorrect meta data.\nUse Gene Name Matcher widget.'
            else:
                info_string += 'No data on input.\n'

        if self.custom_data:
            info_string += '{} marker genes in {} sets\n'.format(self.custom_data.X.shape[0], self.num_of_custom_sets)

        self.input_info.setText(info_string)

    def create_partial(self):
        """Bind the current gene sets, selection and input genes to ``set_items``."""
        return partial(
            self.set_items,
            self.gs_widget.gs_object,
            self.stored_gene_sets_selection,
            set(self.input_genes),
            self.callback,
        )

    def callback(self):
        """Per-gene-set hook passed to ``set_items``.

        Raises ``KeyboardInterrupt`` to abort the computation once the task
        has been cancelled; otherwise queues a progress-bar step.
        """
        if self._task.cancelled:
            raise KeyboardInterrupt()
        if self.progress_bar:
            methodinvoke(self, "progress_advance")()

    def init_gene_sets(self):
        """Start (or restart) the background computation of the result rows."""
        if self._task is not None:
            self.cancel()
        assert self._task is None

        self._task = Task()
        self.init_item_model()

        # save setting on selected hierarchies
        self.stored_gene_sets_selection = self.gs_widget.get_hierarchies(only_selected=True)

        f = self.create_partial()

        # one progress step per gene set in the selected hierarchies
        progress_iterations = sum(
            len(g_set)
            for hier, g_set in self.gs_widget.gs_object.map_hierarchy_to_sets().items()
            if hier in self.stored_gene_sets_selection
        )

        self.progress_bar = ProgressBar(self, iterations=progress_iterations)

        self._task.future = self._executor.submit(f)

        self._task.watcher = FutureWatcher(self._task.future)
        self._task.watcher.done.connect(self._init_gene_sets_finished)

    @Slot(concurrent.futures.Future)
    def _init_gene_sets_finished(self, f):
        """Populate the view with the rows produced by the finished task ``f``."""
        assert self.thread() is QThread.currentThread()
        assert threading.current_thread() == threading.main_thread()
        assert self._task is not None
        assert self._task.future is f
        assert f.done()

        self._task = None
        self.progress_bar.finish()
        self.setStatusMessage('')

        try:
            results = f.result()  # type: list
            for model_item in results:
                self.data_model.appendRow(model_item)
            self.filter_proxy_model.setSourceModel(self.data_model)
            self.data_view.selectionModel().selectionChanged.connect(self.commit)
            self.filter_data_view()
            self.set_selection()
            self.update_info_box()
        except Exception as ex:
            # NOTE(review): errors from the worker or from populating the view
            # are only printed; consider surfacing them as a widget error.
            print(ex)

    def create_filters(self):
        """Build the proxy-model filters for the search text and minimum count."""
        search_term = self.search_pattern.lower().strip().split()

        filters = [
            FilterProxyModel.Filter(
                self.TERM, Qt.DisplayRole, lambda value: all(fs in value.lower() for fs in search_term)
            )
        ]

        if self.use_min_count:
            filters.append(FilterProxyModel.Filter(self.COUNT, Qt.DisplayRole, lambda value: value >= self.min_count))

        return filters

    def filter_data_view(self):
        """Apply the current filters and warn if they hide every row."""
        filter_proxy = self.filter_proxy_model  # type: FilterProxyModel
        model = filter_proxy.sourceModel()  # type: QStandardItemModel

        if isinstance(model, QStandardItemModel):

            # apply filtering rules
            filter_proxy.set_filters(self.create_filters())

            if model.rowCount() and not filter_proxy.rowCount():
                self.Warning.all_sets_filtered()
            else:
                self.Warning.clear()

    def set_selection(self):
        """Re-select the rows stored in ``selected_rows`` (source-model rows)."""
        if not self.selected_rows:
            return

        view = self.data_view
        model = self.data_model

        # Stored rows may not exist in the current model. Check before use,
        # and against the largest row since the list need not be sorted.
        if model.rowCount() <= max(self.selected_rows):
            return

        header_count = view.header().count() - 1
        selection = QItemSelection()

        for source_row in self.selected_rows:
            source_index = model.indexFromItem(model.item(source_row))
            proxy_index = self.filter_proxy_model.mapFromSource(source_index)
            if not proxy_index.isValid():
                # no corresponding row in the proxy model
                continue
            row_index = proxy_index.row()
            selection.append(
                QItemSelectionRange(
                    self.filter_proxy_model.index(row_index, 0),
                    self.filter_proxy_model.index(row_index, header_count),
                )
            )

        view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

    def commit(self):
        """Send the input data restricted to genes of the selected gene sets.

        Also stores the selected source-model rows in ``selected_rows``.
        Nothing is sent when the selection is empty or there are no input
        genes.
        """
        selection_model = self.data_view.selectionModel()

        if selection_model:
            selection = selection_model.selectedRows(self.COUNT)
            self.selected_rows = [self.filter_proxy_model.mapToSource(sel).row() for sel in selection]

            if selection and self.input_genes:
                # the COUNT column carries the set of matched genes in UserRole
                genes = [model_index.data(Qt.UserRole) for model_index in selection]
                # kept as a set: membership is tested once per column below
                output_genes = set.union(*genes)
                self.num_of_sel_genes = len(output_genes)
                self.update_info_box()

                if self.use_attr_names:
                    selected = [
                        column
                        for column in self.input_data.domain.attributes
                        if self.gene_id_attribute in column.attributes
                        and str(column.attributes[self.gene_id_attribute]) in output_genes
                    ]

                    domain = Domain(selected, self.input_data.domain.class_vars, self.input_data.domain.metas)
                    new_data = self.input_data.from_table(domain, self.input_data)
                    self.Outputs.matched_genes.send(new_data)

                else:
                    # create filter from selected column for genes
                    only_known = table_filter.FilterStringList(self.gene_id_column, list(output_genes))
                    # apply filter to the data
                    data_table = table_filter.Values([only_known])(self.input_data)

                    self.Outputs.matched_genes.send(data_table)

    def assign_delegates(self):
        """Install the numerical delegates on the GENES and COUNT columns."""
        self.data_view.setItemDelegateForColumn(self.GENES, NumericalColumnDelegate(self))

        self.data_view.setItemDelegateForColumn(self.COUNT, NumericalColumnDelegate(self))

    def setup_filter_model(self):
        """Create a fresh filter proxy model and attach it to the view."""
        self.filter_proxy_model = FilterProxyModel()
        self.filter_proxy_model.setFilterKeyColumn(self.TERM)
        self.data_view.setModel(self.filter_proxy_model)

    def setup_filter_area(self):
        """Build the minimum-count spin box and the search line edit."""
        h_layout = QHBoxLayout()
        h_layout.setSpacing(100)
        h_widget = widgetBox(self.mainArea, orientation=h_layout)

        spin(
            h_widget,
            self,
            'min_count',
            0,
            1000,
            label='Count',
            tooltip='Minimum genes count',
            checked='use_min_count',
            callback=self.filter_data_view,
            callbackOnReturn=True,
            checkCallback=self.filter_data_view,
        )

        self.line_edit_filter = lineEdit(h_widget, self, 'search_pattern')
        self.line_edit_filter.setPlaceholderText('Filter gene sets ...')
        self.line_edit_filter.textChanged.connect(self.filter_data_view)

    def on_gene_set_indicator_changed(self):
        """Rebuild the custom sets after a different term column was chosen."""
        # self._handle_future_model()
        self.gs_widget.clear_custom_sets()
        self.handle_custom_gene_sets()
        self.invalidate()

    def setup_control_area(self):
        """Build the control area: info box, custom-set column box, gene-set
        selection and the commit button."""
        # Control area
        self.input_info = widgetLabel(widgetBox(self.controlArea, "Info", addSpace=True), 'No data on input.\n')
        self.custom_gs_col_box = box = vBox(self.controlArea, 'Custom Gene Set Term Column')
        # shown only while custom gene sets are on input
        box.hide()

        gene_sets_box = widgetBox(self.controlArea, "Gene Sets")
        self.gs_widget = GeneSetsSelection(gene_sets_box, self, 'stored_gene_sets_selection')
        self.gs_widget.hierarchy_tree_widget.itemClicked.connect(self.update_tree_view)

        self.commit_button = auto_commit(self.controlArea, self, "auto_commit", "&Commit", box=False)

    def setup_gui(self):
        """Build the control area and the main-area tree view."""
        # control area
        self.setup_control_area()

        # main area
        self.data_view = QTreeView()
        self.setup_filter_model()
        self.setup_filter_area()
        self.data_view.setAlternatingRowColors(True)
        self.data_view.sortByColumn(self.COUNT, Qt.DescendingOrder)
        self.data_view.setSortingEnabled(True)
        self.data_view.setSelectionMode(QTreeView.ExtendedSelection)
        self.data_view.setEditTriggers(QTreeView.NoEditTriggers)
        self.data_view.viewport().setMouseTracking(False)
        self.data_view.setItemDelegateForColumn(self.TERM, LinkStyledItemDelegate(self.data_view))

        self.mainArea.layout().addWidget(self.data_view)

        self.data_view.header().setSectionResizeMode(QHeaderView.ResizeToContents)
        self.assign_delegates()

    @staticmethod
    def set_items(gene_sets, sets_to_display, genes, callback):
        """Build one row of model items per gene set that overlaps ``genes``.

        :param gene_sets: iterable of gene sets (sorted before processing)
        :param sets_to_display: hierarchies whose gene sets are considered
        :param genes: set of input gene identifiers
        :param callback: called once per considered gene set; may raise to
            abort the computation
        :return: list of rows, each ``[count, genes, category, term]`` items;
            an empty list when ``genes`` is empty
        """
        model_items = []
        if not genes:
            # always return a list so the caller can iterate the result
            return model_items

        for gene_set in sorted(gene_sets):
            if gene_set.hierarchy not in sets_to_display:
                continue

            callback()

            matched_set = gene_set.genes & genes
            if len(matched_set) > 0:
                category_column = QStandardItem()
                term_column = QStandardItem()
                count_column = QStandardItem()
                genes_column = QStandardItem()

                category_column.setData(", ".join(gene_set.hierarchy), Qt.DisplayRole)
                term_column.setData(gene_set.name, Qt.DisplayRole)
                term_column.setData(gene_set.name, Qt.ToolTipRole)
                term_column.setData(gene_set.link, LinkRole)
                term_column.setForeground(QColor(Qt.blue))

                # matched genes are read back from this column in `commit`
                count_column.setData(matched_set, Qt.UserRole)
                count_column.setData(len(matched_set), Qt.DisplayRole)

                genes_column.setData(len(gene_set.genes), Qt.DisplayRole)
                # all genes of the set, matched or not
                genes_column.setData(set(gene_set.genes), Qt.UserRole)

                model_items.append([count_column, genes_column, category_column, term_column])

        return model_items

    def init_item_model(self):
        """Empty the result model (creating it on first use) and set headers."""
        if self.data_model:
            self.data_model.clear()
            self.setup_filter_model()
        else:
            self.data_model = QStandardItemModel()

        self.data_model.setSortRole(Qt.UserRole)
        self.data_model.setHorizontalHeaderLabels(self.DATA_HEADER_LABELS)

    def sizeHint(self):
        """Preferred initial size of the widget window."""
        return QSize(1280, 960)
Ejemplo n.º 6
0
class NormalizeEditor(ScBaseEditor):
    """Editor for the normalization parameters.

    Lets the user pick the normalization method and, optionally, a discrete
    meta/class variable to group by. ``createinstance`` turns the parameter
    dict into ``NormalizeGroups`` or ``NormalizeSamples``.
    """

    DEFAULT_GROUP_BY = False
    DEFAULT_GROUP_VAR = None
    DEFAULT_METHOD = Normalize.CPM

    def __init__(self, parent=None, master=None, **kwargs):
        """Build the editor controls.

        :param parent: parent widget, forwarded to the base class
        :param master: object providing ``data`` and the
            ``input_data_changed`` signal; may be ``None``, in which case the
            group-by controls stay empty and disabled
        """
        super().__init__(parent, **kwargs)
        self._group_var = self.DEFAULT_GROUP_VAR
        self._master = master
        # `master` defaults to None, so only subscribe when one is given
        if master is not None:
            master.input_data_changed.connect(self._set_model)
        self.setLayout(QVBoxLayout())

        form = QFormLayout()
        cpm_b = QRadioButton("Counts per million", checked=True)
        med_b = QRadioButton("Median")
        self.group = QButtonGroup()
        self.group.buttonClicked.connect(self._on_button_clicked)
        for i, button in enumerate([cpm_b, med_b]):
            # button ids are derived from the enum value (value - 1)
            index = index_to_enum(Normalize.Method, i).value
            self.group.addButton(button, index - 1)
            form.addRow(button)

        self.group_by_check = QCheckBox("Cell Groups: ",
                                        enabled=self.DEFAULT_GROUP_BY)
        self.group_by_check.clicked.connect(self.edited)
        self.group_by_combo = QComboBox(enabled=self.DEFAULT_GROUP_BY)
        self.group_by_model = DomainModel(order=(DomainModel.METAS,
                                                 DomainModel.CLASSES),
                                          valid_types=DiscreteVariable,
                                          alphabetical=True)
        self.group_by_combo.setModel(self.group_by_model)
        self.group_by_combo.currentIndexChanged.connect(self.changed)
        self.group_by_combo.activated.connect(self.edited)

        form.addRow(self.group_by_check, self.group_by_combo)
        self.layout().addLayout(form)

        self._set_model()

    def _set_model(self):
        """Refill the group-by combo from the master's current data."""
        data = self._master.data if self._master is not None else None
        # Test against None explicitly: with `data and data.domain` a falsy
        # data object would be passed on in place of its domain.
        self.group_by_model.set_domain(data.domain if data is not None else None)
        enable = bool(self.group_by_model)
        self.group_by_check.setChecked(False)
        self.group_by_check.setEnabled(enable)
        self.group_by_combo.setEnabled(enable)
        if self.group_by_model:
            self.group_by_combo.setCurrentIndex(0)
            # Look the variable up in the combo model itself; a variable can
            # be in the domain without being offered by the model, and
            # indexOf would then raise.
            if self._group_var and self._group_var in self.group_by_model:
                index = self.group_by_model.indexOf(self._group_var)
                self.group_by_combo.setCurrentIndex(index)
        else:
            self.group_by_combo.setCurrentText(None)

    def _on_button_clicked(self):
        """Report a method change through both ``changed`` and ``edited``."""
        self.changed.emit()
        self.edited.emit()

    def _current_group_var(self):
        """Return the variable selected in the combo, or None if none is."""
        index = self.group_by_combo.currentIndex()
        return self.group_by_model[index] if index > -1 else None

    def setParameters(self, params):
        """Set the controls from a parameter dict.

        Recognised keys: ``"method"``, ``"group_var"`` and ``"group_by"``;
        missing keys fall back to the class defaults.
        """
        method = params.get("method", self.DEFAULT_METHOD)
        index = enum_to_index(Normalize.Method, method)
        self.group.buttons()[index].setChecked(True)
        self._group_var = params.get("group_var", self.DEFAULT_GROUP_VAR)
        # grouping is only possible when the variable is offered by the model
        group = bool(self._group_var) and self._group_var in self.group_by_model
        if group:
            index = self.group_by_model.indexOf(self._group_var)
            self.group_by_combo.setCurrentIndex(index)
        group_by = params.get("group_by", self.DEFAULT_GROUP_BY)
        self.group_by_check.setChecked(bool(group_by and group))

    def parameters(self):
        """Return the current settings as a parameter dict."""
        group_var = self._current_group_var()
        group_by = self.group_by_check.isChecked()
        method = index_to_enum(Normalize.Method, self.group.checkedId())
        return {"group_var": group_var, "group_by": group_by, "method": method}

    @staticmethod
    def createinstance(params):
        """Create the preprocessor described by ``params``.

        Returns ``NormalizeGroups`` when grouping is requested and a group
        variable is given, otherwise ``NormalizeSamples``.
        """
        group_var = params.get("group_var")
        group_by = params.get("group_by", NormalizeEditor.DEFAULT_GROUP_BY)
        method = params.get("method", NormalizeEditor.DEFAULT_METHOD)
        if group_by and group_var:
            return NormalizeGroups(group_var, method)
        return NormalizeSamples(method)

    def __repr__(self):
        """Short human-readable summary of the selected method and grouping."""
        method = self.group.button(self.group.checkedId()).text()
        group_var = self._current_group_var()
        group_by = self.group_by_check.isChecked()
        group_text = ", Grouped by: {}".format(group_var) if group_by else ""
        return "Method: {}".format(method) + group_text