示例#1
0
    def enable_disable_gui(self, layer):
        """Enable/disable the plugin menu options for the current layer.

        All options are disabled first.  They are only re-enabled when a
        layer is active and its data provider is "postgres":

        * ``actionInitDB`` is enabled while the schema is not initialized.
        * ``layerMenu`` is enabled once the schema is initialized.
        * ``actionLayerUpdate``/``actionLayerLoad`` are enabled for a
          historised layer, ``actionInitLayer`` otherwise.

        :param layer: layer whose data provider is inspected.  When ``None``
            the active layer of the interface is used instead, so the call
            no longer fails with ``AttributeError``.
        """
        # Start from a fully disabled state; options are re-enabled below.
        for widget in (self.actionInitDB, self.layerMenu,
                       self.actionInitLayer, self.actionLayerUpdate,
                       self.actionLayerLoad):
            widget.setEnabled(False)

        selectedLayer = self.iface.activeLayer()
        if not selectedLayer:
            return

        # The original tested the active layer but dereferenced ``layer``;
        # fall back to the active layer when no layer object was passed in.
        target = layer if layer is not None else selectedLayer
        provider = target.dataProvider()
        if provider.name() != "postgres":
            return

        self.actionInitDB.setEnabled(True)
        uri = QgsDataSourceURI(provider.dataSourceUri())
        execute = SQLExecute(self.iface, self.iface.mainWindow(), uri)
        historised = execute.check_if_historised(
            uri.schema(),
            selectedLayer.name())
        db_initialized = execute.db_initialize_check(uri.schema())

        if db_initialized:
            # Database is already set up: offer the layer menu instead.
            self.actionInitDB.setEnabled(False)
            self.layerMenu.setEnabled(True)
        else:
            self.layerMenu.setEnabled(False)

        if historised:
            self.actionLayerUpdate.setEnabled(True)
            self.actionLayerLoad.setEnabled(True)
        else:
            self.actionInitLayer.setEnabled(True)
示例#2
0
    def initialize_layer(self):
        """Initialize history tracking for the active layer.

        Resolves the active layer's data source URI, opens a database
        connection, asks the user for confirmation and then runs
        ``SQLExecute.Init_hist_tabs`` for the layer's schema and table.
        The outcome is reported in a message box and the menu state is
        refreshed afterwards.

        Returns ``None`` in every case; the method silently does nothing
        when no layer is active, when the connection could not be
        established, or when the user declines.
        """
        selectedLayer = self.iface.activeLayer()
        if selectedLayer is None:
            # No layer selected: there is no provider/URI to work with.
            return

        provider = selectedLayer.dataProvider()
        uri = QgsDataSourceURI(provider.dataSourceUri())
        # NOTE(review): the connection is only used as a reachability check
        # here and is never closed in this method - confirm who owns it.
        conn = self.dbconn.connect_to_DB(uri)
        if conn is False:
            return

        result = QMessageBox.warning(
            self.iface.mainWindow(), self.tr(u"Initialize Layer"),
            self.tr(u"Are you sure you wish to proceed?"),
            QMessageBox.No | QMessageBox.Yes)
        if result != QMessageBox.Yes:
            return

        # Get SQL vars
        hasGeometry = selectedLayer.hasGeometryType()
        schema = uri.schema()
        table = uri.table()

        execute = SQLExecute(self.iface, self.iface.mainWindow(), uri)
        success, msg = execute.Init_hist_tabs(hasGeometry, schema, table)
        if success:
            # A success notice is informational, not a warning.
            QMessageBox.information(
                self.iface.mainWindow(), self.tr(u"Success"),
                self.tr(u"Layer successfully initialized!"))
        else:
            # Translate only the fixed text; a string concatenated with the
            # runtime message can never match a translation catalogue entry.
            QMessageBox.warning(
                self.iface.mainWindow(), self.tr(u"Error"),
                self.tr(u"Initialization failed!") + u"\n" + msg)
        self.enable_disable_gui(selectedLayer)
示例#3
0
class SelectDateDialog(QDialog, Ui_SelectDate):
    """
    Dialog that lets the user pick one of the historized versions (dates)
    of the active layer and load that version.
    """
    def __init__(self, iface):
        """Build the dialog and fill the date combo box.

        :param iface: QGIS interface object used to reach the active layer
            and the main window.
        """
        QDialog.__init__(self)
        self.setupUi(self)
        self.iface = iface
        self.dbconn = DBConn(iface)
        self.cmbDate.clear()
        self.get_dates()

    def get_dates(self):
        """Get all historized dates from the active layer.

        Sets ``self.uri``, ``self.conn``, ``self.execute`` and
        ``self.records`` (``True`` when at least one version was returned).
        """
        layer = self.iface.activeLayer()
        provider = layer.dataProvider()
        self.uri = QgsDataSourceURI(provider.dataSourceUri())
        self.conn = self.dbconn.connect_to_DB(self.uri)
        if self.conn is False:
            # connect_to_DB signals failure with ``False``; without a
            # connection there is nothing to query.
            self.execute = None
            self.records = False
            return

        self.execute = SQLExecute(self.iface, self.iface.mainWindow(),
                                  self.uri)
        dateList = self.execute.retrieve_all_table_versions(
            layer.name(), self.uri.schema())

        # Decide once, before iterating, so the attribute always exists.
        self.records = bool(dateList)
        if not self.records:
            QMessageBox.warning(self.iface.mainWindow(), self.tr(u"Error"),
                                self.tr(u"No historized versions found!"))
            return

        for date in dateList:
            self.cmbDate.addItem(str(date[0]))

    @pyqtSignature("")
    def on_buttonBox_accepted(self):
        """
        Run hist_tabs.version() SQL-function for the selected date.
        """
        if self.records:
            self.execute.get_older_table_version(
                self.uri.schema(),
                self.iface.activeLayer().name(), self.cmbDate.currentText(),
                self.uri)

    @pyqtSignature("")
    def on_buttonBox_rejected(self):
        """
        Close Window and DB-Connection
        """
        # ``self.conn`` is ``False`` when the connection attempt failed.
        if self.conn is not None and self.conn is not False:
            self.conn.close()
        self.close()
示例#4
0
	def _bg_refresh(self, sqlexecute, callbacks, completer_options):
		"""Build a fresh completer and hand it to the callbacks.

		Every refresher in ``self.refreshers`` is run against a new
		``SQLCompleter``.  When ``self._restart_refresh`` is set after a
		refresher has run, the flag is cleared and the whole pass starts
		over from the first refresher.

		:param sqlexecute: executor currently used by the caller
		:param callbacks: one callable or an iterable of callables, each
			called with the populated completer
		:param completer_options: keyword arguments for ``SQLCompleter``
		"""
		fresh_completer = SQLCompleter(**completer_options)

		if sqlexecute.dbname == ":memory:":
			# if DB is memory, needed to use same connection
			runner = sqlexecute
		else:
			# Create a new executor to populate the completions.
			runner = SQLExecute(sqlexecute.dbname)

		# A single function is treated as a one-element list.
		if callable(callbacks):
			callbacks = [callbacks]

		finished = False
		while not finished:
			for refresh in self.refreshers.values():
				refresh(fresh_completer, runner)
				if self._restart_refresh.is_set():
					# A restart was requested: begin the pass again.
					self._restart_refresh.clear()
					break
			else:
				# Every refresher ran without a restart request.
				finished = True

		for notify in callbacks:
			notify(fresh_completer)
示例#5
0
    def set_importable_tables(self):
        """Fill the import combo box with the importable tables.

        Each entry is shown as ``<schema>.<tablename>``.  Also stores the
        data source URI in ``self.uri`` and the executor in
        ``self.execute``.
        """
        self.uri = QgsDataSourceURI(self.provider.dataSourceUri())
        self.execute = SQLExecute(
            self.iface, self.iface.mainWindow(), self.uri)

        # Each returned row carries the schema first, the table second.
        rows = self.execute.retrieve_all_importable_tables()
        labels = [row[0] + '.' + row[1] for row in rows] if rows else []
        self.cmbImportTable.addItems(labels)
示例#6
0
    def get_dates(self):
        """Get all historized dates from the active layer.

        Stores the data source URI, the database connection and the SQL
        executor on the instance, then fills ``self.cmbDate`` with the
        first column of every returned version row.

        ``self.records`` is always set: ``True`` when versions were
        returned, ``False`` (plus a warning box) otherwise.
        """
        layer = self.iface.activeLayer()
        provider = layer.dataProvider()
        self.uri = QgsDataSourceURI(provider.dataSourceUri())
        self.conn = self.dbconn.connect_to_DB(self.uri)
        self.execute = SQLExecute(self.iface, self.iface.mainWindow(),
                                  self.uri)
        dateList = self.execute.retrieve_all_table_versions(
            layer.name(), self.uri.schema())

        # Set the flag once, up front, instead of on every loop iteration;
        # this also guarantees the attribute exists when nothing is yielded.
        self.records = bool(dateList)
        if not self.records:
            QMessageBox.warning(self.iface.mainWindow(), self.tr(u"Error"),
                                self.tr(u"No historized versions found!"))
            return

        for date in dateList:
            self.cmbDate.addItem(str(date[0]))
示例#7
0
class ImportUpdateDialog(QDialog, Ui_ImportUpdate):
    """
    Dialog that collects the parameters for the update() function: the
    table to import from and the attributes to exclude.
    """
    def __init__(self, iface):
        """Build the dialog and populate its attribute list and combo box.

        :param iface: QGIS interface object used to reach the active layer
            and the main window.
        """
        QDialog.__init__(self)
        self.setupUi(self)
        self.iface = iface
        self.dbconn = DBConn(iface)
        # get_attributes() must run first: it sets ``self.provider`` which
        # set_importable_tables() reads.
        self.get_attributes()
        self.set_importable_tables()

    def get_attributes(self):
        """Place excludable attributes in model list"""
        # Model Structure used from Tutorial at
        # http://pythoncentral.io/pyside-pyqt-tutorial-qlistview-and-qstandarditemmodel/

        layer = self.iface.activeLayer()
        self.model = QStandardItemModel(self.listTblAttrib)
        self.hasGeometry = layer.hasGeometryType()
        self.provider = layer.dataProvider()

        for field in layer.pendingFields():
            item = QStandardItem(field.name())
            item.setCheckable(True)
            self.model.appendRow(item)
        self.listTblAttrib.setModel(self.model)

    def get_checked_attributes(self):
        """Return the texts of all checked items (attributes to exclude)."""
        exclList = []
        for row in range(self.model.rowCount()):
            item = self.model.item(row)
            if item is None:
                # Mirrors the original loop, which stopped at the first gap.
                break
            if item.checkState():
                exclList.append(item.text())
        return exclList

    def set_importable_tables(self):
        """Populates combobox with importable tables

           Name Structure <schema.tablename>"""
        self.uri = QgsDataSourceURI(self.provider.dataSourceUri())
        self.execute = SQLExecute(
            self.iface, self.iface.mainWindow(), self.uri)

        # Each returned row carries the schema first, the table second.
        schemaTableList = self.execute.retrieve_all_importable_tables()
        cmbList = []
        if schemaTableList:
            cmbList = [entry[0] + '.' + entry[1]
                       for entry in schemaTableList]
        self.cmbImportTable.addItems(cmbList)

    @pyqtSignature("")
    def on_buttonBox_accepted(self):
        """
        Run hist_tabs.update() SQL function with given parameters
        """
        select = self.cmbImportTable.currentText()
        # Split the user selection for querying
        splitString = select.split('.')
        if len(splitString) < 2:
            # Empty combo box (no importable tables) or malformed entry:
            # there is no schema/table pair to run the update with.
            return
        exclList = self.get_checked_attributes()
        importSchema = splitString[0]
        importTable = splitString[1]
        self.execute.update_table_entries(
            importSchema,
            importTable,
            self.uri.schema(),
            self.iface.activeLayer().name(),
            self.hasGeometry, exclList)

    @pyqtSignature("")
    def on_buttonBox_rejected(self):
        """
        Close the dialog without running the update.
        """
        self.close()
示例#8
0
class LiteCli(object):
    """Application object of the interactive SQL command line.

    Holds the loaded configuration, the SQL executor, the completer and
    the prompt session, and drives the read/execute/print loop in
    ``run_cli``.
    """

    # Prompt format used when neither the constructor argument nor the
    # configuration supplies one.
    default_prompt = "\\d> "
    # Rendered-prompt length limit checked in run_cli for the default prompt.
    max_len_prompt = 45

    def __init__(
        self,
        sqlexecute=None,
        prompt=None,
        logfile=None,
        auto_vertical_output=False,
        warn=None,
        liteclirc=None,
    ):
        """Load the configuration and set up logging, prompt and completer.

        :param sqlexecute: executor object stored as ``self.sqlexecute``
        :param prompt: prompt format; takes precedence over configured ones
        :param logfile: audit log file object; when ``None`` the
            ``audit_log`` entry of the ``main`` config section is opened
        :param auto_vertical_output: enables automatic vertical output in
            addition to the config setting
        :param warn: overrides the ``destructive_warning`` config value
            when not ``None``
        :param liteclirc: passed to ``get_config`` to load the configuration
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile

        # Load config.
        c = self.config = get_config(liteclirc)

        self.multi_line = c["main"].as_bool("multi_line")
        self.key_bindings = c["main"]["key_bindings"]
        special.set_favorite_queries(self.config)
        self.formatter = TabularOutputFormatter(
            format_name=c["main"]["table_format"])
        self.formatter.litecli = self
        self.syntax_style = c["main"]["syntax_style"]
        self.less_chatty = c["main"].as_bool("less_chatty")
        self.cli_style = c["colors"]
        self.output_style = style_factory_output(self.syntax_style,
                                                 self.cli_style)
        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
        c_dest_warning = c["main"].as_bool("destructive_warning")
        # An explicit ``warn`` argument wins over the configured value.
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.login_path_as_host = c["main"].as_bool("login_path_as_host")

        # read from cli argument or user config file
        self.auto_vertical_output = auto_vertical_output or c["main"].as_bool(
            "auto_vertical_output")

        # audit log
        if self.logfile is None and "audit_log" in c["main"]:
            try:
                self.logfile = open(os.path.expanduser(c["main"]["audit_log"]),
                                    "a")
            except (IOError, OSError):
                self.echo(
                    "Error: Unable to open the audit log file. Your queries will not be logged.",
                    err=True,
                    fg="red",
                )
                # ``False`` (unlike ``None``) marks a failed audit log;
                # run_cli checks ``self.logfile is False`` to warn per query.
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        # Prompt precedence: argument, value found by read_my_cnf_files,
        # ``main.prompt`` config entry, then the class default.
        prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"]
        self.prompt_format = (prompt or prompt_cnf or c["main"]["prompt"]
                              or self.default_prompt)
        self.prompt_continuation_format = c["main"]["prompt_continuation"]
        keyword_casing = c["main"].get("keyword_casing", "auto")

        # Query objects appended by run_cli after each executed statement.
        self.query_history = []

        # Initialize completer.
        self.completer = SQLCompleter(
            supported_formats=self.formatter.supported_formats,
            keyword_casing=keyword_casing,
        )
        # Guards replacement/reset of ``self.completer``.
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Created in run_cli; ``None`` until then.
        self.prompt_app = None

    def register_special_commands(self):
        """Register this object's handlers as special commands.

        Each table row holds the four positional arguments passed to
        ``special.register_special_command`` followed by that call's
        keyword arguments; rows are registered in order.
        """
        command_table = (
            (self.change_db, ".open", ".open",
             "Change to a new database.",
             {"aliases": ("use", "\\u")}),
            (self.refresh_completions, "rehash", "\\#",
             "Refresh auto-completions.",
             {"arg_type": NO_QUERY, "aliases": ("\\#", )}),
            (self.change_table_format, ".mode", "\\T",
             "Change the table format used to output results.",
             {"aliases": ("tableformat", "\\T"), "case_sensitive": True}),
            (self.execute_from_file, "source", "\\. filename",
             "Execute commands from file.",
             {"aliases": ("\\.", )}),
            (self.change_prompt_format, "prompt", "\\R",
             "Change prompt format.",
             {"aliases": ("\\R", ), "case_sensitive": True}),
        )
        for handler, command, shortcut, description, options in command_table:
            special.register_special_command(
                handler, command, shortcut, description, **options)

    def change_table_format(self, arg, **_):
        """Switch the output formatter to the table format ``arg``.

        Generator yielding one ``(None, None, None, message)`` tuple.  If
        assigning the format raises ``ValueError`` the message lists the
        supported formats instead of confirming the change.
        """
        try:
            self.formatter.format_name = arg
        except ValueError:
            lines = [
                "Table format {} not recognized. Allowed formats:".format(arg)
            ]
            lines.extend("\t{}".format(table_type)
                         for table_type in self.formatter.supported_formats)
            message = "\n".join(lines)
        else:
            message = "Changed table format to {}".format(arg)
        yield (None, None, None, message)

    def change_db(self, arg, **_):
        """Reconnect the executor and report the database now in use.

        With ``arg`` set, it is passed to ``connect`` as ``database``;
        with ``None``, ``connect`` is called without arguments.  A
        completion refresh is started afterwards.  Generator yielding one
        ``(None, None, None, message)`` tuple.
        """
        connect_args = {} if arg is None else {"database": arg}
        self.sqlexecute.connect(**connect_args)

        self.refresh_completions()
        status = 'You are now connected to database "%s"' % (
            self.sqlexecute.dbname)
        yield (None, None, None, status)

    def execute_from_file(self, arg, **_):
        """Read the file named by ``arg`` and run its contents as a query.

        Returns a one-element list with a ``(None, None, None, message)``
        tuple when the file name is missing, the file cannot be read, or
        the user declines a destructive query; otherwise returns whatever
        ``self.sqlexecute.run`` returns.
        """
        def reply(text):
            # Wrap a plain message in the result shape used by commands.
            return [(None, None, None, text)]

        if not arg:
            return reply("Missing required argument, filename.")

        try:
            with open(os.path.expanduser(arg), encoding="utf-8") as handle:
                sql = handle.read()
        except IOError as err:
            return reply(str(err))

        # Only an explicit ``False`` from the confirmation aborts the run.
        if self.destructive_warning and confirm_destructive_query(
                sql) is False:
            return reply("Wise choice. Command execution stopped.")

        return self.sqlexecute.run(sql)

    def change_prompt_format(self, arg, **_):
        """Change the prompt format.

        Stores the result of ``self.get_prompt(arg)`` as the new prompt
        format and returns a one-element list holding a
        ``(None, None, None, message)`` tuple.  An empty ``arg`` leaves
        the prompt untouched and reports the missing argument.
        """
        if arg:
            self.prompt_format = self.get_prompt(arg)
            message = "Changed prompt format to %s" % arg
        else:
            message = "Missing required argument, format."
        return [(None, None, None, message)]

    def initialize_logging(self):
        """Attach a handler to the "litecli" logger as configured.

        Reads ``log_file`` and ``log_level`` from the ``main`` config
        section.  The level "NONE" installs a no-op handler; otherwise a
        file handler is used when ``dir_path_exists`` accepts the log
        file.  When it does not, an error is echoed and nothing is set up.
        """
        log_path = self.config["main"]["log_file"]
        if log_path == "default":
            log_path = config_location() + "log"
        ensure_dir_exists(log_path)

        level_name = self.config["main"]["log_level"]

        # Disable logging if value is NONE by switching to a no-op handler
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if level_name.upper() == "NONE":
            handler = logging.NullHandler()
            level_name = "CRITICAL"
        elif dir_path_exists(log_path):
            handler = logging.FileHandler(log_path)
        else:
            self.echo(
                'Error: Unable to open the log file "{}".'.format(log_path),
                err=True,
                fg="red",
            )
            return

        handler.setFormatter(
            logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) "
                              "%(name)s %(levelname)s - %(message)s"))

        # Accepted level names mapped to the logging module's constants.
        levels = {
            name: getattr(logging, name)
            for name in ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
        }

        litecli_logger = logging.getLogger("litecli")
        litecli_logger.addHandler(handler)
        litecli_logger.setLevel(levels[level_name.upper()])

        logging.captureWarnings(True)

        litecli_logger.debug("Initializing litecli logging.")
        litecli_logger.debug("Log file %r.", log_path)

    def read_my_cnf_files(self, keys):
        """Look up ``keys`` in the "main" section of the loaded config.

        :param keys: iterable of keys to retrieve
        :returns: dict mapping every key to its configured value, or to
            ``None`` when the key is not present.
        """
        config = self.config
        searched = ["main"]

        def lookup(key):
            found = None
            for name in config:
                if name not in searched:
                    continue
                section = config[name]
                if key in section:
                    found = section[key]
            return found

        return dict((key, lookup(key)) for key in keys)

    def connect(self, database=""):
        """Create the SQL executor for ``database``.

        :param database: database to open; when empty, the ``database``
            value returned by ``read_my_cnf_files`` is used instead.
        :raises SystemExit: with exit status 1 when the executor cannot be
            created; the error is echoed in red beforehand.
        """
        # Fall back to config values only if user did not specify a value.
        configured = self.read_my_cnf_files(["database"])
        database = database or configured["database"]

        # Connect to the database.
        try:
            self.sqlexecute = SQLExecute(database)
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.echo(str(e), err=True, fg="red")
            # ``exit()`` is a helper injected by the ``site`` module for
            # interactive use and may be missing (e.g. ``python -S``);
            # raising SystemExit directly has the same effect.
            raise SystemExit(1)

    def handle_editor_command(self, text):
        r"""Open the external editor for as long as ``text`` asks for it.

        Editor command is any query that is prefixed or suffixed by a '\e'.
        The reason for a while loop is because a user might edit a query
        multiple times. For eg:

        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.

        :param text: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message

        """
        # The docstring is a raw string because it contains a literal
        # backslash-e, which is not a valid escape sequence.
        while special.editor_command(text):
            filename = special.get_filename(text)
            # With no query in the command, fall back to the last query.
            query = special.get_editor_query(text) or self.get_last_query()
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            while True:
                try:
                    text = self.prompt_app.prompt(default=sql)
                    break
                except KeyboardInterrupt:
                    # Interrupted: prompt again, this time with no default.
                    sql = ""
        return text

    def run_cli(self):
        """Run the interactive prompt loop until EOF.

        Configures the pager, starts a completion refresh, creates the
        prompt session and then repeatedly reads a statement, executes it
        through ``self.sqlexecute`` and prints the formatted results.
        Every executed statement is appended to ``self.query_history``.
        The loop ends when ``EOFError`` is raised.
        """
        iterations = 0
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()
        self.refresh_completions()

        history_file = config_location() + "history"
        if dir_path_exists(history_file):
            history = FileHistory(history_file)
        else:
            history = None
            self.echo(
                'Error: Unable to open the history file "{}". '
                "Your query history will not be saved.".format(history_file),
                err=True,
                fg="red",
            )

        key_bindings = cli_bindings(self)

        # NOTE(review): ``0 and ...`` is always false, so this banner is
        # never printed.
        if 0 and not self.less_chatty:
            print("Version:", __version__)
            print(
                "Mail: https://groups.google.com/forum/#!forum/litecli-users")
            print("Github: https://github.com/dbcli/litecli")
            # print("Home: https://litecli.com")

        def get_message():
            # NOTE(review): ``prompt`` is computed but not returned; the
            # prompt shown is built from ``DBNAME``, which is not defined in
            # the visible part of this file - confirm where it comes from.
            prompt = self.get_prompt(self.prompt_format)
            if (self.prompt_format == self.default_prompt
                    and len(prompt) > self.max_len_prompt):
                prompt = self.get_prompt("\\d> ")
            #print(prompt)
            return [("class:prompt", '%s> ' % DBNAME)]

        def get_continuation(width, line_number, is_soft_wrap):
            # Continuation lines get ``width`` spaces.
            continuation = " " * (width - 1) + " "
            return [("class:continuation", continuation)]

        def show_suggestion_tip():
            # True only for the first two loop iterations.
            return iterations < 2

        def one_iteration(text=None):
            # Read (unless ``text`` is given), execute and print one input.

            if text is None:
                try:
                    text = self.prompt_app.prompt()

                except KeyboardInterrupt:
                    return

                special.set_expanded_output(False)

                try:
                    text = self.handle_editor_command(text)
                except RuntimeError as e:
                    #logger.error("sql: %r, error: %r", text, e)
                    #logger.error("traceback: %r", traceback.format_exc())
                    self.showerr(e)
                    return

            if not text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo("Your call!")
                else:
                    self.echo("Wise choice!")
                    return

            # Keep track of whether or not the query is mutating. In case
            # of a multi-statement query, the overall query is considered
            # mutating if any one of the component statements is mutating
            mutating = False

            try:
                logger.debug("sql: %r", text)

                special.write_tee(self.get_prompt(self.prompt_format) + text)
                if self.logfile:
                    self.logfile.write("\n# %s\n" % datetime.now())
                    self.logfile.write(text)
                    self.logfile.write("\n")

                # NOTE(review): ``successful`` is first assigned here; if a
                # statement above raises an exception that is handled below,
                # the ``Query(...)`` line after the try block would hit an
                # unbound local - confirm.
                successful = False
                start = time()
                res = sqlexecute.run(text, DBNAME)
                self.formatter.query = text
                successful = True
                result_count = 0
                for title, cur, headers, status in res:
                    logger.debug("headers: %r", headers)
                    logger.debug("rows: %r", cur)
                    logger.debug("status: %r", status)
                    # Ask before printing very large SELECT results.
                    threshold = 1000
                    if is_select(status) and cur and cur.rowcount > threshold:
                        self.echo(
                            "The result set has more than {} rows.".format(
                                threshold),
                            fg="red",
                        )
                        if not confirm("Do you want to continue?"):
                            self.echo("Aborted!", err=True, fg="red")
                            break

                    if self.auto_vertical_output:
                        max_width = self.prompt_app.output.get_size().columns
                    else:
                        max_width = None

                    formatted = self.format_output(
                        title, cur, headers, special.is_expanded_output(),
                        max_width)

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo("")
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        self.echo("Time: %0.03fs" % t)
                    except KeyboardInterrupt:
                        pass

                    # Restart the timer for the next result set.
                    start = time()
                    result_count += 1
                    mutating = mutating or is_mutating(status)
                special.unset_once_if_written()
            except EOFError as e:
                # Propagate so the outer loop can terminate.
                raise e
            except KeyboardInterrupt:
                # get last connection id
                connection_id_to_kill = sqlexecute.connection_id
                logger.debug("connection id to kill: %r",
                             connection_id_to_kill)
                # Restart connection to the database
                sqlexecute.connect()
                try:
                    for title, cur, headers, status in sqlexecute.run(
                            "kill %s" % connection_id_to_kill):
                        status_str = str(status).lower()
                        if status_str.find("ok") > -1:
                            logger.debug(
                                "cancelled query, connection id: %r, sql: %r",
                                connection_id_to_kill,
                                text,
                            )
                            self.echo("cancelled query", err=True, fg="red")
                except Exception as e:
                    self.echo(
                        "Encountered error while cancelling query: {}".format(
                            e),
                        err=True,
                        fg="red",
                    )
            except NotImplementedError:
                self.echo("Not Yet Implemented.", fg="yellow")
            except OperationalError as e:
                logger.debug("Exception: %r", e)
                # These error codes trigger one reconnect-and-retry attempt.
                if e.args[0] in (2003, 2006, 2013):
                    logger.debug("Attempting to reconnect.")
                    self.echo("Reconnecting...", fg="yellow")
                    try:
                        sqlexecute.connect()
                        logger.debug("Reconnected successfully.")
                        one_iteration(text)
                        return  # OK to just return, cuz the recursion call runs to the end.
                    except OperationalError as e:
                        logger.debug("Reconnect failed. e: %r", e)
                        self.echo(str(e), err=True, fg="red")
                        # If reconnection failed, don't proceed further.
                        return
                else:
                    #logger.error("sql: %r, error: %r", text, e)
                    #logger.error("traceback: %r", traceback.format_exc())
                    self.showerr(e)
            except Exception as e:
                #logger.error("sql: %r, error: %r", text, e)
                #logger.error("traceback: %r", traceback.format_exc())
                self.showerr(e)
            else:
                # Runs only when the statement executed without an exception.
                if is_dropping_database(text, self.sqlexecute.dbname):
                    self.sqlexecute.dbname = None
                    self.sqlexecute.connect()

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(text):
                    self.refresh_completions(reset=need_completion_reset(text))
            finally:
                # ``False`` is set by __init__ when the audit log could not
                # be opened.
                if self.logfile is False:
                    self.echo("Warning: This query was not logged.",
                              err=True,
                              fg="red")
            query = Query(text, successful, mutating)
            self.query_history.append(query)

        get_toolbar_tokens = create_toolbar_tokens_func(
            self, show_suggestion_tip)

        if self.wider_completion_menu:
            complete_style = CompleteStyle.MULTI_COLUMN
        else:
            complete_style = CompleteStyle.COLUMN

        # The session is created while holding the completer lock.
        with self._completer_lock:

            if self.key_bindings == "vi":
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            self.prompt_app = PromptSession(
                lexer=PygmentsLexer(LiteCliLexer),
                reserve_space_for_menu=self.get_reserved_space(),
                message=get_message,
                prompt_continuation=get_continuation,
                bottom_toolbar=get_toolbar_tokens,
                complete_style=complete_style,
                input_processors=[
                    ConditionalProcessor(
                        processor=HighlightMatchingBracketProcessor(
                            chars="[](){}"),
                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                    )
                ],
                tempfile_suffix=".sql",
                # The lambda re-reads ``self.completer`` on every call, so a
                # completer swapped in later is picked up.
                completer=DynamicCompleter(lambda: self.completer),
                history=history,
                auto_suggest=AutoSuggestFromHistory(),
                complete_while_typing=True,
                multiline=cli_is_multiline(self),
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=editing_mode,
                search_ignore_case=True,
            )

        try:
            while True:
                one_iteration()
                iterations += 1
        except EOFError:
            # EOF ends the session.
            special.close_tee()
            if not self.less_chatty:
                self.echo("Goodbye!")

    def showerr(self, e):
        """Echo ``e`` in red to stderr, framed by two '##' marker lines."""
        for row in range(3):
            # Only the middle row carries the error text.
            line = '##' + str(e) if row == 1 else '##'
            self.echo(line, err=True, fg="red")

    def log_output(self, output):
        """Write ``output`` to the audit log when one is open."""
        if not self.logfile:
            # No audit log (``None``) or it failed to open (``False``).
            return
        click.echo(utf8tounicode(output), file=self.logfile)

    def echo(self, s, **kwargs):
        """Print the message ``s`` and record it in the audit log.

        The message is first handed to ``log_output`` (which writes it
        only when an audit log is open) and then printed with
        ``click.secho``; every keyword argument is forwarded to
        ``click.secho`` unchanged.
        """
        self.log_output(s)
        click.secho(s, **kwargs)

    def get_output_margin(self, status=None):
        """Return the number of rows taken by the prompt, footer and
        timing message.

        The margin is the reserved space plus the newlines in the
        rendered prompt plus two; a non-empty ``status`` adds one row
        plus one per newline it contains.
        """
        reserved_rows = self.get_reserved_space()
        prompt_newlines = self.get_prompt(self.prompt_format).count("\n")
        status_rows = 1 + status.count("\n") if status else 0
        return reserved_rows + prompt_newlines + 2 + status_rows

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.

        :param output: iterable of output lines; skipped when falsy
        :param status: optional status text, logged and printed after the
            output lines
        """
        if output:
            size = self.prompt_app.output.get_size()

            margin = self.get_output_margin(status)

            # ``fits`` stays True while every line seen so far fits the
            # terminal; lines are buffered until that is known.
            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled(
            )
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    # Too wide, or more lines than rows left after the margin.
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled(
                        ):
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for line in buf:
                                click.secho(line)
                            buf = []
                else:
                    # Already known not to fit and no pager: print directly.
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for line in buf:
                        click.secho(line)

        if status:
            self.log_output(status)
            click.secho(status)

    def configure_pager(self):
        """Set up the pager from the environment and configuration files.

        Gives ``LESS`` a default of ``-RXF`` when it is unset or empty,
        registers the pager named in the my.cnf files (recording whether
        one was explicitly configured in ``self.explicit_pager``), and
        disables paging when ``skip-pager`` is set or ``enable_pager`` is
        off in the main config section.

        """
        # Provide sane defaults for less if they are empty.
        if not os.environ.get("LESS"):
            os.environ["LESS"] = "-RXF"

        cnf = self.read_my_cnf_files(["pager", "skip-pager"])
        pager_cmd = cnf["pager"]
        if pager_cmd:
            special.set_pager(pager_cmd)
        self.explicit_pager = bool(pager_cmd)

        pager_allowed = not cnf["skip-pager"] and self.config["main"].as_bool(
            "enable_pager")
        if not pager_allowed:
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background refresh of the auto-completion data.

        When *reset* is true the current completer's completions are
        cleared first, under the completer lock. The refresher is then
        given the SQL executor, the callback to invoke with the new
        completer, and the options used to build it.

        Returns a one-element list holding a
        ``(title, cursor, headers, status)`` tuple whose status announces
        that the refresh has started.

        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()

        completer_options = {
            "supported_formats": self.formatter.supported_formats,
            "keyword_casing": self.completer.keyword_casing,
        }
        self.completion_refresher.refresh(
            self.sqlexecute, self._on_completions_refreshed, completer_options)

        status = "Auto-completion refresh started in the background."
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer):
        """Install *new_completer* as the active completer.

        The swap happens under the completer lock. If a prompt application
        exists it is then invalidated so the CLI is redrawn.

        """
        with self._completer_lock:
            self.completer = new_completer

        prompt_app = self.prompt_app
        if not prompt_app:
            return

        # After refreshing, redraw the CLI to clear the statusbar
        # "Refreshing completions..." indicator
        prompt_app.app.invalidate()

    def get_completions(self, text, cursor_positition):
        """Return the completer's completions for *text* at the cursor.

        A ``Document`` is built from *text* and *cursor_positition* and
        handed to the current completer while the completer lock is held.

        """
        with self._completer_lock:
            document = Document(text=text, cursor_position=cursor_positition)
            return self.completer.get_completions(document, None)

    def get_prompt(self, string):
        """Expand the backslash escapes in the prompt format *string*.

        Supported escapes: ``\\d`` database name (``(none)`` when unset),
        ``\\n`` newline, ``\\D`` full date and time, ``\\m`` minutes,
        ``\\P`` AM/PM, ``\\R`` hour (24h), ``\\r`` hour (12h),
        ``\\s`` seconds and ``\\_`` a space.

        """
        self.logger.debug("Getting prompt")
        executor = self.sqlexecute
        now = datetime.now()

        # Applied in this exact order; each replacement operates on the
        # result of the previous one.
        substitutions = (
            ("\\d", executor.dbname or "(none)"),
            ("\\n", "\n"),
            ("\\D", now.strftime("%a %b %d %H:%M:%S %Y")),
            ("\\m", now.strftime("%M")),
            ("\\P", now.strftime("%p")),
            ("\\R", now.strftime("%H")),
            ("\\r", now.strftime("%I")),
            ("\\s", now.strftime("%S")),
            ("\\_", " "),
        )
        for escape, replacement in substitutions:
            string = string.replace(escape, replacement)
        return string

    def run_query(self, query, new_line=True):
        """Execute *query* and echo each formatted result line.

        Each result produced by the SQL executor is formatted with
        ``self.format_output`` and written with ``click.echo``; *new_line*
        controls whether a newline follows every line.

        """
        for title, cur, headers, _status in self.sqlexecute.run(query):
            self.formatter.query = query
            for line in self.format_output(title, cur, headers):
                click.echo(line, nl=new_line)

    def format_output(self,
                      title,
                      cur,
                      headers,
                      expanded=False,
                      max_width=None):
        """Format one query result as an iterable of output lines.

        :param title: optional title; emitted as the first line when truthy.
        :param cur: the result rows (e.g. a cursor); nothing is formatted
            when it is falsy.
        :param headers: column headers passed to the formatter.
        :param expanded: force the "vertical" format. It is also forced
            when the formatter's configured format is already "vertical".
        :param max_width: when set, and the first formatted line is longer
            than this, the result is re-formatted in the "vertical" format.
        :return: an iterable of lines (a list or an ``itertools.chain``).

        """
        expanded = expanded or self.formatter.format_name == "vertical"
        output = []

        output_kwargs = {
            "dialect": "unix",
            "disable_numparse": True,
            "preserve_whitespace": True,
            "preprocessors": (preprocessors.align_decimals, ),
            "style": self.output_style,
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, "description"):
                # Every column is rendered as text, one entry per column.
                column_types = [text_type for _ in cur.description]

            if max_width is not None:
                # Materialise the rows so they can be formatted a second
                # time if the first attempt turns out to be too wide.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur,
                headers,
                format_name="vertical" if expanded else None,
                column_types=column_types,
                **output_kwargs)

            if isinstance(formatted, text_type):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            # Peek at the first line to measure the table width. A default
            # is supplied so that a formatter producing no lines at all
            # yields empty output instead of raising StopIteration.
            first_line = next(formatted, None)
            if first_line is not None:
                formatted = itertools.chain([first_line], formatted)

            if (not expanded and max_width and headers and cur
                    and first_line is not None
                    and len(first_line) > max_width):
                formatted = self.formatter.format_output(
                    cur,
                    headers,
                    format_name="vertical",
                    column_types=column_types,
                    **output_kwargs)
                if isinstance(formatted, text_type):
                    formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu.

        The result is 45% of the terminal height, rounded, and capped at
        8 lines.

        """
        reserved_space_ratio = 0.45
        max_reserved_space = 8

        # click.get_terminal_size() was deprecated in Click 8.0 and removed
        # in 8.1, so prefer the standard-library equivalent and only fall
        # back to click on interpreters whose shutil lacks it.
        try:
            from shutil import get_terminal_size
        except ImportError:
            get_terminal_size = click.get_terminal_size

        _, height = get_terminal_size()
        return min(int(round(height * reserved_space_ratio)),
                   max_reserved_space)

    def get_last_query(self):
        """Return the first field of the newest history entry, or None.

        None is returned when the query history is empty.

        """
        history = self.query_history
        if not history:
            return None
        return history[-1][0]
示例#9
0
 def _connect():
     # Create a SQLExecute for ``database`` and store it on ``self``.
     # NOTE(review): neither ``self`` nor ``database`` is a parameter here;
     # presumably both come from an enclosing scope outside this excerpt.
     self.sqlexecute = SQLExecute(database)
示例#10
0
文件: vcli.py 项目: pie-crust/etl
	def _connect(self,database):
		#click.echo('asdasd')
		self.sqlexecute = SQLExecute(database, True)