Пример #1
0
    def default_outputs(cls, attrib):
        """Build the default output signals for a learner widget.

        Reads the learner class from ``attrib['LEARNER']``, picks the name of
        the model output from the kind of learner, stores that name back into
        ``attrib['OUTPUT_MODEL_NAME']`` and returns the two output signals.

        Parameters
        ----------
        attrib : dict
            Mapping that must contain the key ``'LEARNER'``; it is updated in
            place with ``'OUTPUT_MODEL_NAME'``.

        Returns
        -------
        list
            ``[learner_signal, model_signal]``.
        """
        learner_class = attrib['LEARNER']

        # Only the generic 'Model' output declares the two specialised names
        # as the signals it replaces.
        if issubclass(learner_class, LearnerClassification):
            model_name, replaces = 'Classifier', []
        elif issubclass(learner_class, LearnerRegression):
            model_name, replaces = 'Predictor', []
        else:
            model_name, replaces = 'Model', ['Classifier', 'Predictor']

        attrib['OUTPUT_MODEL_NAME'] = model_name

        learner_signal = widget.OutputSignal("Learner", learner_class)
        model_signal = widget.OutputSignal(
            model_name, learner_class.__returns__, replaces=replaces)
        return [learner_signal, model_signal]
Пример #2
0
class OWFile(widget.OWWidget, RecentPathsWComboMixin):
    """Widget that reads a data table from a local file or from a URL.

    The source is chosen with two radio buttons bound to the ``source``
    setting (``LOCAL_FILE`` or ``URL``).  Local files come from a combo box
    of recent paths or from a file dialog; for Excel workbooks with more than
    one sheet an extra sheet selector is shown.  The loaded table is sent on
    the "Data" output.
    """
    name = "File"
    id = "orange.widgets.data.file"
    # Adjacent string literals are joined verbatim, so the first fragment
    # needs its own trailing space (it used to read "networkand").
    description = "Read a data from an input file or network " \
                  "and send the data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["data", "file", "load", "read"]
    outputs = [
        widget.OutputSignal(
            "Data",
            Table,
            doc="Attribute-valued data set read from the input file.")
    ]

    want_main_area = False
    resizing_enabled = False

    SEARCH_PATHS = [("sample-datasets", get_sample_datasets_dir())]

    # Values of the `source` setting.  They are also used as indices into
    # the list of loader functions in `load_data`.
    LOCAL_FILE, URL = range(2)

    settingsHandler = XlsContextHandler()

    # Overload RecentPathsWidgetMixin.recent_paths to set defaults
    recent_paths = Setting([
        RecentPath("", "sample-datasets", "iris.tab"),
        RecentPath("", "sample-datasets", "titanic.tab"),
        RecentPath("", "sample-datasets", "housing.tab"),
    ])
    recent_urls = Setting([])
    source = Setting(LOCAL_FILE)
    xls_sheet = ContextSetting("")
    sheet_names = Setting({})
    url = Setting("")

    # Filter string for the file dialog: an "All readable files" entry
    # followed by one entry per reader class, ordered by the reader's first
    # occurrence among FileFormat.readers values.
    dlg_formats = ("All readable files ({});;".format(
        '*' + ' *'.join(FileFormat.readers.keys())) + ";;".join(
            "{} (*{})".format(f.DESCRIPTION, ' *'.join(f.EXTENSIONS))
            for f in sorted(set(FileFormat.readers.values()),
                            key=list(FileFormat.readers.values()).index)))

    def __init__(self):
        """Build the control area and schedule the initial load."""
        super().__init__()
        RecentPathsWComboMixin.__init__(self)
        self.domain = None
        self.data = None
        self.loaded_file = ""

        # All source controls live in one grid: row 0 is the file row,
        # row 2 the (initially hidden) sheet selector, row 3 the URL row.
        layout = QtGui.QGridLayout()
        gui.widgetBox(self.controlArea, margin=0, orientation=layout)
        vbox = gui.radioButtons(None,
                                self,
                                "source",
                                box=True,
                                addSpace=True,
                                callback=self.load_data,
                                addToLayout=False)

        rb_button = gui.appendRadioButton(vbox, "File", addToLayout=False)
        layout.addWidget(rb_button, 0, 0, QtCore.Qt.AlignVCenter)

        box = gui.hBox(None, addToLayout=False, margin=0)
        box.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.activated[int].connect(self.select_file)
        box.layout().addWidget(self.file_combo)
        layout.addWidget(box, 0, 1)

        file_button = gui.button(None,
                                 self,
                                 '...',
                                 callback=self.browse_file,
                                 autoDefault=False)
        file_button.setIcon(self.style().standardIcon(
            QtGui.QStyle.SP_DirOpenIcon))
        file_button.setSizePolicy(Policy.Maximum, Policy.Fixed)
        layout.addWidget(file_button, 0, 2)

        reload_button = gui.button(None,
                                   self,
                                   "Reload",
                                   callback=self.reload,
                                   autoDefault=False)
        reload_button.setIcon(self.style().standardIcon(
            QtGui.QStyle.SP_BrowserReload))
        reload_button.setSizePolicy(Policy.Fixed, Policy.Fixed)
        layout.addWidget(reload_button, 0, 3)

        # Sheet selector; shown by `fill_sheet_combo` only for workbooks
        # with more than one sheet.
        self.sheet_box = gui.hBox(None, addToLayout=False, margin=0)
        self.sheet_combo = gui.comboBox(None,
                                        self,
                                        "xls_sheet",
                                        callback=self.load_data,
                                        sendSelectedValue=True)
        self.sheet_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_label = QtGui.QLabel()
        self.sheet_label.setText('Sheet')
        self.sheet_label.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_box.layout().addWidget(self.sheet_label,
                                          QtCore.Qt.AlignLeft)
        self.sheet_box.layout().addWidget(self.sheet_combo,
                                          QtCore.Qt.AlignVCenter)
        layout.addWidget(self.sheet_box, 2, 1)
        self.sheet_box.hide()

        rb_button = gui.appendRadioButton(vbox, "URL", addToLayout=False)
        layout.addWidget(rb_button, 3, 0, QtCore.Qt.AlignVCenter)

        # Editable combo of recent URLs, backed by a model that wraps the
        # `recent_urls` setting.
        self.url_combo = url_combo = QtGui.QComboBox()
        url_model = NamedURLModel(self.sheet_names)
        url_model.wrap(self.recent_urls)
        url_combo.setModel(url_model)
        url_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        url_combo.setMaximumWidth(500)
        url_combo.setEditable(True)
        url_combo.setInsertPolicy(url_combo.InsertAtTop)
        url_edit = url_combo.lineEdit()
        l, t, r, b = url_edit.getTextMargins()
        url_edit.setTextMargins(l + 5, t, r, b)
        layout.addWidget(url_combo, 3, 1, 3, 3)
        url_combo.activated.connect(self._url_set)

        box = gui.vBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        box = gui.hBox(self.controlArea)
        gui.button(box,
                   self,
                   "Browse documentation data sets",
                   callback=lambda: self.browse_file(True),
                   autoDefault=False)
        gui.rubber(box)
        box.layout().addWidget(self.report_button)
        self.report_button.setFixedWidth(170)

        # Set word wrap, so long warnings won't expand the widget
        self.warnings.setWordWrap(True)
        self.warnings.setSizePolicy(Policy.Ignored, Policy.MinimumExpanding)

        self.set_file_list()
        if self.last_path() is not None:
            self.fill_sheet_combo(self.last_path())
        # Must not call open_file from within __init__. open_file
        # explicitly re-enters the event loop (by a progress bar)
        QtCore.QTimer.singleShot(0, self.load_data)

    def reload(self):
        """Reload the file shown in the file combo.

        If the combo text matches the most recent path, the source is
        switched to ``LOCAL_FILE``, the sheet selector is refreshed for
        multi-sheet workbooks and the data is loaded again.  Otherwise the
        call is forwarded to `select_file` with an out-of-range index.
        """
        if self.recent_paths:
            basename = self.file_combo.currentText()
            path = self.recent_paths[0]
            if basename in [path.relpath, path.basename]:
                self.source = self.LOCAL_FILE
                if self.is_multisheet_excel(path.abspath):
                    self.fill_sheet_combo(path.abspath)
                return self.load_data()
        self.select_file(len(self.recent_paths) + 1)

    def select_file(self, n):
        """Handle activation of item ``n`` in the file combo.

        For an index inside ``recent_paths`` the mixin's handler is called
        and the sheet selector refreshed.  For any other non-zero index the
        combo text is treated as a path: it is added to the recent paths if
        it exists, otherwise an error is shown and nothing is loaded.
        """
        if n < len(self.recent_paths):
            super().select_file(n)
            self.fill_sheet_combo(self.last_path())
        # TODO: This is weird. Has it remained here from "Browse data sets"
        # or from when this combo was editable? A `n` this large can come from
        # `reload`, but ... how?!
        elif n:
            path = self.file_combo.currentText()
            if os.path.exists(path):
                self.add_path(path)
            else:
                self.info.setText('Data was not loaded:')
                self.warnings.setText("File {} does not exist".format(path))
                self.file_combo.removeItem(n)
                # QComboBox.lineEdit() returns None unless the combo is
                # editable; do not crash while reporting the error.
                line_edit = self.file_combo.lineEdit()
                if line_edit is not None:
                    line_edit.setText(path)
                return

        if self.recent_paths:
            self.source = self.LOCAL_FILE
            self.load_data()
            self.set_file_list()

    def _url_set(self):
        """Switch the source to URL and load; slot for the URL combo."""
        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Let the user pick a file with a dialog and load it.

        Parameters
        ----------
        in_demos : bool
            If true, the dialog starts in the sample data sets directory;
            an information box is shown and nothing is loaded when that
            directory does not exist.  Otherwise the dialog starts at the
            last used path, or at the home directory if there is none.
        """
        if in_demos:
            start_file = get_sample_datasets_dir()
            if not os.path.exists(start_file):
                QtGui.QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with documentation data sets")
                return
        else:
            start_file = self.last_path() or os.path.expanduser("~/")

        filename = QtGui.QFileDialog.getOpenFileName(self,
                                                     'Open Orange Data File',
                                                     start_file,
                                                     self.dlg_formats)
        if not filename:
            return

        self.add_path(filename)
        self.source = self.LOCAL_FILE
        self.fill_sheet_combo(filename)
        self.load_data()

    def fill_sheet_combo(self, path):
        """Show and fill the sheet selector for a multi-sheet workbook.

        For an existing multi-sheet Excel file the current context is
        closed, the combo is filled with the sheet names and a context is
        opened for ``(path, sheet_names)``.  For anything else the sheet
        selector is hidden.
        """
        if os.path.exists(path) and self.is_multisheet_excel(path):
            self.closeContext()
            self.sheet_combo.clear()
            self.sheet_box.show()
            book = open_workbook(path)
            sheet_names = [
                str(book.sheet_by_index(i).name) for i in range(book.nsheets)
            ]
            self.sheet_combo.addItems(sheet_names)
            self.openContext(path, sheet_names)
        else:
            self.sheet_box.hide()

    @staticmethod
    def is_multisheet_excel(fn):
        """Return True if ``fn`` opens as a workbook with several sheets.

        Files that cannot be parsed as a workbook (``XLRDError``) and files
        that cannot be opened at all (``OSError``, e.g. a recent path that
        has since been removed) are reported as not multi-sheet, so callers
        such as `reload` need not check for existence first.
        """
        try:
            return open_workbook(fn).nsheets > 1
        except (XLRDError, OSError):
            return False

    # Open a file, create data from it and send it over the data channel
    def load_data(self):
        """Load data from the selected source and send it on "Data".

        On success the info label describes the table and the table is
        sent.  If nothing is selected, ``None`` is sent.  If loading fails,
        the widget state is reset, ``None`` is sent and the failure is shown
        in the info/warning labels; the exception is not propagated.
        """
        def load(method, fn):
            """Call ``method(fn)`` and show the last recorded warning."""
            with catch_warnings(record=True) as warnings:
                data = method(fn)
                self.warning(33,
                             warnings[-1].message.args[0] if warnings else '')
            return data, fn

        def load_from_file():
            """Return ``(table, path)`` for the most recent local path."""
            fn = self.last_path()
            if not fn:
                return None, ""
            if not os.path.exists(fn):
                # Fall back to a file of the same name in the current
                # directory, if there is one.
                basename = os.path.basename(fn)
                if os.path.exists(os.path.join(".", basename)):
                    fn = os.path.join(".", basename)
                    self.information(
                        "Loading '{}' from the current directory.".format(
                            basename))
            if self.is_multisheet_excel(fn):
                data = ExcelFormat.read_file(fn + ':' + self.xls_sheet)
                if data:
                    return data, fn

            try:
                return load(Table.from_file, fn)
            except Exception as exc:
                self.warnings.setText(str(exc))
                # Let us not remove from recent files: user may fix them
                raise

        def load_from_network():
            """Return ``(table, name)`` for the URL selected in the combo."""
            combo = self.url_combo
            model = combo.model()
            # combo.currentText does not work when the widget is initialized
            url = model.data(model.index(combo.currentIndex()),
                             QtCore.Qt.EditRole)
            if not url:
                return None, ""
            elif "://" not in url:
                url = "http://" + url
            try:
                data, url = load(Table.from_url, url)
            except Exception:
                self.warnings.setText(
                    "URL '{}' does not contain valid data".format(url))
                # Don't remove from recent_urls:
                # resource may reappear, or the user mistyped it
                # and would like to retrieve it from history and fix it.
                raise
            combo.clearFocus()
            if "://docs.google.com/spreadsheets" in url:
                model.add_name(url, data.name)
                self.url = \
                    "{} from {}".format(data.name.replace("- Sheet1", ""), url)
                combo.lineEdit().setPlaceholderText(self.url)
                return data, data.name
            else:
                self.url = url
                return data, url

        self.warning()
        self.information()

        try:
            loader = [load_from_file, load_from_network][self.source]
            self.data, self.loaded_file = loader()
        except Exception:
            # Anything can go wrong inside a reader, but KeyboardInterrupt
            # and SystemExit must still propagate, hence no bare `except`.
            self.info.setText("Data was not loaded:")
            self.data = None
            self.loaded_file = ""
            # Do not leave the previously loaded table on the output.
            self.send("Data", None)
            return
        else:
            self.warnings.setText("")

        data = self.data
        if data is None:
            self.send("Data", None)
            self.info.setText("No data loaded")
            return

        domain = data.domain
        text = "{} instance(s), {} feature(s), {} meta attribute(s)".format(
            len(data), len(domain.attributes), len(domain.metas))
        if domain.has_continuous_class:
            text += "\nRegression; numerical class."
        elif domain.has_discrete_class:
            text += "\nClassification; discrete class with {} values.".format(
                len(domain.class_var.values))
        elif data.domain.class_vars:
            text += "\nMulti-target; {} target variables.".format(
                len(data.domain.class_vars))
        else:
            text += "\nData has no target variable."
        if 'Timestamp' in data.domain:
            # Google Forms uses this header to timestamp responses
            text += '\n\nFirst entry: {}\nLast entry: {}'.format(
                data[0, 'Timestamp'], data[-1, 'Timestamp'])
        self.info.setText(text)

        add_origin(data, self.loaded_file)
        self.send("Data", data)

    def get_widget_name_extension(self):
        """Return the loaded file's base name without its extension."""
        _, name = os.path.split(self.loaded_file)
        return os.path.splitext(name)[0]

    def send_report(self):
        """Add the data source, its format and the data to the report."""
        def get_ext_name(filename):
            """Map the extension of ``filename`` to a format name."""
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            # Abbreviate paths under the home directory with "~".
            home = os.path.expanduser("~")
            if self.loaded_file.startswith(home):
                # os.path.join does not like ~
                name = "~/" + \
                       self.loaded_file[len(home):].lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            if self.sheet_combo.isVisible():
                name += " ({})".format(self.sheet_combo.currentText())
            self.report_items("File", [("File name", name),
                                       ("Format", get_ext_name(name))])
        else:
            self.report_items("Data", [("Resource", self.url),
                                       ("Format", get_ext_name(self.url))])

        self.report_data("Data", self.data)
Пример #3
0
class OWFile(widget.OWWidget, RecentPathsWComboMixin):
    """Widget that reads a data table from a local file or from a URL.

    The source is chosen with two radio buttons bound to the ``source``
    setting (``LOCAL_FILE`` or ``URL``).  After loading, the columns can be
    edited in a domain editor; the (possibly edited) table is sent on the
    "Data" output.  Files can also be dropped onto the widget.
    """
    name = "File"
    id = "orange.widgets.data.file"
    description = "Read data from an input file or network " \
                  "and send a data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["data", "file", "load", "read"]
    outputs = [
        widget.OutputSignal(
            "Data",
            Table,
            doc="Attribute-valued data set read from the input file.")
    ]

    want_main_area = False

    SEARCH_PATHS = [("sample-datasets", get_sample_datasets_dir())]
    # Local files larger than this many bytes are not loaded automatically
    # when the widget is created (see __init__).
    SIZE_LIMIT = 1e7
    # Values of the `source` setting.
    LOCAL_FILE, URL = range(2)

    settingsHandler = PerfectDomainContextHandler()

    # Overload RecentPathsWidgetMixin.recent_paths to set defaults
    recent_paths = Setting([
        RecentPath("", "sample-datasets", "iris.tab"),
        RecentPath("", "sample-datasets", "titanic.tab"),
        RecentPath("", "sample-datasets", "housing.tab"),
        RecentPath("", "sample-datasets", "heart_disease.tab"),
    ])
    recent_urls = Setting([])
    source = Setting(LOCAL_FILE)
    xls_sheet = ContextSetting("")
    sheet_names = Setting({})
    # NOTE(review): `url` is read by send_report but never assigned in this
    # class -- confirm whether it is meant to be updated on a URL load.
    url = Setting("")

    variables = ContextSetting([])

    # Filter string for the file dialog: an "All readable files" entry
    # followed by one entry per reader class, ordered by the reader's first
    # occurrence among FileFormat.readers values.
    dlg_formats = ("All readable files ({});;".format(
        '*' + ' *'.join(FileFormat.readers.keys())) + ";;".join(
            "{} (*{})".format(f.DESCRIPTION, ' *'.join(f.EXTENSIONS))
            for f in sorted(set(FileFormat.readers.values()),
                            key=list(FileFormat.readers.values()).index)))

    domain_editor = SettingProvider(DomainEditor)

    class Warning(widget.OWWidget.Warning):
        file_too_big = widget.Msg(
            "The file is too large to load automatically."
            " Press Reload to load.")

    class Error(widget.OWWidget.Error):
        file_not_found = widget.Msg("File not found.")

    def __init__(self):
        """Build the control area and schedule the initial load.

        The initial load is skipped, with a warning, when the source is a
        local file larger than ``SIZE_LIMIT``.
        """
        super().__init__()
        RecentPathsWComboMixin.__init__(self)
        self.domain = None
        self.data = None
        self.loaded_file = ""
        self.reader = None

        # All source controls live in one grid: row 0 is the file row,
        # row 2 the (initially hidden) sheet selector, row 3 the URL row.
        layout = QGridLayout()
        gui.widgetBox(self.controlArea, margin=0, orientation=layout)
        vbox = gui.radioButtons(None,
                                self,
                                "source",
                                box=True,
                                addSpace=True,
                                callback=self.load_data,
                                addToLayout=False)

        rb_button = gui.appendRadioButton(vbox, "File:", addToLayout=False)
        layout.addWidget(rb_button, 0, 0, Qt.AlignVCenter)

        box = gui.hBox(None, addToLayout=False, margin=0)
        box.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.activated[int].connect(self.select_file)
        box.layout().addWidget(self.file_combo)
        layout.addWidget(box, 0, 1)

        file_button = gui.button(None,
                                 self,
                                 '...',
                                 callback=self.browse_file,
                                 autoDefault=False)
        file_button.setIcon(self.style().standardIcon(QStyle.SP_DirOpenIcon))
        file_button.setSizePolicy(Policy.Maximum, Policy.Fixed)
        layout.addWidget(file_button, 0, 2)

        reload_button = gui.button(None,
                                   self,
                                   "Reload",
                                   callback=self.load_data,
                                   autoDefault=False)
        reload_button.setIcon(self.style().standardIcon(
            QStyle.SP_BrowserReload))
        reload_button.setSizePolicy(Policy.Fixed, Policy.Fixed)
        layout.addWidget(reload_button, 0, 3)

        # Sheet selector; shown by `_update_sheet_combo` only when the
        # reader reports at least two sheets.
        self.sheet_box = gui.hBox(None, addToLayout=False, margin=0)
        self.sheet_combo = gui.comboBox(
            None,
            self,
            "xls_sheet",
            callback=self.select_sheet,
            sendSelectedValue=True,
        )
        self.sheet_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_label = QLabel()
        self.sheet_label.setText('Sheet')
        self.sheet_label.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_box.layout().addWidget(self.sheet_label, Qt.AlignLeft)
        self.sheet_box.layout().addWidget(self.sheet_combo, Qt.AlignVCenter)
        layout.addWidget(self.sheet_box, 2, 1)
        self.sheet_box.hide()

        rb_button = gui.appendRadioButton(vbox, "URL:", addToLayout=False)
        layout.addWidget(rb_button, 3, 0, Qt.AlignVCenter)

        # Editable combo of recent URLs, backed by a model that wraps the
        # `recent_urls` setting.
        self.url_combo = url_combo = QComboBox()
        url_model = NamedURLModel(self.sheet_names)
        url_model.wrap(self.recent_urls)
        url_combo.setLineEdit(LineEditSelectOnFocus())
        url_combo.setModel(url_model)
        url_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        url_combo.setEditable(True)
        url_combo.setInsertPolicy(url_combo.InsertAtTop)
        url_edit = url_combo.lineEdit()
        l, t, r, b = url_edit.getTextMargins()
        url_edit.setTextMargins(l + 5, t, r, b)
        layout.addWidget(url_combo, 3, 1, 3, 3)
        url_combo.activated.connect(self._url_set)

        box = gui.vBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        box = gui.widgetBox(self.controlArea, "Columns (Double click to edit)")
        self.domain_editor = DomainEditor(self)
        self.editor_model = self.domain_editor.model()
        box.layout().addWidget(self.domain_editor)

        box = gui.hBox(self.controlArea)
        gui.button(box,
                   self,
                   "Browse documentation data sets",
                   callback=lambda: self.browse_file(True),
                   autoDefault=False)
        gui.rubber(box)
        box.layout().addWidget(self.report_button)
        self.report_button.setFixedWidth(170)

        # "Apply" is enabled only while there are unapplied domain edits.
        self.apply_button = gui.button(box,
                                       self,
                                       "Apply",
                                       callback=self.apply_domain_edit)
        self.apply_button.setEnabled(False)
        self.apply_button.setFixedWidth(170)
        self.editor_model.dataChanged.connect(
            lambda: self.apply_button.setEnabled(True))

        self.set_file_list()

        self.setAcceptDrops(True)

        if self.source == self.LOCAL_FILE:
            last_path = self.last_path()
            if last_path and os.path.exists(last_path) and \
                    os.path.getsize(last_path) > self.SIZE_LIMIT:
                self.Warning.file_too_big()
                return

        # Must not call open_file from within __init__. open_file
        # explicitly re-enters the event loop (by a progress bar)
        QTimer.singleShot(0, self.load_data)

    def sizeHint(self):
        """Return the preferred initial size of the widget."""
        return QSize(600, 550)

    def select_file(self, n):
        """Handle activation of item ``n`` in the file combo and load it."""
        # `n` comes from the combo's `activated` signal; the assert
        # documents the invariant rather than validating input.
        assert n < len(self.recent_paths)
        super().select_file(n)
        if self.recent_paths:
            self.source = self.LOCAL_FILE
            self.load_data()
            self.set_file_list()

    def select_sheet(self):
        """Store the chosen sheet on the most recent path and reload."""
        self.recent_paths[0].sheet = self.sheet_combo.currentText()
        self.load_data()

    def _url_set(self):
        """Switch the source to URL and load; slot for the URL combo."""
        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Let the user pick a file with a dialog and load it.

        Parameters
        ----------
        in_demos : bool
            If true, the dialog starts in the sample data sets directory;
            an information box is shown and nothing is loaded when that
            directory does not exist.  Otherwise the dialog starts at the
            last used path, or at the home directory if there is none.
        """
        if in_demos:
            start_file = get_sample_datasets_dir()
            if not os.path.exists(start_file):
                QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with documentation data sets")
                return
        else:
            start_file = self.last_path() or os.path.expanduser("~/")

        filename, _ = QFileDialog.getOpenFileName(self,
                                                  'Open Orange Data File',
                                                  start_file, self.dlg_formats)
        if not filename:
            return
        self.add_path(filename)
        self.source = self.LOCAL_FILE
        self.load_data()

    # Open a file, create data from it and send it over the data channel
    def load_data(self):
        """Load data from the selected source and send it on "Data".

        Resets the domain editor and messages, obtains a reader for the
        current source, reads the table and passes it through
        `apply_domain_edit`, which sends it.  When there is nothing to load
        or loading fails, ``None`` is sent and the info label explains why;
        exceptions from readers are not propagated.
        """
        # We need to catch any exception type since anything can happen in
        # file readers
        # pylint: disable=broad-except
        self.closeContext()
        self.domain_editor.set_domain(None)
        self.apply_button.setEnabled(False)
        self.clear_messages()
        self.set_file_list()
        # The existence check concerns the local path only; a stale recent
        # path must not prevent loading from a URL.
        if self.source == self.LOCAL_FILE and self.last_path() \
                and not os.path.exists(self.last_path()):
            self.Error.file_not_found()
            # Keep the widget state in sync with what is sent downstream.
            self.data = None
            self.send("Data", None)
            self.info.setText("No data.")
            return

        error = None
        try:
            self.reader = self._get_reader()
            if self.reader is None:
                self.data = None
                self.send("Data", None)
                self.info.setText("No data.")
                self.sheet_box.hide()
                return
        except Exception as ex:
            error = ex

        if not error:
            self._update_sheet_combo()
            with catch_warnings(record=True) as warnings:
                try:
                    data = self.reader.read()
                except Exception as ex:
                    log.exception(ex)
                    error = ex
                self.warning(warnings[-1].message.args[0] if warnings else '')

        if error:
            self.data = None
            self.send("Data", None)
            self.info.setText("An error occurred:\n{}".format(error))
            self.sheet_box.hide()
            return

        self.info.setText(self._describe(data))

        self.loaded_file = self.last_path()
        add_origin(data, self.loaded_file)
        self.data = data
        self.openContext(data.domain)
        self.apply_domain_edit()  # sends data

    def _get_reader(self):
        """Return a reader for the currently selected source.

        For a local file the reader is obtained from
        ``FileFormat.get_reader`` and the sheet stored on the most recent
        path, if any, is selected on it.  For a URL a ``UrlReader`` is
        created from the text of the URL combo.

        Returns
        -------
        FileFormat or None
            ``None`` when there is nothing to read: no local path is
            available, or the URL text is empty.
        """
        if self.source == self.LOCAL_FILE:
            path = self.last_path()
            if not path:
                # No recent path at all; let load_data report "No data."
                # instead of failing inside get_reader.
                return None
            reader = FileFormat.get_reader(path)
            if self.recent_paths and self.recent_paths[0].sheet:
                reader.select_sheet(self.recent_paths[0].sheet)
            return reader
        elif self.source == self.URL:
            url = self.url_combo.currentText().strip()
            if url:
                return UrlReader(url)
        return None

    def _update_sheet_combo(self):
        """Show the sheet selector only if the reader has several sheets."""
        if len(self.reader.sheets) < 2:
            self.sheet_box.hide()
            self.reader.select_sheet(None)
            return

        self.sheet_combo.clear()
        self.sheet_combo.addItems(self.reader.sheets)
        self._select_active_sheet()
        self.sheet_box.show()

    def _select_active_sheet(self):
        """Make the sheet combo show the reader's selected sheet."""
        if self.reader.sheet:
            try:
                idx = self.reader.sheets.index(self.reader.sheet)
                self.sheet_combo.setCurrentIndex(idx)
            except ValueError:
                # Requested sheet does not exist in this file
                self.reader.select_sheet(None)
        else:
            self.sheet_combo.setCurrentIndex(0)

    def _describe(self, table):
        """Return an HTML summary of ``table`` for the info label.

        The summary contains the table's "Name"/"Description" attributes
        when present, the numbers of instances, features and metas, the
        kind of target and, if there is a 'Timestamp' column, its first and
        last entries.
        """
        domain = table.domain
        text = ""

        attrs = getattr(table, "attributes", {})
        descs = [
            attrs[desc] for desc in ("Name", "Description") if desc in attrs
        ]
        # With both present, the first one (the name) is set in bold.
        if len(descs) == 2:
            descs[0] = "<b>{}</b>".format(descs[0])
        if descs:
            text += "<p>{}</p>".format("<br/>".join(descs))

        text += "<p>{} instance(s), {} feature(s), {} meta attribute(s)".\
            format(len(table), len(domain.attributes), len(domain.metas))
        if domain.has_continuous_class:
            text += "<br/>Regression; numerical class."
        elif domain.has_discrete_class:
            text += "<br/>Classification; discrete class with {} values.".\
                format(len(domain.class_var.values))
        elif table.domain.class_vars:
            text += "<br/>Multi-target; {} target variables.".format(
                len(table.domain.class_vars))
        else:
            text += "<br/>Data has no target variable."
        text += "</p>"

        if 'Timestamp' in table.domain:
            # Google Forms uses this header to timestamp responses
            text += '<p>First entry: {}<br/>Last entry: {}</p>'.format(
                table[0, 'Timestamp'], table[-1, 'Timestamp'])
        return text

    def storeSpecificSettings(self):
        """Copy ``variables`` into the current context."""
        self.current_context.modified_variables = self.variables[:]

    def retrieveSpecificSettings(self):
        """Restore ``variables`` from the current context, if stored."""
        if hasattr(self.current_context, "modified_variables"):
            self.variables[:] = self.current_context.modified_variables

    def apply_domain_edit(self):
        """Send the data with the domain editor's changes applied.

        When data is loaded, a new table is built from the domain and
        columns returned by the domain editor, keeping the original
        weights, name, ids and attributes.  Without data, ``None`` is sent.
        The "Apply" button is disabled afterwards.
        """
        if self.data is not None:
            domain, cols = self.domain_editor.get_domain(
                self.data.domain, self.data)
            X, y, m = cols
            X = np.array(X).T if len(X) else np.empty((len(self.data), 0))
            y = np.array(y).T if len(y) else None
            # Metas need an object array as soon as one of them is a string
            # variable; otherwise they are stored as floats.
            dtpe = object if any(
                isinstance(meta, StringVariable)
                for meta in domain.metas) else float
            m = np.array(m, dtype=dtpe).T if len(m) else None
            table = Table.from_numpy(domain, X, y, m, self.data.W)
            table.name = self.data.name
            table.ids = np.array(self.data.ids)
            table.attributes = getattr(self.data, 'attributes', {})
        else:
            table = self.data

        self.send("Data", table)
        self.apply_button.setEnabled(False)

    def get_widget_name_extension(self):
        """Return the loaded file's base name without its extension."""
        _, name = os.path.split(self.loaded_file)
        return os.path.splitext(name)[0]

    def send_report(self):
        """Add the data source, its format and the data to the report."""
        def get_ext_name(filename):
            """Map the extension of ``filename`` to a format name."""
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            # Abbreviate paths under the home directory with "~".
            home = os.path.expanduser("~")
            if self.loaded_file.startswith(home):
                # os.path.join does not like ~
                name = "~" + os.path.sep + \
                       self.loaded_file[len(home):].lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            if self.sheet_combo.isVisible():
                name += " ({})".format(self.sheet_combo.currentText())
            self.report_items("File", [("File name", name),
                                       ("Format", get_ext_name(name))])
        else:
            self.report_items("Data", [("Resource", self.url),
                                       ("Format", get_ext_name(self.url))])

        self.report_data("Data", self.data)

    def dragEnterEvent(self, event):
        """Accept drops of valid file urls"""
        urls = event.mimeData().urls()
        if urls:
            try:
                # Only the first url is examined; the drag is accepted if
                # get_reader does not raise IOError for it.
                FileFormat.get_reader(
                    OSX_NSURL_toLocalFile(urls[0]) or urls[0].toLocalFile())
                event.acceptProposedAction()
            except IOError:
                pass

    def dropEvent(self, event):
        """Handle file drops"""
        urls = event.mimeData().urls()
        if urls:
            self.add_path(
                OSX_NSURL_toLocalFile(urls[0])
                or urls[0].toLocalFile())  # add first file
            self.source = self.LOCAL_FILE
            self.load_data()
Пример #4
0
class OWTreeGraph(OWTreeViewer2D):
    """Graphical visualization of tree models"""

    # Widget metadata
    name = "Tree Viewer"
    icon = "icons/TreeViewer.svg"
    priority = 35
    # Single input: a tree model, handled by the `ctree` method
    inputs = [
        widget.InputSignal(
            "Tree",
            TreeModel,
            "ctree",
            # Had different input names before merging from
            # Classification/Regression tree variants
            replaces=["Classification Tree", "Regression Tree"])
    ]
    # Two outputs: the instances in the selected nodes, and the whole data
    # set annotated with the selection
    outputs = [
        widget.OutputSignal(
            "Selected Data",
            Table,
            widget.Default,
            id="selected-data",
        ),
        widget.OutputSignal(ANNOTATED_DATA_SIGNAL_NAME,
                            Table,
                            id="annotated-data")
    ]

    settingsHandler = ClassValuesContextHandler()
    # Index into the colour combo for classification trees; 0 is the "None"
    # item, i > 0 refers to class value i - 1
    target_class_index = ContextSetting(0)
    # Index into COL_OPTIONS for regression trees
    regression_colors = Setting(0)

    # Qualified names of the widgets this one supersedes
    replaces = [
        "Orange.widgets.classify.owclassificationtreegraph.OWClassificationTreeGraph",
        "Orange.widgets.classify.owregressiontreegraph.OWRegressionTreeGraph"
    ]

    # Labels of the regression colouring modes and their matching indices
    COL_OPTIONS = ["Default", "Number of instances", "Mean value", "Variance"]
    COL_DEFAULT, COL_INSTANCE, COL_MEAN, COL_VARIANCE = range(4)

    def __init__(self):
        """Reset the model-related state and add the colour combo box."""
        super().__init__()
        # Filled in by `ctree` when a model arrives
        self.domain = None
        self.dataset = None
        self.clf_dataset = None
        self.tree_adapter = None

        self.color_label = QLabel("Target class: ")
        self.color_combo = gui.OrangeComboBox()
        color_combo = self.color_combo
        color_combo.setSizePolicy(
            QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
        color_combo.setSizeAdjustPolicy(
            QComboBox.AdjustToMinimumContentsLengthWithIcon)
        color_combo.setMinimumContentsLength(8)
        color_combo.activated[int].connect(self.color_changed)
        # Label and combo form one row of the display box
        self.display_box.layout().addRow(self.color_label, color_combo)

    def set_node_info(self):
        """Refresh the text of every node and give all nodes one width.

        The common width is the widest node after its text is updated,
        capped at ``self.max_node_width``.
        """
        scene = self.scene
        for item in scene.nodes():
            item.set_rect(QRectF())
            self.update_node_info(item)
        widths = [item.rect().width() for item in scene.nodes()]
        widths.append(0)  # keeps max() defined when there are no nodes
        common_width = min(max(widths), self.max_node_width)
        for item in scene.nodes():
            old = item.rect()
            item.set_rect(
                QRectF(old.x(), old.y(), common_width, old.height()))
        scene.fix_pos(self.root_node, 10, 10)

    def _update_node_info_attr_name(self, node, text):
        attr = self.tree_adapter.attribute(node.node_inst)
        if attr is not None:
            text += "<hr/>{}".format(attr.name)
        return text

    def activate_loaded_settings(self):
        """Apply the current settings to the combo, node colours and text.

        Does nothing when there is no model.
        """
        if not self.model:
            return
        super().activate_loaded_settings()
        discrete = self.domain.class_var.is_discrete
        if discrete:
            combo_index = self.target_class_index
        else:
            combo_index = self.regression_colors
        self.color_combo.setCurrentIndex(combo_index)
        if discrete:
            self.toggle_node_color_cls()
        else:
            self.toggle_node_color_reg()
        self.set_node_info()

    def color_changed(self, i):
        """Store the activated combo index `i` and recolour the nodes.

        For a discrete class the index goes to ``target_class_index`` and
        the node texts are rebuilt as well; otherwise it goes to
        ``regression_colors``.
        """
        if not self.domain.class_var.is_discrete:
            self.regression_colors = i
            self.toggle_node_color_reg()
            return
        self.target_class_index = i
        self.toggle_node_color_cls()
        self.set_node_info()

    def toggle_node_size(self):
        """Rebuild the node contents, then redraw the scene and its view."""
        self.set_node_info()
        scene, view = self.scene, self.scene_view
        scene.update()
        view.repaint()

    def toggle_color_cls(self):
        """Recolour the nodes (classification), rebuild texts and redraw."""
        scene = self.scene
        self.toggle_node_color_cls()
        self.set_node_info()
        scene.update()

    def toggle_color_reg(self):
        """Recolour the nodes (regression), rebuild texts and redraw."""
        scene = self.scene
        self.toggle_node_color_reg()
        self.set_node_info()
        scene.update()

    def ctree(self, model=None):
        """Input signal handler

        Clears the scene and the colour combo, closes the current settings
        context and then either resets the widget (`model` is None) or
        builds the node structure for the new model. In both cases the
        scene is set up again and both outputs are sent: None for
        "Selected Data" and the data annotated with an empty selection.

        Parameters
        ----------
        model
            The tree model to show, or None to clear the widget.
        """
        self.clear_scene()
        self.color_combo.clear()
        self.closeContext()
        self.model = model
        # The index left over from the previous model may point past the
        # classes of the new one; start from "None" and let a matching
        # context, if any, restore the stored choice in openContext below.
        self.target_class_index = 0
        if model is None:
            self.info.setText('No tree.')
            self.root_node = None
            self.dataset = None
            # Do not keep the converted training data of the old model alive
            self.clf_dataset = None
            self.tree_adapter = None
        else:
            self.tree_adapter = self._get_tree_adapter(model)
            self.domain = model.domain
            self.dataset = model.instances
            # Convert the data only when its domain differs from the model's
            if self.dataset is not None and self.dataset.domain != self.domain:
                self.clf_dataset = Table.from_table(model.domain, self.dataset)
            else:
                self.clf_dataset = self.dataset
            class_var = self.domain.class_var
            if class_var.is_discrete:
                self.scene.colors = [QColor(*col) for col in class_var.colors]
                self.color_label.setText("Target class: ")
                # Item 0 is "None"; class value i sits at combo index i + 1
                self.color_combo.addItem("None")
                self.color_combo.addItems(self.domain.class_vars[0].values)
                self.color_combo.setCurrentIndex(self.target_class_index)
            else:
                self.scene.colors = \
                    ContinuousPaletteGenerator(*model.domain.class_var.colors)
                self.color_label.setText("Color by: ")
                self.color_combo.addItems(self.COL_OPTIONS)
                self.color_combo.setCurrentIndex(self.regression_colors)
            self.openContext(self.domain.class_var)
            self.root_node = self.walkcreate(self.tree_adapter.root)
            self.info.setText('{} nodes, {} leaves'.format(
                self.tree_adapter.num_nodes,
                len(self.tree_adapter.leaves(self.tree_adapter.root))))
        self.setup_scene()
        self.send("Selected Data", None)
        self.send(ANNOTATED_DATA_SIGNAL_NAME,
                  create_annotated_table(self.dataset, []))

    def walkcreate(self, node, parent=None):
        """Add graphics items for `node` and, recursively, its children.

        A `TreeNode` is added to the scene for `node`; when `parent` is
        given, an edge from `parent` to the new item is added as well.
        Children reported as None by the tree adapter are skipped.

        Returns the item created for `node`.
        """
        item = TreeNode(self.tree_adapter, node, parent)
        self.scene.addItem(item)
        if parent:
            link = GraphicsEdge(node1=parent, node2=item)
            self.scene.addItem(link)
            parent.graph_add_edge(link)
        present_children = (
            child for child in self.tree_adapter.children(node)
            if child is not None)
        for child in present_children:
            self.walkcreate(child, item)
        return item

    def node_tooltip(self, node):
        """Return the node's rules as HTML, one rule per line."""
        rules = self.tree_adapter.rules(node.node_inst)
        html_lines = [to_html(str(rule)) for rule in rules]
        return "<br>".join(html_lines)

    def update_selection(self):
        """Send the data in the selected nodes and the annotated data set.

        Does nothing when there is no model.
        """
        if self.model is None:
            return
        # Other kinds of scene items can be selected too; keep only tree nodes
        selected = [item.node_inst
                    for item in self.scene.selectedItems()
                    if isinstance(item, TreeNode)]
        adapter = self.tree_adapter
        selected_data = adapter.get_instances_in_nodes(
            self.clf_dataset, selected)
        self.send("Selected Data", selected_data)
        annotated = create_annotated_table(
            self.dataset, adapter.get_indices(selected))
        self.send(ANNOTATED_DATA_SIGNAL_NAME, annotated)

    def send_report(self):
        """Report the tree size, edge width mode, colouring and the plot.

        Does nothing when there is no model. The colouring entry is left
        out for regression trees that use the default colouring.
        """
        if not self.model:
            return
        width_labels = ("Fixed", "Relative to root", "Relative to parent")
        # pylint: disable=invalid-sequence-index
        edge_widths = width_labels[self.line_width_method]
        items = [("Tree size", self.info.text()),
                 ("Edge widths", edge_widths)]
        if self.domain.class_var.is_discrete:
            items.append(("Target class", self.color_combo.currentText()))
        elif self.regression_colors != self.COL_DEFAULT:
            color_by = self.COL_OPTIONS[self.regression_colors]
            items.append(("Color by", color_by))
        self.report_items(items)
        self.report_plot(self.scene)

    def update_node_info(self, node):
        """Update the node text with the method matching the class type."""
        if self.domain.class_var.is_discrete:
            updater = self.update_node_info_cls
        else:
            updater = self.update_node_info_reg
        updater(node)

    def update_node_info_cls(self, node):
        """Update the printed contents of the node for classification trees

        The text shows the share and the count of the target class or,
        when no target class is chosen (index 0), the name, share and count
        of the majority class. The name of the node's attribute is appended
        by `_update_node_info_attr_name`.
        """
        node_inst = node.node_inst
        distr = self.tree_adapter.get_distribution(node_inst)[0]
        total = self.tree_adapter.num_samples(node_inst)
        # An all-zero distribution would turn into NaNs (0/0) and make the
        # integer conversion below raise; divide by 1 instead.
        distr = distr / (np.sum(distr) or 1)
        if self.target_class_index:
            # Combo index 0 is "None", so class i is at index i + 1
            tabs = distr[self.target_class_index - 1]
            text = ""
        else:
            modus = np.argmax(distr)
            tabs = distr[modus]
            text = self.domain.class_vars[0].values[int(modus)] + "<br/>"
        if tabs > 0.999:
            text += "100%, {}/{}".format(total, total)
        else:
            # Round rather than truncate: 100 * 0.29 is 28.999999999999996
            # in floating point and int() alone would print 28, not 29.
            text += "{:2.1f}%, {}/{}".format(
                100 * tabs, int(round(total * tabs)), total)

        text = self._update_node_info_attr_name(node, text)
        node.setHtml('<p style="line-height: 120%; margin-bottom: 0">'
                     '{}</p>'.format(text))

    def update_node_info_reg(self, node):
        """Update the printed contents of the node for regression trees

        The text holds the two values of the node's distribution, the
        number of instances and the attribute name added by
        `_update_node_info_attr_name`.
        """
        inst = node.node_inst
        mean, var = self.tree_adapter.get_distribution(inst)[0]
        n_instances = self.tree_adapter.num_samples(inst)
        summary = "{:.1f} ± {:.1f}<br/>{} instances".format(
            mean, var, n_instances)
        summary = self._update_node_info_attr_name(node, summary)
        template = '<p style="line-height: 120%; margin-bottom: 0">{}</p>'
        node.setHtml(template.format(summary))

    def toggle_node_color_cls(self):
        """Update the node color for classification trees

        With a target class selected, each node gets that class' colour,
        lighter for a smaller share of the class in the node. Otherwise the
        colour of the node's majority class is used in the same way.
        """
        colors = self.scene.colors
        for node in self.scene.nodes():
            distr = node.tree_adapter.get_distribution(node.node_inst)[0]
            # Guard both branches against a node with an all-zero
            # distribution; it is then treated as having share 0.
            total = sum(distr) or 1
            if self.target_class_index:
                # Combo index 0 is "None", so class i is at index i + 1
                class_index = self.target_class_index - 1
                p = distr[class_index] / total
                factor = 200 - 100 * p
            else:
                class_index = int(np.argmax(distr))
                p = distr[class_index] / total
                factor = 300 - 200 * p
            # QColor.lighter takes an integer factor
            color = colors[class_index].lighter(int(factor))
            node.backgroundBrush = QBrush(color)
        self.scene.update()

    def toggle_node_color_reg(self):
        """Update the node color for regression trees

        The colouring follows ``self.regression_colors``: one default
        colour, a shade by the number of instances in the node, a palette
        colour by the node's mean, or a shade by the node's variance.
        Colouring by instances or by mean reads ``self.dataset``; when
        there is no data set, the default colour is used instead.
        """
        def_color = QColor(192, 192, 255)
        mode = self.regression_colors
        if self.dataset is None and \
                mode in (self.COL_INSTANCE, self.COL_MEAN):
            # These two modes cannot be computed without the data
            mode = self.COL_DEFAULT

        if mode == self.COL_DEFAULT:
            brush = QBrush(def_color.lighter(100))
            for node in self.scene.nodes():
                node.backgroundBrush = brush
        elif mode == self.COL_INSTANCE:
            # `or 1` avoids dividing by zero for an empty data set
            max_insts = len(
                self.tree_adapter.get_instances_in_nodes(
                    self.dataset, [self.tree_adapter.root])) or 1
            for node in self.scene.nodes():
                node_insts = len(
                    self.tree_adapter.get_instances_in_nodes(
                        self.dataset, [node.node_inst]))
                # QColor.lighter takes an integer factor
                node.backgroundBrush = QBrush(def_color.lighter(
                    int(120 - 20 * node_insts / max_insts)))
        elif mode == self.COL_MEAN:
            minv = np.nanmin(self.dataset.Y)
            maxv = np.nanmax(self.dataset.Y)
            fact = 1 / (maxv - minv) if minv != maxv else 1
            colors = self.scene.colors
            for node in self.scene.nodes():
                node_mean = self.tree_adapter.get_distribution(
                    node.node_inst)[0][0]
                node.backgroundBrush = QBrush(colors[fact *
                                                     (node_mean - minv)])
        else:
            nodes = list(self.scene.nodes())
            variances = [
                self.tree_adapter.get_distribution(node.node_inst)[0][1]
                for node in nodes
            ]
            # No nodes, or all variances zero: fall back to 1 so that the
            # division below is defined
            max_var = max(variances, default=0) or 1
            for node, var in zip(nodes, variances):
                node.backgroundBrush = QBrush(
                    def_color.lighter(int(120 - 20 * var / max_var)))
        self.scene.update()

    def _get_tree_adapter(self, model):
        """Wrap `model` in the adapter class matching its type."""
        adapter_class = (SklTreeAdapter if isinstance(model, SklModel)
                         else TreeAdapter)
        return adapter_class(model)
Пример #5
0
class OWFile(widget.OWWidget, RecentPathsWComboMixin):
    # Widget metadata
    name = "File"
    id = "orange.widgets.data.file"
    description = "Read a data from an input file or network " \
                  "and send the data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["data", "file", "load", "read"]
    # Single output: the table that was read
    outputs = [widget.OutputSignal(
        "Data", Table,
        doc="Attribute-valued data set read from the input file.")]

    want_main_area = False

    SEARCH_PATHS = [("sample-datasets", get_sample_datasets_dir())]

    # Possible values of the `source` setting
    LOCAL_FILE, URL = range(2)

    settingsHandler = PerfectDomainContextHandler()

    # Overload RecentPathsWidgetMixin.recent_paths to set defaults
    recent_paths = Setting([
        RecentPath("", "sample-datasets", "iris.tab"),
        RecentPath("", "sample-datasets", "titanic.tab"),
        RecentPath("", "sample-datasets", "housing.tab"),
    ])
    recent_urls = Setting([])
    # Selected source: LOCAL_FILE or URL
    source = Setting(LOCAL_FILE)
    xls_sheet = ContextSetting("")
    sheet_names = Setting({})
    url = Setting("")

    # List handed to the DomainEditor in __init__; stored per context
    variables = ContextSetting([])

    # File dialog filter: one entry with all readable extensions, followed
    # by one entry per reader class in the order of first appearance in
    # FileFormat.readers
    dlg_formats = (
        "All readable files ({});;".format(
            '*' + ' *'.join(FileFormat.readers.keys())) +
        ";;".join("{} (*{})".format(f.DESCRIPTION, ' *'.join(f.EXTENSIONS))
                  for f in sorted(set(FileFormat.readers.values()),
                                  key=list(FileFormat.readers.values()).index)))

    def __init__(self):
        """Build the widget's controls and schedule the first data load.

        The control area holds a grid with the two source rows (file and
        URL), an info box, the domain editor and a row of buttons.
        """
        super().__init__()
        RecentPathsWComboMixin.__init__(self)
        self.domain = None
        self.data = None
        self.loaded_file = ""
        self.reader = None

        # Grid that holds the radio buttons and the controls of both sources
        layout = QtGui.QGridLayout()
        gui.widgetBox(self.controlArea, margin=0, orientation=layout)
        vbox = gui.radioButtons(None, self, "source", box=True, addSpace=True,
                                callback=self.load_data, addToLayout=False)

        # Row 0: local file -- radio button, combo of recent files, browse
        # button and reload button
        rb_button = gui.appendRadioButton(vbox, "File:", addToLayout=False)
        layout.addWidget(rb_button, 0, 0, QtCore.Qt.AlignVCenter)

        box = gui.hBox(None, addToLayout=False, margin=0)
        box.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        self.file_combo.activated[int].connect(self.select_file)
        box.layout().addWidget(self.file_combo)
        layout.addWidget(box, 0, 1)

        file_button = gui.button(
            None, self, '...', callback=self.browse_file, autoDefault=False)
        file_button.setIcon(self.style().standardIcon(
            QtGui.QStyle.SP_DirOpenIcon))
        file_button.setSizePolicy(Policy.Maximum, Policy.Fixed)
        layout.addWidget(file_button, 0, 2)

        reload_button = gui.button(
            None, self, "Reload", callback=self.load_data, autoDefault=False)
        reload_button.setIcon(self.style().standardIcon(
            QtGui.QStyle.SP_BrowserReload))
        reload_button.setSizePolicy(Policy.Fixed, Policy.Fixed)
        layout.addWidget(reload_button, 0, 3)

        # Row 2: sheet selection; hidden here and shown by
        # _update_sheet_combo when the reader has at least two sheets
        self.sheet_box = gui.hBox(None, addToLayout=False, margin=0)
        self.sheet_combo = gui.comboBox(None, self, "xls_sheet",
                                        callback=self.select_sheet,
                                        sendSelectedValue=True)
        self.sheet_combo.setSizePolicy(
            Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_label = QtGui.QLabel()
        self.sheet_label.setText('Sheet')
        self.sheet_label.setSizePolicy(
            Policy.MinimumExpanding, Policy.Fixed)
        self.sheet_box.layout().addWidget(
            self.sheet_label, QtCore.Qt.AlignLeft)
        self.sheet_box.layout().addWidget(
            self.sheet_combo, QtCore.Qt.AlignVCenter)
        layout.addWidget(self.sheet_box, 2, 1)
        self.sheet_box.hide()

        # Row 3: URL -- radio button and an editable combo of recent URLs
        rb_button = gui.appendRadioButton(vbox, "URL:", addToLayout=False)
        layout.addWidget(rb_button, 3, 0, QtCore.Qt.AlignVCenter)

        self.url_combo = url_combo = QtGui.QComboBox()
        url_model = NamedURLModel(self.sheet_names)
        url_model.wrap(self.recent_urls)
        url_combo.setModel(url_model)
        url_combo.setSizePolicy(Policy.MinimumExpanding, Policy.Fixed)
        url_combo.setEditable(True)
        url_combo.setInsertPolicy(url_combo.InsertAtTop)
        url_edit = url_combo.lineEdit()
        # Widen the left text margin of the line edit by 5
        l, t, r, b = url_edit.getTextMargins()
        url_edit.setTextMargins(l + 5, t, r, b)
        layout.addWidget(url_combo, 3, 1, 3, 3)
        url_combo.activated.connect(self._url_set)

        # Labels for the data description and for warnings
        box = gui.vBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        # Editor for the columns of the loaded data
        box = gui.widgetBox(self.controlArea, "Columns (Double click to edit)")
        domain_editor = DomainEditor(self.variables)
        self.editor_model = domain_editor.model()
        box.layout().addWidget(domain_editor)

        # Bottom row: documentation data sets, report and apply buttons
        box = gui.hBox(self.controlArea)
        gui.button(
            box, self, "Browse documentation data sets",
            callback=lambda: self.browse_file(True), autoDefault=False)
        gui.rubber(box)
        box.layout().addWidget(self.report_button)
        self.report_button.setFixedWidth(170)

        # The apply button is hidden until the editor model reports a change
        self.apply_button = gui.button(
            box, self, "Apply", callback=self.apply_domain_edit)
        self.apply_button.hide()
        self.apply_button.setFixedWidth(170)
        self.editor_model.dataChanged.connect(self.apply_button.show)

        self.set_file_list()
        # Must not load the data from within __init__: per the original
        # note, loading explicitly re-enters the event loop (by a progress
        # bar). The zero-delay timer defers load_data instead.
        QtCore.QTimer.singleShot(0, self.load_data)

        self.setAcceptDrops(True)

    def sizeHint(self):
        """Return the size hint of the widget, 600 x 550."""
        width, height = 600, 550
        return QtCore.QSize(width, height)

    def select_file(self, n):
        """Select the `n`-th recent path and load it as a local file."""
        assert n < len(self.recent_paths)
        super().select_file(n)
        if not self.recent_paths:
            return
        self.source = self.LOCAL_FILE
        self.load_data()
        self.set_file_list()

    def select_sheet(self):
        """Store the chosen sheet on the first recent path and reload."""
        chosen_sheet = self.sheet_combo.currentText()
        current_path = self.recent_paths[0]
        current_path.sheet = chosen_sheet
        self.load_data()

    def _url_set(self):
        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Let the user pick a file in a dialog and load it.

        The dialog starts in the sample data sets directory when `in_demos`
        is true, otherwise at the last used path or the home directory. A
        message box is shown, and nothing is loaded, if the sample data
        sets directory does not exist. Nothing happens either when the
        dialog returns no file name.
        """
        if not in_demos:
            start_file = self.last_path() or os.path.expanduser("~/")
        else:
            start_file = get_sample_datasets_dir()
            if not os.path.exists(start_file):
                QtGui.QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with documentation data sets")
                return

        chosen = QtGui.QFileDialog.getOpenFileName(
            self, 'Open Orange Data File', start_file, self.dlg_formats)
        if not chosen:
            return
        self.loaded_file = chosen
        self.add_path(chosen)
        self.source = self.LOCAL_FILE
        self.load_data()

    # Open a file, create data from it and send it over the data channel
    def load_data(self):
        """Open a file, create data from it and send it over the data channel

        A reader is created for the current source and the sheet combo is
        updated for it. If reading raises, the error text is shown in the
        info label, the domain editor is reset and None is sent. Otherwise
        the data is described in the info label, tagged with its origin,
        sent to the output, handed to the domain editor and stored in
        ``self.data``. The last recorded warning, if any, is shown.
        """
        self.reader = self._get_reader()
        self._update_sheet_combo()

        errors = []
        with catch_warnings(record=True) as caught:
            try:
                data = self.reader.read()
            # Readers can fail in arbitrary ways; the message is shown to
            # the user instead of crashing the widget.
            except Exception as ex:  # pylint: disable=broad-except
                errors.append("An error occurred:")
                errors.append(str(ex))
                data = None
                self.editor_model.reset()
            self.warning(caught[-1].message.args[0] if caught else '')

        if data is None:
            # Forget the previously loaded table: the output is now empty,
            # and a report must not describe data that is no longer sent.
            self.data = None
            self.send("Data", None)
            self.info.setText("\n".join(errors))
            return

        self.info.setText(self._describe(data))

        if self.source == self.LOCAL_FILE:
            # Files chosen from the combo or dropped onto the widget do not
            # pass through browse_file, so record the file that was read.
            self.loaded_file = self.last_path()
        add_origin(data, self.loaded_file or self.last_path())
        self.send("Data", data)
        self.editor_model.set_domain(data.domain)
        self.data = data

    def _get_reader(self):
        """

        Returns
        -------
        FileFormat
        """
        if self.source == self.LOCAL_FILE:
            reader = FileFormat.get_reader(self.last_path())
            if self.recent_paths and self.recent_paths[0].sheet:
                reader.select_sheet(self.recent_paths[0].sheet)
            return reader
        elif self.source == self.URL:
            return UrlReader(self.url_combo.currentText())

    def _update_sheet_combo(self):
        if len(self.reader.sheets) < 2:
            self.sheet_box.hide()
            self.reader.select_sheet(None)
            return

        self.sheet_combo.clear()
        self.sheet_combo.addItems(self.reader.sheets)
        self._select_active_sheet()
        self.sheet_box.show()

    def _select_active_sheet(self):
        if self.reader.sheet:
            try:
                idx = self.reader.sheets.index(self.reader.sheet)
                self.sheet_combo.setCurrentIndex(idx)
            except ValueError:
                # Requested sheet does not exist in this file
                self.reader.select_sheet(None)
        else:
            self.sheet_combo.setCurrentIndex(0)

    def _describe(self, table):
        domain = table.domain
        text = "{} instance(s), {} feature(s), {} meta attribute(s)".format(
            len(table), len(domain.attributes), len(domain.metas))
        if domain.has_continuous_class:
            text += "\nRegression; numerical class."
        elif domain.has_discrete_class:
            text += "\nClassification; discrete class with {} values.".format(
                len(domain.class_var.values))
        elif table.domain.class_vars:
            text += "\nMulti-target; {} target variables.".format(
                len(table.domain.class_vars))
        else:
            text += "\nData has no target variable."
        if 'Timestamp' in table.domain:
            # Google Forms uses this header to timestamp responses
            text += '\n\nFirst entry: {}\nLast entry: {}'.format(
                table[0, 'Timestamp'], table[-1, 'Timestamp'])
        return text

    def storeSpecificSettings(self):
        """Save a copy of `variables` on the current context."""
        snapshot = self.variables[:]
        self.current_context.modified_variables = snapshot

    def retrieveSpecificSettings(self):
        """Restore `variables` in place from the current context, if saved."""
        context = self.current_context
        try:
            saved = context.modified_variables
        except AttributeError:
            # Nothing was stored on this context
            return
        # Slice assignment keeps the identity of the list
        self.variables[:] = saved

    def apply_domain_edit(self):
        """Build a table from the edited column descriptions and send it.

        The rows of the editor model are paired, in order, with the
        attributes, class variables and metas of the loaded data. Each
        column that is kept is placed among the attributes, class variables
        or metas as chosen in the editor. A column whose name and type are
        unchanged keeps its original variable; otherwise a new variable of
        the chosen type is created. The apply button is hidden afterwards.
        """
        attributes = []
        class_vars = []
        metas = []
        places = [attributes, class_vars, metas]
        X, y, m = [], [], []
        cols = [X, y, m]  # Xcols, Ycols, Mcols

        def is_missing(x):
            """Tell whether `x` prints as a missing value."""
            return str(x) in ("nan", "")

        # Original variables with the index of the part they come from
        original_vars = chain(
            [(at, 0) for at in self.data.domain.attributes],
            [(cl, 1) for cl in self.data.domain.class_vars],
            [(mt, 2) for mt in self.data.domain.metas])
        for (name, tpe, place, _vals, _is_con), (orig_var, orig_plc) in \
                zip(self.editor_model.variables, original_vars):
            if place == 3:
                # NOTE(review): place 3 presumably means "skip this column"
                # in the editor -- confirm against DomainEditor
                continue
            if orig_plc == 2:
                col_data = list(chain(*self.data[:, orig_var].metas))
            else:
                col_data = list(chain(*self.data[:, orig_var]))
            if name == orig_var.name and tpe == type(orig_var):
                var = orig_var
            elif tpe == DiscreteVariable:
                # Sort the distinct labels: the iteration order of a set is
                # not stable, so the value order (and with it the encoded
                # indices) used to differ between runs.
                values = sorted(
                    {str(i) for i in col_data if not is_missing(i)})
                value_index = {
                    value: idx for idx, value in enumerate(values)}
                var = tpe(name, values)
                col_data = [np.nan if is_missing(x) else value_index[str(x)]
                            for x in col_data]
            elif tpe == StringVariable and type(orig_var) == DiscreteVariable:
                var = tpe(name)
                col_data = [orig_var.repr_val(x) if not np.isnan(x) else ""
                            for x in col_data]
            else:
                var = tpe(name)
            places[place].append(var)
            cols[place].append(col_data)
        domain = Domain(attributes, class_vars, metas)
        # Columns were collected one per list; transpose to rows x columns
        X = np.array(X).T if len(X) else np.empty((len(self.data), 0))
        y = np.array(y).T if len(y) else None
        # Metas must be an object array as soon as one of them is a string
        dtpe = object if any(isinstance(m, StringVariable)
                             for m in domain.metas) else float
        m = np.array(m, dtype=dtpe).T if len(m) else None
        table = Table.from_numpy(domain, X, y, m, self.data.W)
        self.send("Data", table)
        self.apply_button.hide()

    def get_widget_name_extension(self):
        """Return the base name of the loaded file without its extension."""
        base_name = os.path.basename(self.loaded_file)
        stem, _extension = os.path.splitext(base_name)
        return stem

    def send_report(self):
        """Add a description of the data source and the data to the report.

        Reports "No file." when no data is loaded. For a local file the
        path is shown relative to ``~`` when it lies under the home
        directory, followed by the sheet name when the sheet combo is
        visible. For any other source the value of ``self.url`` is reported.
        """
        def get_ext_name(filename):
            """Map the extension of `filename` to a name in FileFormat.names.

            Returns "unknown" when the extension is not a key of the mapping.
            """
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            home = os.path.expanduser("~")
            if self.loaded_file.startswith(home):
                # os.path.join does not like ~
                name = "~/" + \
                       self.loaded_file[len(home):].lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            # Resolve the format before the sheet name is appended: with the
            # " (sheet)" suffix the extension would read e.g. ".xlsx (Sheet1)"
            # and the lookup would always end up as "unknown".
            format_name = get_ext_name(name)
            if self.sheet_combo.isVisible():
                name += " ({})".format(self.sheet_combo.currentText())
            self.report_items("File", [("File name", name),
                                       ("Format", format_name)])
        else:
            self.report_items("Data", [("Resource", self.url),
                                       ("Format", get_ext_name(self.url))])

        self.report_data("Data", self.data)

    def dragEnterEvent(self, event):
        """Accept the drag when a reader exists for the first dragged URL."""
        dragged = event.mimeData().urls()
        if not dragged:
            return
        first = dragged[0]
        try:
            local_path = OSX_NSURL_toLocalFile(first) or first.toLocalFile()
            # get_reader is called only to find out whether it raises
            FileFormat.get_reader(local_path)
            event.acceptProposedAction()
        except IOError:
            # The event is simply left unaccepted
            pass

    def dropEvent(self, event):
        """Add the first dropped file to the paths and load it."""
        dropped = event.mimeData().urls()
        if not dropped:
            return
        first = dropped[0]
        # Only the first of the dropped files is used
        local_path = OSX_NSURL_toLocalFile(first) or first.toLocalFile()
        self.add_path(local_path)
        self.source = self.LOCAL_FILE
        self.load_data()
Пример #6
0
class OWCSVFileImport(widget.OWWidget):
    """
    Widget importing a data table from a CSV formatted file.

    The loaded data is sent on two outputs: "Data" (an Orange.data.Table)
    and "Data Frame" (a pandas DataFrame).
    """
    # Widget meta data
    name = "CSV File Import"
    description = "Import a data table from a CSV formatted file."
    icon = "icons/CSVFile.svg"
    priority = 11
    category = "Data"
    keywords = ["file", "load", "read", "open", "csv"]

    # Output signals
    outputs = [
        widget.OutputSignal(
            name="Data",
            type=Orange.data.Table,
            doc="Loaded data set."),
        widget.OutputSignal(
            name="Data Frame",
            type=pd.DataFrame,
            doc=""
        )
    ]

    class Error(widget.OWWidget.Error):
        """Error messages this widget can display."""
        # Shown for any exception that is not a UnicodeDecodeError
        error = widget.Msg(
            "Unexpected error"
        )
        # Shown when loading fails with a UnicodeDecodeError
        encoding_error = widget.Msg(
            "Encoding error\n"
            "The file might be encoded in an unsupported encoding or it "
            "might be binary"
        )

    #: Paths and options of files accessed in a 'session'
    _session_items = settings.Setting(
        [], schema_only=True)  # type: List[Tuple[str, dict]]

    #: Saved dialog state (last directory and selected filter)
    dialog_state = settings.Setting({
        "directory": "",
        "filter": ""
    })  # type: Dict[str, str]
    #: Upper bound on the number of entries kept in `_session_items`
    MaxHistorySize = 50

    want_main_area = False
    buttons_area_orientation = None

    def __init__(self, *args, **kwargs):
        """
        Build the widget GUI and restore the list of recent files.

        All positional and keyword arguments are forwarded to the base
        class initializer. When an item is current after the state is
        restored, a load is scheduled right away.
        """
        # `super().__init__` is already bound; `self` must not be passed
        # again (it would arrive as an extra positional argument).
        super().__init__(*args, **kwargs)

        # Single shot timer used to (re)schedule `commit`
        self.__committimer = QTimer(self, singleShot=True)
        self.__committimer.timeout.connect(self.commit)

        self.__executor = qconcurrent.ThreadExecutor()
        # Watcher of the currently running load task (None when idle)
        self.__watcher = None  # type: Optional[qconcurrent.FutureWatcher]

        self.controlArea.layout().setSpacing(-1)  # reset spacing
        grid = QGridLayout()
        grid.addWidget(QLabel("File:", self), 0, 0, 1, 1)

        # Model holding the recent import items displayed by the combo box
        self.import_items_model = QStandardItemModel(self)
        self.recent_combo = QComboBox(
            self, objectName="recent-combo", toolTip="Recent files.",
            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
            minimumContentsLength=16,
        )
        self.recent_combo.setModel(self.import_items_model)
        self.recent_combo.activated.connect(self.activate_recent)
        self.recent_combo.setSizePolicy(
            QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
        self.browse_button = QPushButton(
            "…", icon=self.style().standardIcon(QStyle.SP_DirOpenIcon),
            toolTip="Browse filesystem", autoDefault=False,
        )
        self.browse_button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self.browse_button.clicked.connect(self.browse)
        grid.addWidget(self.recent_combo, 0, 1, 1, 1)
        grid.addWidget(self.browse_button, 0, 2, 1, 1)
        self.controlArea.layout().addLayout(grid)

        ###########
        # Info text
        ###########
        box = gui.widgetBox(self.controlArea, "Info", addSpace=False)
        self.summary_text = QTextBrowser(
            verticalScrollBarPolicy=Qt.ScrollBarAsNeeded,
            readOnly=True,
        )
        self.summary_text.viewport().setBackgroundRole(QPalette.NoRole)
        self.summary_text.setFrameStyle(QTextBrowser.NoFrame)
        self.summary_text.setMinimumHeight(self.fontMetrics().ascent() * 2 + 4)
        self.summary_text.viewport().setAutoFillBackground(False)
        box.layout().addWidget(self.summary_text)

        #########
        # Buttons
        #########
        button_box = QDialogButtonBox(
            orientation=Qt.Horizontal,
            standardButtons=QDialogButtonBox.Cancel | QDialogButtonBox.Retry
        )
        # The standard 'Retry' button is relabelled and used as 'Load'
        self.load_button = b = button_box.button(QDialogButtonBox.Retry)
        b.setText("Load")
        b.clicked.connect(self.__committimer.start)
        b.setEnabled(False)
        b.setDefault(True)

        self.cancel_button = b = button_box.button(QDialogButtonBox.Cancel)
        b.clicked.connect(self.cancel)
        b.setEnabled(False)
        b.setAutoDefault(False)

        self.import_options_button = QPushButton(
            "Import Options…", enabled=False, autoDefault=False,
            clicked=self._activate_import_dialog
        )

        def update_buttons(cbindex):
            # Both buttons need a current item (index -1 means no selection)
            self.import_options_button.setEnabled(cbindex != -1)
            self.load_button.setEnabled(cbindex != -1)
        self.recent_combo.currentIndexChanged.connect(update_buttons)

        button_box.addButton(
            self.import_options_button, QDialogButtonBox.ActionRole
        )
        button_box.setStyleSheet(
            "button-layout: {:d};".format(QDialogButtonBox.MacLayout)
        )
        self.controlArea.layout().addWidget(button_box)

        self._restoreState()
        if self.current_item() is not None:
            self._invalidate()
        self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Maximum)

    @Slot(int)
    def activate_recent(self, index):
        """
        Make the recent list entry at `index` the selected file.

        An `index` outside of the model's rows clears the combo box
        selection instead.
        """
        model = self.import_items_model
        if not 0 <= index < model.rowCount():
            self.recent_combo.setCurrentIndex(-1)
            return

        entry = model.item(index)
        assert entry is not None
        path = entry.data(ImportItem.PathRole)
        stored = entry.data(ImportItem.OptionsRole)
        # anything that is not an Options instance counts as 'no options'
        options = stored if isinstance(stored, Options) else None
        self.set_selected_file(path, options)

    @Slot()
    def browse(self):
        """
        Ask the user for a file and its import options, then select it.

        A file dialog is shown first (its last directory and name filter
        are remembered in `dialog_state`). For a file that does not look
        like plain text the user is asked for confirmation. The import
        options dialog is then shown, initialized either with options
        stored in the local history for the same path or with options
        derived from the chosen name filter / sniffed from the file.
        """
        name_filters = [
            "Text - comma separated (*.csv, *)",
            "Text - tab separated (*.tsv, *)",
            "Text - all files (*)"
        ]

        file_dlg = QFileDialog(
            self, windowTitle="Open Data File",
            acceptMode=QFileDialog.AcceptOpen,
            fileMode=QFileDialog.ExistingFile
        )
        file_dlg.setNameFilters(name_filters)
        saved = self.dialog_state
        lastdir = saved.get("directory", "")
        lastfilter = saved.get("filter", "")

        if lastdir and os.path.isdir(lastdir):
            file_dlg.setDirectory(lastdir)
        if lastfilter:
            file_dlg.selectNameFilter(lastfilter)

        accepted = file_dlg.exec_() == QFileDialog.Accepted
        file_dlg.deleteLater()
        if not accepted:
            return

        selected_filter = file_dlg.selectedNameFilter()
        self.dialog_state["directory"] = file_dlg.directory().absolutePath()
        self.dialog_state["filter"] = selected_filter

        path = file_dlg.selectedFiles()[0]
        # pre-flight check; try to determine the nature of the file
        if not _mime_type_for_path(path).inherits("text/plain"):
            confirm = QMessageBox(
                parent=self,
                windowTitle="",
                icon=QMessageBox.Question,
                text="The '{basename}' may be a binary file.\n"
                     "Are you sure you want to continue?".format(
                         basename=os.path.basename(path)),
                standardButtons=QMessageBox.Cancel | QMessageBox.Yes
            )
            confirm.setWindowModality(Qt.WindowModal)
            if confirm.exec() == QMessageBox.Cancel:
                return

        # initialize dialect based on selected extension
        if selected_filter in name_filters[:-1]:
            # the first filter is comma separated, the other one tab separated
            if name_filters.index(selected_filter) == 0:
                dialect = csv.excel()
            else:
                dialect = csv.excel_tab()
            header = True
        else:
            try:
                dialect, header = sniff_csv_with_path(path)
            except Exception:  # pylint: disable=broad-except
                dialect, header = csv.excel(), True

        # Search for path in history.
        # If found use the stored params to initialize the import dialog
        options = None
        history = self.itemsFromSettings()
        found = index_where(history, lambda t: samepath(t[0], path))
        if found is not None:
            _, options = history[found]

        if options is None:
            rowspec = [(range(0, 1), RowSpec.Header)] if header else []
            options = Options(
                encoding="utf-8", dialect=dialect, rowspec=rowspec)

        import_dlg = CSVImportDialog(
            self, windowTitle="Import Options", sizeGripEnabled=True)
        import_dlg.setWindowModality(Qt.WindowModal)
        import_dlg.setPath(path)
        import_dlg.setOptions(options)
        status = import_dlg.exec_()
        import_dlg.deleteLater()
        if status == QDialog.Accepted:
            self.set_selected_file(path, import_dlg.options())

    def current_item(self):
        # type: () -> Optional[ImportItem]
        """
        Return the ImportItem currently selected in the recent files combo.

        None is returned when nothing is selected or when the selected
        model entry is not an ImportItem.
        """
        row = self.recent_combo.currentIndex()
        if row == -1:
            return None

        candidate = self.recent_combo.model().item(row)  # type: QStandardItem
        return candidate if isinstance(candidate, ImportItem) else None

    def _activate_import_dialog(self):
        """
        Open the Import Options dialog for the current item.

        When the dialog is accepted the new options are stored on the item
        and in the recent list; a reload is scheduled only if they differ
        from the options the dialog was opened with. The dialog size is
        saved when it finishes and restored the next time it is opened.
        """
        current = self.current_item()
        assert current is not None
        dialog = CSVImportDialog(
            self, windowTitle="Import Options", sizeGripEnabled=True,
        )
        dialog.setWindowModality(Qt.WindowModal)
        dialog.setAttribute(Qt.WA_DeleteOnClose)

        # the dialog size is kept in a settings group named after the class
        store = QSettings()
        store.beginGroup(qname(type(self)))
        saved_size = store.value("size", QSize(), type=QSize)  # type: QSize
        if saved_size.isValid():
            dialog.resize(saved_size)

        path = current.data(ImportItem.PathRole)
        old_options = current.data(ImportItem.OptionsRole)
        dialog.setPath(path)  # the path is set before the options
        if isinstance(old_options, Options):
            dialog.setOptions(old_options)

        def on_accepted():
            new_options = dialog.options()
            current.setData(new_options, ImportItem.OptionsRole)
            # update the stored item
            self._add_recent(path, new_options)
            if new_options != old_options:
                self._invalidate()

        def remember_size():
            store.setValue("size", dialog.size())

        dialog.accepted.connect(on_accepted)
        dialog.finished.connect(remember_size)
        dialog.show()

    def set_selected_file(self, filename, options=None):
        """
        Set the current selected filename path.

        The file is moved/added to the front of the recent list (together
        with `options`, when given) and a new commit is scheduled.
        """
        self._add_recent(filename, options)
        self._invalidate()

    #: Saved options for a filename.
    #: Passed to QSettings_readArray when the "recent" array is read from
    #: the local settings.
    SCHEMA = {
        "path": str,  # Local filesystem path
        "options": str,  # json encoded 'Options'
    }

    @classmethod
    def _local_settings(cls):
        # type: () -> QSettings
        """
        Return the QSettings object holding the local persistent settings.

        The settings are kept in an ini file named after the qualified name
        of the class, inside the widget settings directory.
        """
        ini_name = "{}.ini".format(qname(cls))
        ini_path = os.path.join(settings.widget_settings_dir(), ini_name)
        return QSettings(ini_path, QSettings.IniFormat)

    def _add_recent(self, filename, options=None):
        # type: (str, Optional[Options]) -> None
        """
        Add filename to the list of recent files.

        The entry for `filename` is moved (or inserted) to the front of the
        combo box model and made current, then recorded in the local
        persistent settings and in the workflow session items.

        Parameters
        ----------
        filename : str
            Path of the file.
        options : Optional[Options]
            Import options to store with the file. When None, the options
            already present on the model item are left untouched and the
            session entry for the path (if any) keeps its stored options.
        """
        model = self.import_items_model
        index = index_where(
            (model.index(i, 0).data(ImportItem.PathRole)
             for i in range(model.rowCount())),
            lambda path: isinstance(path, str) and samepath(path, filename)
        )
        if index is not None:
            item, *_ = model.takeRow(index)
        else:
            item = ImportItem.fromPath(filename)

        model.insertRow(0, item)

        if options is not None:
            item.setOptions(options)

        self.recent_combo.setCurrentIndex(0)
        # store items to local persistent settings
        s = self._local_settings()
        arr = QSettings_readArray(s, "recent", OWCSVFileImport.SCHEMA)
        record = {"path": filename}
        if options is not None:
            record["options"] = json.dumps(options.as_dict())

        arr = [entry for entry in arr if entry.get("path") != filename]
        arr.append(record)
        QSettings_writeArray(s, "recent", arr)

        # update workflow session items
        items = self._session_items[:]
        idx = index_where(items, lambda t: samepath(t[0], filename))
        previous = items.pop(idx) if idx is not None else None
        if options is not None:
            items.insert(0, (filename, options.as_dict()))
        elif previous is not None:
            # No options were supplied (options defaults to None): calling
            # options.as_dict() would raise, so move the existing session
            # entry to the front with its stored options instead.
            items.insert(0, previous)
        # With neither new nor stored options there is nothing that could
        # be restored from the session later, so no entry is added.
        self._session_items = items[:OWCSVFileImport.MaxHistorySize]

    def _invalidate(self):
        """
        Invalidate the current output and schedule a new commit call.

        A task that is still running is cancelled first.
        (NOTE: The widget enters a blocking state)
        """
        self.__committimer.start()
        running = self.__watcher
        if running is not None:
            self.__cancel_task()
        self.setBlocking(True)

    def commit(self):
        """
        Submit a load task for the current item.

        The commit timer is stopped, any pending task is cancelled and the
        widget error is cleared. Nothing is submitted when there is no
        current item or when the item carries no valid Options.
        """
        self.__committimer.stop()
        if self.__watcher is not None:
            self.__cancel_task()
        self.error()

        current = self.current_item()
        if current is None:
            return
        path = current.data(ImportItem.PathRole)
        options = current.data(ImportItem.OptionsRole)
        if not isinstance(options, Options):
            return

        task = TaskState()
        task.future = ...
        task.watcher = qconcurrent.FutureWatcher()
        task.progressChanged.connect(self.__set_read_progress,
                                     Qt.QueuedConnection)

        def report_progress(i, j):
            # handed to load_csv as its progress callback
            task.emitProgressChangedOrCancel(i, j)

        task.future = self.__executor.submit(
            clear_stack_on_cancel(load_csv),
            path, options, report_progress,
        )
        watcher = task.watcher
        watcher.setFuture(task.future)
        watcher.done.connect(self.__handle_result)
        # the task state travels with the watcher so it can be cancelled
        watcher.progress = task
        self.__watcher = watcher
        self.__set_running_state()

    @Slot('qint64', 'qint64')
    def __set_read_progress(self, read, count):
        """Show `read` out of `count` as a percentage on the progress bar."""
        if count > 0:
            percent = 100 * read / count
            self.progressBarSet(percent)

    def __cancel_task(self):
        """
        Cancel and dispose of the current task.

        The future is cancelled, the task state is flagged as cancelled and
        disconnected from this widget, and the call blocks until the future
        has completed.
        """
        assert self.__watcher is not None
        watcher = self.__watcher
        task_state = watcher.progress
        watcher.future().cancel()
        task_state.cancel = True
        watcher.done.disconnect(self.__handle_result)
        task_state.progressChanged.disconnect(self.__set_read_progress)
        task_state.deleteLater()
        # wait until completion
        futures.wait([watcher.future()])
        self.__watcher = None

    def cancel(self):
        """
        Cancel the task that is pending or executing, if there is one.

        The running state is cleared and a 'Cancelled' notice is shown.
        """
        if self.__watcher is None:
            return

        self.__cancel_task()
        self.__clear_running_state()
        self.setStatusMessage("Cancelled")
        notice = (
            "<div>Cancelled<br/><small>Press 'Reload' to try again</small></div>"
        )
        self.summary_text.setText(notice)

    def __set_running_state(self):
        """Put the GUI into the 'loading' state for the current item."""
        self.progressBarInit()
        self.setBlocking(True)
        self.setStatusMessage("Running")
        self.cancel_button.setEnabled(True)
        self.load_button.setText("Restart")
        current_path = self.current_item().path()
        self.Error.clear()
        message = "<div>Loading: <i>{}</i><br/>".format(
            prettyfypath(current_path))
        self.summary_text.setText(message)

    def __clear_running_state(self):
        """Return the GUI from the 'loading' state to the idle state."""
        self.progressBarFinished()
        self.setStatusMessage("")
        self.setBlocking(False)
        # the task is over: nothing to cancel, the file can be reloaded
        self.cancel_button.setEnabled(False)
        self.load_button.setText("Reload")

    def __set_error_state(self, err):
        """
        Display `err` as the widget error and describe it in the info box.

        A UnicodeDecodeError is reported as an encoding error; any other
        exception as an unexpected error, with the exception text appended
        to the description.

        Parameters
        ----------
        err : Exception
            The exception raised while loading the current item.
        """
        is_encoding_error = isinstance(err, UnicodeDecodeError)
        self.Error.clear()
        if is_encoding_error:
            self.Error.encoding_error(exc_info=err)
        else:
            self.Error.error(exc_info=err)

        path = self.current_item().path()
        basename = os.path.basename(path)
        if is_encoding_error:
            text = (
                "<div><i>{basename}</i> was not loaded due to a text encoding "
                "error. The file might be saved in an unknown or invalid "
                "encoding, or it might be a binary file.</div>"
            ).format(
                basename=escape(basename)
            )
        else:
            # The exception text can contain markup characters (it may quote
            # file contents); escape it like the file name so it is shown
            # verbatim rather than interpreted as rich text.
            err_text = "".join(traceback.format_exception_only(type(err), err))
            text = (
                "<div><i>{basename}</i> was not loaded due to an error:"
                "<p style='white-space: pre;'>{err}</p></div>"
            ).format(
                basename=escape(basename),
                err=escape(err_text)
            )
        self.summary_text.setText(text)

    def __clear_error_state(self):
        """Clear the 'Unexpected error' message and empty the info text."""
        self.Error.error.clear()
        self.summary_text.setText("")

    def onDeleteWidget(self):
        """
        Reimplemented.

        Cancel the running task (if any) and shut down the executor before
        the widget is deleted.
        """
        if self.__watcher is not None:
            self.__cancel_task()
        # Shut the executor down unconditionally; previously it was only
        # shut down when a task happened to be running at deletion time.
        self.__executor.shutdown()
        super().onDeleteWidget()

    @Slot(object)
    def __handle_result(self, f):
        # type: (qconcurrent.Future[pd.DataFrame]) -> None
        """
        Process the finished load task `f` and send its result.

        An empty file yields an empty data frame; any other failure puts
        the widget into the error state and sends None on both outputs.
        """
        assert f.done()
        assert f is self.__watcher.future()
        self.__watcher = None
        self.__clear_running_state()

        try:
            df = f.result()
            assert isinstance(df, pd.DataFrame)
        except pandas.errors.EmptyDataError:
            df = pd.DataFrame({})
        except Exception as err:  # pylint: disable=broad-except
            self.__set_error_state(err)
            df = None
        else:
            self.__clear_error_state()

        table = pandas_to_table(df) if df is not None else None
        self.send("Data Frame", df)
        self.send('Data', table)
        self._update_status_messages(table)

    def _update_status_messages(self, data):
        """
        Show the number of rows, features and metas of `data` in the info box.

        Nothing is changed when `data` is None.
        """
        if data is None:
            return

        def counted(seq, noun):
            # "<count> <noun>", with an "s" appended unless count is one
            count = len(seq)
            return "{} {}{}".format(count, noun, "" if count == 1 else "s")

        domain = data.domain
        summary = ", ".join([
            counted(data, "row"),
            counted(domain.attributes, "feature"),
            counted(domain.metas, "meta"),
        ])
        self.summary_text.setText(summary)

    def itemsFromSettings(self):
        # type: () -> List[Tuple[str, Options]]
        """
        Return the (path, options) pairs stored in the local history.

        Records without a path are skipped; records whose options can not
        be reconstructed are logged and skipped. The pairs are returned in
        the reverse of their stored order.
        """
        records = QSettings_readArray(
            self._local_settings(), "recent", OWCSVFileImport.SCHEMA)
        restored = []  # type: List[Tuple[str, Options]]
        for record in records:
            path = record.get("path", "")
            if not path:
                continue
            opts_json = record.get("options", "")
            try:
                opts = Options.from_dict(json.loads(opts_json))
            except (csv.Error, LookupError, TypeError, json.JSONDecodeError):
                _log.error("Could not reconstruct options for '%s'", path,
                           exc_info=True)
                continue
            restored.append((path, opts))
        return list(reversed(restored))

    def _restoreState(self):
        """
        Fill the recent files model from the session and the local history.

        Session (workflow) items come first, followed by the local history;
        duplicates (by normalized path) are dropped. The path that was
        current in the combo box beforehand, if any, is reselected.
        """
        model = self.import_items_model
        # local history
        local_items = self.itemsFromSettings()
        # stored session items
        session_items = []
        for path, optdict in self._session_items:
            try:
                session_items.append((path, Options.from_dict(optdict)))
            except (csv.Error, LookupError):
                # Is it better to fail then to lose a item slot?
                _log.error("Failed to restore '%s'", path, exc_info=True)

        merged = unique(session_items + local_items,
                        key=lambda t: pathnormalize(t[0]))

        combo = self.recent_combo
        if combo.currentIndex() != -1:
            currentpath = combo.currentData(ImportItem.PathRole)
        else:
            currentpath = None

        for path, options in merged:
            entry = ImportItem.fromPath(path)
            entry.setOptions(options)
            model.appendRow(entry)

        if currentpath is None:
            return
        idx = combo.findData(currentpath, ImportItem.PathRole)
        if idx != -1:
            combo.setCurrentIndex(idx)
class OWWorldBankIndicators(owwidget_base.OWWidgetBase):
    """World bank data widget for Orange."""
    # pylint: disable=invalid-name
    # Some names have to be invalid to override parent fields.
    # pylint: disable=too-many-ancestors
    # False positive from fetching all ancestors from QWWidget.
    # pylint: disable=too-many-instance-attributes
    # False positive from fetching all attributes from QWWidget.

    # Widget needs a name, or it is considered an abstract widget
    # and not shown in the menu.
    name = "WB Indicators"
    icon = "icons/wb_icon.png"
    outputs = [
        widget.OutputSignal("Data",
                            table.Table,
                            doc="Indicator data from World bank Indicator API")
    ]

    replaces = [
        "Orange.orangecontrib.wbd.widgets.OWWorldBankIndicators",
    ]

    # Labels of the "Indicators" radio buttons, keyed by the value stored
    # in `indicator_list_selection` (see basic_indicator_filter).
    indicator_list_map = collections.OrderedDict([
        (0, "All"),
        (1, "Common"),
        (2, "Featured"),
    ])

    settingsList = [
        "indicator_list_selection",
        "country_selection",
        "indicator_selection",
        "splitterSettings",
        "currentGds",
        "auto_commit",
        "output_type",
    ]

    country_selection = Setting({})
    indicator_selection = Setting([])
    # NOTE(review): the two defaults below are booleans, while _init_layout
    # assigns integer indices (2 and 0) to these attributes - confirm intent.
    indicator_list_selection = Setting(True)
    output_type = Setting(True)
    auto_commit = Setting(False)

    # Saved states of the two splitters (vertical, horizontal); restored
    # with QSplitter.restoreState in _init_layout.
    splitterSettings = Setting(
        (b'\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x01\xea'
         b'\x00\x00\x00\xd7\x01\x00\x00\x00\x07\x01\x00\x00\x00\x02',
         b'\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x01\xb5'
         b'\x00\x00\x02\x10\x01\x00\x00\x00\x07\x01\x00\x00\x00\x01'))

    def __init__(self):
        """Create the indicator API wrapper and set up the widget layout."""
        super().__init__()
        self._api = api_wrapper.IndicatorAPI()
        self._init_layout()
        self._check_server_status()

    def _init_layout(self):
        """Initialize widget layout.

        The control area receives the info label, the indicator list and
        output type radio buttons and the commit box. The main area receives
        the filter line edit above a vertical splitter holding the indicator
        tree and, in a nested horizontal splitter, the description box and
        the country tree.
        """

        # Control area
        info_box = gui.widgetBox(self.controlArea, "Info", addSpace=True)
        self._info_label = gui.widgetLabel(info_box, "Initializing\n\n")

        indicator_filter_box = gui.widgetBox(self.controlArea,
                                             "Indicators",
                                             addSpace=True)
        gui.radioButtonsInBox(indicator_filter_box,
                              self,
                              "indicator_list_selection",
                              self.indicator_list_map.values(),
                              "Rows",
                              callback=self.indicator_list_selected)
        # NOTE(review): this overrides whatever value the setting held;
        # index 2 corresponds to "Featured" in indicator_list_map.
        self.indicator_list_selection = 2

        gui.separator(indicator_filter_box)

        output_box = gui.widgetBox(self.controlArea, "Output", addSpace=True)
        gui.radioButtonsInBox(output_box,
                              self,
                              "output_type", ["Countries", "Time Series"],
                              "Rows",
                              callback=self.output_type_selected)
        # NOTE(review): likewise overrides the setting; 0 is "Countries".
        self.output_type = 0

        gui.separator(output_box)
        # pylint: disable=duplicate-code
        gui.auto_commit(self.controlArea,
                        self,
                        "auto_commit",
                        "Commit",
                        box="Commit")
        gui.rubber(self.controlArea)

        # Main area

        gui.widgetLabel(self.mainArea, "Filter")
        # Every change of the filter text re-filters the indicator list
        self.filter_text = QtWidgets.QLineEdit(
            textChanged=self.filter_indicator_list)
        self.completer = QtWidgets.QCompleter(
            caseSensitivity=QtCore.Qt.CaseInsensitive)
        self.completer.setModel(QtCore.QStringListModel(self))
        self.filter_text.setCompleter(self.completer)

        spliter_v = QtWidgets.QSplitter(QtCore.Qt.Vertical, self.mainArea)

        self.mainArea.layout().addWidget(self.filter_text)
        self.mainArea.layout().addWidget(spliter_v)

        self.indicator_widget = IndicatorsTreeView(spliter_v, main_widget=self)

        splitter_h = QtWidgets.QSplitter(QtCore.Qt.Horizontal, spliter_v)

        self.description_box = gui.widgetBox(splitter_h, "Description")

        self.indicator_description = QtWidgets.QTextEdit()
        self.indicator_description.setReadOnly(True)
        self.description_box.layout().addWidget(self.indicator_description)

        box = gui.widgetBox(splitter_h, "Countries and Regions")
        # The tree is created with the splitter as its first argument and
        # then added to the layout of `box`.
        self.country_tree = CountryTreeWidget(
            splitter_h,
            self.country_selection,
            default_select=True,
            default_colapse=True,
        )
        box.layout().addWidget(self.country_tree)
        self.country_tree.set_data(countries.get_countries_regions_dict())

        self.splitters = spliter_v, splitter_h

        # Restore the saved splitter states and track further changes
        for splitter, setting in zip(self.splitters, self.splitterSettings):
            splitter.splitterMoved.connect(self._splitter_moved)
            splitter.restoreState(setting)

        # self.resize(2000, 600)  # why does this not work

        self.progressBarInit()

    def filter_indicator_list(self):
        """Apply the filter text to the proxy model and refresh the info box.

        The text is lower-cased and split on whitespace; nothing happens
        while the indicator widget has no model.
        """
        filter_string = self.filter_text.text()
        proxy_model = self.indicator_widget.model()
        if not proxy_model:
            return
        words = filter_string.lower().strip().split()
        proxy_model.setFilterFixedStrings(words)
        self.print_info()

    def output_type_selected(self):
        """Radio button callback for the output type; calls commit_if()."""
        self.commit_if()

    def basic_indicator_filter(self):
        """Return the indicator list label for the current selection.

        None is returned when the selection is not a key of
        `indicator_list_map`.
        """
        selection = self.indicator_list_selection
        return self.indicator_list_map.get(selection)

    def indicator_list_selected(self):
        """Handle a change of the indicator list selection.

        Logs the newly selected list (All, Common or Featured) and asks the
        indicator widget to fetch the indicators again.
        """
        logger.debug("Indicator list selected: %s",
                     self.basic_indicator_filter())
        self.indicator_widget.fetch_indicators()

    def _splitter_moved(self, *_):
        """Store the current state of every splitter in the settings."""
        states = []
        for splitter in self.splitters:
            states.append(bytes(splitter.saveState()))
        self.splitterSettings = states

    def _fetch_dataset(self, set_progress=None):
        """Fetch indicator dataset.

        Args:
            set_progress: optional callable taking the progress value; when
                given it is called once with 0 before the download starts.

        Returns:
            The dataset returned by the API for the selected indicators and
            countries.
        """
        # The parameter defaults to None, so only call it when supplied.
        if set_progress is not None:
            set_progress(0)
        self._start_progerss_task()

        try:
            country_codes = self.get_country_codes()

            # More than 250 codes: pass None to the API instead of the list
            # (presumably "no country filter" - confirm against the API).
            if len(country_codes) > 250:
                country_codes = None
            logger.debug("Fetch: selected country codes: %s", country_codes)
            logger.debug("Fetch: selected indicators: %s",
                         self.indicator_selection)
            return self._api.get_dataset(self.indicator_selection,
                                         countries=country_codes)
        finally:
            # Always lower the flag polled by _dataset_progress; otherwise
            # a failed download would leave that loop running forever.
            self._set_progress_flag = False

    def _dataset_to_table(self, dataset):
        """Convert `dataset` to an Orange table.

        The time series layout is requested when `output_type` equals 1.
        """
        return dataset.as_orange_table(time_series=(self.output_type == 1))

    @staticmethod
    def _fetch_dataset_exception(exception):
        """Log `exception` through the module logger."""
        logger.exception(exception)

    def _dataset_progress(self, set_progress=None):
        """Report the dataset download progress.

        While `_set_progress_flag` is set, the progress state of the world
        bank API is read once per second and passed to `set_progress` as a
        whole percentage. This thread should only read data and ask the GUI
        thread to update the progress for this to be thread safe.
        """
        while self._set_progress_flag:
            indicators = self._api.progress["indicators"]
            current_indicator = self._api.progress["current_indicator"]
            indicator_pages = self._api.progress["indicator_pages"]
            current_page = self._api.progress["current_page"]
            logger.debug("api progress: %s", self._api.progress)
            if indicator_pages > 0 and indicators > 0:
                # every indicator owns an equal share of the 100 percent;
                # finished indicators count fully, the current one by pages
                share = 100 / indicators
                progress = ((share * (current_indicator - 1)) +
                            share * (current_page / indicator_pages))
                logger.debug("calculated progress: %s", progress)
                set_progress(math.floor(progress))
            time.sleep(1)

    def _dataset_progress_exception(self, exception):
        """Log `exception` and refresh the info box."""
        logger.exception(exception)
        self.print_info()
Пример #8
0
class OWWorldBankClimate(owwidget_base.OWWidgetBase):
    """World bank data widget for Orange.

    The user selects countries, averaging intervals and data types; the data
    is fetched through ``api_wrapper.ClimateAPI`` and converted with the
    dataset's ``as_orange_table`` method.
    """
    # pylint: disable=invalid-name
    # Some names have to be invalid to override parent fields.
    # pylint: disable=too-many-ancestors
    # False positive from fetching all ancestors from QWWidget.
    # pylint: disable=too-many-instance-attributes
    # False positive from fetching all attributes from QWWidget.

    # Widget needs a name, or it is considered an abstract widget
    # and not shown in the menu.
    name = "WB Climate"
    icon = "icons/climate.png"
    outputs = [
        widget.OutputSignal("Data",
                            table.Table,
                            doc="Climate data from World bank Climate API")
    ]

    replaces = [
        "Orange.orangecontrib.wbd.widgets.OWWorldBankClimate",
    ]

    # Every entry needs its own trailing comma: adjacent string literals are
    # concatenated by Python, which silently merges two names into one.
    settingsList = [
        "auto_commit",
        "country_selection",
        "mergeSpots",
        "output_type",
        "splitterSettings",
    ]

    country_selection = Setting({})
    output_type = Setting(True)
    mergeSpots = Setting(True)
    auto_commit = Setting(False)
    use_country_names = Setting(False)

    # Names of the selected averaging intervals ("month", "year", "decade")
    # and data types ("tas", "pr"); exposed to the GUI through the boolean
    # properties below.
    include_intervals = Setting([])
    include_data_types = Setting([])

    def _data_type_setter(self, name, value):
        """Add ``name`` to the selected data types, or drop it.

        Args:
            name: data type identifier (e.g. "tas" or "pr").
            value: truthy to select the data type, falsy to deselect it.
        """
        data_types = set(self.include_data_types)
        if value:
            data_types.add(name)
        else:
            data_types.discard(name)
        self.include_data_types = list(data_types)
        logger.debug("New data types: %s", self.include_data_types)

    def _interval_setter(self, name, value):
        """Add ``name`` to the selected intervals, or drop it.

        Args:
            name: interval identifier (e.g. "month", "year" or "decade").
            value: truthy to select the interval, falsy to deselect it.
        """
        intervals = set(self.include_intervals)
        if value:
            intervals.add(name)
        else:
            intervals.discard(name)
        self.include_intervals = list(intervals)
        logger.debug("New intervals: %s", self.include_intervals)

    @property
    def include_month(self):
        """Whether "month" is among the selected intervals."""
        return "month" in self.include_intervals

    @include_month.setter
    def include_month(self, value):
        self._interval_setter("month", value)

    @property
    def include_year(self):
        """Whether "year" is among the selected intervals."""
        return "year" in self.include_intervals

    @include_year.setter
    def include_year(self, value):
        self._interval_setter("year", value)

    @property
    def include_decade(self):
        """Whether "decade" is among the selected intervals."""
        return "decade" in self.include_intervals

    @include_decade.setter
    def include_decade(self, value):
        self._interval_setter("decade", value)

    @property
    def include_temperature(self):
        """Whether the "tas" data type is selected."""
        return "tas" in self.include_data_types

    @include_temperature.setter
    def include_temperature(self, value):
        self._data_type_setter("tas", value)

    @property
    def include_precipitation(self):
        """Whether the "pr" data type is selected."""
        return "pr" in self.include_data_types

    @include_precipitation.setter
    def include_precipitation(self, value):
        self._data_type_setter("pr", value)

    splitterSettings = Setting((
        b'\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x01\xea\x00'
        b'\x00\x00\xd7\x01\x00\x00\x00\x07\x01\x00\x00\x00\x02',
        b'\x00\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x01\xb5\x00'
        b'\x00\x02\x10\x01\x00\x00\x00\x07\x01\x00\x00\x00\x01'))

    def __init__(self):
        """Create the API wrapper, build the layout and check the server."""
        super().__init__()
        self._api = api_wrapper.ClimateAPI()
        self._init_layout()
        self.print_selection_count()
        try:
            self._check_server_status()
        except ConnectionError:
            # Best effort: the widget is still constructed when the status
            # check fails, but the failure is logged instead of being lost.
            logger.warning("Could not check the climate API server status.",
                           exc_info=True)

    def print_selection_count(self):
        """Update info widget with new selection count."""
        country_codes = self.get_country_codes()
        self.info_data["Selected countries"] = len(country_codes)
        self.print_info()

    def _init_layout(self):
        """Initialize widget layout."""

        # Control area
        info_box = gui.widgetBox(self.controlArea, "Info", addSpace=True)
        self._info_label = gui.widgetLabel(info_box, "Initializing\n\n")

        box = gui.vBox(self.controlArea, "Average intervals:")
        self.ch_month = gui.checkBox(box,
                                     self,
                                     "include_month",
                                     "Month",
                                     callback=self.commit_if)
        self.ch_year = gui.checkBox(box,
                                    self,
                                    "include_year",
                                    'Year',
                                    callback=self.commit_if)
        self.ch_decade = gui.checkBox(box,
                                      self,
                                      "include_decade",
                                      'Decade',
                                      callback=self.commit_if)

        box = gui.vBox(self.controlArea, "Data Types")
        gui.checkBox(box,
                     self,
                     "include_temperature",
                     "Temperature",
                     callback=self.commit_if)
        gui.checkBox(box,
                     self,
                     "include_precipitation",
                     'Precipitation',
                     callback=self.commit_if)

        output_box = gui.widgetBox(self.controlArea, "Output", addSpace=True)
        gui.radioButtonsInBox(output_box,
                              self,
                              "output_type", ["Countries", "Time Series"],
                              "Rows",
                              callback=self.output_type_selected)

        gui.checkBox(output_box,
                     self,
                     "use_country_names",
                     "Use Country names",
                     callback=self.commit_if)
        # The output type always starts at 0 ("Countries").
        self.output_type = 0

        # pylint: disable=duplicate-code
        gui.separator(output_box)
        gui.auto_commit(self.controlArea,
                        self,
                        "auto_commit",
                        "Commit",
                        box="Commit")
        gui.rubber(self.controlArea)

        # Main area
        box = gui.widgetBox(self.mainArea, "Countries")
        self.country_tree = CountryTreeWidget(
            self.mainArea,
            self.country_selection,
            commit_callback=self.commit_if,
            default_colapse=True,
            default_select=False,
        )
        countriesdict = countries.get_countries_dict()
        if countriesdict is not None:
            self.country_tree.set_data(countriesdict)
        box.layout().addWidget(self.country_tree)
        self.resize(500, 400)  # why does this not work

    def output_type_selected(self):
        """Output type handle.

        For time series output (type 1) the interval check boxes are disabled
        and the interval selection is forced to "year" only; otherwise the
        check boxes are enabled again.
        """
        logger.debug("output type set to: %s", self.output_type)
        time_series = self.output_type == 1
        for check_box in (self.ch_decade, self.ch_month, self.ch_year):
            check_box.setEnabled(not time_series)
        if time_series:
            self.include_year = True
            self.include_month = False
            self.include_decade = False
        self.commit_if()

    def _splitter_moved(self, *_):
        """Store the current state of all splitters in the settings."""
        self.splitterSettings = [
            bytes(sp.saveState()) for sp in self.splitters
        ]

    def _check_big_selection(self):
        """Set or clear the "Warning" info entry for large selections.

        An empty data type or interval selection is counted as 2. The warning
        is shown when types * intervals * selected countries exceeds 100.
        """
        types = len(self.include_data_types) if self.include_data_types else 2
        intervals = len(
            self.include_intervals) if self.include_intervals else 2
        selected_countries = len(self.get_country_codes())
        if types * intervals * selected_countries > 100:
            self.info_data[
                "Warning"] = "Fetching data\nmight take a few minutes."
        else:
            self.info_data["Warning"] = None
        self.print_info()

    def commit_if(self):
        """Auto commit handler.

        This function must be called on every action that should trigger an
        auto commit.
        """
        self._check_big_selection()
        self.print_selection_count()
        super().commit_if()

    def _fetch_dataset(self, set_progress=None):
        """Fetch climate dataset.

        Args:
            set_progress: optional callable receiving the progress as a
                percentage; it is reset to 0 before the download starts.

        Returns:
            The dataset returned by the API's ``get_instrumental`` call for
            the selected countries, data types and intervals.
        """
        if set_progress is not None:
            set_progress(0)
        self._start_progerss_task()

        country_codes = self.get_country_codes()

        logger.debug("Fetch: selected country codes: %s", country_codes)
        climate_dataset = self._api.get_instrumental(
            country_codes,
            data_types=self.include_data_types,
            intervals=self.include_intervals)
        # Tells the loop in _dataset_progress to stop.
        self._set_progress_flag = False
        return climate_dataset

    def _dataset_to_table(self, dataset):
        """Convert ``dataset`` with its ``as_orange_table`` method.

        ``time_series`` is true only for output type 1, and country names are
        used when ``use_country_names`` is set.
        """
        time_series = self.output_type == 1
        return dataset.as_orange_table(
            time_series=time_series,
            use_names=self.use_country_names,
        )

    def _dataset_progress(self, set_progress=None):
        """Report download progress once per second while the flag is set.

        Args:
            set_progress: optional callable receiving the progress as an
                integer percentage computed from the API page counters.
        """
        while self._set_progress_flag:
            pages = self._api.progress["pages"]
            current_page = self._api.progress["current_page"]
            logger.debug("api progress: %s", self._api.progress)
            if pages > 0:
                progress = ((100 / pages) * (current_page - 1))
                logger.debug("calculated progress: %s", progress)
                if set_progress is not None:
                    set_progress(math.floor(progress))
            time.sleep(1)

    @staticmethod
    def _dataset_progress_exception(exception):
        """Log an error raised by the progress tracking."""
        logger.exception(exception)
Пример #9
0
class OWMergeData(widget.OWWidget):
    """Widget that merges two data tables on a pair of selected attributes."""

    name = "Merge Data"
    description = "Merge data sets based on the values of selected data features."
    icon = "icons/MergeData.svg"
    priority = 1110

    inputs = [("Data A", Orange.data.Table, "setDataA", widget.Default),
              ("Data B", Orange.data.Table, "setDataB")]
    outputs = [
        widget.OutputSignal("Merged Data",
                            Orange.data.Table,
                            replaces=["Merged Data A+B", "Merged Data B+A"])
    ]

    # Names of the matching attributes in each table, and whether rows
    # without a match are excluded.
    attr_a = settings.Setting('', schema_only=True)
    attr_b = settings.Setting('', schema_only=True)
    inner = settings.Setting(True)

    want_main_area = False

    def __init__(self):
        """Build the attribute selectors, info labels and the match option."""
        super().__init__()

        # Input tables
        self.dataA = None
        self.dataB = None

        match_box = gui.hBox(self.controlArea, "Match instances by")

        def attribute_selector(setting_name, caption):
            # One combo box bound to `setting_name`, backed by its own model.
            view = gui.comboBox(match_box,
                                self,
                                setting_name,
                                label=caption,
                                orientation=Qt.Vertical,
                                sendSelectedValue=True,
                                callback=self._invalidate)
            model = itemmodels.VariableListModel()
            view.setModel(model)
            return view, model

        self.attrViewA, self.attrModelA = attribute_selector(
            'attr_a', "Data A")
        self.attrViewB, self.attrModelB = attribute_selector(
            'attr_b', "Data B")

        # Side-by-side summaries of both inputs
        info_row = gui.hBox(self.controlArea, box=None)
        self.infoBoxDataA = gui.label(info_row,
                                      self,
                                      self.dataInfoText(None),
                                      box="Data A Info")
        self.infoBoxDataB = gui.label(info_row,
                                      self,
                                      self.dataInfoText(None),
                                      box="Data B Info")

        gui.separator(self.controlArea)
        options_box = gui.vBox(self.controlArea, box=True)
        gui.checkBox(options_box,
                     self,
                     "inner",
                     "Exclude instances without a match",
                     callback=self._invalidate)

    def _setAttrs(self, model, data, othermodel, otherdata):
        """Fill ``model`` with the variables of ``data``.

        When both tables are present and share at least one instance id,
        INSTANCEID is put in front of every non-empty model that does not
        already start with it.
        """
        if data is None:
            model[:] = []
        else:
            model[:] = allvars(data)

        if data is None or otherdata is None:
            return
        if not len(numpy.intersect1d(data.ids, otherdata.ids)):
            return
        for candidate in (model, othermodel):
            if len(candidate) and candidate[0] != INSTANCEID:
                candidate.insert(0, INSTANCEID)

    @check_sql_input
    def setDataA(self, data):
        """Store input A, refresh its attribute model, selection and info."""
        self.dataA = data
        self._setAttrs(self.attrModelA, data, self.attrModelB, self.dataB)
        position = -1
        if self.attr_a:
            for row, variable in enumerate(self.attrModelA):
                if str(variable) == self.attr_a:
                    position = row
                    break
        if position == -1:
            # The stored attribute is not available; fall back to INDEX.
            self.attr_a = INDEX
        else:
            self.attrViewA.setCurrentIndex(position)
        self.infoBoxDataA.setText(self.dataInfoText(data))

    @check_sql_input
    def setDataB(self, data):
        """Store input B, refresh its attribute model, selection and info."""
        self.dataB = data
        self._setAttrs(self.attrModelB, data, self.attrModelA, self.dataA)
        position = -1
        if self.attr_b:
            for row, variable in enumerate(self.attrModelB):
                if str(variable) == self.attr_b:
                    position = row
                    break
        if position == -1:
            # The stored attribute is not available; fall back to INDEX.
            self.attr_b = INDEX
        else:
            self.attrViewB.setCurrentIndex(position)
        self.infoBoxDataB.setText(self.dataInfoText(data))

    def handleNewSignals(self):
        """Recompute the output."""
        self._invalidate()

    def dataInfoText(self, data):
        """Return a two-line instance/variable count summary for ``data``.

        Both counts are zero when ``data`` is None.
        """
        if data is None:
            ninstances = nvariables = 0
        else:
            ninstances, nvariables = len(data), len(data.domain)

        return "\n".join([
            self.tr("%n instance(s)", None, ninstances),
            self.tr("%n variable(s)", None, nvariables),
        ])

    def commit(self):
        """Merge both inputs on the selected attributes and send the result.

        None is sent unless both attributes are chosen and both tables are
        present.
        """
        merged = None
        attributes_chosen = self.attr_a and self.attr_b
        tables_present = self.dataA is not None and self.dataB is not None
        if attributes_chosen and tables_present:
            # INDEX and INSTANCEID are passed through as they are; any other
            # name is looked up in the table's domain.
            if self.attr_a in (INDEX, INSTANCEID):
                key_a = self.attr_a
            else:
                key_a = self.dataA.domain[self.attr_a]
            if self.attr_b in (INDEX, INSTANCEID):
                key_b = self.attr_b
            else:
                key_b = self.dataB.domain[self.attr_b]
            merged = merge(self.dataA, key_a, self.dataB, key_b, self.inner)
        self.send("Merged Data", merged)

    def _invalidate(self):
        """Commit immediately."""
        self.commit()

    def send_report(self):
        """Report the attributes used for matching in both tables."""

        def reported(data, attr_name):
            # None without data; the domain's variable when the name is found
            # there; otherwise the stored name itself.
            if data is None:
                return None
            if attr_name in data.domain:
                return data.domain[attr_name]
            return attr_name

        self.report_items((
            ("Attribute A", reported(self.dataA, self.attr_a)),
            ("Attribute B", reported(self.dataB, self.attr_b)),
        ))
Пример #10
0
class OWFile(widget.OWWidget):
    """Widget that loads a data table from a local file or from a URL."""

    name = "File"
    id = "orange.widgets.data.file"
    description = "Read a data from an input file or network " \
                  "and send the data table to the output."
    icon = "icons/File.svg"
    priority = 10
    category = "Data"
    keywords = ["data", "file", "load", "read"]
    outputs = [
        widget.OutputSignal(
            "Data",
            Table,
            doc="Attribute-valued data set read from the input file.")
    ]

    want_main_area = False
    resizing_enabled = False

    # Values of the `source` setting; they also index the loader list in
    # load_data.
    LOCAL_FILE, URL = range(2)

    #: List[RecentPath]
    recent_paths = Setting([])
    recent_urls = Setting([])
    source = Setting(LOCAL_FILE)
    url = Setting("")

    dlg_formats = ("All readable files ({});;".format(
        '*' + ' *'.join(FileFormat.readers.keys())) + ";;".join(
            "{} (*{})".format(f.DESCRIPTION, ' *'.join(f.EXTENSIONS))
            for f in sorted(set(FileFormat.readers.values()),
                            key=list(FileFormat.readers.values()).index)))

    def __init__(self):
        """Build the source selection, info box and documentation button."""
        super().__init__()
        self.domain = None
        self.data = None
        self.loaded_file = ""
        self._relocate_recent_files()

        vbox = gui.radioButtons(self.controlArea,
                                self,
                                "source",
                                box=True,
                                addSpace=True,
                                callback=self.load_data)

        box = gui.widgetBox(vbox, orientation="horizontal")
        gui.appendRadioButton(vbox, "File", insertInto=box)
        self.file_combo = QtGui.QComboBox(
            box, sizeAdjustPolicy=QtGui.QComboBox.AdjustToContents)
        self.file_combo.setMinimumWidth(250)
        box.layout().addWidget(self.file_combo)
        self.file_combo.activated[int].connect(self.select_file)
        button = gui.button(box,
                            self,
                            '...',
                            callback=self.browse_file,
                            autoDefault=False)
        button.setIcon(self.style().standardIcon(QtGui.QStyle.SP_DirOpenIcon))
        button.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed)
        button = gui.button(box,
                            self,
                            "Reload",
                            callback=self.reload,
                            autoDefault=False)
        button.setIcon(self.style().standardIcon(
            QtGui.QStyle.SP_BrowserReload))
        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)

        box = gui.widgetBox(vbox, orientation="horizontal")
        gui.appendRadioButton(vbox, "URL", insertInto=box)
        self.le_url = le_url = QtGui.QLineEdit(self.url)
        l, t, r, b = le_url.getTextMargins()
        le_url.setTextMargins(l + 5, t, r, b)
        le_url.editingFinished.connect(self._url_set)
        box.layout().addWidget(le_url)

        # Completion of the URL field from the recently used URLs.
        self.completer_model = PyListModel()
        self.completer_model.wrap(self.recent_urls)
        completer = QtGui.QCompleter()
        completer.setModel(self.completer_model)
        completer.setCompletionMode(completer.PopupCompletion)
        completer.setCaseSensitivity(QtCore.Qt.CaseInsensitive)
        le_url.setCompleter(completer)

        box = gui.widgetBox(self.controlArea, "Info")
        self.info = gui.widgetLabel(box, 'No data loaded.')
        self.warnings = gui.widgetLabel(box, '')

        box = gui.widgetBox(self.controlArea, orientation="horizontal")
        gui.button(box,
                   self,
                   "Browse documentation data sets",
                   callback=lambda: self.browse_file(True),
                   autoDefault=False)
        gui.rubber(box)
        box.layout().addWidget(self.report_button)
        self.report_button.setFixedWidth(170)

        # Set word wrap, so long warnings won't expand the widget
        self.warnings.setWordWrap(True)
        self.warnings.setSizePolicy(QSizePolicy.Ignored,
                                    QSizePolicy.MinimumExpanding)

        self.set_file_list()
        # Must not call open_file from within __init__. open_file
        # explicitly re-enters the event loop (by a progress bar)
        QtCore.QTimer.singleShot(0, self.load_data)

    def _relocate_recent_files(self):
        """Re-resolve the recent paths against the current search paths.

        The search paths are the sample data sets directory and, when the
        workflow environment provides one, "basedir". Entries that can
        neither be resolved nor found by search are dropped.
        """
        paths = [("sample-datasets", get_sample_datasets_dir())]
        basedir = self.workflowEnv().get("basedir", None)
        if basedir is not None:
            paths.append(("basedir", basedir))

        rec = []
        for recent in self.recent_paths:
            resolved = recent.resolve(paths)
            if resolved is not None:
                rec.append(RecentPath.create(resolved.abspath, paths))
            elif recent.search(paths) is not None:
                rec.append(RecentPath.create(recent.search(paths), paths))
        self.recent_paths = rec

    def set_file_list(self):
        """Refill the file combo box from the recent paths.

        A disabled "(none)" placeholder is shown when there are no paths.
        """
        self.file_combo.clear()
        if not self.recent_paths:
            self.file_combo.addItem("(none)")
            self.file_combo.model().item(0).setEnabled(False)
        else:
            for i, recent in enumerate(self.recent_paths):
                self.file_combo.addItem(recent.value)
                self.file_combo.model().item(i).setToolTip(recent.abspath)

    def reload(self):
        """Reload the current file.

        When the combo box text matches the first recent path, that file is
        loaded again; otherwise the text is handed to select_file with an
        index past the end of the recent list.
        """
        if self.recent_paths:
            basename = self.file_combo.currentText()
            path = self.recent_paths[0]
            if basename in [path.relpath, path.value]:
                self.source = self.LOCAL_FILE
                return self.load_data()
        self.select_file(len(self.recent_paths) + 1)

    def select_file(self, n):
        """Make entry ``n`` the current file and load it.

        Args:
            n: index into the recent paths. An index past the end of the list
                means "use the combo box text as a path"; if that path does
                not exist, a message is shown and nothing is loaded.
        """
        if n < len(self.recent_paths):
            # Move the chosen entry to the front of the list.
            recent = self.recent_paths[n]
            del self.recent_paths[n]
            self.recent_paths.insert(0, recent)
        elif n:
            path = self.file_combo.currentText()
            if os.path.exists(path):
                self._add_path(path)
            else:
                self.info.setText('Data was not loaded:')
                self.warnings.setText("File {} does not exist".format(path))
                self.file_combo.removeItem(n)
                # lineEdit() is None for a non-editable combo box, so the
                # text can only be restored when an editor exists.
                line_edit = self.file_combo.lineEdit()
                if line_edit is not None:
                    line_edit.setText(path)
                return

        if len(self.recent_paths) > 0:
            self.source = self.LOCAL_FILE
            self.load_data()
            self.set_file_list()

    def _url_set(self):
        """Switch the source to URL and load the data."""
        self.source = self.URL
        self.load_data()

    def browse_file(self, in_demos=False):
        """Ask the user for a file with a dialog and load it.

        Args:
            in_demos: when true the dialog starts in the directory with the
                example data sets; otherwise in the directory of the most
                recent file, or in the home directory.
        """
        if in_demos:
            try:
                start_file = get_sample_datasets_dir()
            except AttributeError:
                start_file = ""
            if not start_file or not os.path.exists(start_file):
                widgets_dir = os.path.dirname(gui.__file__)
                orange_dir = os.path.dirname(widgets_dir)
                start_file = os.path.join(orange_dir, "doc", "datasets")
            if not start_file or not os.path.exists(start_file):
                d = os.getcwd()
                if os.path.basename(d) == "canvas":
                    d = os.path.dirname(d)
                start_file = os.path.join(os.path.dirname(d), "doc",
                                          "datasets")
            if not os.path.exists(start_file):
                QtGui.QMessageBox.information(
                    None, "File",
                    "Cannot find the directory with example data sets")
                return
        else:
            if self.recent_paths:
                start_file = self.recent_paths[0].abspath
            else:
                start_file = os.path.expanduser("~/")

        filename = QtGui.QFileDialog.getOpenFileName(self,
                                                     'Open Orange Data File',
                                                     start_file,
                                                     self.dlg_formats)
        if not filename:
            return

        self._add_path(filename)
        self.set_file_list()
        self.source = self.LOCAL_FILE
        self.load_data()

    def _add_path(self, filename):
        """Put ``filename`` at the front of the recent paths, without
        duplicates."""
        searchpaths = [("sample-datasets", get_sample_datasets_dir())]
        basedir = self.workflowEnv().get("basedir", None)
        if basedir is not None:
            searchpaths.append(("basedir", basedir))
        recent = RecentPath.create(filename, searchpaths)
        if recent in self.recent_paths:
            self.recent_paths.remove(recent)
        self.recent_paths.insert(0, recent)

    # Open a file, create data from it and send it over the data channel
    def load_data(self):
        """Load the data from the selected source and send it to the output.

        On failure the info label reports that the data was not loaded and
        nothing is sent. When there is nothing to load (no recent file or an
        empty URL), None is sent.
        """
        def load(method, fn):
            # Run the reader and show the last warning it issued, if any.
            with catch_warnings(record=True) as caught:
                data = method(fn)
                self.warning(33,
                             caught[-1].message.args[0] if caught else '')
            return data, fn

        def load_from_file():
            if not self.recent_paths:
                # No file has been selected yet.
                return None, ""
            fn = fn_original = self.recent_paths[0].abspath
            if fn == "(none)":
                return None, ""
            if not os.path.exists(fn):
                # Fall back to a file with the same name in the current
                # directory.
                _, basename = os.path.split(fn)
                if os.path.exists(os.path.join(".", basename)):
                    fn = os.path.join(".", basename)
                    self.information(
                        "Loading '{}' from the current directory.".format(
                            basename))
            try:
                return load(Table.from_file, fn)
            except Exception as exc:
                # Show the error and drop the unreadable file from the list.
                self.warnings.setText(str(exc))
                ind = self.file_combo.currentIndex()
                self.file_combo.removeItem(ind)
                if ind < len(self.recent_paths) and \
                        self.recent_paths[ind].abspath == fn_original:
                    del self.recent_paths[ind]
                raise

        def load_from_network():
            def update_model():
                # Runs later from the event loop, so `url` is the value it
                # has after the scheme has been added below.
                try:
                    self.completer_model.remove(url or self.url)
                except ValueError:
                    pass
                self.completer_model.insert(0, url)

            self.url = url = self.le_url.text()
            if not url:
                return None, ""
            QtCore.QTimer.singleShot(0, update_model)
            if "://" not in url:
                url = "http://" + url
            try:
                return load(Table.from_url, url)
            except Exception:
                self.warnings.setText(
                    "URL '{}' does not contain valid data".format(url))
                # Don't remove from recent_urls:
                # resource may reappear, or the user mistyped it
                # and would like to retrieve it from history and fix it.
                raise

        self.warning()
        self.information()

        try:
            self.data, self.loaded_file = \
                [load_from_file, load_from_network][self.source]()
        except Exception:
            # The loaders have already put the details in self.warnings.
            self.info.setText("Data was not loaded:")
            self.data = None
            self.loaded_file = ""
            return
        else:
            self.warnings.setText("")

        data = self.data
        if data is None:
            self.send("Data", None)
            self.info.setText("No data loaded")
            return

        domain = data.domain
        text = "{} instance(s), {} feature(s), {} meta attribute(s)".format(
            len(data), len(domain.attributes), len(domain.metas))
        if domain.has_continuous_class:
            text += "\nRegression; numerical class."
        elif domain.has_discrete_class:
            text += "\nClassification; discrete class with {} values.".format(
                len(domain.class_var.values))
        elif data.domain.class_vars:
            text += "\nMulti-target; {} target variables.".format(
                len(data.domain.class_vars))
        else:
            text += "\nData has no target variable."
        if 'Timestamp' in data.domain:
            # Google Forms uses this header to timestamp responses
            text += '\n\nFirst entry: {}\nLast entry: {}'.format(
                data[0, 'Timestamp'], data[-1, 'Timestamp'])
        self.info.setText(text)

        add_origin(data, self.loaded_file)
        self.send("Data", data)

    def get_widget_name_extension(self):
        """Return the loaded file's base name without its extension."""
        _, name = os.path.split(self.loaded_file)
        return os.path.splitext(name)[0]

    def send_report(self):
        """Report the data source, its format and the loaded data."""
        def get_ext_name(filename):
            # Format name for the file extension, "unknown" if not known.
            try:
                return FileFormat.names[os.path.splitext(filename)[1]]
            except KeyError:
                return "unknown"

        if self.data is None:
            self.report_paragraph("File", "No file.")
            return

        if self.source == self.LOCAL_FILE:
            home = os.path.expanduser("~")
            if self.loaded_file.startswith(home):
                # os.path.join does not like ~
                name = "~/" + \
                       self.loaded_file[len(home):].lstrip("/").lstrip("\\")
            else:
                name = self.loaded_file
            self.report_items("File", [("File name", name),
                                       ("Format", get_ext_name(name))])
        else:
            self.report_items("Data", [("URL", self.url),
                                       ("Format", get_ext_name(self.url))])

        self.report_data("Data", self.data)

    def workflowEnvChanged(self, key, value, oldvalue):
        """Re-resolve the recent files when the workflow "basedir" changes."""
        if key == "basedir":
            self._relocate_recent_files()
            self.set_file_list()
Пример #11
0
class OWMergeData(widget.OWWidget):
    """Widget that merges the "Data" input table with the "Extra Data" table.

    Rows of the two tables are paired by the values of one attribute chosen
    from each table; which attribute pair is used depends on the `merging`
    setting.
    """
    # Widget metadata
    name = "Merge Data"
    description = "Merge data sets based on the values of selected features."
    icon = "icons/MergeData.svg"
    priority = 1110

    # Input signals and the names of the methods that handle them.
    # NOTE(review): `replaces` presumably lists former signal names kept for
    # compatibility with older workflows - confirm.
    inputs = [
        widget.InputSignal("Data",
                           Orange.data.Table,
                           "setData",
                           widget.Default,
                           replaces=["Data A"]),
        widget.InputSignal("Extra Data",
                           Orange.data.Table,
                           "setExtraData",
                           replaces=["Data B"])
    ]
    # Single output signal carrying the merged table
    outputs = [
        widget.OutputSignal(
            "Data",
            Orange.data.Table,
            replaces=["Merged Data A+B", "Merged Data B+A", "Merged Data"])
    ]

    # Matching attributes, one (data, extra) pair per merging option; the
    # setting names follow the pattern 'attr_<operation>_data' /
    # 'attr_<operation>_extra' used when they are looked up with getattr.
    attr_augment_data = settings.Setting('', schema_only=True)
    attr_augment_extra = settings.Setting('', schema_only=True)
    attr_merge_data = settings.Setting('', schema_only=True)
    attr_merge_extra = settings.Setting('', schema_only=True)
    attr_combine_data = settings.Setting('', schema_only=True)
    attr_combine_extra = settings.Setting('', schema_only=True)
    # Index of the selected merging option: 0 = augment, 1 = merge,
    # 2 = combine (the order used in `merge`)
    merging = settings.Setting(0)

    want_main_area = False
    resizing_enabled = False

    class Warning(widget.OWWidget.Warning):
        # Shown by `commit` when the merged table has repeated variable names
        duplicate_names = widget.Msg("Duplicate variable names in output.")

    def __init__(self):
        """Initialize the widget state and build the control area.

        The control area contains two info labels (one per input table) and
        a group of three radio buttons, one per merging option. Each option
        has a row with two combo boxes for choosing the matching attribute
        in the data and in the extra data.
        """
        super().__init__()

        self.data = None
        self.extra_data = None

        # Models with the attributes offered in the combo boxes
        self.model = itemmodels.VariableListModel()
        self.model_unique_with_id = itemmodels.VariableListModel()
        self.extra_model_unique = itemmodels.VariableListModel()
        self.extra_model_unique_with_id = itemmodels.VariableListModel()

        box = gui.hBox(self.controlArea, box=None)
        self.infoBoxData = gui.label(box,
                                     self,
                                     self.dataInfoText(None),
                                     box="Data")
        self.infoBoxExtraData = gui.label(box,
                                          self,
                                          self.dataInfoText(None),
                                          box="Extra Data")

        grp = gui.radioButtonsInBox(self.controlArea,
                                    self,
                                    "merging",
                                    box="Merging",
                                    callback=self.change_merging)
        # Rows with combo boxes, in the same order as the radio buttons;
        # `set_merging` shows only the row of the selected option
        self.attr_boxes = []

        # Width of the radio button indicator; used to indent the rows
        radio_width = \
            QApplication.style().pixelMetric(QStyle.PM_ExclusiveIndicatorWidth)

        def add_option(label, pre_label, between_label, merge_type, model,
                       extra_model):
            """Add a radio button and the row with its two combo boxes.

            `merge_type` selects the pair of settings
            'attr_<merge_type>_data' and 'attr_<merge_type>_extra' that the
            combo boxes are bound to.
            """
            gui.appendRadioButton(grp, label)
            vbox = gui.vBox(grp)
            box = gui.hBox(vbox)
            box.layout().addSpacing(radio_width)
            self.attr_boxes.append(box)
            gui.widgetLabel(box, pre_label)
            # Seed both models with the current values of the settings.
            # NOTE(review): presumably needed so that the combo boxes can
            # be created with the stored value selected - confirm.
            model[:] = [getattr(self, 'attr_{}_data'.format(merge_type))]
            extra_model[:] = [
                getattr(self, 'attr_{}_extra'.format(merge_type))
            ]
            cb = gui.comboBox(box,
                              self,
                              'attr_{}_data'.format(merge_type),
                              callback=self._invalidate,
                              model=model)
            cb.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
            cb.setFixedWidth(190)
            gui.widgetLabel(box, between_label)
            cb = gui.comboBox(box,
                              self,
                              'attr_{}_extra'.format(merge_type),
                              callback=self._invalidate,
                              model=extra_model)
            cb.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
            cb.setFixedWidth(190)
            vbox.layout().addSpacing(6)

        add_option("Append columns from Extra Data", "by matching", "with",
                   "augment", self.model, self.extra_model_unique)
        add_option("Find matching rows", "where", "equals", "merge",
                   self.model_unique_with_id, self.extra_model_unique_with_id)
        add_option("Concatenate tables, merge rows", "where", "equals",
                   "combine", self.model_unique_with_id,
                   self.extra_model_unique_with_id)
        self.set_merging()

    def set_merging(self):
        # pylint: disable=invalid-sequence-index
        # all boxes should be hidden before one is shown, otherwise widget's
        # layout changes height
        for box in self.attr_boxes:
            box.hide()
        self.attr_boxes[self.merging].show()

    def change_merging(self):
        self.set_merging()
        self._invalidate()

    @staticmethod
    def _set_unique_model(data, model):
        if data is None:
            model[:] = []
            return
        m = [INDEX]
        for attr in chain(data.domain.variables, data.domain.metas):
            col = data.get_column_view(attr)[0]
            if attr.is_primitive():
                col = col.astype(float)
                col = col[~np.isnan(col)]
            else:
                col = col[~(col == "")]
            if len(np.unique(col)) == len(col):
                m.append(attr)
        model[:] = m

    @staticmethod
    def _set_model(data, model):
        if data is None:
            model[:] = []
            return
        model[:] = list(chain([INDEX], data.domain, data.domain.metas))

    def _add_instanceid_to_models(self):
        """Keep INSTANCEID in the "with id" models in sync with the inputs.

        INSTANCEID is placed at the start of both models when both tables
        are present and share at least one instance id; otherwise it is
        removed from any model that contains it.
        """
        both_present = self.data is not None and self.extra_data is not None
        needs_id = both_present and \
            len(np.intersect1d(self.data.ids, self.extra_data.ids)) > 0
        id_models = (self.model_unique_with_id,
                     self.extra_model_unique_with_id)
        for id_model in id_models:
            has_id = INSTANCEID in id_model
            if needs_id == has_id:
                # Already in the required state
                continue
            if needs_id:
                id_model.insert(0, INSTANCEID)
            else:
                id_model.remove(INSTANCEID)

    def _init_combo_current_items(self, variables, models):
        for var, model in zip(variables, models):
            value = getattr(self, var)
            if len(model) > 0:
                setattr(self, var, value if value in model else INDEX)

    def _find_best_match(self):
        def get_unique_str_metas_names(model_):
            return [m for m in model_ if isinstance(m, StringVariable)]

        def best_match(model, extra_model):
            attr, extra_attr, n_max_intersect = INDEX, INDEX, 0
            str_metas = get_unique_str_metas_names(model)
            extra_str_metas = get_unique_str_metas_names(extra_model)
            for m_a, m_b in product(str_metas, extra_str_metas):
                n_inter = len(
                    np.intersect1d(self.data[:, m_a].metas,
                                   self.extra_data[:, m_b].metas))
                if n_inter > n_max_intersect:
                    n_max_intersect, attr, extra_attr = n_inter, m_a, m_b
            return attr, extra_attr

        def set_attrs(attr_name, attr_extra_name, attr, extra_attr):
            if getattr(self, attr_name) == INDEX and \
                    getattr(self, attr_extra_name) == INDEX:
                setattr(self, attr_name, attr)
                setattr(self, attr_extra_name, extra_attr)

        if self.data and self.extra_data:
            attrs = best_match(self.model, self.extra_model_unique)
            set_attrs("attr_augment_data", "attr_augment_extra", *attrs)
            attrs = best_match(self.model_unique_with_id,
                               self.extra_model_unique_with_id)
            set_attrs("attr_merge_data", "attr_merge_extra", *attrs)
            set_attrs("attr_combine_data", "attr_combine_extra", *attrs)

    @check_sql_input
    def setData(self, data):
        """Handler for the "Data" input.

        Stores the table, refills the models for the primary table,
        re-validates the three `attr_*_data` settings against those models,
        updates the info label and looks for a matching attribute pair.
        """
        self.data = data
        unique_model = self.model_unique_with_id
        self._set_model(data, self.model)
        self._set_unique_model(data, unique_model)
        self._add_instanceid_to_models()
        setting_names = ("attr_augment_data", "attr_merge_data",
                         "attr_combine_data")
        setting_models = (self.model, unique_model, unique_model)
        self._init_combo_current_items(setting_names, setting_models)
        info_text = self.dataInfoText(data)
        self.infoBoxData.setText(info_text)
        self._find_best_match()

    @check_sql_input
    def setExtraData(self, data):
        """Handler for the "Extra Data" input.

        Stores the table, refills the models for the extra table,
        re-validates the three `attr_*_extra` settings against those models,
        updates the info label and looks for a matching attribute pair.
        """
        self.extra_data = data
        unique_model = self.extra_model_unique
        unique_id_model = self.extra_model_unique_with_id
        self._set_unique_model(data, unique_model)
        self._set_unique_model(data, unique_id_model)
        self._add_instanceid_to_models()
        setting_names = ("attr_augment_extra", "attr_merge_extra",
                         "attr_combine_extra")
        setting_models = (unique_model, unique_id_model, unique_id_model)
        self._init_combo_current_items(setting_names, setting_models)
        info_text = self.dataInfoText(data)
        self.infoBoxExtraData.setText(info_text)
        self._find_best_match()

    def handleNewSignals(self):
        self._invalidate()

    def dataInfoText(self, data):
        if data is None:
            return "No data."
        else:
            return "{}\n{} instances\n{} variables".format(
                data.name, len(data),
                len(data.domain) + len(data.domain.metas))

    def commit(self):
        self.Warning.duplicate_names.clear()
        if self.data is None or len(self.data) == 0 or \
                self.extra_data is None or len(self.extra_data) == 0:
            merged_data = None
        else:
            merged_data = self.merge()
            if merged_data:
                merged_domain = merged_data.domain
                var_names = [
                    var.name for var in chain(merged_domain.variables,
                                              merged_domain.metas)
                ]
                if len(set(var_names)) != len(var_names):
                    self.Warning.duplicate_names()
        self.send("Data", merged_data)

    def _invalidate(self):
        self.commit()

    def send_report(self):
        # pylint: disable=invalid-sequence-index
        attr = (self.attr_augment_data, self.attr_merge_data,
                self.attr_combine_data)
        extra_attr = (self.attr_augment_extra, self.attr_merge_extra,
                      self.attr_combine_extra)
        merging_types = ("Append columns from Extra Data",
                         "Find matching rows",
                         "Concatenate tables, merge rows")
        self.report_items((("Merging", merging_types[self.merging]),
                           ("Data attribute", attr[self.merging]),
                           ("Extra data attribute", extra_attr[self.merging])))

    def merge(self):
        """Merge the two tables according to the selected option.

        The operation name ("augment", "merge" or "combine") selects both
        the pair of matching attributes and the `_<operation>_indices`
        method that computes the pairs of rows. Keys are compared as
        strings unless both matching attributes are continuous variables.
        """
        # pylint: disable=invalid-sequence-index
        operation = ("augment", "merge", "combine")[self.merging]
        var_data = getattr(self, "attr_{}_data".format(operation))
        var_extra_data = getattr(self, "attr_{}_extra".format(operation))
        compute_indices = getattr(self, "_{}_indices".format(operation))

        both_continuous = isinstance(var_data, ContinuousVariable) \
            and isinstance(var_extra_data, ContinuousVariable)
        as_string = not both_continuous
        extra_keys = self._get_keymap(self.extra_data, var_extra_data,
                                      as_string)
        row_pairs = compute_indices(var_data, extra_keys, as_string)
        extra_columns = self._compute_reduced_extra_data(var_extra_data)
        return self._join_table_by_indices(extra_columns, row_pairs)

    def _compute_reduced_extra_data(self, var_extra_data):
        """Return the extra table reduced to the columns to be appended.

        Variables that already appear in the primary table are dropped.
        The matching attribute of the extra table is dropped as well,
        except when the merging option is MergeType.OUTER_JOIN.
        """
        primary = self.data.domain
        extra = self.extra_data.domain
        excluded = set(chain(primary.variables, primary.metas))
        if self.merging != MergeType.OUTER_JOIN:
            excluded.add(var_extra_data)
        kept = [var
                for var in chain(extra.variables, extra.metas)
                if var not in excluded]
        return self.extra_data[:, kept]

    @staticmethod
    def _values(data, var, as_string):
        """Return an iterable over the keys for the rows of the table.

        The keys are instance ids for INSTANCEID, row positions for INDEX
        and the values of the column of `var` otherwise. When `as_string`
        is true, column values are converted to strings and missing values
        (nan for primitive variables, falsy values otherwise) become nan.
        """
        if var == INSTANCEID:
            return (row.id for row in data)
        if var == INDEX:
            return range(len(data))
        column = data.get_column_view(var)[0]
        if not as_string:
            return column
        if not var.is_primitive():
            return (np.nan if not val else str(val) for val in column)
        return (np.nan if np.isnan(val) else var.str_val(val)
                for val in column)

    @classmethod
    def _get_keymap(cls, data, var, as_string):
        """Return a generator of pairs (key, index) by enumerating and
        switching the values for rows (method `_values`).
        """
        return ((val, i)
                for i, val in enumerate(cls._values(data, var, as_string)))

    def _augment_indices(self, var_data, extra_map, as_string):
        """Compute a two-row array of indices:
        - the first row contains indices for the primary table,
        - the second row contains the matching rows in the extra table or -1"""
        data = self.data
        extra_map = dict(extra_map)
        # Don't match nans. This is needed since numpy supports using nan as
        # keys. If numpy fixes this, the below conditions will always be false,
        # so we're OK again.
        if np.nan in extra_map:
            del extra_map[np.nan]
        keys = (extra_map.get(val, -1)
                for val in self._values(data, var_data, as_string))
        return np.vstack((np.arange(len(data), dtype=np.int64),
                          np.fromiter(keys, dtype=np.int64, count=len(data))))

    def _merge_indices(self, var_data, extra_map, as_string):
        """Use _augment_indices to compute the array of indices,
        then remove those with no match in the second table"""
        augmented = self._augment_indices(var_data, extra_map, as_string)
        return augmented[:, augmented[1] != -1]

    def _combine_indices(self, var_data, extra_map, as_string):
        """Use _augment_indices to compute the array of indices,
        then add rows in the second table without a match in the first"""
        to_add, extra_map = tee(extra_map)
        # dict instead of set because we have pairs; we'll need only keys
        key_map = dict(self._get_keymap(self.data, var_data, as_string))
        # _augment indices will skip rows where the key in the left table
        # is nan. See comment in `_augment_indices` wrt numpy and nan in dicts
        if np.nan in key_map:
            del key_map[np.nan]
        keys = np.fromiter((j for key, j in to_add if key not in key_map),
                           dtype=np.int64)
        right_indices = np.vstack((np.full(len(keys), -1, np.int64), keys))
        return np.hstack((self._augment_indices(var_data, extra_map,
                                                as_string), right_indices))

    def _join_table_by_indices(self, reduced_extra, indices):
        """Join (horizontally) self.data and reduced_extra, taking the pairs
        of rows given in indices"""
        if not len(indices):
            return None
        domain = Orange.data.Domain(*(getattr(self.data.domain, x) +
                                      getattr(reduced_extra.domain, x)
                                      for x in ("attributes", "class_vars",
                                                "metas")))
        X = self._join_array_by_indices(self.data.X, reduced_extra.X, indices)
        Y = self._join_array_by_indices(np.c_[self.data.Y],
                                        np.c_[reduced_extra.Y], indices)
        string_cols = [
            i for i, var in enumerate(domain.metas) if var.is_string
        ]
        metas = self._join_array_by_indices(self.data.metas,
                                            reduced_extra.metas, indices,
                                            string_cols)
        return Orange.data.Table.from_numpy(domain, X, Y, metas)

    @staticmethod
    def _join_array_by_indices(left, right, indices, string_cols=None):
        """Join (horizontally) two arrays, taking pairs of rows given in indices
        """
        tpe = object if object in (left.dtype, right.dtype) else left.dtype
        left_width, right_width = left.shape[1], right.shape[1]
        arr = np.full((indices.shape[1], left_width + right_width), np.nan,
                      tpe)
        if string_cols:
            arr[:, string_cols] = ""
        for indices, to_change, lookup in ((indices[0], arr[:, :left_width],
                                            left), (indices[1],
                                                    arr[:,
                                                        left_width:], right)):
            known = indices != -1
            to_change[known] = lookup[indices[known]]
        return arr