Example #1
0
def run(conn, sql, config, user_namespace):
    """Execute every statement found in ``sql`` and return the last result.

    :param conn: connection wrapper; ``conn.session`` is used to execute,
        ``conn.dialect`` to detect the engine and ``conn.name`` for the
        "Connected" message.
    :param sql: text holding zero or more SQL statements.
    :param config: object providing ``autocommit``, ``feedback`` and
        ``autopandas``.
    :param user_namespace: bind parameters passed to ``session.execute``.
    :return: ``'Connected: <name>'`` when ``sql`` is blank; otherwise a
        ``ResultSet`` for the last statement, or its ``DataFrame()`` when
        ``config.autopandas`` is set.
    :raises Exception: when a statement starts with ``begin``.
    """
    if not sql.strip():
        return 'Connected: %s' % conn.name

    for statement in sqlparse.split(sql):
        # Look at the statement being executed, not at the whole input:
        # otherwise only the very first word of ``sql`` would ever be
        # checked, for every statement in the batch.
        words = statement.strip().split()
        first_word = words[0].lower() if words else ''
        if first_word == 'begin':
            raise Exception("ipython_sql does not support transactions")
        if first_word.startswith('\\') and 'postgres' in str(conn.dialect):
            pgspecial = PGSpecial()
            _, cur, headers, _ = pgspecial.execute(
                conn.session.connection.cursor(),
                statement)[0]
            result = FakeResultProxy(cur, headers)
        else:
            txt = sqlalchemy.sql.text(statement)
            result = conn.session.execute(txt, user_namespace)
        try:
            # mssql has autocommit
            if config.autocommit and ('mssql' not in str(conn.dialect)):
                conn.session.execute('commit')
        except sqlalchemy.exc.OperationalError:
            pass  # not all engines can commit
        if result and config.feedback:
            print(interpret_rowcount(result.rowcount))

    # Returning only the last result, intentionally.
    resultset = ResultSet(result, statement, config)
    if config.autopandas:
        return resultset.DataFrame()
    return resultset
Example #2
0
def test_exit_without_active_connection(executor):
    """The quit command must work on a broken connection; queries must not."""
    on_quit = MagicMock()
    special = PGSpecial()
    special.register(
        on_quit, '\\q', '\\q', 'Quit pgcli.',
        arg_type=NO_QUERY, case_sensitive=True, aliases=(':q',))

    broken_conn = patch.object(executor, "conn", BrokenConnection())
    with broken_conn:
        # Quitting the app has to succeed even without an active connection.
        run(executor, "\\q", pgspecial=special)
        on_quit.assert_called_once()

        # A real query, on the other hand, has to surface the driver error.
        with pytest.raises(psycopg2.InterfaceError):
            run(executor, "select 1", pgspecial=special)
Example #3
0
File: main.py Project: avdd/pgcli
    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None, row_limit=None):
        """Build the application state from the arguments and the config file.

        :param force_passwd_prompt: stored unchanged on the instance.
        :param never_passwd_prompt: stored unchanged on the instance.
        :param pgexecute: stored unchanged on the instance.
        :param pgclirc_file: handed to ``get_config`` to load the settings.
        :param row_limit: replaces the ``row_limit`` config value when it is
            not ``None``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        # Load config; most settings below come from its 'main' section.
        cfg = self.config = get_config(pgclirc_file)
        main = cfg['main']

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = main.as_bool('multi_line')
        self.vi_mode = main.as_bool('vi')
        self.pgspecial.timing_enabled = main.as_bool('timing')
        # An explicit argument wins over the configured row limit.
        if row_limit is None:
            self.row_limit = main.as_int('row_limit')
        else:
            self.row_limit = row_limit

        self.table_format = main['table_format']
        self.syntax_style = main['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main.as_bool('wider_completion_menu')
        self.less_chatty = main.as_bool('less_chatty')
        self.null_string = main.get('null_string', '<null>')
        self.prompt_format = main.get('prompt', self.default_prompt)
        self.on_error = main['on_error'].upper()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool('smart_completion')
        self.settings = {
            'casing_file': get_casing_file(cfg),
            'generate_casing_file': main.as_bool('generate_casing_file'),
            'generate_aliases': main.as_bool('generate_aliases'),
            'asterisk_column_order': main['asterisk_column_order'],
        }
        self.completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial,
            settings=self.settings)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.eventloop = create_eventloop()
        self.cli = None
Example #4
0
File: main.py Project: Smotko/pgcli
    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None):
        """Build the application state from the arguments and the config file.

        :param force_passwd_prompt: stored unchanged on the instance.
        :param never_passwd_prompt: stored unchanged on the instance.
        :param pgexecute: stored unchanged on the instance.
        :param pgclirc_file: path of the user config file; when falsy it is
            derived from ``config_location()``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        from pgcli import __file__ as package_root
        package_root = os.path.dirname(package_root)

        if not pgclirc_file:
            pgclirc_file = '%sconfig' % config_location()

        # The packaged 'pgclirc' is the template for the user config file.
        default_config = os.path.join(package_root, 'pgclirc')
        write_default_config(default_config, pgclirc_file)

        # Load config; most settings below come from its 'main' section.
        cfg = self.config = load_config(pgclirc_file, default_config)
        main = cfg['main']

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = main.as_bool('multi_line')
        self.vi_mode = main.as_bool('vi')
        self.pgspecial.timing_enabled = main.as_bool('timing')

        self.table_format = main['table_format']
        self.syntax_style = main['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main.as_bool('wider_completion_menu')

        self.on_error = main['on_error'].upper()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool('smart_completion')
        self.completer = PGCompleter(smart_completion,
                                     pgspecial=self.pgspecial)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.eventloop = create_eventloop()
        self.cli = None
Example #5
0
    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False, pgexecute=None, pgclirc_file=None):
        """Build the application state from the arguments and the config file.

        :param force_passwd_prompt: stored unchanged on the instance.
        :param never_passwd_prompt: stored unchanged on the instance.
        :param pgexecute: stored unchanged on the instance.
        :param pgclirc_file: path of the user config file, passed to
            ``write_default_config`` and ``load_config``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        from pgcli import __file__ as package_root

        package_root = os.path.dirname(package_root)

        # The packaged "pgclirc" is the template for the user config file.
        default_config = os.path.join(package_root, "pgclirc")
        write_default_config(default_config, pgclirc_file)

        self.pgspecial = PGSpecial()

        # Load config; most settings below come from its "main" section.
        cfg = self.config = load_config(pgclirc_file, default_config)
        main = cfg["main"]
        self.multi_line = main.as_bool("multi_line")
        self.vi_mode = main.as_bool("vi")
        self.pgspecial.timing_enabled = main.as_bool("timing")
        self.table_format = main["table_format"]
        self.syntax_style = main["syntax_style"]
        self.cli_style = cfg["colors"]
        self.wider_completion_menu = main.as_bool("wider_completion_menu")

        self.on_error = main["on_error"].upper()

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool("smart_completion")
        self.completer = PGCompleter(smart_completion,
                                     pgspecial=self.pgspecial)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.cli = None
Example #6
0
    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None, row_limit=None,
                 single_connection=False, less_chatty=None, prompt=None, prompt_dsn=None,
                 auto_vertical_output=False, warn=None):
        """Build the application state from the arguments and the config file.

        :param force_passwd_prompt: stored unchanged on the instance.
        :param never_passwd_prompt: stored unchanged on the instance.
        :param pgexecute: stored unchanged on the instance.
        :param pgclirc_file: handed to ``get_config`` to load the settings.
        :param row_limit: replaces the configured ``row_limit`` unless ``None``.
        :param single_connection: forwarded to the completer settings.
        :param less_chatty: truthy value enables ``less_chatty`` regardless of
            the config; the raw value is also forwarded to the completer
            settings.
        :param prompt: replaces the configured prompt format unless ``None``.
        :param prompt_dsn: stored as ``prompt_dsn_format``.
        :param auto_vertical_output: truthy value enables ``auto_expand``
            regardless of the config.
        :param warn: replaces the configured ``destructive_warning`` unless
            ``None``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None

        # Load config; most settings below come from its 'main' section.
        cfg = self.config = get_config(pgclirc_file)
        main = cfg['main']

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = main.as_bool('multi_line')
        self.multiline_mode = main.get('multi_line_mode', 'psql')
        self.vi_mode = main.as_bool('vi')
        self.auto_expand = auto_vertical_output or main.as_bool('auto_expand')
        self.expanded_output = main.as_bool('expand')
        self.pgspecial.timing_enabled = main.as_bool('timing')
        # An explicit argument wins over the configured row limit.
        if row_limit is None:
            self.row_limit = main.as_int('row_limit')
        else:
            self.row_limit = row_limit

        self.min_num_menu_lines = main.as_int('min_num_menu_lines')
        self.multiline_continuation_char = main['multiline_continuation_char']
        self.table_format = main['table_format']
        self.syntax_style = main['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main.as_bool('wider_completion_menu')
        configured_warning = main.as_bool('destructive_warning')
        if warn is None:
            self.destructive_warning = configured_warning
        else:
            self.destructive_warning = warn
        self.less_chatty = bool(less_chatty) or main.as_bool('less_chatty')
        self.null_string = main.get('null_string', '<null>')
        if prompt is not None:
            self.prompt_format = prompt
        else:
            self.prompt_format = main.get('prompt', self.default_prompt)
        self.prompt_dsn_format = prompt_dsn
        self.on_error = main['on_error'].upper()
        data_formats = cfg['data_formats']
        self.decimal_format = data_formats['decimal']
        self.float_format = data_formats['float']

        pager_enabled = self.config['main'].as_bool('enable_pager')
        self.pgspecial.pset_pager('on' if pager_enabled else 'off')

        self.style_output = style_factory_output(
            self.syntax_style, cfg['colors'])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool('smart_completion')
        keyword_casing = main['keyword_casing']
        self.settings = {
            'casing_file': get_casing_file(cfg),
            'generate_casing_file': main.as_bool('generate_casing_file'),
            'generate_aliases': main.as_bool('generate_aliases'),
            'asterisk_column_order': main['asterisk_column_order'],
            'qualify_columns': main['qualify_columns'],
            'case_column_headers': main.as_bool('case_column_headers'),
            'search_path_filter': main.as_bool('search_path_filter'),
            'single_connection': single_connection,
            'less_chatty': less_chatty,
            'keyword_casing': keyword_casing,
        }

        self.completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial,
            settings=self.settings)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.eventloop = create_eventloop()
        self.cli = None
Example #7
0
class PGCli(object):

    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None):
        """Build the application state from the arguments and the config file.

        :param force_passwd_prompt: stored unchanged on the instance.
        :param never_passwd_prompt: stored unchanged on the instance.
        :param pgexecute: stored unchanged on the instance.
        :param pgclirc_file: path of the user config file, passed to
            ``write_default_config`` and ``load_config``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        from pgcli import __file__ as package_root
        package_root = os.path.dirname(package_root)

        # The packaged 'pgclirc' is the template for the user config file.
        default_config = os.path.join(package_root, 'pgclirc')
        write_default_config(default_config, pgclirc_file)

        self.pgspecial = PGSpecial()

        # Load config; most settings below come from its 'main' section.
        cfg = self.config = load_config(pgclirc_file, default_config)
        main = cfg['main']
        self.multi_line = main.as_bool('multi_line')
        self.vi_mode = main.as_bool('vi')
        self.pgspecial.timing_enabled = main.as_bool('timing')
        self.table_format = main['table_format']
        self.syntax_style = main['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main.as_bool('wider_completion_menu')

        self.on_error = main['on_error'].upper()

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool('smart_completion')
        self.completer = PGCompleter(smart_completion,
                                     pgspecial=self.pgspecial)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.cli = None

    def register_special_commands(self):
        """Register this application's own commands with ``self.pgspecial``."""
        register = self.pgspecial.register

        register(self.change_db, '\\c', '\\c[onnect] database_name',
                 'Change to a new database.',
                 aliases=('use', '\\connect', 'USE'))
        # Both spellings trigger the same completion refresh.
        for command in ('\\#', '\\refresh'):
            register(self.refresh_completions, command, command,
                     'Refresh auto-completions.', arg_type=NO_QUERY)
        register(self.execute_from_file, '\\i', '\\i filename',
                 'Execute commands from file.')

    def change_db(self, pattern, **_):
        if pattern:
            db = pattern[1:-1] if pattern[0] == pattern[-1] == '"' else pattern
            self.pgexecute.connect(database=db)
        else:
            self.pgexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
                'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        if not pattern:
            message = '\\i: missing required argument'
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(pattern), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        return self.pgexecute.run(query, self.pgspecial, on_error=self.on_error)

    def initialize_logging(self):

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG
                     }

        handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('pgcli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        root_logger.debug('Initializing pgcli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_dsn(self, dsn):
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        self.connect(database, uri.hostname, uri.username,
                     uri.port, uri.password)

    def connect(self, database='', host='', user='', port='', passwd='',
                dsn=''):
        """Create a ``PGExecute`` connection and store it on ``self.pgexecute``.

        :param database: database name; falls back to the user name when empty.
        :param host: passed through to ``PGExecute``.
        :param user: user name; falls back to ``getuser()`` when empty.
        :param port: passed through to ``PGExecute``.
        :param passwd: password; when empty it may be taken from the
            ``PGPASSWORD`` environment variable or prompted for (see the
            comments in the body).
        :param dsn: passed through to ``PGExecute``.

        Any exception raised while connecting is logged and printed, after
        which ``exit(1)`` is called.
        """
        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt('Password', hide_input=True,
                                  show_default=False, type=str)

        # Prompt for a password after 1st attempt to connect without a password
        # fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not passwd and not self.never_passwd_prompt

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                # Only a "no password supplied" failure is retried, and only
                # once; every other operational error is re-raised to the
                # outer handler.
                if ('no password supplied' in utf8tounicode(e.args[0]) and
                        auto_passwd_prompt):
                    passwd = click.prompt('Password', hide_input=True,
                                          show_default=False, type=str)
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn)
                else:
                    raise e

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            exit(1)

        # Reached only when one of the connection attempts succeeded.
        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""Open the external editor for as long as the input asks for it.

        An editor command is any query that ``special.editor_command``
        recognises (a query prefixed or suffixed by '\e'). The check is
        repeated after every edit because the user might edit a query
        several times in a row, e.g. "select * from \e" followed by
        "select * from blah where q = 'abc'\e".

        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the editor helper reports a message.
        """
        while True:
            if not special.editor_command(document.text):
                return document

            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(
                filename, sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)

            # Put the edited text in the buffer and let the user continue
            # from the end of it.
            cli.current_buffer.document = Document(
                sql, cursor_position=len(sql))
            document = cli.run(False)

    def run_cli(self):
        """Run the interactive read-evaluate-print loop until the user quits.

        Each iteration reads a document from ``self.cli``, handles editor
        commands, evaluates the text with ``_evaluate_command``, prints the
        output and appends a ``MetaQuery`` to ``self.query_history``. The
        loop ends on ``EOFError``; the ``LESS`` environment variable is
        restored to its previous value on the way out.
        """
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        self.refresh_completions()

        self.cli = self._build_cli()

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        try:
            while True:
                document = self.cli.run()

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    # Skip evaluation and history for this input.
                    continue

                # Initialize default metaquery in case execution fails
                query = MetaQuery(query=document.text, successful=False)

                try:
                    output, query = self._evaluate_command(document.text)
                except KeyboardInterrupt:
                    # Restart connection to the database
                    self.pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    if ('server closed the connection'
                            in utf8tounicode(e.args[0])):
                        self._handle_server_closed_connection()
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    # Only reached when evaluation raised nothing.
                    try:
                        click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        pass

                    if self.pgspecial.timing_enabled:
                        print('Time: %0.03fs' % query.total_time)

                    # Check if we need to update completions, in order of most
                    # to least drastic changes
                    if query.db_changed:
                        self.refresh_completions(reset=True)
                    elif query.meta_changed:
                        self.refresh_completions(reset=False)
                    elif query.path_changed:
                        logger.debug('Refreshing search path')
                        with self._completer_lock:
                            self.completer.set_search_path(
                                self.pgexecute.search_path())
                        logger.debug('Search path: %r',
                                     self.completer.search_path)

                # On failure this is the default, unsuccessful MetaQuery
                # created above.
                self.query_history.append(query)

        except EOFError:
            print ('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts

    def _build_cli(self):
        """Assemble and return the ``CommandLineInterface`` used by the REPL.

        Key bindings, layout, input buffer and application are created here;
        the buffer is built while holding ``self._completer_lock`` because
        it captures ``self.completer``.
        """

        def store_vi_mode(enabled):
            self.vi_mode = enabled

        def current_prompt(_):
            return [(Token.Prompt, '%s> ' % self.pgexecute.dbname)]

        bindings = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=store_vi_mode)

        toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode, self.completion_refresher.is_refreshing)

        # Highlight matching brackets while editing.
        bracket_highlighter = ConditionalProcessor(
            processor=HighlightMatchingBracketProcessor(chars='[](){}'),
            filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())

        layout = create_default_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=current_prompt,
            get_bottom_toolbar_tokens=toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[bracket_highlighter])

        history_file = self.config['main']['history_file']
        with self._completer_lock:
            input_buffer = PGBuffer(
                always_multiline=self.multi_line,
                completer=self.completer,
                history=FileHistory(os.path.expanduser(history_file)),
                complete_while_typing=Always())

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=input_buffer,
                key_bindings_registry=bindings.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                ignore_case=True)

            return CommandLineInterface(
                application=application,
                eventloop=create_eventloop())

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        :param text: the command text to execute.
        :return: ``(output, MetaQuery)`` where ``output`` is the list of
            formatted output lines and the ``MetaQuery`` records the text,
            overall success, elapsed time and which kinds of change the
            statements made.
        """
        logger = self.logger
        logger.debug('sql: %r', text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0
        # Row count above which the user has to confirm the output.
        threshold = 1000

        # Run the query.
        start = time()
        on_error_resume = self.on_error == 'RESUME'
        res = self.pgexecute.run(text, self.pgspecial,
                                 exception_formatter, on_error_resume)

        for title, cur, headers, status, sql, success in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)
            if (is_select(status) and
                    cur and cur.rowcount > threshold):
                click.secho('The result set has more than %s rows.'
                            % threshold, fg='red')
                if not click.confirm('Do you want to continue?'):
                    click.secho("Aborted!", err=True, fg='red')
                    break

            if self.pgspecial.auto_expand:
                max_width = self.cli.output.get_size().columns
            else:
                max_width = None

            formatted = format_output(
                title, cur, headers, status, self.table_format,
                self.pgspecial.expanded_output, max_width)

            output.extend(formatted)
            # ``start`` is taken once, before the first result, so the
            # elapsed time is assigned rather than accumulated: summing
            # ``now - start`` per result would count earlier results again.
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(text, all_success, total, meta_changed,
                               db_changed, path_changed, mutated)

        return output, meta_query

    def _handle_server_closed_connection(self):
        """Used during CLI execution: offer to reconnect after a reset."""
        wants_reconnect = click.prompt(
            'Connection reset. Reconnect (Y/n)',
            show_default=False, type=bool, default=True)
        if not wants_reconnect:
            return

        try:
            self.pgexecute.connect()
            click.secho('Reconnected!\nTry the command again.', fg='green')
        except OperationalError as error:
            # Show why reconnecting failed instead of leaving the loop.
            click.secho(str(error), err=True, fg='red')

    def adjust_less_opts(self):
        less_opts = os.environ.get('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', less_opts)
        os.environ['LESS'] = '-SRXF'

        return less_opts

    def refresh_completions(self, reset=False):
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(self.pgexecute, self.pgspecial,
                                          self._on_completions_refreshed)
        return [(None, None, None,
                'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When pgcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return the current completer's completions for ``text``.

        :param text: the input text.
        :param cursor_positition: cursor position inside ``text``.
        :return: whatever ``self.completer.get_completions`` returns for a
            ``Document`` built from the two arguments.
        """
        with self._completer_lock:
            document = Document(text=text, cursor_position=cursor_positition)
            return self.completer.get_completions(document, None)
Example #8
0
File: main.py Project: dbcli/pgcli
    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None, row_limit=None,
                 single_connection=False, less_chatty=None, prompt=None, prompt_dsn=None,
                 auto_vertical_output=False, warn=None):
        """Build the application state from the arguments and the config file.

        :param force_passwd_prompt: stored unchanged on the instance.
        :param never_passwd_prompt: stored unchanged on the instance.
        :param pgexecute: stored unchanged on the instance.
        :param pgclirc_file: handed to ``get_config`` to load the settings.
        :param row_limit: replaces the configured ``row_limit`` unless ``None``.
        :param single_connection: forwarded to the completer settings.
        :param less_chatty: truthy value enables ``less_chatty`` regardless of
            the config; the raw value is also forwarded to the completer
            settings.
        :param prompt: replaces the configured prompt format unless ``None``.
        :param prompt_dsn: stored as ``prompt_dsn_format``.
        :param auto_vertical_output: truthy value enables ``auto_expand``
            regardless of the config.
        :param warn: replaces the configured ``destructive_warning`` unless
            ``None``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None
        self.watch_command = None

        # Load config; most settings below come from its 'main' section.
        cfg = self.config = get_config(pgclirc_file)
        main = cfg['main']

        NamedQueries.instance = NamedQueries.from_config(self.config)

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = main.as_bool('multi_line')
        self.multiline_mode = main.get('multi_line_mode', 'psql')
        self.vi_mode = main.as_bool('vi')
        self.auto_expand = auto_vertical_output or main.as_bool('auto_expand')
        self.expanded_output = main.as_bool('expand')
        self.pgspecial.timing_enabled = main.as_bool('timing')
        # An explicit argument wins over the configured row limit.
        if row_limit is None:
            self.row_limit = main.as_int('row_limit')
        else:
            self.row_limit = row_limit

        self.min_num_menu_lines = main.as_int('min_num_menu_lines')
        self.multiline_continuation_char = main['multiline_continuation_char']
        self.table_format = main['table_format']
        self.syntax_style = main['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main.as_bool('wider_completion_menu')
        configured_warning = main.as_bool('destructive_warning')
        if warn is None:
            self.destructive_warning = configured_warning
        else:
            self.destructive_warning = warn
        self.less_chatty = bool(less_chatty) or main.as_bool('less_chatty')
        self.null_string = main.get('null_string', '<null>')
        if prompt is not None:
            self.prompt_format = prompt
        else:
            self.prompt_format = main.get('prompt', self.default_prompt)
        self.prompt_dsn_format = prompt_dsn
        self.on_error = main['on_error'].upper()
        data_formats = cfg['data_formats']
        self.decimal_format = data_formats['decimal']
        self.float_format = data_formats['float']
        self.initialize_keyring()

        pager_enabled = self.config['main'].as_bool('enable_pager')
        self.pgspecial.pset_pager('on' if pager_enabled else 'off')

        self.style_output = style_factory_output(
            self.syntax_style, cfg['colors'])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool('smart_completion')
        keyword_casing = main['keyword_casing']
        self.settings = {
            'casing_file': get_casing_file(cfg),
            'generate_casing_file': main.as_bool('generate_casing_file'),
            'generate_aliases': main.as_bool('generate_aliases'),
            'asterisk_column_order': main['asterisk_column_order'],
            'qualify_columns': main['qualify_columns'],
            'case_column_headers': main.as_bool('case_column_headers'),
            'search_path_filter': main.as_bool('search_path_filter'),
            'single_connection': single_connection,
            'less_chatty': less_chatty,
            'keyword_casing': keyword_casing,
        }

        self.completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial,
            settings=self.settings)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.prompt_app = None
Example #9
0
File: main.py Project: dbcli/pgcli
class PGCli(object):

    default_prompt = '\\u@\\h:\\d> '
    max_len_prompt = 30

    def set_default_pager(self, config):
        configured_pager = config['main'].get('pager')
        os_environ_pager = os.environ.get('PAGER')

        if configured_pager:
            self.logger.info(
                'Default pager found in config file: "{}"'.format(configured_pager))
            os.environ['PAGER'] = configured_pager
        elif os_environ_pager:
            self.logger.info('Default pager found in PAGER environment variable: "{}"'.format(
                os_environ_pager))
            os.environ['PAGER'] = os_environ_pager
        else:
            self.logger.info(
                'No default pager found in environment. Using os default pager')

        # Set default set of less recommended options, if they are not already set.
        # They are ignored if pager is different than less.
        if not os.environ.get('LESS'):
            os.environ['LESS'] = '-SRXF'

    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None, row_limit=None,
                 single_connection=False, less_chatty=None, prompt=None, prompt_dsn=None,
                 auto_vertical_output=False, warn=None):
        """Load the configuration and prepare the CLI state.

        Explicit arguments (``row_limit``, ``prompt``, ``warn``) override the
        matching config entries when they are not None; ``less_chatty`` and
        ``auto_vertical_output`` are OR-ed with their config counterparts.
        No database connection is opened here.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None
        self.watch_command = None

        # Load config.
        cfg = self.config = get_config(pgclirc_file)

        NamedQueries.instance = NamedQueries.from_config(self.config)

        # The logger must exist before set_default_pager(), which logs.
        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        main_cfg = cfg['main']

        self.multi_line = main_cfg.as_bool('multi_line')
        self.multiline_mode = main_cfg.get('multi_line_mode', 'psql')
        self.vi_mode = main_cfg.as_bool('vi')
        self.auto_expand = auto_vertical_output or main_cfg.as_bool(
            'auto_expand')
        self.expanded_output = main_cfg.as_bool('expand')
        self.pgspecial.timing_enabled = main_cfg.as_bool('timing')
        self.row_limit = (
            row_limit if row_limit is not None
            else main_cfg.as_int('row_limit'))

        self.min_num_menu_lines = main_cfg.as_int('min_num_menu_lines')
        self.multiline_continuation_char = main_cfg['multiline_continuation_char']
        self.table_format = main_cfg['table_format']
        self.syntax_style = main_cfg['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main_cfg.as_bool('wider_completion_menu')

        configured_warning = main_cfg.as_bool('destructive_warning')
        if warn is None:
            self.destructive_warning = configured_warning
        else:
            self.destructive_warning = warn

        self.less_chatty = bool(less_chatty) or main_cfg.as_bool('less_chatty')
        self.null_string = main_cfg.get('null_string', '<null>')
        if prompt is not None:
            self.prompt_format = prompt
        else:
            self.prompt_format = main_cfg.get('prompt', self.default_prompt)
        self.prompt_dsn_format = prompt_dsn
        self.on_error = main_cfg['on_error'].upper()

        data_formats = cfg['data_formats']
        self.decimal_format = data_formats['decimal']
        self.float_format = data_formats['float']
        self.initialize_keyring()

        pager_state = "on" if main_cfg.as_bool('enable_pager') else "off"
        self.pgspecial.pset_pager(pager_state)

        self.style_output = style_factory_output(
            self.syntax_style, cfg['colors'])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main_cfg.as_bool('smart_completion')
        self.settings = {
            'casing_file': get_casing_file(cfg),
            'generate_casing_file': main_cfg.as_bool('generate_casing_file'),
            'generate_aliases': main_cfg.as_bool('generate_aliases'),
            'asterisk_column_order': main_cfg['asterisk_column_order'],
            'qualify_columns': main_cfg['qualify_columns'],
            'case_column_headers': main_cfg.as_bool('case_column_headers'),
            'search_path_filter': main_cfg.as_bool('search_path_filter'),
            'single_connection': single_connection,
            'less_chatty': less_chatty,
            'keyword_casing': main_cfg['keyword_casing'],
        }

        self.completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial,
            settings=self.settings)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.prompt_app = None

    def quit(self):
        """Leave the CLI by raising PgCliQuitError."""
        raise PgCliQuitError()

    def register_special_commands(self):
        """Register pgcli's own backslash/keyword commands on ``self.pgspecial``."""

        def refresh_callback():
            return self.refresh_completions(persist_priorities='all')

        # (handler, command, syntax, description, extra keyword options),
        # registered in this order.
        registrations = (
            (self.change_db, '\\c', '\\c[onnect] database_name',
             'Change to a new database.',
             {'aliases': ('use', '\\connect', 'USE')}),
            (self.quit, '\\q', '\\q', 'Quit pgcli.',
             {'arg_type': NO_QUERY, 'case_sensitive': True,
              'aliases': (':q',)}),
            (self.quit, 'quit', 'quit', 'Quit pgcli.',
             {'arg_type': NO_QUERY, 'case_sensitive': False,
              'aliases': ('exit',)}),
            (refresh_callback, '\\#', '\\#', 'Refresh auto-completions.',
             {'arg_type': NO_QUERY}),
            (refresh_callback, '\\refresh', '\\refresh',
             'Refresh auto-completions.', {'arg_type': NO_QUERY}),
            (self.execute_from_file, '\\i', '\\i filename',
             'Execute commands from file.', {}),
            (self.write_to_file, '\\o', '\\o [filename]',
             'Send all query results to file.', {}),
            (self.info_connection, '\\conninfo', '\\conninfo',
             'Get connection details', {}),
            (self.change_table_format, '\\T', '\\T [format]',
             'Change the table format used to output results', {}),
        )

        for handler, command, syntax, description, options in registrations:
            self.pgspecial.register(
                handler, command, syntax, description, **options)

    def change_table_format(self, pattern, **_):
        """Switch the output table format to ``pattern``.

        Yields one ``(None, None, None, message)`` tuple: a confirmation
        when ``pattern`` is a supported format, otherwise a listing of the
        supported formats and the one currently in use.
        """
        if pattern in TabularOutputFormatter().supported_formats:
            self.table_format = pattern
            yield (None, None, None,
                   'Changed table format to {}'.format(pattern))
            return

        header = 'Table format {} not recognized. Allowed formats:'.format(
            pattern)
        listing = ''.join(
            "\n\t{}".format(table_type)
            for table_type in TabularOutputFormatter().supported_formats)
        footer = '\nCurrently set to: %s' % self.table_format
        yield (None, None, None, header + listing + footer)

    def info_connection(self, **_):
        if self.pgexecute.host.startswith('/'):
            host = 'socket "%s"' % self.pgexecute.host
        else:
            host = 'host "%s"' % self.pgexecute.host

        yield (None, None, None, 'You are connected to database "%s" as user '
               '"%s" on %s at port "%s".' % (self.pgexecute.dbname,
                                             self.pgexecute.user,
                                             host,
                                             self.pgexecute.port))

    def change_db(self, pattern, **_):
        if pattern:
            # Get all the parameters in pattern, handling double quotes if any.
            infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
            # Now removing quotes.
            list(map(lambda s: s.strip('"'), infos))

            infos.extend([None] * (4 - len(infos)))
            db, user, host, port = infos
            try:
                self.pgexecute.connect(database=db, user=user, host=host,
                                       port=port, **self.pgexecute.extra_args)
            except OperationalError as e:
                click.secho(str(e), err=True, fg='red')
                click.echo("Previous connection kept")
        else:
            self.pgexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        if not pattern:
            message = '\\i: missing required argument'
            return [(None, None, None, message, '', False, True)]
        try:
            with open(os.path.expanduser(pattern), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e), '', False, True)]

        if (self.destructive_warning and
                confirm_destructive_query(query) is False):
            message = 'Wise choice. Command execution stopped.'
            return [(None, None, None, message)]

        on_error_resume = (self.on_error == 'RESUME')
        return self.pgexecute.run(
            query, self.pgspecial, on_error_resume=on_error_resume
        )

    def write_to_file(self, pattern, **_):
        if not pattern:
            self.output_file = None
            message = 'File output disabled'
            return [(None, None, None, message, '', True, True)]
        filename = os.path.abspath(os.path.expanduser(pattern))
        if not os.path.isfile(filename):
            try:
                open(filename, 'w').close()
            except IOError as e:
                self.output_file = None
                message = str(e) + '\nFile output disabled'
                return [(None, None, None, message, '', False, True)]
        self.output_file = filename
        message = 'Writing to file "%s"' % self.output_file
        return [(None, None, None, message, '', True, True)]

    def initialize_logging(self):
        """Attach a log handler to the ``pgcli`` and ``pgspecial`` loggers.

        The log file and level come from the ``main`` config section. A
        ``log_file`` of ``default`` is replaced by ``config_location() +
        'log'``. Level ``NONE`` installs a no-op handler and sets the logger
        level to CRITICAL; an unknown level name raises KeyError.
        """
        main_cfg = self.config['main']

        log_file = main_cfg['log_file']
        if log_file == 'default':
            log_file = config_location() + 'log'
        ensure_dir_exists(log_file)
        level_name = main_cfg['log_level'].upper()

        # NONE disables logging through a handler that discards records.
        if level_name == 'NONE':
            handler = logging.NullHandler()
        else:
            handler = logging.FileHandler(os.path.expanduser(log_file))

        levels = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG,
            # High threshold so a disabled logger rarely does any work.
            'NONE': logging.CRITICAL,
        }
        level = levels[level_name]

        handler.setFormatter(logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s'))

        pgcli_logger = logging.getLogger('pgcli')
        pgcli_logger.addHandler(handler)
        pgcli_logger.setLevel(level)

        pgcli_logger.debug('Initializing pgcli logging.')
        pgcli_logger.debug('Log file %r.', log_file)

        special_logger = logging.getLogger('pgspecial')
        special_logger.addHandler(handler)
        special_logger.setLevel(level)

    def initialize_keyring(self):
        global keyring

        keyring_enabled = self.config["main"].as_bool("keyring")
        if keyring_enabled:
            # Try best to load keyring (issue #1041).
            import importlib
            try:
                keyring = importlib.import_module('keyring')
            except Exception as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
                self.logger.warning('import keyring failed: %r.', e)

    def connect_dsn(self, dsn):
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        """Connect using a connection URI.

        The URI is split into keyword/value pairs with
        ``psycopg2.extensions.parse_dsn``. The ``dbname`` and ``password``
        keys are renamed to the ``database`` and ``passwd`` parameters of
        ``connect()``; every other key is passed through unchanged.

        :param uri: connection string understood by ``parse_dsn``
        """
        kwargs = psycopg2.extensions.parse_dsn(uri)
        # connect() names its password parameter ``passwd``. Under any other
        # key the password would land in **kwargs and never be used for the
        # password handling in connect().
        remap = {
            'dbname': 'database',
            'password': 'passwd',
        }
        kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
        self.connect(**kwargs)

    def connect(self, database='', host='', user='', port='', passwd='',
                dsn='', **kwargs):
        """Create a PGExecute connection and store it in ``self.pgexecute``.

        A missing ``user`` defaults to ``getuser()`` and a missing
        ``database`` to the user name. The password is looked up in this
        order: the ``passwd`` argument, the ``PGPASSWORD`` environment
        variable (skipped when a prompt is forced), the keyring (when the
        module-level ``keyring`` is set), and finally an interactive prompt.
        A password that led to a successful connection is written back to
        the keyring.

        Any failure is logged, printed in red and ends the process through
        ``exit(1)``.

        :param kwargs: extra arguments handed to PGExecute;
            ``application_name`` defaults to ``pgcli``
        """
        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        kwargs.setdefault('application_name', 'pgcli')

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Find password from store
        key = '%s@%s' % (user, host)
        # Shared template for keyring failures; the two placeholders take a
        # short description of the failed action and the error text.
        keyring_error_message = dedent("""\
            {}
            {}
            To remove this message do one of the following:
            - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
            - uninstall keyring: pip uninstall keyring
            - disable keyring in our configuration: add keyring = False to [main]""")
        if not passwd and keyring:

            try:
                passwd = keyring.get_password('pgcli', key)
            except (
                RuntimeError,
                keyring.errors.InitError
            ) as e:
                # A broken keyring is reported but does not stop the
                # connection attempt.
                click.secho(
                    keyring_error_message.format(
                        "Load your password from keyring returned:",
                        str(e)
                    ),
                    err=True,
                    fg='red'
                )

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt('Password for %s' % user, hide_input=True,
                                  show_default=False, type=str)

        def should_ask_for_password(exc):
            # Prompt for a password after 1st attempt to connect
            # fails. Don't prompt if the -w flag is supplied
            if self.never_passwd_prompt:
                return False
            # Only password-related failures justify a prompt; they are
            # recognised by the text of the error message.
            error_msg = utf8tounicode(exc.args[0])
            if "no password supplied" in error_msg:
                return True
            if "password authentication failed" in error_msg:
                return True
            return False

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing or incorrect password, but we're allowed to
        # prompt for a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn,
                                      **kwargs)
            except (OperationalError, InterfaceError) as e:
                if should_ask_for_password(e):
                    passwd = click.prompt('Password for %s' % user,
                                          hide_input=True, show_default=False,
                                          type=str)
                    # Second and last attempt; a failure here reaches the
                    # outer handler below.
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn, **kwargs)
                else:
                    raise e
            # Reached only after a successful connection: remember the
            # password that worked.
            if passwd and keyring:
                try:
                    keyring.set_password('pgcli', key, passwd)
                except (
                    RuntimeError,
                    keyring.errors.KeyringError,
                ) as e:
                    click.secho(
                        keyring_error_message.format(
                            "Set password in keyring returned:",
                            str(e)
                        ),
                        err=True,
                        fg='red'
                    )

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, text):
        r"""
        Resolve editor commands (``\e``, ``\ev``, ``\ef``) found in ``text``.

        The query is opened in an external editor and the edited result is
        offered back at the prompt. This repeats for as long as the text
        returned by the prompt still contains an editor command, because a
        user might edit a query multiple times. For eg:
        "select * from \e"<enter> to edit it in vim, then come back to the
        prompt with the edited query "select * from blah where q = 'abc'\e"
        to edit it again.

        :param text: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message
        """
        while True:
            command = special.editor_command(text)
            if not command:
                return text

            if command != '\\e':  # \ev or \ef
                filename = None
                spec = text.split()[1]
                if command == '\\ev':
                    query = self.pgexecute.view_definition(spec)
                elif command == '\\ef':
                    query = self.pgexecute.function_definition(spec)
            else:
                filename = special.get_filename(text)
                query = (special.get_editor_query(text)
                         or self.get_last_query())

            edited, problem = special.open_external_editor(
                filename, sql=query)
            if problem:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(problem)

            # Ctrl-C at the prompt clears the pre-filled text and asks again.
            while True:
                try:
                    text = self.prompt_app.prompt(default=edited)
                except KeyboardInterrupt:
                    edited = ""
                else:
                    break

    def execute_command(self, text):
        """Run ``text``, show its output and refresh completions if needed.

        Handles the destructive-query confirmation, evaluation, error
        reporting, output to file or pager, timing display and completion
        refresh. PgCliQuitError and EOFError are passed on to the caller;
        other errors are reported on the terminal.

        :param text: command entered by the user
        :return: MetaQuery describing the run; an unsuccessful placeholder
                 when evaluation did not complete
        """
        logger = self.logger

        query = MetaQuery(query=text, successful=False)

        try:
            if self.destructive_warning:
                verdict = confirm_destructive_query(text)
                if verdict is False:
                    click.secho('Wise choice!')
                    raise KeyboardInterrupt
                elif verdict:
                    click.secho('Your call!')
            output, query = self._evaluate_command(text)
        except KeyboardInterrupt:
            # Restart connection to the database
            self.pgexecute.connect()
            logger.debug("cancelled query, sql: %r", text)
            click.secho("cancelled query", err=True, fg='red')
        except NotImplementedError:
            click.secho('Not Yet Implemented.', fg="yellow")
        except OperationalError as err:
            logger.error("sql: %r, error: %r", text, err)
            logger.error("traceback: %r", traceback.format_exc())
            self._handle_server_closed_connection(text)
        except (PgCliQuitError, EOFError):
            raise
        except Exception as err:
            logger.error("sql: %r, error: %r", text, err)
            logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(err), err=True, fg='red')
        else:
            try:
                write_to_file = (
                    self.output_file
                    and not text.startswith(('\\o ', '\\? ')))
                if write_to_file:
                    try:
                        with open(self.output_file, 'a', encoding='utf-8') as sink:
                            click.echo(text, file=sink)
                            click.echo('\n'.join(output), file=sink)
                            click.echo('', file=sink)  # extra newline
                    except IOError as err:
                        click.secho(str(err), err=True, fg='red')
                else:
                    self.echo_via_pager('\n'.join(output))
            except KeyboardInterrupt:
                pass

            if self.pgspecial.timing_enabled:
                # Only add humanized time display if > 1 second
                if query.total_time > 1:
                    print('Time: %0.03fs (%s), executed in: %0.03fs (%s)' % (
                        query.total_time,
                        humanize.time.naturaldelta(query.total_time),
                        query.execution_time,
                        humanize.time.naturaldelta(query.execution_time)))
                else:
                    print('Time: %0.03fs' % query.total_time)

            # Check if we need to update completions, in order of most
            # to least drastic changes
            if query.db_changed:
                with self._completer_lock:
                    self.completer.reset_completions()
                self.refresh_completions(persist_priorities='keywords')
            elif query.meta_changed:
                self.refresh_completions(persist_priorities='all')
            elif query.path_changed:
                logger.debug('Refreshing search path')
                with self._completer_lock:
                    self.completer.set_search_path(
                        self.pgexecute.search_path())
                logger.debug('Search path: %r',
                             self.completer.search_path)
        return query

    def run_cli(self):
        """Run the interactive read-eval-print loop until the user quits.

        Sets up the history file, starts a completion refresh, builds the
        prompt session and then reads and executes commands in a loop. The
        loop ends on PgCliQuitError or EOFError.
        """
        logger = self.logger

        history_file = self.config['main']['history_file']
        if history_file == 'default':
            history_file = config_location() + 'history'
        history = FileHistory(os.path.expanduser(history_file))
        self.refresh_completions(history=history,
                                 persist_priorities='none')

        self.prompt_app = self._build_cli(history)

        if not self.less_chatty:
            print('Server: PostgreSQL', self.pgexecute.server_version)
            print('Version:', __version__)
            print('Chat: https://gitter.im/dbcli/pgcli')
            print('Mail: https://groups.google.com/forum/#!forum/pgcli')
            print('Home: http://pgcli.com')

        try:
            while True:
                try:
                    text = self.prompt_app.prompt()
                except KeyboardInterrupt:
                    # Ctrl-C at the prompt discards the input and asks again.
                    continue

                try:
                    text = self.handle_editor_command(text)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Split off a watch command, if any, together with its
                # repeat interval.
                self.watch_command, timing = special.get_watch_command(text)
                if self.watch_command:
                    # Repeat the command until interrupted with Ctrl-C.
                    while self.watch_command:
                        try:
                            query = self.execute_command(self.watch_command)
                            click.echo(
                                'Waiting for {0} seconds before repeating'
                                .format(timing))
                            sleep(timing)
                        except KeyboardInterrupt:
                            self.watch_command = None
                else:
                    query = self.execute_command(text)

                # Timestamp used for the \t prompt placeholder.
                self.now = dt.datetime.today()

                # Allow PGCompleter to learn user's preferred keywords, etc.
                with self._completer_lock:
                    self.completer.extend_query_history(text)

                self.query_history.append(query)

        except (PgCliQuitError, EOFError):
            if not self.less_chatty:
                print ('Goodbye!')

    def _build_cli(self, history):
        """Create the prompt_toolkit PromptSession used by the REPL.

        :param history: history object handed to the session
        :return: the configured PromptSession
        """
        key_bindings = pgcli_bindings(self)

        def render_prompt():
            use_dsn_format = (
                self.dsn_alias and self.prompt_dsn_format is not None)
            prompt_format = (
                self.prompt_dsn_format if use_dsn_format
                else self.prompt_format)

            prompt = self.get_prompt(prompt_format)

            # An over-long default prompt is shortened to the database name.
            too_long = len(prompt) > self.max_len_prompt
            if prompt_format == self.default_prompt and too_long:
                prompt = self.get_prompt('\\d> ')

            return [('class:prompt', prompt)]

        def render_continuation(width, line_number, is_soft_wrap):
            filler = self.multiline_continuation_char * (width - 1) + ' '
            return [('class:continuation', filler)]

        toolbar_tokens = create_toolbar_tokens_func(self)

        complete_style = (
            CompleteStyle.MULTI_COLUMN if self.wider_completion_menu
            else CompleteStyle.COLUMN)

        input_processors = [
            # Highlight matching brackets while editing.
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(
                    chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            # Render \t as 4 spaces instead of "^I"
            TabsProcessor(char1=' ', char2=' '),
        ]

        with self._completer_lock:
            session = PromptSession(
                lexer=PygmentsLexer(PostgresLexer),
                reserve_space_for_menu=self.min_num_menu_lines,
                message=render_prompt,
                prompt_continuation=render_continuation,
                bottom_toolbar=toolbar_tokens,
                complete_style=complete_style,
                input_processors=input_processors,
                auto_suggest=AutoSuggestFromHistory(),
                tempfile_suffix='.sql',
                multiline=pg_is_multiline(self),
                history=history,
                # The lambda follows completer swaps made after a refresh.
                completer=ThreadedCompleter(
                    DynamicCompleter(lambda: self.completer)),
                complete_while_typing=True,
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
                search_ignore_case=True)

            return session

    def _should_show_limit_prompt(self, status, cur):
        """Tell whether the row-limit confirmation should be shown.

        The result is truthy only for a select whose cursor reports more
        rows than a positive ``self.row_limit``.
        """
        if is_select(status):
            return self.row_limit > 0 and cur and (cur.rowcount > self.row_limit)
        return False

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        Runs ``text`` through ``self.pgexecute.run`` and formats every
        result. For selects exceeding the row limit the user is asked for
        confirmation; declining stops the processing of further results.

        :param text: command text to execute
        :return: tuple ``(output, MetaQuery)`` where ``output`` is the list
                 of formatted output lines
        """
        logger = self.logger
        logger.debug('sql: %r', text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        # Bound up front: the loop below is the only other place assigning
        # it, so an executor yielding no results would otherwise make the
        # MetaQuery construction fail with UnboundLocalError.
        is_special = False
        output = []
        total = 0
        execution = 0

        # Run the query.
        start = time()
        on_error_resume = self.on_error == 'RESUME'
        res = self.pgexecute.run(text, self.pgspecial,
                                 exception_formatter, on_error_resume)

        for title, cur, headers, status, sql, success, is_special in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)
            threshold = self.row_limit
            if self._should_show_limit_prompt(status, cur):
                click.secho('The result set has more than %s rows.'
                            % threshold, fg='red')
                if not click.confirm('Do you want to continue?'):
                    click.secho("Aborted!", err=True, fg='red')
                    break

            # The terminal width is only available once the prompt session
            # exists; without it no width limit is applied.
            wants_auto_expand = self.pgspecial.auto_expand or self.auto_expand
            if wants_auto_expand and self.prompt_app:
                max_width = self.prompt_app.output.get_size().columns
            else:
                max_width = None

            expanded = self.pgspecial.expanded_output or self.expanded_output
            settings = OutputSettings(
                table_format=self.table_format,
                dcmlfmt=self.decimal_format,
                floatfmt=self.float_format,
                missingval=self.null_string,
                expanded=expanded,
                max_width=max_width,
                case_function=(
                    self.completer.case if self.settings['case_column_headers']
                    else lambda x: x
                ),
                style_output=self.style_output
            )
            # Execution time is taken before formatting, total time after it.
            execution = time() - start
            formatted = format_output(title, cur, headers, status, settings)

            output.extend(formatted)
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(text, all_success, total, execution, meta_changed,
                               db_changed, path_changed, mutated, is_special)

        return output, meta_query

    def _handle_server_closed_connection(self, text):
        """Used during CLI execution.

        Reconnects with the executor's current settings and runs ``text``
        again. A failed reconnect (OperationalError) is reported on the
        terminal instead of being raised.
        """
        try:
            click.secho('Reconnecting...', fg='green')
            self.pgexecute.connect()
            click.secho('Reconnected!', fg='green')
            self.execute_command(text)
        except OperationalError as err:
            failure_text = str(err)
            click.secho('Reconnect Failed', fg='red')
            click.secho(failure_text, err=True, fg='red')

    def refresh_completions(self, history=None, persist_priorities='all'):
        """ Refresh outdated completions

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: 'all' or 'keywords'
        """

        callback = functools.partial(self._on_completions_refreshed,
                                     persist_priorities=persist_priorities)
        self.completion_refresher.refresh(self.pgexecute, self.pgspecial,
            callback, history=history, settings=self.settings)
        return [(None, None, None,
                'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        self._swap_completer_objects(new_completer, persist_priorities)

        if self.prompt_app:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.prompt_app.app.invalidate()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Swap the completer object with the newly created completer.

        persist_priorities is a string specifying how the old completer's
        learned prioritizer should be transferred to the new completer.

          'none'     - The new prioritizer is left in a new/clean state

          'all'      - The new prioritizer is updated to exactly reflect
                       the old one

          'keywords' - The new prioritizer is updated with old keyword
                       priorities, but not any other.

        """
        with self._completer_lock:
            old_completer = self.completer
            self.completer = new_completer

            if persist_priorities == 'all':
                # Just swap over the entire prioritizer
                new_completer.prioritizer = old_completer.prioritizer
            elif persist_priorities == 'keywords':
                # Swap over the entire prioritizer, but clear name priorities,
                # leaving learned keyword priorities alone
                new_completer.prioritizer = old_completer.prioritizer
                new_completer.prioritizer.clear_names()
            elif persist_priorities == 'none':
                # Leave the new prioritizer as is
                pass
            self.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return the completer's completions for ``text`` at the cursor.

        The completer is queried while holding the completer lock.
        """
        document = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(document, None)

    def get_prompt(self, string):
        # should be before replacing \\d
        string = string.replace('\\dsn_alias', self.dsn_alias or '')
        string = string.replace('\\t', self.now.strftime('%x %X'))
        string = string.replace('\\u', self.pgexecute.user or '(none)')
        string = string.replace('\\H', self.pgexecute.host or '(none)')
        string = string.replace('\\h', self.pgexecute.short_host or '(none)')
        string = string.replace('\\d', self.pgexecute.dbname or '(none)')
        string = string.replace('\\p', str(
            self.pgexecute.port) if self.pgexecute.port is not None else '5432')
        string = string.replace('\\i', str(self.pgexecute.pid) or '(none)')
        string = string.replace('\\#', "#" if (self.pgexecute.superuser) else ">")
        string = string.replace('\\n', "\n")
        return string

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None

    def is_too_wide(self, line):
        """Will this line be too wide to fit into terminal?"""
        if not self.prompt_app:
            return False
        return len(COLOR_CODE_REGEX.sub('', line)) > self.prompt_app.output.get_size().columns

    def is_too_tall(self, lines):
        """Are there too many lines to fit into terminal?"""
        if not self.prompt_app:
            return False
        return len(lines) >= (self.prompt_app.output.get_size().rows - 4)

    def echo_via_pager(self, text, color=None):
        """Print ``text`` directly or through the pager.

        The pager is skipped when it is switched off or a watch command is
        active. In long-output mode it is used only when the text is too
        tall or too wide for the terminal; otherwise it is always used.
        """
        pager_config = self.pgspecial.pager_config

        if pager_config == PAGER_OFF or self.watch_command:
            click.echo(text, color=color)
            return

        if pager_config == PAGER_LONG_OUTPUT:
            lines = text.split('\n')

            # The last 4 lines are reserved for the pgcli menu and padding
            fits_terminal = not (
                self.is_too_tall(lines)
                or any(self.is_too_wide(line) for line in lines))
            if fits_terminal:
                click.echo(text, color=color)
                return

        click.echo_via_pager(text, color=color)
Example #10
0
class PGCli(object):

    def set_default_pager(self, config):
        """Select the pager used for long output and export it via ``PAGER``.

        The ``pager`` entry of the ``main`` config section wins; otherwise an
        existing ``PAGER`` environment variable is kept; otherwise nothing is
        set. The outcome is reported through ``self.logger.info``.

        ``LESS`` is given the default ``-SRXF`` only when it is not already
        set, so a user's own ``LESS`` options are not overwritten.

        :param config: mapping with a ``main`` section that supports ``.get``
        """
        configured_pager = config['main'].get('pager')
        os_environ_pager = os.environ.get('PAGER')

        if configured_pager:
            self.logger.info("Default pager found in config file: '%s'",
                             configured_pager)
            os.environ['PAGER'] = configured_pager
        elif os_environ_pager:
            # PAGER already holds this value, so there is nothing to export.
            self.logger.info(
                "Default pager found in PAGER environment variable: '%s'",
                os_environ_pager)
        else:
            self.logger.info('No default pager found in environment. Using os default pager')

        # Recommended options for less; they are ignored if the pager is not
        # less. Only applied when the user has not configured LESS already.
        if not os.environ.get('LESS'):
            os.environ['LESS'] = '-SRXF'

    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None, row_limit=None):
        """Build the application state from the config file and arguments.

        :param force_passwd_prompt: stored on the instance as-is
        :param never_passwd_prompt: stored on the instance as-is
        :param pgexecute: stored as ``self.pgexecute``
        :param pgclirc_file: passed to ``get_config`` to load the config
        :param row_limit: when not None, used instead of the ``row_limit``
                          config entry
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        # Load config.
        self.config = get_config(pgclirc_file)
        cfg = self.config

        # The logger has to exist before set_default_pager, which logs.
        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        main_cfg = cfg['main']
        self.multi_line = main_cfg.as_bool('multi_line')
        self.vi_mode = main_cfg.as_bool('vi')
        self.pgspecial.timing_enabled = main_cfg.as_bool('timing')
        self.row_limit = (main_cfg.as_int('row_limit') if row_limit is None
                          else row_limit)

        self.table_format = main_cfg['table_format']
        self.syntax_style = main_cfg['syntax_style']
        self.cli_style = cfg['colors']
        self.wider_completion_menu = main_cfg.as_bool('wider_completion_menu')
        self.less_chatty = main_cfg.as_bool('less_chatty')

        self.on_error = main_cfg['on_error'].upper()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main_cfg.as_bool('smart_completion')
        completer_settings = {
            'casing_file': get_casing_file(cfg),
            'generate_casing_file': main_cfg.as_bool('generate_casing_file'),
            'asterisk_column_order': main_cfg['asterisk_column_order'],
        }
        self.completer = PGCompleter(smart_completion,
                                     pgspecial=self.pgspecial,
                                     settings=completer_settings)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.eventloop = create_eventloop()
        self.cli = None

    def register_special_commands(self):
        """Register pgcli's own backslash commands with ``self.pgspecial``."""

        def refresh_callback():
            return self.refresh_completions(persist_priorities='all')

        register = self.pgspecial.register

        register(self.change_db, '\\c', '\\c[onnect] database_name',
                 'Change to a new database.',
                 aliases=('use', '\\connect', 'USE'))

        # Both spellings trigger the same completion refresh.
        for command in ('\\#', '\\refresh'):
            register(refresh_callback, command, command,
                     'Refresh auto-completions.', arg_type=NO_QUERY)

        register(self.execute_from_file, '\\i', '\\i filename',
                 'Execute commands from file.')
        register(self.write_to_file, '\\o', '\\o [filename]',
                 'Send all query results to file.')

    def change_db(self, pattern, **_):
        """Reconnect, optionally to another database, and report the result.

        :param pattern: database name, optionally wrapped in double quotes;
                        when empty, ``self.pgexecute.connect()`` is called
                        without arguments
        :return: generator yielding one ``(None, None, None, status)`` tuple
        """
        if pattern:
            # Strip one pair of surrounding double quotes. The length check
            # keeps a lone '"' from being read as both the opening and the
            # closing quote, which would silently produce an empty name.
            quoted = len(pattern) >= 2 and pattern[0] == pattern[-1] == '"'
            db = pattern[1:-1] if quoted else pattern
            self.pgexecute.connect(database=db)
        else:
            self.pgexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
                'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        """Run the SQL contained in the file named by ``pattern``.

        :param pattern: path of the file to read; ``~`` is expanded
        :return: a one-item list holding an error row when the argument is
                 missing or the file cannot be read, otherwise whatever
                 ``self.pgexecute.run`` returns for the file's contents
        """
        if pattern:
            path = os.path.expanduser(pattern)
            try:
                with open(path, encoding='utf-8') as sql_file:
                    query = sql_file.read()
            except IOError as err:
                return [(None, None, None, str(err), '', False)]

            # Keep going after a failed statement only in RESUME mode.
            return self.pgexecute.run(
                query, self.pgspecial,
                on_error_resume=(self.on_error == 'RESUME'))

        return [(None, None, None, '\\i: missing required argument', '', False)]

    def write_to_file(self, pattern, **_):
        """Turn file output on (to ``pattern``) or off (empty ``pattern``).

        :param pattern: target file name; ``~`` is expanded and the path is
                        made absolute. An empty value disables file output.
        :return: a one-item list with a status row; its last field is False
                 only when the target file could not be created
        """
        if not pattern:
            self.output_file = None
            return [(None, None, None, 'File output disabled', '', True)]

        target = os.path.abspath(os.path.expanduser(pattern))
        if not os.path.isfile(target):
            # Create an empty file; a failure here disables file output.
            try:
                with open(target, 'w'):
                    pass
            except IOError as err:
                self.output_file = None
                failure = str(err) + '\nFile output disabled'
                return [(None, None, None, failure, '', False)]

        self.output_file = target
        return [(None, None, None,
                 'Writing to file "%s"' % self.output_file, '', True)]

    def initialize_logging(self):
        """Attach one handler to the 'pgcli' and 'pgspecial' loggers.

        Reads ``log_file`` and ``log_level`` from the ``main`` config
        section. A level of NONE installs a ``NullHandler`` and sets the
        logger level to CRITICAL; any other level logs to ``log_file``.
        """
        main_cfg = self.config['main']

        log_file = main_cfg['log_file']
        if log_file == 'default':
            log_file = config_location() + 'log'
        ensure_dir_exists(log_file)

        level_name = main_cfg['log_level'].upper()

        # Disable logging if value is NONE by switching to a no-op handler.
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if level_name == 'NONE':
            handler = NullHandler()
        else:
            handler = logging.FileHandler(os.path.expanduser(log_file))

        level_map = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG,
            'NONE': logging.CRITICAL,
        }
        log_level = level_map[level_name]

        handler.setFormatter(logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s'))

        root_logger = logging.getLogger('pgcli')
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        root_logger.debug('Initializing pgcli logging.')
        root_logger.debug('Log file %r.', log_file)

        # The pgspecial logger shares the same handler and level.
        pgspecial_logger = logging.getLogger('pgspecial')
        pgspecial_logger.addHandler(handler)
        pgspecial_logger.setLevel(log_level)

    def connect_dsn(self, dsn):
        """Connect using a DSN by delegating to ``self.connect(dsn=dsn)``."""
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        """Connect using the parts of a database URI.

        The database name, host, user and password are percent-decoded
        before being handed to ``connect``; the port is passed through as
        parsed.

        :param uri: URI string, e.g. ``postgres://user:pw@host:5432/dbname``
        """
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash

        def decode(part):
            # unquote each URI part (they may be percent encoded); missing
            # parts (None or '') are passed through unchanged.
            return unquote(part) if part else part

        # uri.port is an int (or None), not text: unquote() raises TypeError
        # on it, so it must not go through decode().
        self.connect(decode(database), decode(uri.hostname),
                     decode(uri.username), uri.port, decode(uri.password))

    def connect(self, database='', host='', user='', port='', passwd='',
                dsn=''):
        """Create a ``PGExecute`` connection and store it in ``self.pgexecute``.

        ``user`` falls back to ``getuser()`` and ``database`` falls back to
        the user name. Without a password, ``PGPASSWORD`` is consulted unless
        a prompt is forced, in which case the user is prompted right away.

        If the first attempt fails with "no password supplied" and prompting
        is allowed, the user is asked for a password and the connection is
        retried once. A failure that is not recovered this way is logged,
        printed in red to stderr, and ends the process through ``exit(1)``.
        """
        # Connect to the database.
        user = user or getuser()
        database = database or user

        if not passwd:
            if self.force_passwd_prompt:
                # Prompt for a password immediately if requested via the -W
                # flag. This avoids wasting time trying to connect to the
                # database and catching a no-password exception. A password
                # parsed from a URI skips this prompt, even with -W.
                passwd = click.prompt('Password', hide_input=True,
                                      show_default=False, type=str)
            else:
                # Prompt is not forced: try the environment variable.
                passwd = os.environ.get('PGPASSWORD', '')

        # Prompt for a password after 1st attempt to connect without a
        # password fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not (passwd or self.never_passwd_prompt)

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                password_missing = (
                    'no password supplied' in utf8tounicode(e.args[0]))
                if not (password_missing and auto_passwd_prompt):
                    raise e
                passwd = click.prompt('Password', hide_input=True,
                                      show_default=False, type=str)
                pgexecute = PGExecute(database, user, passwd, host, port,
                                      dsn)

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""Open the external editor for as long as the input asks for it.

        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.

        The docstring is a raw string so that '\e' is not an invalid
        escape sequence.

        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                          sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited text back in the buffer and prompt again.
            cli.current_buffer.document = Document(sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive prompt loop until an EOFError ends it.

        Each iteration reads a document from ``self.cli``, handles the quit
        and editor commands, evaluates the text, prints or writes the output,
        refreshes completions when the query changed something relevant, and
        records the query in ``self.query_history``.
        """
        logger = self.logger

        # 'default' means a file named 'history' under the config location.
        history_file = self.config['main']['history_file']
        if history_file == 'default':
            history_file = config_location() + 'history'
        history = FileHistory(os.path.expanduser(history_file))
        self.refresh_completions(history=history,
                                 persist_priorities='none')

        self.cli = self._build_cli(history)

        if not self.less_chatty:
            print('Version:', __version__)
            print('Chat: https://gitter.im/dbcli/pgcli')
            print('Mail: https://groups.google.com/forum/#!forum/pgcli')
            print('Home: http://pgcli.com')

        try:
            while True:
                document = self.cli.run(True)

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Initialize default metaquery in case execution fails
                query = MetaQuery(query=document.text, successful=False)

                try:
                    output, query = self._evaluate_command(document.text)
                except KeyboardInterrupt:
                    # Restart connection to the database
                    self.pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    # A dropped connection gets the reconnect prompt; every
                    # other operational error is logged and shown.
                    if ('server closed the connection'
                            in utf8tounicode(e.args[0])):
                        self._handle_server_closed_connection()
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    # Evaluation succeeded: emit the output, then timing.
                    try:
                        # With an output file set, append the query and its
                        # output there, except for input starting with
                        # '\o ' or '\? ', which goes to the pager.
                        if self.output_file and not document.text.startswith(('\\o ', '\\? ')):
                            try:
                                with open(self.output_file, 'a', encoding='utf-8') as f:
                                    click.echo(document.text, file=f)
                                    click.echo('\n'.join(output), file=f)
                                    click.echo('', file=f) # extra newline
                            except IOError as e:
                                click.secho(str(e), err=True, fg='red')
                        else:
                            click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        # Ctrl-C while printing output is ignored.
                        pass

                    if self.pgspecial.timing_enabled:
                        # Only add humanized time display if > 1 second
                        if query.total_time > 1:
                            print('Time: %0.03fs (%s)' % (query.total_time,
                                  humanize.time.naturaldelta(query.total_time)))
                        else:
                            print('Time: %0.03fs' % query.total_time)

                    # Check if we need to update completions, in order of most
                    # to least drastic changes
                    if query.db_changed:
                        with self._completer_lock:
                            self.completer.reset_completions()
                        self.refresh_completions(persist_priorities='keywords')
                    elif query.meta_changed:
                        self.refresh_completions(persist_priorities='all')
                    elif query.path_changed:
                        logger.debug('Refreshing search path')
                        with self._completer_lock:
                            self.completer.set_search_path(
                                self.pgexecute.search_path())
                        logger.debug('Search path: %r',
                                     self.completer.search_path)

                # Allow PGCompleter to learn user's preferred keywords, etc.
                with self._completer_lock:
                    self.completer.extend_query_history(document.text)

                # Recorded on success and failure alike; on failure this is
                # the default MetaQuery created before evaluation.
                self.query_history.append(query)

        except EOFError:
            if not self.less_chatty:
                print ('Goodbye!')

    def _build_cli(self, history):
        """Assemble the ``CommandLineInterface`` used for the prompt.

        :param history: history object handed to the input buffer
        :return: the constructed ``CommandLineInterface``
        """

        def vi_mode_enabled():
            return self.vi_mode

        def set_vi_mode(value):
            self.vi_mode = value

        def prompt_tokens(_):
            # The prompt shows the current database name.
            return [(Token.Prompt, '%s> ' % self.pgexecute.dbname)]

        def get_continuation_tokens(cli, width):
            dots = '.' * (width - 1)
            return [(Token.Continuation, dots + ' ')]

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=vi_mode_enabled,
            set_vi_mode_enabled=set_vi_mode)

        get_toolbar_tokens = create_toolbar_tokens_func(
            vi_mode_enabled, self.completion_refresher.is_refreshing)

        # Highlight matching brackets while editing.
        bracket_highlighter = ConditionalProcessor(
            processor=HighlightMatchingBracketProcessor(chars='[](){}'),
            filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())

        layout = create_prompt_layout(
            lexer=PygmentsLexer(PostgresLexer),
            reserve_space_for_menu=4,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[bracket_highlighter])

        # self.completer is read while holding the completer lock.
        with self._completer_lock:
            input_buffer = PGBuffer(
                always_multiline=self.multi_line,
                completer=self.completer,
                history=history,
                complete_while_typing=Always(),
                accept_action=AcceptAction.RETURN_DOCUMENT)

            if self.vi_mode:
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=input_buffer,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                ignore_case=True,
                editing_mode=editing_mode)

            return CommandLineInterface(application=application,
                                        eventloop=self.eventloop)

    def _should_show_limit_prompt(self, status, cur):
        """Decide whether to ask the user before showing a large result.

        False unless ``is_select(status)`` holds. Otherwise the result is
        truthy only when ``row_limit`` is positive, ``cur`` is truthy and
        its ``rowcount`` exceeds the limit.
        """
        if is_select(status):
            limit = self.row_limit
            return limit > 0 and cur and cur.rowcount > limit
        return False

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        Runs ``text`` through ``self.pgexecute.run``, formats every result
        and tracks what the statements changed. When a result exceeds the
        row limit the user is asked to confirm; declining stops processing
        of the remaining results.

        returns (results, MetaQuery)
        """
        logger = self.logger
        logger.debug('sql: %r', text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0

        # Run the query.
        start = time()
        on_error_resume = self.on_error == 'RESUME'
        res = self.pgexecute.run(text, self.pgspecial,
                                 exception_formatter, on_error_resume)

        for title, cur, headers, status, sql, success in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)
            if self._should_show_limit_prompt(status, cur):
                click.secho('The result set has more than %s rows.'
                            % self.row_limit, fg='red')
                if not click.confirm('Do you want to continue?'):
                    click.secho("Aborted!", err=True, fg='red')
                    break

            if self.pgspecial.auto_expand:
                max_width = self.cli.output.get_size().columns
            else:
                max_width = None

            formatted = format_output(
                title, cur, headers, status, self.table_format,
                self.pgspecial.expanded_output, max_width)

            output.extend(formatted)

            # Elapsed time since the run started. `start` is taken once, so
            # this must be assigned rather than accumulated: adding
            # (now - start) on every iteration would count the time of
            # earlier statements again for each later one.
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(text, all_success, total, meta_changed,
                               db_changed, path_changed, mutated)

        return output, meta_query

    def _handle_server_closed_connection(self):
        """Offer to reconnect after the server closed the connection.

        Used during CLI execution. Declining the prompt does nothing; a
        failed reconnect prints the error in red to stderr.
        """
        wants_reconnect = click.prompt(
            'Connection reset. Reconnect (Y/n)',
            show_default=False, type=bool, default=True)
        if not wants_reconnect:
            return

        try:
            self.pgexecute.connect()
            click.secho('Reconnected!\nTry the command again.', fg='green')
        except OperationalError as err:
            click.secho(str(err), err=True, fg='red')

    def refresh_completions(self, history=None, persist_priorities='all'):
        """Ask the completion refresher to rebuild the completions.

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: forwarded to
                        ``_on_completions_refreshed`` once the refresh is
                        done; 'all', 'keywords' or 'none'

        :return: a one-item list with a status row
        """
        on_refreshed = functools.partial(
            self._on_completions_refreshed,
            persist_priorities=persist_priorities)

        cfg = self.config
        refresher_settings = {
            'casing_file': get_casing_file(cfg),
            'generate_casing_file': cfg['main'].as_bool('generate_casing_file'),
            'asterisk_column_order': cfg['main']['asterisk_column_order'],
        }

        self.completion_refresher.refresh(
            self.pgexecute, self.pgspecial, on_refreshed,
            history=history, settings=refresher_settings)

        status = 'Auto-completion refresh started in the background.'
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        self._swap_completer_objects(new_completer, persist_priorities)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Swap the completer object in cli with the newly created completer.

            persist_priorities is a string specifying how the old completer's
            learned prioritizer should be transferred to the new completer.

              'none'     - The new prioritizer is left in a new/clean state

              'all'      - The new prioritizer is updated to exactly reflect
                           the old one

              'keywords' - The new prioritizer is updated with old keyword
                           priorities, but not any other.
        """
        with self._completer_lock:
            old_completer = self.completer
            self.completer = new_completer

            if persist_priorities == 'all':
                # Just swap over the entire prioritizer
                new_completer.prioritizer = old_completer.prioritizer
            elif persist_priorities == 'keywords':
                # Swap over the entire prioritizer, but clear name priorities,
                # leaving learned keyword priorities alone
                new_completer.prioritizer = old_completer.prioritizer
                new_completer.prioritizer.clear_names()
            elif persist_priorities == 'none':
                # Leave the new prioritizer as is
                pass

            # When pgcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return the completer's completions for ``text`` at the cursor.

        The completer is queried while holding the completer lock, with
        None as its second argument.
        """
        document = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(document, None)
Example #11
0
class PGCli(object):
    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False, pgexecute=None, pgclirc_file=None):
        """Build the application state from the config file and arguments.

        ``write_default_config`` is called with the path of the ``pgclirc``
        file inside the pgcli package directory and ``pgclirc_file``; the
        config is then loaded with ``load_config``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        from pgcli import __file__ as package_init

        default_config = os.path.join(os.path.dirname(package_init), "pgclirc")
        write_default_config(default_config, pgclirc_file)

        self.pgspecial = PGSpecial()

        # Load config.
        self.config = load_config(pgclirc_file, default_config)
        main_cfg = self.config["main"]
        self.multi_line = main_cfg.as_bool("multi_line")
        self.vi_mode = main_cfg.as_bool("vi")
        self.pgspecial.timing_enabled = main_cfg.as_bool("timing")
        self.table_format = main_cfg["table_format"]
        self.syntax_style = main_cfg["syntax_style"]
        self.cli_style = self.config["colors"]
        self.wider_completion_menu = main_cfg.as_bool("wider_completion_menu")

        self.on_error = main_cfg["on_error"].upper()

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer
        self.completer = PGCompleter(main_cfg.as_bool("smart_completion"), pgspecial=self.pgspecial)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.cli = None

    def register_special_commands(self):
        """Register pgcli's own backslash commands with ``self.pgspecial``."""

        def refresh_callback():
            return self.refresh_completions(persist_priorities="all")

        register = self.pgspecial.register

        register(
            self.change_db,
            "\\c",
            "\\c[onnect] database_name",
            "Change to a new database.",
            aliases=("use", "\\connect", "USE"),
        )

        # Both spellings trigger the same completion refresh.
        for command in ("\\#", "\\refresh"):
            register(refresh_callback, command, command, "Refresh auto-completions.", arg_type=NO_QUERY)

        register(self.execute_from_file, "\\i", "\\i filename", "Execute commands from file.")

    def change_db(self, pattern, **_):
        """Reconnect, optionally to another database, and report the result.

        :param pattern: database name, optionally wrapped in double quotes;
                        when empty, ``self.pgexecute.connect()`` is called
                        without arguments
        :return: generator yielding one ``(None, None, None, status)`` tuple
        """
        if pattern:
            # Strip one pair of surrounding double quotes. The length check
            # keeps a lone '"' from being read as both the opening and the
            # closing quote, which would silently produce an empty name.
            quoted = len(pattern) >= 2 and pattern[0] == pattern[-1] == '"'
            db = pattern[1:-1] if quoted else pattern
            self.pgexecute.connect(database=db)
        else:
            self.pgexecute.connect()

        yield (
            None,
            None,
            None,
            'You are now connected to database "%s" as ' 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
        )

    def execute_from_file(self, pattern, **_):
        """Run the SQL contained in the file named by ``pattern``.

        :param pattern: path of the file to read; ``~`` is expanded
        :return: a one-item list holding an error row when the argument is
                 missing or the file cannot be read, otherwise whatever
                 ``self.pgexecute.run`` returns for the file's contents
        """
        if pattern:
            path = os.path.expanduser(pattern)
            try:
                with open(path, encoding="utf-8") as sql_file:
                    query = sql_file.read()
            except IOError as err:
                return [(None, None, None, str(err))]

            return self.pgexecute.run(query, self.pgspecial, on_error=self.on_error)

        return [(None, None, None, "\\i: missing required argument")]

    def initialize_logging(self):
        """Attach a file handler to the 'pgcli' logger as configured.

        Reads ``log_file`` and ``log_level`` from the ``main`` config
        section; a ``log_file`` of "default" means a file named "log" under
        the config location.
        """
        main_cfg = self.config["main"]

        log_file = main_cfg["log_file"]
        if log_file == "default":
            log_file = config_location() + "log"
        log_level = main_cfg["log_level"]

        # Supported level names, mapped to the logging module's constants.
        level_map = {
            name: getattr(logging, name)
            for name in ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
        }

        handler = logging.FileHandler(os.path.expanduser(log_file))
        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s (%(process)d/%(threadName)s) " "%(name)s %(levelname)s - %(message)s"
            )
        )

        root_logger = logging.getLogger("pgcli")
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        root_logger.debug("Initializing pgcli logging.")
        root_logger.debug("Log file %r.", log_file)

    def connect_dsn(self, dsn):
        """Connect using a DSN by delegating to ``self.connect(dsn=dsn)``."""
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        """Connect using the parts of a database URI.

        The path (minus its leading slash), host name, user name, port and
        password parsed from ``uri`` are passed to ``connect`` as-is.
        """
        parts = urlparse(uri)
        # ignore the leading fwd slash of the path to get the database name
        self.connect(
            parts.path[1:],
            parts.hostname,
            parts.username,
            parts.port,
            parts.password,
        )

    def connect(self, database="", host="", user="", port="", passwd="", dsn=""):
        """Create a ``PGExecute`` connection and store it in ``self.pgexecute``.

        ``user`` falls back to ``getuser()`` and ``database`` falls back to
        the user name. Without a password, ``PGPASSWORD`` is consulted unless
        a prompt is forced, in which case the user is prompted right away.

        If the first attempt fails with "no password supplied" and prompting
        is allowed, the user is asked for a password and the connection is
        retried once. A failure that is not recovered this way is logged,
        printed in red to stderr, and ends the process through ``exit(1)``.
        """
        # Connect to the database.
        user = user or getuser()
        database = database or user

        if not passwd:
            if self.force_passwd_prompt:
                # Prompt for a password immediately if requested via the -W
                # flag. This avoids wasting time trying to connect to the
                # database and catching a no-password exception. A password
                # parsed from a URI skips this prompt, even with -W.
                passwd = click.prompt("Password", hide_input=True, show_default=False, type=str)
            else:
                # Prompt is not forced: try the environment variable.
                passwd = os.environ.get("PGPASSWORD", "")

        # Prompt for a password after 1st attempt to connect without a
        # password fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not (passwd or self.never_passwd_prompt)

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                password_missing = "no password supplied" in utf8tounicode(e.args[0])
                if not (password_missing and auto_passwd_prompt):
                    raise e
                passwd = click.prompt("Password", hide_input=True, show_default=False, type=str)
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg="red")
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""Open the external editor for as long as the input asks for it.

        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.

        The docstring is a raw string so that '\e' is not an invalid
        escape sequence.

        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename, sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited text back in the buffer and prompt again.
            cli.current_buffer.document = Document(sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive read-evaluate-print loop.

        Each iteration reads a document from the CLI, resolves any editor
        command, evaluates the text, sends the output through the pager and
        appends the resulting MetaQuery to ``self.query_history``. The loop
        is left when EOFError is raised (a quit command raises it as well).
        The LESS environment variable is restored on the way out.
        """
        logger = self.logger
        # Remember the user's LESS value; it is put back in the finally block.
        original_less_opts = self.adjust_less_opts()

        history_file = self.config["main"]["history_file"]
        if history_file == "default":
            history_file = config_location() + "history"
        history = FileHistory(os.path.expanduser(history_file))
        # First refresh: start with a clean prioritizer ("none").
        self.refresh_completions(history=history, persist_priorities="none")

        self.cli = self._build_cli(history)

        print("Version:", __version__)
        print("Chat: https://gitter.im/dbcli/pgcli")
        print("Mail: https://groups.google.com/forum/#!forum/pgcli")
        print("Home: http://pgcli.com")

        try:
            while True:
                document = self.cli.run()

                # The quit check happens here rather than inside pgexecute so
                # that the EOFError propagates to the outer try/except, which
                # ends the loop and prints the goodbye message.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    # The editor failed; report it and prompt again.
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg="red")
                    continue

                # Initialize default metaquery in case execution fails
                query = MetaQuery(query=document.text, successful=False)

                try:
                    output, query = self._evaluate_command(document.text)
                except KeyboardInterrupt:
                    # Restart connection to the database
                    self.pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg="red")
                except NotImplementedError:
                    click.secho("Not Yet Implemented.", fg="yellow")
                except OperationalError as e:
                    # A dropped connection gets the reconnect prompt; any
                    # other operational error is just reported.
                    if "server closed the connection" in utf8tounicode(e.args[0]):
                        self._handle_server_closed_connection()
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg="red")
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg="red")
                else:
                    # Evaluation succeeded: show the output. Ctrl-C while
                    # paging only stops the paging.
                    try:
                        click.echo_via_pager("\n".join(output))
                    except KeyboardInterrupt:
                        pass

                    if self.pgspecial.timing_enabled:
                        # Only add humanized time display if > 1 second
                        if query.total_time > 1:
                            print(
                                "Time: %0.03fs (%s)" % (query.total_time, humanize.time.naturaldelta(query.total_time))
                            )
                        else:
                            print("Time: %0.03fs" % query.total_time)

                    # Check if we need to update completions, in order of most
                    # to least drastic changes
                    if query.db_changed:
                        with self._completer_lock:
                            self.completer.reset_completions()
                        self.refresh_completions(persist_priorities="keywords")
                    elif query.meta_changed:
                        self.refresh_completions(persist_priorities="all")
                    elif query.path_changed:
                        logger.debug("Refreshing search path")
                        with self._completer_lock:
                            self.completer.set_search_path(self.pgexecute.search_path())
                        logger.debug("Search path: %r", self.completer.search_path)

                # Allow PGCompleter to learn user's preferred keywords, etc.
                # This runs for failed commands too (only the editor failure
                # above skips it via `continue`).
                with self._completer_lock:
                    self.completer.extend_query_history(document.text)

                self.query_history.append(query)

        except EOFError:
            print("Goodbye!")
        finally:  # Reset the less opts back to original.
            logger.debug("Restoring env var LESS to %r.", original_less_opts)
            os.environ["LESS"] = original_less_opts

    def _build_cli(self, history):
        """Assemble the command line interface object used by the prompt loop.

        Builds the key bindings, the prompt layout, the input buffer and the
        application, then wraps the application in a CommandLineInterface.

        :param history: history object handed to the input buffer
        :return: the newly created CommandLineInterface
        """

        def store_vi_mode(enabled):
            self.vi_mode = enabled

        def prompt_tokens(_):
            return [(Token.Prompt, "%s> " % self.pgexecute.dbname)]

        bindings = pgcli_bindings(get_vi_mode_enabled=lambda: self.vi_mode, set_vi_mode_enabled=store_vi_mode)
        toolbar_tokens = create_toolbar_tokens_func(lambda: self.vi_mode, self.completion_refresher.is_refreshing)

        sql_lexer = PygmentsLexer(PostgresLexer)
        # Highlight matching brackets while editing.
        bracket_highlighter = ConditionalProcessor(
            processor=HighlightMatchingBracketProcessor(chars="[](){}"),
            filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
        )

        prompt_layout = create_prompt_layout(
            lexer=sql_lexer,
            reserve_space_for_menu=4,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[bracket_highlighter],
        )

        # The completer is read under the lock, and the lock stays held until
        # the interface object has been created.
        with self._completer_lock:
            input_buffer = PGBuffer(
                always_multiline=self.multi_line,
                completer=self.completer,
                history=history,
                complete_while_typing=Always(),
                accept_action=AcceptAction.RETURN_DOCUMENT,
            )

            app = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=prompt_layout,
                buffer=input_buffer,
                key_bindings_registry=bindings.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                ignore_case=True,
            )

            return CommandLineInterface(application=app)

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        :param text: the command text to execute
        returns (results, MetaQuery)

        ``results`` is the list of formatted output lines of every statement
        that was processed. The MetaQuery records the text, whether all
        statements succeeded, the elapsed time and which kinds of change
        (meta, database, search path, data mutation) were seen.
        """
        logger = self.logger
        logger.debug("sql: %r", text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0
        # Result sets with more rows than this need a confirmation first.
        threshold = 1000

        # Run the query.
        start = time()
        on_error_resume = self.on_error == "RESUME"
        res = self.pgexecute.run(text, self.pgspecial, exception_formatter, on_error_resume)

        for title, cur, headers, status, sql, success in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)
            if is_select(status) and cur and cur.rowcount > threshold:
                click.secho("The result set has more than %s rows." % threshold, fg="red")
                if not click.confirm("Do you want to continue?"):
                    click.secho("Aborted!", err=True, fg="red")
                    break

            if self.pgspecial.auto_expand:
                max_width = self.cli.output.get_size().columns
            else:
                max_width = None

            formatted = format_output(
                title, cur, headers, status, self.table_format, self.pgspecial.expanded_output, max_width
            )

            output.extend(formatted)
            # `start` is taken once, before the first statement, so the time
            # elapsed so far is simply now - start. Adding that difference up
            # on every iteration would count earlier statements repeatedly.
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(text, all_success, total, meta_changed, db_changed, path_changed, mutated)

        return output, meta_query

    def _handle_server_closed_connection(self):
        """Ask whether to reconnect after a connection reset and do so on yes.

        Used during CLI execution. A failed reconnect is reported in red; the
        error is not re-raised.
        """
        wants_reconnect = click.prompt(
            "Connection reset. Reconnect (Y/n)", show_default=False, type=bool, default=True
        )
        if not wants_reconnect:
            return

        try:
            self.pgexecute.connect()
        except OperationalError as err:
            click.secho(str(err), err=True, fg="red")
        else:
            click.secho("Reconnected!\nTry the command again.", fg="green")

    def adjust_less_opts(self):
        """Set the LESS environment variable to ``-SRXF``.

        :return: the value LESS had before the change ('' when it was unset)
        """
        environ = os.environ
        previous = environ.get("LESS", "")
        self.logger.debug("Original value for LESS env var: %r", previous)
        environ["LESS"] = "-SRXF"
        return previous

    def refresh_completions(self, history=None, persist_priorities="all"):
        """ Ask the completion refresher to rebuild the completions

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: 'all', 'keywords' or 'none'; bound into the
                        callback that is handed to the refresher

        :return: a one-element list with a status row saying that the refresh
                        was started
        """
        on_refreshed = functools.partial(self._on_completions_refreshed, persist_priorities=persist_priorities)
        refresher = self.completion_refresher
        refresher.refresh(self.pgexecute, self.pgspecial, on_refreshed, history=history)

        status = "Auto-completion refresh started in the background."
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        """Install the refreshed completer and repaint the CLI if there is one.

        :param new_completer: completer that replaces the current one
        :param persist_priorities: passed through to _swap_completer_objects
        """
        self._swap_completer_objects(new_completer, persist_priorities)

        cli = self.cli
        if not cli:
            return
        # After refreshing, redraw the CLI to clear the statusbar
        # "Refreshing completions..." indicator
        cli.request_redraw()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Replace the active completer with ``new_completer``.

            persist_priorities selects how much of the old completer's
            learned prioritizer is carried over to the new completer:

              'none'     - nothing; the new prioritizer stays as created

              'all'      - the old prioritizer object is reused as is

              'keywords' - the old prioritizer object is reused, but its
                           name priorities are cleared first
        """
        with self._completer_lock:
            previous = self.completer
            self.completer = new_completer

            if persist_priorities in ("all", "keywords"):
                # Both modes hand over the whole prioritizer object ...
                new_completer.prioritizer = previous.prioritizer
                if persist_priorities == "keywords":
                    # ... and 'keywords' then drops the learned name
                    # priorities, leaving keyword priorities alone.
                    new_completer.prioritizer.clear_names()

            # When pgcli is first launched refresh_completions is called
            # before the cli object is instantiated, so the cli may not exist
            # yet when the completer is swapped.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return the completer's completions for ``text`` at the cursor offset.

        :param text: the input text to complete
        :param cursor_positition: cursor offset within ``text``
        """
        doc = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(doc, None)
Example #12
0
    def __init__(self,
                 force_passwd_prompt=False,
                 never_passwd_prompt=False,
                 pgexecute=None,
                 pgclirc_file=None,
                 row_limit=None,
                 single_connection=False,
                 prompt=None):
        """Load the configuration and set up logging, settings and completer.

        :param force_passwd_prompt: stored on the instance as given
        :param never_passwd_prompt: stored on the instance as given
        :param pgexecute: stored as ``self.pgexecute``
        :param pgclirc_file: passed to get_config to load the configuration
        :param row_limit: when not None, used instead of the configured
                          ``row_limit``
        :param single_connection: placed in the completer settings
        :param prompt: when not None, used instead of the configured prompt
        """

        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        # Load config.
        c = self.config = get_config(pgclirc_file)

        # The logger has to exist before set_default_pager is called below.
        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(c)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = c['main'].as_bool('multi_line')
        self.multiline_mode = c['main'].get('multi_line_mode', 'psql')
        self.vi_mode = c['main'].as_bool('vi')
        self.pgspecial.timing_enabled = c['main'].as_bool('timing')
        # An explicit row_limit argument wins over the configured value.
        if row_limit is not None:
            self.row_limit = row_limit
        else:
            self.row_limit = c['main'].as_int('row_limit')

        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        self.less_chatty = c['main'].as_bool('less_chatty')
        self.null_string = c['main'].get('null_string', '<null>')
        # Precedence: prompt argument, then config, then self.default_prompt.
        self.prompt_format = prompt if prompt is not None else c['main'].get(
            'prompt', self.default_prompt)
        self.on_error = c['main']['on_error'].upper()
        self.decimal_format = c['data_formats']['decimal']
        self.float_format = c['data_formats']['float']
        self.completion_refresher = CompletionRefresher()

        # MetaQuery history; starts empty.
        self.query_history = []

        # Initialize completer
        smart_completion = c['main'].as_bool('smart_completion')
        keyword_casing = c['main']['keyword_casing']
        self.settings = {
            'casing_file': get_casing_file(c),
            'generate_casing_file': c['main'].as_bool('generate_casing_file'),
            'generate_aliases': c['main'].as_bool('generate_aliases'),
            'asterisk_column_order': c['main']['asterisk_column_order'],
            'single_connection': single_connection,
            'keyword_casing': keyword_casing,
        }

        completer = PGCompleter(smart_completion,
                                pgspecial=self.pgspecial,
                                settings=self.settings)
        self.completer = completer
        # Lock used around accesses to self.completer.
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.eventloop = create_eventloop()
        # No CLI object yet; it starts out as None.
        self.cli = None
Example #13
0
class PGCli(object):

    default_prompt = '\\u@\\h:\\d> '
    max_len_prompt = 30

    def set_default_pager(self, config):
        """Pick the pager and give LESS default options when it has none.

        A pager named in the ``main`` section of ``config`` is exported via
        the PAGER environment variable; otherwise a PAGER value that is
        already set stays in effect. In every case the choice is logged.

        :param config: configuration with a ``main`` section
        """
        from_config = config['main'].get('pager')
        from_env = os.environ.get('PAGER')

        if from_config:
            chosen = from_config
            note = 'Default pager found in config file: "{}"'.format(from_config)
        elif from_env:
            chosen = from_env
            note = 'Default pager found in PAGER environment variable: "{}"'.format(
                from_env)
        else:
            chosen = None
            note = 'No default pager found in environment. Using os default pager'

        self.logger.info(note)
        if chosen is not None:
            os.environ['PAGER'] = chosen

        # Set default set of less recommended options, if they are not already set.
        # They are ignored if pager is different than less.
        if not os.environ.get('LESS'):
            os.environ['LESS'] = '-SRXF'

    def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
                 pgexecute=None, pgclirc_file=None, row_limit=None,
                 single_connection=False, less_chatty=None, prompt=None, prompt_dsn=None,
                 auto_vertical_output=False, warn=None):
        """Load the configuration and set up logging, settings and completer.

        :param force_passwd_prompt: stored; connect() prompts for a password
                                    up front when it is set
        :param never_passwd_prompt: stored; connect() does not prompt after a
                                    failed attempt when it is set
        :param pgexecute: stored as ``self.pgexecute``
        :param pgclirc_file: passed to get_config to load the configuration
        :param row_limit: when not None, used instead of the configured
                          ``row_limit``
        :param single_connection: placed in the completer settings
        :param less_chatty: a truthy value enables ``self.less_chatty``
                            regardless of the configuration
        :param prompt: when not None, used instead of the configured prompt
        :param prompt_dsn: stored as ``self.prompt_dsn_format``
        :param auto_vertical_output: a truthy value enables
                                     ``self.auto_expand`` regardless of the
                                     configuration
        :param warn: when not None, used instead of the configured
                     ``destructive_warning``
        """

        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None

        # Load config.
        c = self.config = get_config(pgclirc_file)

        # The logger has to exist before set_default_pager is called below.
        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(c)
        self.output_file = None
        self.pgspecial = PGSpecial()

        self.multi_line = c['main'].as_bool('multi_line')
        self.multiline_mode = c['main'].get('multi_line_mode', 'psql')
        self.vi_mode = c['main'].as_bool('vi')
        self.auto_expand = auto_vertical_output or c['main'].as_bool(
            'auto_expand')
        self.expanded_output = c['main'].as_bool('expand')
        self.pgspecial.timing_enabled = c['main'].as_bool('timing')
        # An explicit row_limit argument wins over the configured value.
        if row_limit is not None:
            self.row_limit = row_limit
        else:
            self.row_limit = c['main'].as_int('row_limit')

        self.min_num_menu_lines = c['main'].as_int('min_num_menu_lines')
        self.multiline_continuation_char = c['main']['multiline_continuation_char']
        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        c_dest_warning = c['main'].as_bool('destructive_warning')
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.less_chatty = bool(less_chatty) or c['main'].as_bool('less_chatty')
        self.null_string = c['main'].get('null_string', '<null>')
        # Precedence: prompt argument, then config, then self.default_prompt.
        self.prompt_format = prompt if prompt is not None else c['main'].get('prompt', self.default_prompt)
        self.prompt_dsn_format = prompt_dsn
        self.on_error = c['main']['on_error'].upper()
        self.decimal_format = c['data_formats']['decimal']
        self.float_format = c['data_formats']['float']

        # Passes the string "on" or "off" depending on the boolean
        # ``enable_pager`` setting.
        self.pgspecial.pset_pager(self.config['main'].as_bool(
            'enable_pager') and "on" or "off")

        self.style_output = style_factory_output(
            self.syntax_style, c['colors'])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        # MetaQuery history; starts empty.
        self.query_history = []

        # Initialize completer
        smart_completion = c['main'].as_bool('smart_completion')
        keyword_casing = c['main']['keyword_casing']
        # NOTE(review): 'less_chatty' below is the raw constructor argument
        # (possibly None), not the merged self.less_chatty - confirm that this
        # is what the completer expects.
        self.settings = {
            'casing_file': get_casing_file(c),
            'generate_casing_file': c['main'].as_bool('generate_casing_file'),
            'generate_aliases': c['main'].as_bool('generate_aliases'),
            'asterisk_column_order': c['main']['asterisk_column_order'],
            'qualify_columns': c['main']['qualify_columns'],
            'case_column_headers': c['main'].as_bool('case_column_headers'),
            'search_path_filter': c['main'].as_bool('search_path_filter'),
            'single_connection': single_connection,
            'less_chatty': less_chatty,
            'keyword_casing': keyword_casing,
        }

        completer = PGCompleter(smart_completion, pgspecial=self.pgspecial,
            settings=self.settings)
        self.completer = completer
        # Lock used around accesses to self.completer.
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.eventloop = create_eventloop()
        # No CLI object yet; it starts out as None.
        self.cli = None

    def quit(self):
        """Signal that the user wants to leave by raising PgCliQuitError."""
        raise PgCliQuitError()

    def register_special_commands(self):
        """Register this object's special commands with ``self.pgspecial``.

        Covers changing the database, quitting, refreshing completions,
        executing a file, redirecting output to a file, showing connection
        details and changing the table format.
        """

        def refresh_callback():
            return self.refresh_completions(persist_priorities='all')

        register = self.pgspecial.register

        register(self.change_db, '\\c', '\\c[onnect] database_name',
                 'Change to a new database.',
                 aliases=('use', '\\connect', 'USE'))

        # Two quit commands: the backslash form is case sensitive, the plain
        # word form is not.
        register(self.quit, '\\q', '\\q', 'Quit pgcli.',
                 arg_type=NO_QUERY, case_sensitive=True, aliases=(':q',))
        register(self.quit, 'quit', 'quit', 'Quit pgcli.',
                 arg_type=NO_QUERY, case_sensitive=False, aliases=('exit',))

        register(refresh_callback, '\\#', '\\#',
                 'Refresh auto-completions.', arg_type=NO_QUERY)
        register(refresh_callback, '\\refresh', '\\refresh',
                 'Refresh auto-completions.', arg_type=NO_QUERY)

        register(self.execute_from_file, '\\i', '\\i filename',
                 'Execute commands from file.')
        register(self.write_to_file, '\\o', '\\o [filename]',
                 'Send all query results to file.')
        register(self.info_connection, '\\conninfo', '\\conninfo',
                 'Get connection details')
        register(self.change_table_format, '\\T', '\\T [format]',
                 'Change the table format used to output results')

    def change_table_format(self, pattern, **_):
        """Switch the output table format, or list the allowed formats.

        Yields one status row. When ``pattern`` is a supported format it
        becomes ``self.table_format``; otherwise the row lists the supported
        formats and the current setting, and nothing is changed.

        :param pattern: name of the requested table format
        """
        supported = TabularOutputFormatter().supported_formats
        if pattern in supported:
            self.table_format = pattern
            yield (None, None, None,
                   'Changed table format to {}'.format(pattern))
            return

        lines = ['Table format {} not recognized. Allowed formats:'.format(
            pattern)]
        lines.extend("\t{}".format(name) for name in supported)
        lines.append('Currently set to: %s' % self.table_format)
        yield (None, None, None, "\n".join(lines))


    def info_connection(self, **_):
        """Yield one status row describing the current connection.

        A host starting with '/' is reported as a socket, anything else as a
        host.
        """
        conn = self.pgexecute
        kind = 'socket' if conn.host.startswith('/') else 'host'
        location = '%s "%s"' % (kind, conn.host)

        message = ('You are connected to database "%s" as user "%s" on %s '
                   'at port "%s".' % (conn.dbname, conn.user, location,
                                      conn.port))
        yield (None, None, None, message)

    def change_db(self, pattern, **_):
        """Reconnect, optionally to another database/user/host/port.

        ``pattern`` holds up to four whitespace separated values in the order
        database, user, host, port. A value may be wrapped in double quotes
        so that it can contain whitespace. An empty pattern reconnects with
        the current parameters.

        Yields one status row naming the database and user of the connection
        that is active afterwards. When the new connection fails with an
        OperationalError, the error is printed and that row describes the
        previous connection.

        :param pattern: the argument text of the command, may be empty
        """
        if pattern:
            # Get all the parameters in pattern, handling double quotes if any.
            infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
            # Now removing quotes. The stripped values have to be kept;
            # otherwise a quoted name would reach connect() with its quotes.
            infos = [s.strip('"') for s in infos]

            # Pad the missing trailing values with None.
            infos.extend([None] * (4 - len(infos)))
            db, user, host, port = infos
            try:
                self.pgexecute.connect(database=db, user=user, host=host,
                                       port=port, application_name='pgcli')
            except OperationalError as e:
                click.secho(str(e), err=True, fg='red')
                click.echo("Previous connection kept")
        else:
            self.pgexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        """Read a file and run its contents as a query.

        :param pattern: path of the file to execute ('~' is expanded)
        :return: a one-element list with an error row when the argument is
                 missing or the file cannot be read; a one-element list with
                 a notice when a destructive query is declined; otherwise
                 whatever ``self.pgexecute.run`` returns
        """
        if not pattern:
            return [(None, None, None, '\\i: missing required argument', '', False)]

        path = os.path.expanduser(pattern)
        try:
            with open(path, encoding='utf-8') as handle:
                query = handle.read()
        except IOError as err:
            return [(None, None, None, str(err), '', False)]

        # Only ask for confirmation when destructive warnings are enabled.
        if self.destructive_warning and confirm_destructive_query(query) is False:
            return [(None, None, None, 'Wise choice. Command execution stopped.')]

        return self.pgexecute.run(
            query, self.pgspecial, on_error_resume=(self.on_error == 'RESUME')
        )

    def write_to_file(self, pattern, **_):
        """Turn output redirection to a file on or off.

        An empty ``pattern`` disables file output. Otherwise the path is made
        absolute ('~' is expanded), the file is created when it does not
        exist yet, and the path is stored in ``self.output_file``.

        :param pattern: path of the output file, or empty to disable
        :return: a one-element list with a status row; its last item is False
                 when the file could not be created
        """
        if not pattern:
            self.output_file = None
            return [(None, None, None, 'File output disabled', '', True)]

        target = os.path.abspath(os.path.expanduser(pattern))
        if not os.path.isfile(target):
            try:
                # Create the file right away so that a bad path is reported
                # now.
                with open(target, 'w'):
                    pass
            except IOError as err:
                self.output_file = None
                return [(None, None, None,
                         str(err) + '\nFile output disabled', '', False)]

        self.output_file = target
        return [(None, None, None, 'Writing to file "%s"' % target, '', True)]

    def initialize_logging(self):
        """Attach a handler to the 'pgcli' and 'pgspecial' loggers.

        The log file and level come from the ``main`` section of the
        configuration. A level of NONE installs a no-op handler; any other
        level logs to the configured file.
        """
        log_file = self.config['main']['log_file']
        if log_file == 'default':
            log_file = config_location() + 'log'
        ensure_dir_exists(log_file)

        level_name = self.config['main']['log_level'].upper()

        # Disable logging if value is NONE by switching to a no-op handler.
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if level_name == 'NONE':
            handler = logging.NullHandler()
        else:
            handler = logging.FileHandler(os.path.expanduser(log_file))

        levels_by_name = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG,
            'NONE': logging.CRITICAL,
        }
        log_level = levels_by_name[level_name]

        handler.setFormatter(logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s'))

        pgcli_logger = logging.getLogger('pgcli')
        pgcli_logger.addHandler(handler)
        pgcli_logger.setLevel(log_level)

        pgcli_logger.debug('Initializing pgcli logging.')
        pgcli_logger.debug('Log file %r.', log_file)

        # The pgspecial logger shares the same handler and level.
        special_logger = logging.getLogger('pgspecial')
        special_logger.addHandler(handler)
        special_logger.setLevel(log_level)

    def connect_dsn(self, dsn):
        """Connect using a DSN by passing it to connect() as ``dsn``.

        :param dsn: the DSN value to forward
        """
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        """Connect using a connection URI.

        The database, host, user, port and password are taken from the URI
        and percent-decoded. Query string parameters are passed to connect()
        as additional keyword arguments; the parts taken from the URI itself
        win over a query parameter of the same name. When a query parameter
        is given more than once, its last value is used.

        :param uri: the connection URI, e.g.
                    postgres://user:secret@host:5432/db?sslmode=require
        """
        parsed = urlparse(uri)
        database = parsed.path[1:]  # ignore the leading fwd slash

        def fixup_possible_percent_encoding(s):
            # Each URI part may be percent encoded; empty parts are kept as is.
            return unquote(str(s)) if s else s

        arguments = dict(database=fixup_possible_percent_encoding(database),
                         host=fixup_possible_percent_encoding(parsed.hostname),
                         user=fixup_possible_percent_encoding(parsed.username),
                         port=fixup_possible_percent_encoding(parsed.port),
                         passwd=fixup_possible_percent_encoding(parsed.password))
        # Deal with extra params e.g. ?sslmode=verify-ca&sslrootcert=/myrootcert
        if parsed.query:
            # parse_qs maps every key to the list of its values. Taking the
            # last one keeps a repeated parameter from raising on unpacking.
            extras = {k: values[-1]
                      for k, values in parse_qs(parsed.query).items()}
            arguments = dict(extras, **arguments)

        self.connect(**arguments)

    def connect(self, database='', host='', user='', port='', passwd='',
                dsn='', **kwargs):
        """Connect to the database and store the result in ``self.pgexecute``.

        A password that was not passed in is looked up in this order: the
        PGPASSWORD environment variable (skipped when a prompt is forced),
        the keyring entry for ``user@host``, an immediate prompt when a
        prompt is forced, and finally a prompt after a first attempt failed
        with "no password supplied" (unless prompting is disabled). A
        non-empty password is saved to the keyring after a successful
        connection.

        Any failure to connect is logged and printed, and the process exits
        with status 1.

        :param database: database name; defaults to the user name
        :param host: server host
        :param user: user name; defaults to getuser()
        :param port: server port
        :param passwd: password, may be empty
        :param dsn: DSN passed through to PGExecute
        :param kwargs: extra keyword arguments passed through to PGExecute
        """
        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Find password from store
        key = '%s@%s' % (user, host)
        if not passwd:
            passwd = keyring.get_password('pgcli', key)

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt('Password for %s' % user, hide_input=True,
                                  show_default=False, type=str)

        # Prompt for a password after 1st attempt to connect without a password
        # fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not passwd and not self.never_passwd_prompt

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn,
                                      application_name='pgcli', **kwargs)
                # Only reached when the connection succeeded.
                if passwd:
                    keyring.set_password('pgcli', key, passwd)
            except (OperationalError, InterfaceError) as e:
                if ('no password supplied' in utf8tounicode(e.args[0]) and
                        auto_passwd_prompt):
                    passwd = click.prompt('Password for %s' % user,
                                          hide_input=True, show_default=False,
                                          type=str)
                    # Second and last attempt; a failure here falls through
                    # to the outer handler.
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn, application_name='pgcli',
                                          **kwargs)
                    if passwd:
                        keyring.set_password('pgcli', key, passwd)
                else:
                    raise e

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message
        """
        # FIXME: using application.pre_run_callables like this here is not the best solution.
        # It's internal api of prompt_toolkit that may change. This was added to fix #668.
        # We may find a better way to do it in the future.
        saved_callables = cli.application.pre_run_callables
        try:
            while special.editor_command(document.text):
                filename = special.get_filename(document.text)
                # Edit the query given with the command or, when there is
                # none, the last query.
                query = (special.get_editor_query(document.text) or
                         self.get_last_query())
                sql, message = special.open_external_editor(
                    filename, sql=query)
                if message:
                    # Something went wrong. Raise an exception and bail.
                    raise RuntimeError(message)
                # Put the edited text into the prompt with the cursor at the
                # end.
                cli.current_buffer.document = Document(sql,
                                                       cursor_position=len(sql))
                # Clear the pre-run callables before re-running the prompt
                # (see the FIXME above).
                cli.application.pre_run_callables = []
                document = cli.run()
        finally:
            # Restore the callables even when the editor failed.
            cli.application.pre_run_callables = saved_callables
        return document

    def execute_command(self, text, query):
        """Evaluate one piece of user input and display its output.

        When ``self.destructive_warning`` is set, the text is first passed to
        ``confirm_destructive_query``; a declined confirmation is handled
        like a cancelled query. The command is evaluated whether or not the
        warning is enabled.

        On success the output is appended to ``self.output_file`` (when one
        is set and the text does not start with ``\\o `` or ``\\? ``) or
        shown via ``echo_via_pager``, timing is printed when enabled, and
        completions are refreshed according to what the query changed.

        :param text: the text to evaluate
        :param query: fallback value returned when evaluation does not
                      complete
        :return: the MetaQuery produced by ``_evaluate_command``, or the
                 ``query`` argument unchanged when evaluation failed
        :raises PgCliQuitError: re-raised untouched so the caller can exit
        """
        logger = self.logger

        try:
            if self.destructive_warning:
                destroy = confirm_destructive_query(text)
                if destroy is True:
                    click.secho('Your call!')
                elif destroy is not None:
                    # The user declined; reuse the cancellation handling.
                    click.secho('Wise choice!')
                    raise KeyboardInterrupt
            # Evaluate regardless of the destructive_warning setting; without
            # this, disabling the warning left ``output`` unbound.
            output, query = self._evaluate_command(text)
        except KeyboardInterrupt:
            # Restart connection to the database
            self.pgexecute.connect()
            logger.debug("cancelled query, sql: %r", text)
            click.secho("cancelled query", err=True, fg='red')
        except NotImplementedError:
            click.secho('Not Yet Implemented.', fg="yellow")
        except OperationalError as e:
            logger.error("sql: %r, error: %r", text, e)
            logger.error("traceback: %r", traceback.format_exc())
            self._handle_server_closed_connection()
        except PgCliQuitError:
            raise
        except Exception as e:
            logger.error("sql: %r, error: %r", text, e)
            logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
        else:
            try:
                if self.output_file and not text.startswith(('\\o ', '\\? ')):
                    try:
                        with open(self.output_file, 'a', encoding='utf-8') as f:
                            click.echo(text, file=f)
                            click.echo('\n'.join(output), file=f)
                            click.echo('', file=f)  # extra newline
                    except IOError as e:
                        click.secho(str(e), err=True, fg='red')
                else:
                    self.echo_via_pager('\n'.join(output))
            except KeyboardInterrupt:
                # Interrupting the output display is not an error.
                pass

            if self.pgspecial.timing_enabled:
                # Only add humanized time display if > 1 second
                if query.total_time > 1:
                    print('Time: %0.03fs (%s)' % (query.total_time,
                          humanize.time.naturaldelta(query.total_time)))
                else:
                    print('Time: %0.03fs' % query.total_time)

            # Check if we need to update completions, in order of most
            # to least drastic changes
            if query.db_changed:
                with self._completer_lock:
                    self.completer.reset_completions()
                self.refresh_completions(persist_priorities='keywords')
            elif query.meta_changed:
                self.refresh_completions(persist_priorities='all')
            elif query.path_changed:
                logger.debug('Refreshing search path')
                with self._completer_lock:
                    self.completer.set_search_path(
                        self.pgexecute.search_path())
                logger.debug('Search path: %r',
                             self.completer.search_path)
        return query

    def run_cli(self):
        """Run the interactive read-eval-print loop until the user quits.

        Resolves the history file, starts a completion refresh, builds the
        prompt with ``_build_cli`` and then loops: read a document, resolve
        editor commands, execute the text (repeatedly for a watch command)
        and append the resulting query to ``self.query_history``. The loop
        ends when ``PgCliQuitError`` is raised.
        """
        logger = self.logger

        # 'default' stands for a file named 'history' under the location
        # returned by config_location().
        history_file = self.config['main']['history_file']
        if history_file == 'default':
            history_file = config_location() + 'history'
        history = FileHistory(os.path.expanduser(history_file))
        self.refresh_completions(history=history,
                                 persist_priorities='none')

        self.cli = self._build_cli(history)

        if not self.less_chatty:
            print('Version:', __version__)
            print('Chat: https://gitter.im/dbcli/pgcli')
            print('Mail: https://groups.google.com/forum/#!forum/pgcli')
            print('Home: http://pgcli.com')

        try:
            while True:
                document = self.cli.run()

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    # The editor failed: report it and prompt again.
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Initialize default metaquery in case execution fails
                query = MetaQuery(query=document.text, successful=False)

                self.watch_command, timing = special.get_watch_command(
                        document.text)
                if self.watch_command:
                    # Re-run the watched command every ``timing`` seconds
                    # until the user interrupts with Ctrl-C.
                    while self.watch_command:
                        try:
                            query = self.execute_command(
                                    self.watch_command, query)
                            click.echo(
                                    'Waiting for {0} seconds before repeating'
                                    .format(timing))
                            sleep(timing)
                        except KeyboardInterrupt:
                            self.watch_command = None
                else:
                    query = self.execute_command(document.text, query)

                # Timestamp read by get_prompt when expanding '\t'.
                self.now = dt.datetime.today()

                # Allow PGCompleter to learn user's preferred keywords, etc.
                with self._completer_lock:
                    self.completer.extend_query_history(document.text)

                self.query_history.append(query)

        except PgCliQuitError:
            if not self.less_chatty:
                print ('Goodbye!')

    def _build_cli(self, history):
        """Assemble the prompt_toolkit command line interface.

        :param history: history object handed to the input buffer
        :return: CommandLineInterface wired with pgcli's key bindings,
                 layout, buffer and style
        """

        def set_vi_mode(value):
            # Setter handed to the key bindings so they can flip the mode.
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        def prompt_tokens(_):
            # A DSN alias gets its own prompt format when one is configured.
            if self.dsn_alias and self.prompt_dsn_format is not None:
                prompt_format = self.prompt_dsn_format
            else:
                prompt_format = self.prompt_format

            prompt = self.get_prompt(prompt_format)

            # Fall back to the short '\d> ' prompt when the default prompt
            # expands to more than max_len_prompt characters.
            if (prompt_format == self.default_prompt and
                    len(prompt) > self.max_len_prompt):
                prompt = self.get_prompt('\\d> ')

            return [(Token.Prompt, prompt)]

        def get_continuation_tokens(cli, width):
            # Fill the continuation margin with the configured character,
            # leaving one trailing space.
            continuation=self.multiline_continuation_char * (width - 1) + ' '
            return [(Token.Continuation, continuation)]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode, self.completion_refresher.is_refreshing,
            self.pgexecute.failed_transaction,
            self.pgexecute.valid_transaction)

        layout = create_prompt_layout(
            lexer=PygmentsLexer(PostgresLexer),
            reserve_space_for_menu=self.min_num_menu_lines,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[
               # Highlight matching brackets while editing.
               ConditionalProcessor(
                   processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                   filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])

        # self.completer is read under the lock that guards completer swaps.
        with self._completer_lock:
            buf = PGBuffer(
                auto_suggest=AutoSuggestFromHistory(),
                always_multiline=self.multi_line,
                multiline_mode=self.multiline_mode,
                completer=self.completer,
                history=history,
                complete_while_typing=Always(),
                accept_action=AcceptAction.RETURN_DOCUMENT)

            editing_mode = EditingMode.VI if self.vi_mode else EditingMode.EMACS

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                ignore_case=True,
                editing_mode=editing_mode)

            cli = CommandLineInterface(application=application,
                                       eventloop=self.eventloop)

            return cli

    def _should_show_limit_prompt(self, status, cur):
        """Tell whether the row-limit confirmation should be shown.

        Truthy only for a select status when a positive ``row_limit`` is
        set and the cursor reports more rows than that limit.
        """
        if is_select(status):
            return self.row_limit > 0 and cur and cur.rowcount > self.row_limit
        return False

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        :param text: the text handed to ``self.pgexecute.run``
        :return: (results, MetaQuery) where results is the list of formatted
                 output lines of every result that was processed
        """
        logger = self.logger
        logger.debug('sql: %r', text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0

        # Run the query.
        start = time()
        on_error_resume = self.on_error == 'RESUME'
        res = self.pgexecute.run(text, self.pgspecial,
                                 exception_formatter, on_error_resume)

        for title, cur, headers, status, sql, success in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)
            threshold = self.row_limit
            # Ask before displaying a result set larger than the row limit;
            # declining stops processing of the remaining results.
            if self._should_show_limit_prompt(status, cur):
                click.secho('The result set has more than %s rows.'
                            % threshold, fg='red')
                if not click.confirm('Do you want to continue?'):
                    click.secho("Aborted!", err=True, fg='red')
                    break

            # The terminal width is only passed along when auto expand is on.
            if self.pgspecial.auto_expand or self.auto_expand:
                max_width = self.cli.output.get_size().columns
            else:
                max_width = None

            expanded = self.pgspecial.expanded_output or self.expanded_output
            settings = OutputSettings(
                table_format=self.table_format,
                dcmlfmt=self.decimal_format,
                floatfmt=self.float_format,
                missingval=self.null_string,
                expanded=expanded,
                max_width=max_width,
                case_function=(
                    self.completer.case if self.settings['case_column_headers']
                    else lambda x: x
                ),
                style_output=self.style_output
            )
            formatted = format_output(title, cur, headers, status, settings)

            output.extend(formatted)
            # Elapsed time since the run started, updated after each result.
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(text, all_success, total, meta_changed,
                               db_changed, path_changed, mutated)

        return output, meta_query

    def _handle_server_closed_connection(self):
        """Used during CLI execution: offer to reconnect after a reset."""
        wants_reconnect = click.prompt(
            'Connection reset. Reconnect (Y/n)',
            show_default=False, type=bool, default=True)
        if not wants_reconnect:
            return

        try:
            self.pgexecute.connect()
        except OperationalError as err:
            click.secho(str(err), err=True, fg='red')
        else:
            click.secho('Reconnected!\nTry the command again.', fg='green')

    def refresh_completions(self, history=None, persist_priorities='all'):
        """ Start a background refresh of outdated completions

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: 'all', 'keywords' or 'none'; forwarded to
                        the callback that installs the refreshed completer

        :return: a one-row result list carrying the status message
        """
        on_refreshed = functools.partial(
            self._on_completions_refreshed,
            persist_priorities=persist_priorities)
        self.completion_refresher.refresh(
            self.pgexecute, self.pgspecial, on_refreshed,
            history=history, settings=self.settings)

        status = 'Auto-completion refresh started in the background.'
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        """Install the refreshed completer, then repaint the prompt."""
        self._swap_completer_objects(new_completer, persist_priorities)

        if not self.cli:
            return
        # After refreshing, redraw the CLI to clear the statusbar
        # "Refreshing completions..." indicator
        self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Replace the active completer with the newly created one.

            persist_priorities is a string that selects how much of the old
            completer's learned prioritizer is carried over:

              'none'     - The new prioritizer is left in a new/clean state

              'all'      - The new prioritizer is updated to exactly reflect
                           the old one

              'keywords' - The new prioritizer is updated with old keyword
                           priorities, but not any other.
        """
        with self._completer_lock:
            previous = self.completer
            self.completer = new_completer

            # Both 'all' and 'keywords' start from the old prioritizer;
            # 'none' keeps whatever the new completer came with.
            if persist_priorities in ('all', 'keywords'):
                new_completer.prioritizer = previous.prioritizer
                if persist_priorities == 'keywords':
                    # Drop name priorities, leaving learned keyword
                    # priorities alone
                    new_completer.prioritizer.clear_names()

            # When pgcli is first launched refresh_completions runs before
            # the cli object is instantiated, so only touch the buffer's
            # completer once the cli exists.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Ask the current completer for completions of ``text`` at the cursor.

        The completer is read while holding ``self._completer_lock``.
        """
        document = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(document, None)

    def get_prompt(self, string):
        """Expand the backslash escapes of a prompt format string.

        Supported escapes: ``\\dsn_alias``, ``\\t`` (``self.now`` formatted
        with '%x %X'), ``\\u`` user, ``\\H`` full host, ``\\h`` host up to
        the first dot, ``\\d`` database name, ``\\p`` port, ``\\i`` pid,
        ``\\#`` ('#' for a superuser, '>' otherwise) and ``\\n`` newline.
        Missing user, host, database, port and pid values are shown as
        '(none)'.

        :param string: prompt format
        :return: the expanded prompt text
        """
        def text_or_none(value):
            # str(None) is the truthy text 'None', so the missing-value
            # check has to happen before converting to text.
            if value is None or value == '':
                return '(none)'
            return str(value)

        # should be before replacing \\d
        string = string.replace('\\dsn_alias', self.dsn_alias or '')
        string = string.replace('\\t', self.now.strftime('%x %X'))
        string = string.replace('\\u', self.pgexecute.user or '(none)')
        host = self.pgexecute.host or '(none)'
        string = string.replace('\\H', host)
        short_host, _, _ = host.partition('.')
        string = string.replace('\\h', short_host)
        string = string.replace('\\d', self.pgexecute.dbname or '(none)')
        string = string.replace('\\p', text_or_none(self.pgexecute.port))
        string = string.replace('\\i', text_or_none(self.pgexecute.pid))
        string = string.replace('\\#', "#" if (self.pgexecute.superuser) else ">")
        string = string.replace('\\n', "\n")
        return string

    def get_last_query(self):
        """Return the first field of the newest history entry, or None.

        None is returned when ``self.query_history`` is empty.
        """
        if not self.query_history:
            return None
        latest = self.query_history[-1]
        return latest[0]

    def echo_via_pager(self, text, color=None):
        """Print ``text`` through the pager unless paging does not apply.

        The pager is bypassed when the pager config equals ``PAGER_OFF`` or
        while a watch command is active.
        """
        pager_off = self.pgspecial.pager_config == PAGER_OFF
        if not pager_off and not self.watch_command:
            click.echo_via_pager(text, color)
            return
        click.echo(text, color=color)
Example #14
0
class PGCli(object):
    def __init__(self,
                 force_passwd_prompt=False,
                 never_passwd_prompt=False,
                 pgexecute=None,
                 pgclirc_file=None):
        """Load the configuration and set up logging and completion state.

        :param force_passwd_prompt: stored; consulted by ``connect``
        :param never_passwd_prompt: stored; consulted by ``connect``
        :param pgexecute: initial executor object, may be None
        :param pgclirc_file: config file path handed to
                             ``write_default_config`` and ``load_config``
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        # The default 'pgclirc' sits in the directory of the pgcli package.
        from pgcli import __file__ as package_file
        default_config = os.path.join(os.path.dirname(package_file),
                                      'pgclirc')
        write_default_config(default_config, pgclirc_file)

        self.pgspecial = PGSpecial()

        # Load config.
        self.config = load_config(pgclirc_file, default_config)
        main = self.config['main']
        self.multi_line = main.as_bool('multi_line')
        self.vi_mode = main.as_bool('vi')
        self.pgspecial.timing_enabled = main.as_bool('timing')
        self.table_format = main['table_format']
        self.syntax_style = main['syntax_style']
        self.cli_style = self.config['colors']
        self.wider_completion_menu = main.as_bool('wider_completion_menu')
        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer
        self.completer = PGCompleter(main.as_bool('smart_completion'),
                                     pgspecial=self.pgspecial)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        # Set by run_cli; stays None until the interface is built.
        self.cli = None

    def register_special_commands(self):
        """Register pgcli's own backslash commands on ``self.pgspecial``."""
        register = self.pgspecial.register

        register(self.change_db,
                 '\\c',
                 '\\c[onnect] database_name',
                 'Change to a new database.',
                 aliases=('use', '\\connect', 'USE'))

        # Both spellings trigger the same completion refresh.
        for command in ('\\#', '\\refresh'):
            register(self.refresh_completions,
                     command,
                     command,
                     'Refresh auto-completions.',
                     arg_type=NO_QUERY)

        register(self.execute_from_file, '\\i', '\\i filename',
                 'Execute commands from file.')

    def change_db(self, pattern, **_):
        """Reconnect, optionally to another database, and report the result.

        :param pattern: database name; one pair of surrounding double quotes
                        is stripped. When empty, ``connect`` is called
                        without arguments.
        :return: generator yielding one status row
        """
        target = {}
        if pattern:
            quoted = pattern[0] == pattern[-1] == '"'
            target['database'] = pattern[1:-1] if quoted else pattern
        self.pgexecute.connect(**target)

        message = ('You are now connected to database "%s" as '
                   'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))
        yield (None, None, None, message)

    def execute_from_file(self, pattern, **_):
        """Read a file as UTF-8 and run its contents as a query.

        :param pattern: path of the file; ``~`` is expanded
        :return: the result of ``self.pgexecute.run``, or a one-row result
                 list carrying an error message when the argument is missing
                 or the file cannot be read
        """
        if pattern:
            try:
                with open(os.path.expanduser(pattern),
                          encoding='utf-8') as sql_file:
                    query = sql_file.read()
            except IOError as err:
                return [(None, None, None, str(err))]
            return self.pgexecute.run(query, self.pgspecial)

        return [(None, None, None, '\\i: missing required argument')]

    def initialize_logging(self):
        """Attach a file handler to the 'pgcli' logger as configured.

        Reads ``log_file`` and ``log_level`` from the ``main`` config
        section; the level name is matched case-insensitively against
        CRITICAL, ERROR, WARNING, INFO and DEBUG.
        """
        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        levels = {
            name: getattr(logging, name)
            for name in ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG')
        }

        handler = logging.FileHandler(os.path.expanduser(log_file))
        handler.setFormatter(logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s'))

        pgcli_logger = logging.getLogger('pgcli')
        pgcli_logger.addHandler(handler)
        pgcli_logger.setLevel(levels[log_level.upper()])

        pgcli_logger.debug('Initializing pgcli logging.')
        pgcli_logger.debug('Log file %r.', log_file)

    def connect_dsn(self, dsn):
        """Connect by delegating to ``connect`` with only ``dsn`` supplied."""
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        """Split a connection URI into its parts and hand them to ``connect``.

        :param uri: connection URI; its path (minus the leading slash) is
                    used as the database name
        """
        parts = urlparse(uri)
        self.connect(database=parts.path[1:],  # ignore the leading fwd slash
                     host=parts.hostname,
                     user=parts.username,
                     port=parts.port,
                     passwd=parts.password)

    def connect(self,
                database='',
                host='',
                user='',
                port='',
                passwd='',
                dsn=''):
        """Create a PGExecute for the given parameters and store it.

        ``user`` defaults to the value of ``getuser()`` and ``database``
        defaults to the user name. A missing password is taken from the
        PGPASSWORD environment variable or prompted for, depending on
        ``self.force_passwd_prompt`` and ``self.never_passwd_prompt``.

        On success the executor is stored in ``self.pgexecute``. Any
        connection failure is logged, printed in red and ends the process
        through ``exit(1)``.
        """
        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt('Password',
                                  hide_input=True,
                                  show_default=False,
                                  type=str)

        # Prompt for a password after 1st attempt to connect without a password
        # fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not passwd and not self.never_passwd_prompt

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                if ('no password supplied' in utf8tounicode(e.args[0])
                        and auto_passwd_prompt):
                    passwd = click.prompt('Password',
                                          hide_input=True,
                                          show_default=False,
                                          type=str)
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn)
                else:
                    # Not a missing-password failure, or prompting is not
                    # allowed: let the outer handler report it.
                    raise e

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message
        """
        # The docstring is a raw string because '\e' is not a valid escape
        # sequence in a regular string literal.
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                        sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Show the edited text at the prompt, cursor at the end, and let
            # the user accept or edit it again.
            cli.current_buffer.document = Document(sql,
                                                   cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Build the prompt and run the read-eval-print loop until quit.

        Each iteration reads a document, resolves editor commands, runs the
        text through ``pgexecute.run``, formats and pages the output, and
        records a ``Query`` in ``self.query_history``. The loop ends on
        ``EOFError`` (raised for a quit command); the LESS environment
        variable changed by ``adjust_less_opts`` is restored on the way out.
        """
        pgexecute = self.pgexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        def prompt_tokens(cli):
            return [(Token.Prompt, '%s> ' % pgexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode,
            lambda: self.completion_refresher.is_refreshing())

        layout = create_default_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])
        history_file = self.config['main']['history_file']
        # self.completer is read under the lock that guards completer swaps.
        with self._completer_lock:
            buf = PGBuffer(always_multiline=self.multi_line,
                           completer=self.completer,
                           history=FileHistory(
                               os.path.expanduser(history_file)),
                           complete_while_typing=Always())

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        try:
            while True:
                document = self.cli.run()

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    successful = False
                    # Initialized to [] because res might never get initialized
                    # if an exception occurs in pgexecute.run(). Which causes
                    # finally clause to fail.
                    res = []
                    # Run the query.
                    start = time()
                    res = pgexecute.run(document.text, self.pgspecial)
                    successful = True
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        threshold = 1000
                        if (is_select(status) and cur
                                and cur.rowcount > threshold):
                            click.secho(
                                'The result set has more than %s rows.' %
                                threshold,
                                fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                break

                        if self.pgspecial.auto_expand:
                            max_width = self.cli.output.get_size().columns
                        else:
                            max_width = None

                        formatted = format_output(
                            title, cur, headers, status, self.table_format,
                            self.pgspecial.expanded_output, max_width)
                        output.extend(formatted)
                        # ``start`` is taken once before the run, so the
                        # elapsed time is assigned rather than accumulated;
                        # adding it per result counted earlier results
                        # several times.
                        total = time() - start
                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    reconnect = True
                    if ('server closed the connection'
                            in utf8tounicode(e.args[0])):
                        reconnect = click.prompt(
                            'Connection reset. Reconnect (Y/n)',
                            show_default=False,
                            type=bool,
                            default=True)
                        if reconnect:
                            try:
                                pgexecute.connect()
                                click.secho(
                                    'Reconnected!\nTry the command again.',
                                    fg='green')
                            except OperationalError as e:
                                click.secho(str(e), err=True, fg='red')
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    try:
                        click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        # Interrupting the pager is not an error.
                        pass
                    if self.pgspecial.timing_enabled:
                        print('Time: %0.03fs' % total)

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_completions()

                # Refresh search_path to set default schema.
                if need_search_path_refresh(document.text):
                    logger.debug('Refreshing search path')
                    with self._completer_lock:
                        self.completer.set_search_path(pgexecute.search_path())
                    logger.debug('Search path: %r', self.completer.search_path)

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts

    def adjust_less_opts(self):
        """Set the LESS environment variable to '-SRXF'.

        :return: the previous value of LESS ('' when it was unset), so the
                 caller can restore it later
        """
        previous = os.getenv('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', previous)
        os.environ['LESS'] = '-SRXF'
        return previous

    def refresh_completions(self):
        """Kick off a background completion refresh and report that it began.

        :return: a one-row result list carrying the status message
        """
        refresher = self.completion_refresher
        refresher.refresh(self.pgexecute, self.pgspecial,
                          self._on_completions_refreshed)
        message = 'Auto-completion refresh started in the background.'
        return [(None, None, None, message)]

    def _on_completions_refreshed(self, new_completer):
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When pgcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Ask the current completer for completions of ``text`` at the cursor.

        The completer is read while holding ``self._completer_lock``.
        """
        document = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(document, None)
Example #15
0
def pgspecial():
    """Return a newly constructed ``PGSpecial`` instance."""
    instance = PGSpecial()
    return instance
Example #16
0
    def __init__(
        self,
        force_passwd_prompt=False,
        never_passwd_prompt=False,
        pgexecute=None,
        pgclirc_file=None,
        row_limit=None,
        single_connection=False,
        less_chatty=None,
        prompt=None,
        prompt_dsn=None,
        auto_vertical_output=False,
        warn=None,
    ):
        """Set up CLI state from the constructor arguments and the config file.

        ``row_limit``, ``prompt`` and ``warn`` override the matching config
        values whenever they are not None; ``auto_vertical_output`` and
        ``less_chatty`` are OR-ed with their config counterparts.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None
        self.watch_command = None

        # Load config.
        cfg = self.config = get_config(pgclirc_file)

        NamedQueries.instance = NamedQueries.from_config(self.config)

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        # Most settings below live in the [main] section.
        main = cfg["main"]

        self.multi_line = main.as_bool("multi_line")
        self.multiline_mode = main.get("multi_line_mode", "psql")
        self.vi_mode = main.as_bool("vi")
        self.auto_expand = auto_vertical_output or main.as_bool("auto_expand")
        self.expanded_output = main.as_bool("expand")
        self.pgspecial.timing_enabled = main.as_bool("timing")
        self.row_limit = (
            row_limit if row_limit is not None else main.as_int("row_limit")
        )

        self.min_num_menu_lines = main.as_int("min_num_menu_lines")
        self.multiline_continuation_char = main["multiline_continuation_char"]
        self.table_format = main["table_format"]
        self.syntax_style = main["syntax_style"]
        self.cli_style = cfg["colors"]
        self.wider_completion_menu = main.as_bool("wider_completion_menu")
        # The config value is always read, even when ``warn`` overrides it.
        config_warning = main.as_bool("destructive_warning")
        self.destructive_warning = warn if warn is not None else config_warning
        self.less_chatty = bool(less_chatty) or main.as_bool("less_chatty")
        self.null_string = main.get("null_string", "<null>")
        if prompt is not None:
            self.prompt_format = prompt
        else:
            self.prompt_format = main.get("prompt", self.default_prompt)
        self.prompt_dsn_format = prompt_dsn
        self.on_error = main["on_error"].upper()
        data_formats = cfg["data_formats"]
        self.decimal_format = data_formats["decimal"]
        self.float_format = data_formats["float"]
        self.initialize_keyring()

        self.pgspecial.pset_pager("on" if main.as_bool("enable_pager") else "off")

        self.style_output = style_factory_output(self.syntax_style, cfg["colors"])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool("smart_completion")
        keyword_casing = main["keyword_casing"]
        self.settings = {
            "casing_file": get_casing_file(cfg),
            "generate_casing_file": main.as_bool("generate_casing_file"),
            "generate_aliases": main.as_bool("generate_aliases"),
            "asterisk_column_order": main["asterisk_column_order"],
            "qualify_columns": main["qualify_columns"],
            "case_column_headers": main.as_bool("case_column_headers"),
            "search_path_filter": main.as_bool("search_path_filter"),
            "single_connection": single_connection,
            "less_chatty": less_chatty,
            "keyword_casing": keyword_casing,
        }

        self.completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial, settings=self.settings
        )
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.prompt_app = None
Example #17
0
class PGCli(object):
    default_prompt = "\\u@\\h:\\d> "
    max_len_prompt = 30

    def set_default_pager(self, config):
        configured_pager = config["main"].get("pager")
        os_environ_pager = os.environ.get("PAGER")

        if configured_pager:
            self.logger.info(
                'Default pager found in config file: "%s"', configured_pager
            )
            os.environ["PAGER"] = configured_pager
        elif os_environ_pager:
            self.logger.info(
                'Default pager found in PAGER environment variable: "%s"',
                os_environ_pager,
            )
            os.environ["PAGER"] = os_environ_pager
        else:
            self.logger.info(
                "No default pager found in environment. Using os default pager"
            )

        # Set default set of less recommended options, if they are not already set.
        # They are ignored if pager is different than less.
        if not os.environ.get("LESS"):
            os.environ["LESS"] = "-SRXF"

    def __init__(
        self,
        force_passwd_prompt=False,
        never_passwd_prompt=False,
        pgexecute=None,
        pgclirc_file=None,
        row_limit=None,
        single_connection=False,
        less_chatty=None,
        prompt=None,
        prompt_dsn=None,
        auto_vertical_output=False,
        warn=None,
    ):
        """Set up CLI state from the constructor arguments and the config file.

        ``row_limit``, ``prompt`` and ``warn`` override the matching config
        values whenever they are not None; ``auto_vertical_output`` and
        ``less_chatty`` are OR-ed with their config counterparts.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute
        self.dsn_alias = None
        self.watch_command = None

        # Load config.
        cfg = self.config = get_config(pgclirc_file)

        NamedQueries.instance = NamedQueries.from_config(self.config)

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(cfg)
        self.output_file = None
        self.pgspecial = PGSpecial()

        # Most settings below live in the [main] section.
        main = cfg["main"]

        self.multi_line = main.as_bool("multi_line")
        self.multiline_mode = main.get("multi_line_mode", "psql")
        self.vi_mode = main.as_bool("vi")
        self.auto_expand = auto_vertical_output or main.as_bool("auto_expand")
        self.expanded_output = main.as_bool("expand")
        self.pgspecial.timing_enabled = main.as_bool("timing")
        self.row_limit = (
            row_limit if row_limit is not None else main.as_int("row_limit")
        )

        self.min_num_menu_lines = main.as_int("min_num_menu_lines")
        self.multiline_continuation_char = main["multiline_continuation_char"]
        self.table_format = main["table_format"]
        self.syntax_style = main["syntax_style"]
        self.cli_style = cfg["colors"]
        self.wider_completion_menu = main.as_bool("wider_completion_menu")
        # The config value is always read, even when ``warn`` overrides it.
        config_warning = main.as_bool("destructive_warning")
        self.destructive_warning = warn if warn is not None else config_warning
        self.less_chatty = bool(less_chatty) or main.as_bool("less_chatty")
        self.null_string = main.get("null_string", "<null>")
        if prompt is not None:
            self.prompt_format = prompt
        else:
            self.prompt_format = main.get("prompt", self.default_prompt)
        self.prompt_dsn_format = prompt_dsn
        self.on_error = main["on_error"].upper()
        data_formats = cfg["data_formats"]
        self.decimal_format = data_formats["decimal"]
        self.float_format = data_formats["float"]
        self.initialize_keyring()

        self.pgspecial.pset_pager("on" if main.as_bool("enable_pager") else "off")

        self.style_output = style_factory_output(self.syntax_style, cfg["colors"])

        self.now = dt.datetime.today()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = main.as_bool("smart_completion")
        keyword_casing = main["keyword_casing"]
        self.settings = {
            "casing_file": get_casing_file(cfg),
            "generate_casing_file": main.as_bool("generate_casing_file"),
            "generate_aliases": main.as_bool("generate_aliases"),
            "asterisk_column_order": main["asterisk_column_order"],
            "qualify_columns": main["qualify_columns"],
            "case_column_headers": main.as_bool("case_column_headers"),
            "search_path_filter": main.as_bool("search_path_filter"),
            "single_connection": single_connection,
            "less_chatty": less_chatty,
            "keyword_casing": keyword_casing,
        }

        self.completer = PGCompleter(
            smart_completion, pgspecial=self.pgspecial, settings=self.settings
        )
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.prompt_app = None

    def quit(self):
        """Signal that the user asked to leave by raising ``PgCliQuitError``."""
        raise PgCliQuitError()

    def register_special_commands(self):
        """Register pgcli's own special commands with ``self.pgspecial``.

        Each registration supplies a handler, the command, its usage syntax
        and a description, in the same order as they are listed here.
        """
        register = self.pgspecial.register

        register(
            self.change_db,
            "\\c",
            "\\c[onnect] database_name",
            "Change to a new database.",
            aliases=("use", "\\connect", "USE"),
        )

        def refresh_callback():
            return self.refresh_completions(persist_priorities="all")

        # Quitting: the backslash form is case sensitive, the word form is not.
        register(
            self.quit,
            "\\q",
            "\\q",
            "Quit pgcli.",
            arg_type=NO_QUERY,
            case_sensitive=True,
            aliases=(":q",),
        )
        register(
            self.quit,
            "quit",
            "quit",
            "Quit pgcli.",
            arg_type=NO_QUERY,
            case_sensitive=False,
            aliases=("exit",),
        )

        # Two spellings of the same completion-refresh command.
        for command in ("\\#", "\\refresh"):
            register(
                refresh_callback,
                command,
                command,
                "Refresh auto-completions.",
                arg_type=NO_QUERY,
            )

        register(
            self.execute_from_file,
            "\\i",
            "\\i filename",
            "Execute commands from file.",
        )
        register(
            self.write_to_file,
            "\\o",
            "\\o [filename]",
            "Send all query results to file.",
        )
        register(
            self.info_connection,
            "\\conninfo",
            "\\conninfo",
            "Get connection details",
        )
        register(
            self.change_table_format,
            "\\T",
            "\\T [format]",
            "Change the table format used to output results",
        )

    def change_table_format(self, pattern, **_):
        """Switch the output table format, or list the allowed formats.

        Yields one ``(None, None, None, message)`` tuple. When ``pattern`` is
        one of ``TabularOutputFormatter().supported_formats`` it becomes
        ``self.table_format``; otherwise the message lists every supported
        format plus the one currently in use, and nothing is changed.
        """
        supported = TabularOutputFormatter().supported_formats
        if pattern in supported:
            self.table_format = pattern
            yield (None, None, None, "Changed table format to {}".format(pattern))
        else:
            msg = "Table format {} not recognized. Allowed formats:".format(pattern)
            msg += "".join("\n\t{}".format(table_type) for table_type in supported)
            msg += "\nCurrently set to: %s" % self.table_format
            yield (None, None, None, msg)

    def info_connection(self, **_):
        if self.pgexecute.host.startswith("/"):
            host = 'socket "%s"' % self.pgexecute.host
        else:
            host = 'host "%s"' % self.pgexecute.host

        yield (
            None,
            None,
            None,
            'You are connected to database "%s" as user '
            '"%s" on %s at port "%s".'
            % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
        )

    def change_db(self, pattern, **_):
        if pattern:
            # Get all the parameters in pattern, handling double quotes if any.
            infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
            # Now removing quotes.
            list(map(lambda s: s.strip('"'), infos))

            infos.extend([None] * (4 - len(infos)))
            db, user, host, port = infos
            try:
                self.pgexecute.connect(
                    database=db,
                    user=user,
                    host=host,
                    port=port,
                    **self.pgexecute.extra_args
                )
            except OperationalError as e:
                click.secho(str(e), err=True, fg="red")
                click.echo("Previous connection kept")
        else:
            self.pgexecute.connect()

        yield (
            None,
            None,
            None,
            'You are now connected to database "%s" as '
            'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
        )

    def execute_from_file(self, pattern, **_):
        if not pattern:
            message = "\\i: missing required argument"
            return [(None, None, None, message, "", False, True)]
        try:
            with open(os.path.expanduser(pattern), encoding="utf-8") as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e), "", False, True)]

        if self.destructive_warning and confirm_destructive_query(query) is False:
            message = "Wise choice. Command execution stopped."
            return [(None, None, None, message)]

        on_error_resume = self.on_error == "RESUME"
        return self.pgexecute.run(
            query, self.pgspecial, on_error_resume=on_error_resume
        )

    def write_to_file(self, pattern, **_):
        if not pattern:
            self.output_file = None
            message = "File output disabled"
            return [(None, None, None, message, "", True, True)]
        filename = os.path.abspath(os.path.expanduser(pattern))
        if not os.path.isfile(filename):
            try:
                open(filename, "w").close()
            except IOError as e:
                self.output_file = None
                message = str(e) + "\nFile output disabled"
                return [(None, None, None, message, "", False, True)]
        self.output_file = filename
        message = 'Writing to file "%s"' % self.output_file
        return [(None, None, None, message, "", True, True)]

    def initialize_logging(self):
        """Attach one shared handler to the "pgcli" and "pgspecial" loggers.

        The log file and level are read from the ``[main]`` config section.
        A ``log_file`` of ``"default"`` resolves to ``config_location() +
        "log"``. A ``log_level`` of ``NONE`` uses a no-op handler; any other
        level logs to the file.

        Raises:
            KeyError: if ``log_level`` is not one of the known level names.
        """
        log_file = self.config["main"]["log_file"]
        if log_file == "default":
            log_file = config_location() + "log"
        ensure_dir_exists(log_file)

        level_name = self.config["main"]["log_level"].upper()

        # Disable logging if value is NONE by switching to a no-op handler.
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if level_name == "NONE":
            handler = logging.NullHandler()
        else:
            handler = logging.FileHandler(os.path.expanduser(log_file))

        level_map = {
            name: getattr(logging, name)
            for name in ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
        }
        level_map["NONE"] = logging.CRITICAL
        level = level_map[level_name]

        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s (%(process)d/%(threadName)s) "
                "%(name)s %(levelname)s - %(message)s"
            )
        )

        pgcli_logger = logging.getLogger("pgcli")
        pgcli_logger.addHandler(handler)
        pgcli_logger.setLevel(level)

        pgcli_logger.debug("Initializing pgcli logging.")
        pgcli_logger.debug("Log file %r.", log_file)

        special_logger = logging.getLogger("pgspecial")
        special_logger.addHandler(handler)
        special_logger.setLevel(level)

    def initialize_keyring(self):
        global keyring

        keyring_enabled = self.config["main"].as_bool("keyring")
        if keyring_enabled:
            # Try best to load keyring (issue #1041).
            import importlib

            try:
                keyring = importlib.import_module("keyring")
            except Exception as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
                self.logger.warning("import keyring failed: %r.", e)

    def connect_dsn(self, dsn, **kwargs):
        self.connect(dsn=dsn, **kwargs)

    def connect_uri(self, uri):
        """Connect using a connection URI.

        The URI is parsed with ``psycopg2.extensions.parse_dsn``. The parsed
        ``dbname`` and ``password`` keys are renamed to ``database`` and
        ``passwd``, the parameter names ``connect`` declares; every other key
        is forwarded unchanged.
        """
        kwargs = psycopg2.extensions.parse_dsn(uri)
        # The password must arrive as ``passwd``: under any other key it
        # would land in connect()'s **kwargs and never be used as the
        # password.
        remap = {"dbname": "database", "password": "passwd"}
        kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
        self.connect(**kwargs)

    def connect(
        self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
    ):
        """Open a database connection and store it in ``self.pgexecute``.

        ``user`` defaults to ``getuser()`` and ``database`` to the user name.
        The password is looked up, in order, from the ``passwd`` argument,
        the ``PGPASSWORD`` environment variable (skipped when a prompt is
        forced), the keyring, and finally an interactive prompt. A password
        that led to a successful connection is saved to the keyring when one
        is available. Any failure is logged, printed in red, and ends the
        process with ``exit(1)``.
        """
        # Connect to the database.
        user = user or getuser()
        database = database or user

        kwargs.setdefault("application_name", "pgcli")

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not (self.force_passwd_prompt or passwd):
            passwd = os.environ.get("PGPASSWORD", "")

        # Find password from store
        key = "%s@%s" % (user, host)
        keyring_error_message = dedent(
            """\
            {}
            {}
            To remove this message do one of the following:
            - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
            - uninstall keyring: pip uninstall keyring
            - disable keyring in our configuration: add keyring = False to [main]"""
        )
        if not passwd and keyring:
            try:
                passwd = keyring.get_password("pgcli", key)
            except (RuntimeError, keyring.errors.InitError) as e:
                click.secho(
                    keyring_error_message.format(
                        "Load your password from keyring returned:", str(e)
                    ),
                    err=True,
                    fg="red",
                )

        def prompt_for_password():
            # Shared by the forced (-W) prompt and the retry after a failure.
            return click.prompt(
                "Password for %s" % user, hide_input=True, show_default=False, type=str
            )

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = prompt_for_password()

        def should_ask_for_password(exc):
            # Prompt for a password after 1st attempt to connect
            # fails. Don't prompt if the -w flag is supplied
            if self.never_passwd_prompt:
                return False
            error_msg = utf8tounicode(exc.args[0])
            return any(
                marker in error_msg
                for marker in ("no password supplied", "password authentication failed")
            )

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing or incorrect password, but we're allowed to
        # prompt for a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                executor = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
            except (OperationalError, InterfaceError) as e:
                if not should_ask_for_password(e):
                    raise e
                passwd = prompt_for_password()
                executor = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
            if passwd and keyring:
                try:
                    keyring.set_password("pgcli", key, passwd)
                except (RuntimeError, keyring.errors.KeyringError) as e:
                    click.secho(
                        keyring_error_message.format(
                            "Set password in keyring returned:", str(e)
                        ),
                        err=True,
                        fg="red",
                    )

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg="red")
            exit(1)

        self.pgexecute = executor

    def handle_editor_command(self, text):
        r"""Resolve editor commands (\e, \ev, \ef) found in ``text``.

        While ``special.editor_command`` still finds an editor command, the
        relevant query is opened in the external editor and the result is
        offered at the prompt as the default input. The user may add another
        editor command there, so this repeats until none is left.

        :param text: the text entered at the prompt
        :return: the final text, free of editor commands
        :raises RuntimeError: if the external editor reports a message
        """
        while True:
            command = special.editor_command(text)
            if not command:
                return text

            if command == "\\e":
                filename = special.get_filename(text)
                query = special.get_editor_query(text) or self.get_last_query()
            else:  # \ev or \ef
                filename = None
                spec = text.split()[1]
                if command == "\\ev":
                    query = self.pgexecute.view_definition(spec)
                elif command == "\\ef":
                    query = self.pgexecute.function_definition(spec)

            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)

            # Ctrl-C at the prompt discards the edited text and prompts again
            # with an empty default.
            while True:
                try:
                    text = self.prompt_app.prompt(default=sql)
                except KeyboardInterrupt:
                    sql = ""
                else:
                    break

    def execute_command(self, text):
        """Run ``text``, show its output and return a ``MetaQuery`` for it.

        When destructive warnings are on, the user is asked to confirm first
        and a refusal is treated like a cancelled query. On failure the error
        is reported and a ``MetaQuery`` marked unsuccessful is returned.
        ``PgCliQuitError`` and ``EOFError`` are re-raised to the caller.

        On success the output goes to ``self.output_file`` when one is set
        (except for commands starting with ``\\o `` or ``\\? ``) or to the
        pager, timing is printed if enabled, and completions are refreshed
        as far as the query's effects require.
        """
        logger = self.logger

        query = MetaQuery(query=text, successful=False)

        try:
            if self.destructive_warning:
                verdict = confirm_destructive_query(text)
                if verdict is False:
                    click.secho("Wise choice!")
                    raise KeyboardInterrupt
                if verdict:
                    click.secho("Your call!")
            output, query = self._evaluate_command(text)
        except KeyboardInterrupt:
            # Restart connection to the database
            self.pgexecute.connect()
            logger.debug("cancelled query, sql: %r", text)
            click.secho("cancelled query", err=True, fg="red")
            return query
        except NotImplementedError:
            click.secho("Not Yet Implemented.", fg="yellow")
            return query
        except OperationalError as e:
            logger.error("sql: %r, error: %r", text, e)
            logger.error("traceback: %r", traceback.format_exc())
            self._handle_server_closed_connection(text)
            return query
        except (PgCliQuitError, EOFError):
            raise
        except Exception as e:
            logger.error("sql: %r, error: %r", text, e)
            logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg="red")
            return query

        # From here on the command succeeded.
        try:
            if self.output_file and not text.startswith(("\\o ", "\\? ")):
                try:
                    with open(self.output_file, "a", encoding="utf-8") as f:
                        click.echo(text, file=f)
                        click.echo("\n".join(output), file=f)
                        click.echo("", file=f)  # extra newline
                except IOError as e:
                    click.secho(str(e), err=True, fg="red")
            else:
                self.echo_via_pager("\n".join(output))
        except KeyboardInterrupt:
            pass

        if self.pgspecial.timing_enabled:
            # Only add humanized time display if > 1 second
            if query.total_time > 1:
                print(
                    "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
                    % (
                        query.total_time,
                        humanize.time.naturaldelta(query.total_time),
                        query.execution_time,
                        humanize.time.naturaldelta(query.execution_time),
                    )
                )
            else:
                print("Time: %0.03fs" % query.total_time)

        # Check if we need to update completions, in order of most
        # to least drastic changes
        if query.db_changed:
            with self._completer_lock:
                self.completer.reset_completions()
            self.refresh_completions(persist_priorities="keywords")
        elif query.meta_changed:
            self.refresh_completions(persist_priorities="all")
        elif query.path_changed:
            logger.debug("Refreshing search path")
            with self._completer_lock:
                self.completer.set_search_path(self.pgexecute.search_path())
            logger.debug("Search path: %r", self.completer.search_path)

        return query

    def run_cli(self):
        """Run the interactive read-execute loop until the user quits.

        Sets up history and the prompt session, prints a banner unless
        ``less_chatty`` is set, then repeatedly reads input, resolves editor
        commands and executes the result. The loop ends on
        ``PgCliQuitError`` or ``EOFError``.
        """
        logger = self.logger

        history_file = self.config["main"]["history_file"]
        if history_file == "default":
            history_file = config_location() + "history"
        history = FileHistory(os.path.expanduser(history_file))
        self.refresh_completions(history=history, persist_priorities="none")

        self.prompt_app = self._build_cli(history)

        if not self.less_chatty:
            print("Server: PostgreSQL", self.pgexecute.server_version)
            print("Version:", __version__)
            print("Chat: https://gitter.im/dbcli/pgcli")
            print("Home: http://pgcli.com")

        try:
            while True:
                try:
                    text = self.prompt_app.prompt()
                except KeyboardInterrupt:
                    # Ctrl-C at the prompt just starts a fresh prompt.
                    continue

                try:
                    text = self.handle_editor_command(text)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg="red")
                    continue

                # A watch command is re-run until interrupted; anything else
                # is executed exactly once.
                self.watch_command, timing = special.get_watch_command(text)
                if not self.watch_command:
                    query = self.execute_command(text)
                else:
                    while self.watch_command:
                        try:
                            query = self.execute_command(self.watch_command)
                            click.echo(
                                "Waiting for {0} seconds before repeating".format(
                                    timing
                                )
                            )
                            sleep(timing)
                        except KeyboardInterrupt:
                            self.watch_command = None

                self.now = dt.datetime.today()

                # Allow PGCompleter to learn user's preferred keywords, etc.
                with self._completer_lock:
                    self.completer.extend_query_history(text)

                self.query_history.append(query)

        except (PgCliQuitError, EOFError):
            if not self.less_chatty:
                print("Goodbye!")

    def _build_cli(self, history):
        """Create and return the ``PromptSession`` for the interactive prompt.

        ``history`` is passed through as the session's history. The session
        is constructed while holding ``self._completer_lock``.
        """
        key_bindings = pgcli_bindings(self)

        def get_message():
            # The DSN prompt format applies only when a DSN alias is in use
            # and such a format was supplied.
            use_dsn_format = self.dsn_alias and self.prompt_dsn_format is not None
            prompt_format = (
                self.prompt_dsn_format if use_dsn_format else self.prompt_format
            )

            prompt = self.get_prompt(prompt_format)

            # An over-long default prompt is replaced by a shorter one.
            is_default = prompt_format == self.default_prompt
            if is_default and len(prompt) > self.max_len_prompt:
                prompt = self.get_prompt("\\d> ")

            return [("class:prompt", prompt)]

        def get_continuation(width, line_number, is_soft_wrap):
            padding = self.multiline_continuation_char * (width - 1)
            return [("class:continuation", padding + " ")]

        get_toolbar_tokens = create_toolbar_tokens_func(self)

        complete_style = (
            CompleteStyle.MULTI_COLUMN
            if self.wider_completion_menu
            else CompleteStyle.COLUMN
        )

        with self._completer_lock:
            session = PromptSession(
                lexer=PygmentsLexer(PostgresLexer),
                reserve_space_for_menu=self.min_num_menu_lines,
                message=get_message,
                prompt_continuation=get_continuation,
                bottom_toolbar=get_toolbar_tokens,
                complete_style=complete_style,
                input_processors=[
                    # Highlight matching brackets while editing.
                    ConditionalProcessor(
                        processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                    ),
                    # Render \t as 4 spaces instead of "^I"
                    TabsProcessor(char1=" ", char2=" "),
                ],
                auto_suggest=AutoSuggestFromHistory(),
                tempfile_suffix=".sql",
                # N.b. pgcli's multi-line mode controls submit-on-Enter (which
                # overrides the default behaviour of prompt_toolkit) and is
                # distinct from prompt_toolkit's multiline mode here, which
                # controls layout/display of the prompt/buffer
                multiline=True,
                history=history,
                completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
                complete_while_typing=True,
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
                search_ignore_case=True,
            )

        return session

    def _should_limit_output(self, sql, cur):
        """Tell whether the result of ``sql`` should be truncated.

        The result is truthy only for a SELECT without its own limit, with a
        non-zero ``self.row_limit``, whose cursor reports more rows than that
        limit. A falsy ``cur`` is returned as-is.
        """
        if not is_select(sql):
            return False
        if self._has_limit(sql):
            return False
        if self.row_limit == 0:
            return False
        if not cur:
            return cur
        return cur.rowcount > self.row_limit

    def _has_limit(self, sql):
        if not sql:
            return False
        return "limit " in sql.lower()

    def _limit_output(self, cur):
        """Cap ``cur`` at ``self.row_limit`` rows and tell the user.

        Returns ``(rows, status)``, where ``rows`` is an iterator over at
        most ``min(self.row_limit, cur.rowcount)`` items of ``cur`` and
        ``status`` is ``"SELECT <that number>"``.
        """
        row_cap = min(self.row_limit, cur.rowcount)
        capped_rows = itertools.islice(cur, row_cap)
        status = "SELECT " + str(row_cap)
        click.secho("The result was limited to %s rows" % row_cap, fg="red")

        return capped_rows, status

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        Runs ``text`` through ``self.pgexecute.run`` and formats every result
        it yields with ``format_output``.

        :param text: the command text entered by the user
        :return: ``(output, meta_query)`` -- the accumulated formatted output
                 and a ``MetaQuery`` summarising the run
        """
        logger = self.logger
        logger.debug("sql: %r", text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0
        execution = 0

        # Run the query.
        start = time()
        on_error_resume = self.on_error == "RESUME"
        res = self.pgexecute.run(
            text, self.pgspecial, exception_formatter, on_error_resume
        )

        # Stays None when ``res`` yields nothing; otherwise it ends up holding
        # the flag of the last result.
        is_special = None

        for title, cur, headers, status, sql, success, is_special in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)

            # Swap in a truncated row iterator and status when the row limit
            # applies to this statement.
            if self._should_limit_output(sql, cur):
                cur, status = self._limit_output(cur)

            # A maximum width is only passed on when auto-expand is enabled,
            # either on pgspecial or on this object.
            if self.pgspecial.auto_expand or self.auto_expand:
                max_width = self.prompt_app.output.get_size().columns
            else:
                max_width = None

            expanded = self.pgspecial.expanded_output or self.expanded_output
            settings = OutputSettings(
                table_format=self.table_format,
                dcmlfmt=self.decimal_format,
                floatfmt=self.float_format,
                missingval=self.null_string,
                expanded=expanded,
                max_width=max_width,
                case_function=(
                    self.completer.case
                    if self.settings["case_column_headers"]
                    else lambda x: x
                ),
                style_output=self.style_output,
            )
            # Elapsed time measured before formatting; overwritten on every
            # iteration, so the value of the last result is what is reported.
            execution = time() - start
            formatted = format_output(title, cur, headers, status, settings)

            output.extend(formatted)
            # Elapsed time measured after formatting (also overwritten on
            # every iteration).
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(
            text,
            all_success,
            total,
            execution,
            meta_changed,
            db_changed,
            path_changed,
            mutated,
            is_special,
        )

        return output, meta_query

    def _handle_server_closed_connection(self, text):
        """Used during CLI execution.

        Reconnects through ``self.pgexecute.connect()`` and, when that
        succeeds, runs ``text`` again with ``self.execute_command``.

        :param text: the command to run again after reconnecting
        """
        try:
            click.secho("Reconnecting...", fg="green")
            self.pgexecute.connect()
            click.secho("Reconnected!", fg="green")
            self.execute_command(text)
        except OperationalError as e:
            # Also reached when the re-run itself raises OperationalError,
            # because it sits inside the same try block.
            click.secho("Reconnect Failed", fg="red")
            click.secho(str(e), err=True, fg="red")

    def refresh_completions(self, history=None, persist_priorities="all"):
        """Ask the completion refresher to rebuild the completions.

        ``self.completion_refresher.refresh`` is handed a callback bound to
        ``_on_completions_refreshed`` together with the current settings.

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: passed on to the completion callback
                        ('all', 'keywords' or 'none')

        :return: a one-row result list whose last field is a status message
        """
        on_refreshed = functools.partial(
            self._on_completions_refreshed, persist_priorities=persist_priorities
        )
        refresh_options = {"history": history, "settings": self.settings}
        self.completion_refresher.refresh(
            self.pgexecute, self.pgspecial, on_refreshed, **refresh_options
        )

        status = "Auto-completion refresh started in the background."
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        self._swap_completer_objects(new_completer, persist_priorities)

        if self.prompt_app:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.prompt_app.app.invalidate()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Swap the completer object with the newly created completer.

        persist_priorities is a string specifying how the old completer's
        learned prioritizer should be transferred to the new completer.

          'none'     - The new prioritizer is left in a new/clean state

          'all'      - The new prioritizer is updated to exactly reflect
                       the old one

          'keywords' - The new prioritizer is updated with old keyword
                       priorities, but not any other.

        """
        with self._completer_lock:
            old_completer = self.completer
            self.completer = new_completer

            if persist_priorities == "all":
                # Just swap over the entire prioritizer
                new_completer.prioritizer = old_completer.prioritizer
            elif persist_priorities == "keywords":
                # Swap over the entire prioritizer, but clear name priorities,
                # leaving learned keyword priorities alone
                new_completer.prioritizer = old_completer.prioritizer
                new_completer.prioritizer.clear_names()
            elif persist_priorities == "none":
                # Leave the new prioritizer as is
                pass
            self.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return the current completer's completions for ``text``.

        :param text: buffer contents to complete
        :param cursor_positition: cursor offset inside ``text``
        :return: whatever ``self.completer.get_completions`` returns; the
                 call is made while holding the completer lock
        """
        document = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            completions = self.completer.get_completions(document, None)
            return completions

    def get_prompt(self, string):
        r"""Expand the backslash escapes of a prompt template.

        Recognised escapes: ``\dsn_alias`` (``self.dsn_alias`` or empty),
        ``\t`` (``self.now`` formatted with ``%x %X``), ``\u`` user,
        ``\H`` host, ``\h`` short host, ``\d`` database name, ``\p`` port
        ("5432" when the port is None), ``\i`` ``self.pgexecute.pid``,
        ``\#`` ("#" for a superuser, ">" otherwise) and ``\n`` (newline).
        Missing user, host, database and pid values are shown as "(none)".

        :param string: prompt template
        :return: the template with every escape replaced
        """
        pgexecute = self.pgexecute
        pid = pgexecute.pid

        # should be before replacing \\d
        string = string.replace("\\dsn_alias", self.dsn_alias or "")
        string = string.replace("\\t", self.now.strftime("%x %X"))
        string = string.replace("\\u", pgexecute.user or "(none)")
        string = string.replace("\\H", pgexecute.host or "(none)")
        string = string.replace("\\h", pgexecute.short_host or "(none)")
        string = string.replace("\\d", pgexecute.dbname or "(none)")
        string = string.replace(
            "\\p",
            str(pgexecute.port) if pgexecute.port is not None else "5432",
        )
        # str(None) is the truthy string "None", so the fallback has to be
        # chosen before converting the pid to text.
        string = string.replace("\\i", str(pid) if pid is not None else "(none)")
        string = string.replace("\\#", "#" if pgexecute.superuser else ">")
        string = string.replace("\\n", "\n")
        return string

    def get_last_query(self):
        """Return the first field of the newest ``query_history`` entry.

        :return: that field, or None when the history is empty
        """
        history = self.query_history
        if not history:
            return None
        newest_entry = history[-1]
        return newest_entry[0]

    def is_too_wide(self, line):
        """Report whether ``line`` is wider than the terminal.

        Text matched by ``COLOR_CODE_REGEX`` is removed before the length is
        measured. Without a prompt application the answer is always False.

        :param line: a single line of output
        :return: bool
        """
        app = self.prompt_app
        if not app:
            return False

        visible_text = COLOR_CODE_REGEX.sub("", line)
        terminal_columns = self.prompt_app.output.get_size().columns
        return len(visible_text) > terminal_columns

    def is_too_tall(self, lines):
        """Report whether ``lines`` would overflow the terminal height.

        Four rows of the terminal are held back from the available height.
        Without a prompt application the answer is always False.

        :param lines: sequence of output lines
        :return: bool
        """
        if not self.prompt_app:
            return False

        terminal_rows = self.prompt_app.output.get_size().rows
        usable_rows = terminal_rows - 4
        return len(lines) >= usable_rows

    def echo_via_pager(self, text, color=None):
        """Print ``text`` directly or through the pager.

        With the pager off, or while ``self.watch_command`` is truthy, the
        text is echoed directly. With ``PAGER_LONG_OUTPUT`` the pager is used
        only when the text is too tall or any line is too wide. For every
        other pager setting the pager is always used.

        :param text: the output to show
        :param color: forwarded to click
        """
        pager_config = self.pgspecial.pager_config

        if pager_config == PAGER_OFF or self.watch_command:
            click.echo(text, color=color)
            return

        if self.pgspecial.pager_config != PAGER_LONG_OUTPUT:
            click.echo_via_pager(text, color)
            return

        lines = text.split("\n")
        # The last 4 lines are reserved for the pgcli menu and padding
        needs_pager = self.is_too_tall(lines) or any(
            self.is_too_wide(l) for l in lines
        )
        if needs_pager:
            click.echo_via_pager(text, color=color)
        else:
            click.echo(text, color=color)
Example #18
0
File: main.py Project: ajlai/pgcli
class PGCli(object):
    def set_default_pager(self, config):
        configured_pager = config['main'].get('pager')
        os_environ_pager = os.environ.get('PAGER')

        if configured_pager:
            self.logger.info('Default pager found in config file: ' + '\'' +
                             configured_pager + '\'')
            os.environ['PAGER'] = configured_pager
        elif (os_environ_pager):
            self.logger.info(
                'Default pager found in PAGER environment variable: ' + '\'' +
                os_environ_pager + '\'')
            os.environ['PAGER'] = os_environ_pager
        else:
            self.logger.info(
                'No default pager found in environment. Using os default pager'
            )
        # Always set default set of less recommended options, they are ignored if pager is
        # different than less or is already parameterized with their own arguments
        os.environ['LESS'] = '-SRXF'

    def __init__(self,
                 force_passwd_prompt=False,
                 never_passwd_prompt=False,
                 pgexecute=None,
                 pgclirc_file=None):
        """Load the configuration and set up logging, pager and completer.

        :param force_passwd_prompt: stored as-is; read by ``connect`` to
                                    prompt for a password up front
        :param never_passwd_prompt: stored as-is; read by ``connect`` to
                                    suppress the retry password prompt
        :param pgexecute: stored as ``self.pgexecute``
        :param pgclirc_file: config file path handed to
                             ``write_default_config`` and ``load_config``
        """

        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        # Locate the 'pgclirc' file that sits next to the pgcli package.
        from pgcli import __file__ as package_root
        package_root = os.path.dirname(package_root)

        default_config = os.path.join(package_root, 'pgclirc')
        write_default_config(default_config, pgclirc_file)

        # Load config.
        c = self.config = load_config(pgclirc_file, default_config)

        # The logger must exist before set_default_pager, which logs through
        # self.logger.
        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.set_default_pager(c)
        self.pgspecial = PGSpecial()

        self.multi_line = c['main'].as_bool('multi_line')
        self.vi_mode = c['main'].as_bool('vi')
        self.pgspecial.timing_enabled = c['main'].as_bool('timing')

        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')

        # Compared against 'RESUME' when queries are run.
        self.on_error = c['main']['on_error'].upper()

        self.completion_refresher = CompletionRefresher()

        self.query_history = []

        # Initialize completer
        smart_completion = c['main'].as_bool('smart_completion')
        completer = PGCompleter(smart_completion, pgspecial=self.pgspecial)
        self.completer = completer
        # Guards reads and swaps of self.completer in the methods below.
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        # Set by run_cli(); other methods test it before touching the CLI.
        self.cli = None

    def register_special_commands(self):
        """Register this object's backslash commands with ``self.pgspecial``.

        Registers ``\\c`` (with aliases) for ``change_db``, ``\\#`` and
        ``\\refresh`` for a completion refresh, and ``\\i`` for
        ``execute_from_file``.
        """
        register = self.pgspecial.register

        def refresh_callback():
            return self.refresh_completions(persist_priorities='all')

        register(self.change_db,
                 '\\c',
                 '\\c[onnect] database_name',
                 'Change to a new database.',
                 aliases=('use', '\\connect', 'USE'))

        # Both spellings trigger the same refresh and take no query argument.
        for command in ('\\#', '\\refresh'):
            register(refresh_callback,
                     command,
                     command,
                     'Refresh auto-completions.',
                     arg_type=NO_QUERY)

        register(self.execute_from_file, '\\i', '\\i filename',
                 'Execute commands from file.')

    def change_db(self, pattern, **_):
        if pattern:
            db = pattern[1:-1] if pattern[0] == pattern[-1] == '"' else pattern
            self.pgexecute.connect(database=db)
        else:
            self.pgexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        if not pattern:
            message = '\\i: missing required argument'
            return [(None, None, None, message, '', False)]
        try:
            with open(os.path.expanduser(pattern), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e), '', False)]

        on_error_resume = (self.on_error == 'RESUME')
        return self.pgexecute.run(query,
                                  self.pgspecial,
                                  on_error_resume=on_error_resume)

    def initialize_logging(self):
        """Attach a file handler to the 'pgcli' logger as configured.

        ``log_file`` and ``log_level`` are read from the ``main`` config
        section; a ``log_file`` of 'default' resolves to ``'log'`` appended
        to ``config_location()``.

        :raises KeyError: when ``log_level`` is not one of CRITICAL, ERROR,
                          WARNING, INFO or DEBUG (case-insensitive)
        """
        main_config = self.config['main']

        log_file = main_config['log_file']
        if log_file == 'default':
            log_file = config_location() + 'log'
        ensure_dir_exists(log_file)
        log_level = main_config['log_level']

        level_names = ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG')
        level_map = {name: getattr(logging, name) for name in level_names}

        handler = logging.FileHandler(os.path.expanduser(log_file))
        handler.setFormatter(logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s'))

        pgcli_logger = logging.getLogger('pgcli')
        pgcli_logger.addHandler(handler)
        pgcli_logger.setLevel(level_map[log_level.upper()])

        for message, argument in (('Initializing pgcli logging.', None),
                                  ('Log file %r.', log_file)):
            if argument is None:
                pgcli_logger.debug(message)
            else:
                pgcli_logger.debug(message, argument)

    def connect_dsn(self, dsn):
        """Connect using ``dsn``; delegates to ``connect(dsn=dsn)``."""
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        self.connect(database, uri.hostname, uri.username, uri.port,
                     uri.password)

    def connect(self,
                database='',
                host='',
                user='',
                port='',
                passwd='',
                dsn=''):
        """Create a ``PGExecute`` and store it as ``self.pgexecute``.

        On any connection failure the error is logged and printed in red and
        the process exits with status 1.

        :param database: database name; defaults to ``user`` when empty
        :param host: passed through to ``PGExecute``
        :param user: user name; defaults to ``getuser()`` when empty
        :param port: passed through to ``PGExecute``
        :param passwd: password; may be filled from ``$PGPASSWORD`` or an
                       interactive prompt (see the comments below)
        :param dsn: passed through to ``PGExecute``
        """
        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt('Password',
                                  hide_input=True,
                                  show_default=False,
                                  type=str)

        # Prompt for a password after 1st attempt to connect without a password
        # fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not passwd and not self.never_passwd_prompt

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                if ('no password supplied' in utf8tounicode(e.args[0])
                        and auto_passwd_prompt):
                    passwd = click.prompt('Password',
                                          hide_input=True,
                                          show_default=False,
                                          type=str)
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn)
                else:
                    # Any other OperationalError falls through to the outer
                    # handler below.
                    raise e

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports a message
        """
        # The docstring is a raw string because '\e' is not a valid escape
        # sequence in an ordinary string literal.
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                        sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited text back into the buffer, cursor at the end,
            # and read the next input from the prompt.
            cli.current_buffer.document = Document(sql,
                                                   cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive read-evaluate-print loop.

        Builds the history and the CLI, prints a banner, then repeatedly
        reads a document, evaluates it and prints the output until EOFError
        is raised (either by the prompt or by a quit command), at which point
        'Goodbye!' is printed.
        """
        logger = self.logger

        history_file = self.config['main']['history_file']
        if history_file == 'default':
            history_file = config_location() + 'history'
        history = FileHistory(os.path.expanduser(history_file))
        # Started before self.cli exists; the refresh callbacks check
        # self.cli before using it.
        self.refresh_completions(history=history, persist_priorities='none')

        self.cli = self._build_cli(history)

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        try:
            while True:
                document = self.cli.run()

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Initialize default metaquery in case execution fails
                query = MetaQuery(query=document.text, successful=False)

                try:
                    output, query = self._evaluate_command(document.text)
                except KeyboardInterrupt:
                    # Restart connection to the database
                    self.pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    if ('server closed the connection'
                            in utf8tounicode(e.args[0])):
                        self._handle_server_closed_connection()
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    # Only reached when evaluation raised nothing.
                    try:
                        click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        # Interrupting the pager is not treated as an error.
                        pass

                    if self.pgspecial.timing_enabled:
                        # Only add humanized time display if > 1 second
                        if query.total_time > 1:
                            print(
                                'Time: %0.03fs (%s)' %
                                (query.total_time,
                                 humanize.time.naturaldelta(query.total_time)))
                        else:
                            print('Time: %0.03fs' % query.total_time)

                    # Check if we need to update completions, in order of most
                    # to least drastic changes
                    if query.db_changed:
                        with self._completer_lock:
                            self.completer.reset_completions()
                        self.refresh_completions(persist_priorities='keywords')
                    elif query.meta_changed:
                        self.refresh_completions(persist_priorities='all')
                    elif query.path_changed:
                        logger.debug('Refreshing search path')
                        with self._completer_lock:
                            self.completer.set_search_path(
                                self.pgexecute.search_path())
                        logger.debug('Search path: %r',
                                     self.completer.search_path)

                # Runs for failed commands too; in that case ``query`` is
                # still the default MetaQuery created above.
                # Allow PGCompleter to learn user's preferred keywords, etc.
                with self._completer_lock:
                    self.completer.extend_query_history(document.text)

                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')

    def _build_cli(self, history):
        """Assemble the layout, buffer and application for the prompt.

        :param history: history object handed to the input buffer
        :return: a ``CommandLineInterface`` wrapping the built application
        """
        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        def prompt_tokens(_):
            # The prompt shows the current database name followed by '> '.
            return [(Token.Prompt, '%s> ' % self.pgexecute.dbname)]

        def get_continuation_tokens(cli, width):
            # Dots filling the given width, with a trailing space.
            return [(Token.Continuation, '.' * (width - 1) + ' ')]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode, self.completion_refresher.is_refreshing)

        layout = create_prompt_layout(
            lexer=PygmentsLexer(PostgresLexer),
            reserve_space_for_menu=4,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])

        # The lock is held while self.completer is read and handed to the
        # buffer.
        with self._completer_lock:
            buf = PGBuffer(always_multiline=self.multi_line,
                           completer=self.completer,
                           history=history,
                           complete_while_typing=Always(),
                           accept_action=AcceptAction.RETURN_DOCUMENT)

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                ignore_case=True)

            cli = CommandLineInterface(application=application)

            return cli

    def _evaluate_command(self, text):
        """Used to run a command entered by the user during CLI operation
        (Puts the E in REPL)

        Runs ``text`` through ``self.pgexecute.run`` and formats each result
        with ``format_output``. For a result whose status ``is_select``
        accepts and whose cursor reports more than 1000 rows, the user is
        asked to confirm; declining stops processing of the remaining
        results.

        :param text: the command text entered by the user
        :return: ``(output, meta_query)`` -- the accumulated formatted output
                 and a ``MetaQuery`` summarising the run; its total time is
                 the elapsed time from the start of the run to the last
                 processed result
        """
        logger = self.logger
        logger.debug('sql: %r', text)

        all_success = True
        meta_changed = False  # CREATE, ALTER, DROP, etc
        mutated = False  # INSERT, DELETE, etc
        db_changed = False
        path_changed = False
        output = []
        total = 0
        # Result sets larger than this trigger the confirmation prompt.
        threshold = 1000

        # Run the query.
        start = time()
        on_error_resume = self.on_error == 'RESUME'
        res = self.pgexecute.run(text, self.pgspecial, exception_formatter,
                                 on_error_resume)

        for title, cur, headers, status, sql, success in res:
            logger.debug("headers: %r", headers)
            logger.debug("rows: %r", cur)
            logger.debug("status: %r", status)
            if is_select(status) and cur and cur.rowcount > threshold:
                click.secho('The result set has more than %s rows.' %
                            threshold,
                            fg='red')
                if not click.confirm('Do you want to continue?'):
                    click.secho("Aborted!", err=True, fg='red')
                    break

            if self.pgspecial.auto_expand:
                max_width = self.cli.output.get_size().columns
            else:
                max_width = None

            formatted = format_output(title, cur, headers, status,
                                      self.table_format,
                                      self.pgspecial.expanded_output,
                                      max_width)

            output.extend(formatted)
            # ``start`` is fixed for the whole run, so the elapsed time is
            # assigned rather than accumulated; adding it up would count the
            # earlier statements again for every later one.
            total = time() - start

            # Keep track of whether any of the queries are mutating or changing
            # the database
            if success:
                mutated = mutated or is_mutating(status)
                db_changed = db_changed or has_change_db_cmd(sql)
                meta_changed = meta_changed or has_meta_cmd(sql)
                path_changed = path_changed or has_change_path_cmd(sql)
            else:
                all_success = False

        meta_query = MetaQuery(text, all_success, total, meta_changed,
                               db_changed, path_changed, mutated)

        return output, meta_query

    def _handle_server_closed_connection(self):
        """Offer to reconnect after the connection was reset.

        The user is asked whether to reconnect (the default answer is yes).
        A failed reconnect attempt prints the error instead of raising it.
        """
        wants_reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
                                       show_default=False,
                                       type=bool,
                                       default=True)
        if not wants_reconnect:
            return

        try:
            self.pgexecute.connect()
            click.secho('Reconnected!\nTry the command again.', fg='green')
        except OperationalError as exc:
            click.secho(str(exc), err=True, fg='red')

    def refresh_completions(self, history=None, persist_priorities='all'):
        """ Refresh outdated completions

        :param history: A prompt_toolkit.history.FileHistory object. Used to
                        load keyword and identifier preferences

        :param persist_priorities: 'all' or 'keywords'
        """

        callback = functools.partial(self._on_completions_refreshed,
                                     persist_priorities=persist_priorities)
        self.completion_refresher.refresh(self.pgexecute,
                                          self.pgspecial,
                                          callback,
                                          history=history)
        return [(None, None, None,
                 'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer, persist_priorities):
        self._swap_completer_objects(new_completer, persist_priorities)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer, persist_priorities):
        """Swap the completer object in cli with the newly created completer.

            persist_priorities is a string specifying how the old completer's
            learned prioritizer should be transferred to the new completer.

              'none'     - The new prioritizer is left in a new/clean state

              'all'      - The new prioritizer is updated to exactly reflect
                           the old one

              'keywords' - The new prioritizer is updated with old keyword
                           priorities, but not any other.
        """
        with self._completer_lock:
            old_completer = self.completer
            self.completer = new_completer

            if persist_priorities == 'all':
                # Just swap over the entire prioritizer
                new_completer.prioritizer = old_completer.prioritizer
            elif persist_priorities == 'keywords':
                # Swap over the entire prioritizer, but clear name priorities,
                # leaving learned keyword priorities alone
                new_completer.prioritizer = old_completer.prioritizer
                new_completer.prioritizer.clear_names()
            elif persist_priorities == 'none':
                # Leave the new prioritizer as is
                pass

            # When pgcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return the current completer's completions for ``text``.

        :param text: buffer contents to complete
        :param cursor_positition: cursor offset inside ``text``
        :return: whatever ``self.completer.get_completions`` returns; the
                 call is made while holding the completer lock
        """
        document = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            completions = self.completer.get_completions(document, None)
            return completions