def format_output(title, cur, headers, status, settings):
    """Render one query result as a list of printable chunks.

    :param title: optional text placed before the table (skipped if falsy)
    :param cur: cursor / iterable of rows; the table is skipped if falsy
    :param headers: column names; each is passed through ``utf8tounicode``
        and ``settings.case_function``
    :param status: optional text placed after the table (skipped if falsy)
    :param settings: display settings; reads ``expanded``, ``table_format``,
        ``max_width``, ``case_function``, ``missingval``, ``dcmlfmt`` and
        ``floatfmt``
    :return: list of strings: [title], [formatted table], [status]
    """
    output = []
    expanded = (settings.expanded or settings.table_format == 'vertical')
    table_format = ('vertical' if settings.expanded else settings.table_format)
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        # Render (possibly nested) lists as {a,b,c}; None becomes missingval.
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return '{' + ','.join(text_type(format_array(e)) for e in val) + '}'

    def format_arrays(data, headers, **_):
        # Build new row lists instead of assigning to row[:]: rows that come
        # straight from a DB cursor may be tuples, which are immutable.
        data = [
            [format_array(val) if isinstance(val, list) else val
             for val in row]
            for row in data
        ]
        return data, headers

    output_kwargs = {
        'sep_title': 'RECORD {n}',
        'sep_character': '-',
        'sep_length': (1, 25),
        'missing_value': settings.missingval,
        'integer_format': settings.dcmlfmt,
        'float_format': settings.floatfmt,
        'preprocessors': (format_numbers, format_arrays),
        'disable_numparse': True,
        'preserve_whitespace': True
    }
    if not settings.floatfmt:
        output_kwargs['preprocessors'] = (align_decimals, )

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(utf8tounicode(x)) for x in headers]
        rows = list(cur)
        formatted = formatter.format_output(rows, headers, **output_kwargs)
        # partition() yields the whole string when there is no newline;
        # formatted[:formatted.find('\n')] would drop the last character.
        first_line = formatted.partition('\n')[0]

        # Too wide for the terminal: fall back to the vertical layout.
        if (not expanded and max_width and len(first_line) > max_width and
                headers):
            formatted = formatter.format_output(
                rows, headers, format_name='vertical', **output_kwargs)

        output.append(formatted)

    if status:  # Only print the status if it's not None.
        output.append(status)

    return output
def test_unsupported_format():
    """An unknown format name is rejected both on assignment and per call."""
    bogus_name = 'foobar'
    formatter = TabularOutputFormatter()

    # Setting the default format to an unknown name must fail ...
    with pytest.raises(ValueError):
        formatter.format_name = bogus_name

    # ... and so must requesting it for a single format_output call.
    with pytest.raises(ValueError):
        formatter.format_output((), (), format_name=bogus_name)
def format_output(self, data):
    """Format data.

    :param data: raw data get from ES; the keys read are "datarows",
        "schema" (a list of dicts with a "name" key), "total" and "size"
    :return: formatted output, it's either table or vertical format; an
        iterator whose first item is the "fetched rows / total rows" summary
    """
    formatter = TabularOutputFormatter(format_name=self.table_format)

    # parse response data
    datarows = data["datarows"]
    schema = data["schema"]
    total_hits = data["total"]
    cur_size = data["size"]

    # Column names become the table headers.
    fields = [column["name"] for column in schema]

    output = formatter.format_output(datarows, fields, **self.output_kwargs)
    output_message = "fetched rows / total rows = %d/%d" % (cur_size, total_hits)

    # Open Distro for ES sql has a restriction of retrieving 200 rows of data
    # by default; warn when the result was truncated at that limit.
    if total_hits > 200 and cur_size == 200:
        output_message += "\n" + "Attention: Use LIMIT keyword when retrieving more than 200 rows of data"

    # Peek at the first line to check for width overflow. The formatter may
    # yield nothing at all; next() without a default would then raise
    # StopIteration.
    first_line = next(output, None)
    if first_line is None:
        return itertools.chain([output_message])

    output = itertools.chain([output_message], [first_line], output)

    # check width overflow, change format_name for better visual effect
    if len(first_line) > self.max_width:
        click.secho(message="Output longer than terminal width", fg="red")
        if click.confirm(
            "Do you want to display data vertically for better visual effect?"
        ):
            output = formatter.format_output(datarows, fields, format_name="vertical", **self.output_kwargs)
            output = itertools.chain([output_message], output)

    # TODO: if decided to add row_limit. Refer to pgcli -> main -> line 866.
    return output
def test_check_failing_unicode_rendering(self, mode):
    """
    Verify the assumption that echoing certain unicode table modes
    to a Windows console raises UnicodeEncodeError.
    """
    sample_rows = ((1, 87), (2, 80), (3, 79))
    sample_headers = ('day', 'temperature')
    renderer = TabularOutputFormatter(mode)
    with pytest.raises(UnicodeEncodeError):
        rendered = renderer.format_output(sample_rows, sample_headers)
        click.echo(rendered)
def test_all_text_type(extra_kwargs):
    """Test the TabularOutputFormatter class."""
    sample = [[1, "", None, Decimal(2)]]
    column_names = ["col1", "col2", "col3", "col4"]
    renderer = TabularOutputFormatter()

    # Every supported format must yield only text (unicode) lines.
    for name in renderer.supported_formats:
        lines = renderer.format_output(
            iter(sample), column_names, format_name=name, **extra_kwargs
        )
        for line in lines:
            assert isinstance(line, text_type), "not unicode for {}".format(name)
def test_all_text_type():
    """Test the TabularOutputFormatter class."""
    sample = [[1, u"", None, Decimal(2)]]
    column_names = ['col1', 'col2', 'col3', 'col4']
    renderer = TabularOutputFormatter()

    # Each supported format must produce only text (unicode) lines.
    for name in renderer.supported_formats:
        lines = renderer.format_output(iter(sample), column_names,
                                       format_name=name)
        for line in lines:
            assert isinstance(line, text_type), \
                "not unicode for {}".format(name)
def table(self, name):
    """Show name and type of fields for table 'name'.

    :param name: table name, interpolated into the banner and passed as the
        single query parameter to ``self._sql.columns()``'s request
    :return: banner followed by a "simple" table with the columns
        Name / Default / Nullable / Type, one line per output row
    """
    # Collect pieces and join once instead of repeated string +=.
    parts = [
        "========================================================\n",
        # (name,) keeps %-formatting correct even if name is a tuple.
        "========= Table %-24s : ==========\n" % (name,),
        "========================================================\n",
    ]
    formatter = TabularOutputFormatter(format_name="simple")
    request = self._sql.columns()
    records = self._db.request(request, (name, ), ask=True)
    headers = ["Name", "Default", "Nullable", "Type"]
    parts.extend(line + "\n"
                 for line in formatter.format_output(records, headers))
    return "".join(parts)
def test_enforce_iterable():
    """Every output formatter must return an iterator (next() must work)."""
    renderer = TabularOutputFormatter()
    words = 'lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod'.split(
        ' ')

    for name in renderer.supported_formats:
        renderer.format_name = name
        try:
            # A plain string/list result makes next() raise TypeError.
            next(renderer.format_output(zip(words), ['lorem']))
        except TypeError:
            assert False, "{0} doesn't return iterable".format(name)
def types(self):
    """Show types list and values.

    :return: banner followed by a "simple" table with the columns
        Name / Enums values, built from slice [1:3] of each record returned
        by the ``self._sql.types()`` request
    """
    # Collect pieces and join once instead of repeated string +=.
    parts = [
        "========================================================\n",
        "====================== Types =======================\n",
        "========================================================\n",
    ]
    formatter = TabularOutputFormatter(format_name="simple")
    request = self._sql.types()
    records = self._db.request(request, ask=True)
    # Keep record fields 1 and 2, matching the two headers below.
    data = [record[1:3] for record in records]
    headers = ["Name", "Enums values"]
    parts.extend(line + "\n"
                 for line in formatter.format_output(data, headers))
    return "".join(parts)
def message(self):
    """Property message reader.

    Build the listing of the records returned by ``self._request`` on
    ``self._table``, store it in ``self._message`` and return it.
    When ``self._inline`` is true each record is printed as
    ``header: value`` lines (fields matching ``Reader.MATCH_MEMORY`` are
    shown as "-byte data-"); otherwise a "simple" table is rendered.
    """
    formatter = TabularOutputFormatter(format_name="simple")
    title = f"=== Liste des enregistrements de la " \
            f"table '{self._table}' ===\n\n"
    self._message = title
    # The first row returned is the header row.
    headers, *records = self._db.request(self._request, ask=True,
                                         with_headers=True)
    # Accumulate pieces and join once instead of quadratic string +=.
    parts = [title]
    if self._inline:
        for n, record in enumerate(records):
            lines = [f"== resultat => enregistrement {n + 1}: =>\n"]
            for i, field in enumerate(record):
                if Reader.MATCH_MEMORY.match(str(field)):
                    field = "-byte data-"
                lines.append("\t%-12s: %s \n" % (headers[i], str(field)))
            parts.append("".join(lines) + "\n")
    else:
        parts.extend(x + "\n"
                     for x in formatter.format_output(records, headers))
    self._message = "".join(parts)
    return self._message
def main():
    """Interactively look up a city's weather and print it as an ASCII table.

    Loops until the user types "exit" (or sends EOF). Unknown city names
    print an error and re-prompt.
    """
    formatter = TabularOutputFormatter()

    # Set the header style (defined once rather than on every query).
    class HeaderStyle(Style):
        default_style = ""
        styles = {
            Token.Output.Header: '#00ff5f bold',
        }

    headers = ['城市', '时间', "天气", "气温", "风况"]

    while True:
        try:
            city = prompt('输入城市名称: ', completer=ColorCompleter(),
                          complete_style=CompleteStyle.MULTI_COLUMN)
        except EOFError:
            # Ctrl-D ends the session like typing "exit".
            break
        if city == "exit":
            break
        if city not in city_code_dict:
            print('输入错误,请重新输入')
            continue

        code = city_code_dict[city]
        url = "http://www.weather.com.cn/weather1d/{code}.shtml".format(
            code=code)
        forecast = crawl_weather(url)
        data = [
            [city, item['header'], item['wea'], item['tem'], item['win']]
            for item in forecast
        ]
        lines = formatter.format_output(data, headers, format_name='ascii',
                                        style=HeaderStyle)
        for line in lines:
            print(line)
def run_cli(self):
    """Run the interactive SQL prompt until EOF (Ctrl-D).

    Each statement is executed on one cursor of ``self.conn``. Result sets
    (statements with a ``cursor.description``) are rendered in 'psql'
    table format. Execution errors are printed and the loop continues.
    """
    cursor = self.conn.cursor()
    formatter = TabularOutputFormatter(format_name='psql')
    history = InMemoryHistory()

    def get_prompt_message():
        # Re-query the current database each time so the prompt tracks
        # USE statements.
        layout = '{host}/{database}> '
        cursor.execute('SELECT current_database()')
        result = cursor.fetchone()
        return layout.format(host=self.host, database=result[0])

    while True:
        try:
            statement = prompt(get_prompt_message(), lexer=ImpalaLexer,
                               history=history)
        except KeyboardInterrupt:
            # Ctrl-C discards the current input instead of killing the shell.
            continue
        except EOFError:
            click.echo('GoodBye!')
            self.conn.close()
            break

        try:
            cursor.execute(statement)
        except Exception as e:
            print(str(e))
            continue

        if cursor.description is not None:
            data = cursor.fetchall()
            # A list (not a one-shot generator) so the formatter may iterate
            # the headers more than once.
            headers = [item[0] for item in cursor.description]
            output = formatter.format_output(data=data, headers=headers)
            for line in output:
                click.echo(line)
class MyCli(object):
    """Interactive MySQL command-line client.

    Loads myclirc / my.cnf configuration, connects through ``SQLExecute``,
    runs the prompt_toolkit read-eval-print loop and formats results with a
    ``TabularOutputFormatter``.
    """

    default_prompt = '\\t \\u@\\h:\\d> '
    max_len_prompt = 45
    defaults_suffix = None

    # In order of being loaded. Files lower in list override earlier ones.
    cnf_files = [
        '/etc/my.cnf',
        '/etc/mysql/my.cnf',
        '/usr/local/etc/my.cnf',
        '~/.my.cnf'
    ]

    system_config_files = [
        '/etc/myclirc',
    ]

    default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')

    def __init__(self, sqlexecute=None, prompt=None,
                 logfile=None, defaults_suffix=None, defaults_file=None,
                 login_path=None, auto_vertical_output=False, warn=None,
                 myclirc="~/.myclirc"):
        """Load configuration, set up logging, the completer and special commands.

        :param sqlexecute: an existing SQLExecute connection, or None
        :param prompt: prompt format; overrides my.cnf and myclirc values
        :param logfile: audit log file object; if None, ``audit_log`` from
            myclirc is opened
        :param defaults_suffix: suffix appended to my.cnf section names
        :param defaults_file: single my.cnf to read instead of the defaults
        :param login_path: extra my.cnf section to read
        :param auto_vertical_output: switch to vertical output when too wide
        :param warn: overrides the ``destructive_warning`` config value
        :param myclirc: path of the user config file
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile
        self.defaults_suffix = defaults_suffix
        self.login_path = login_path

        # self.cnf_files lists the mysql config files to read in at launch.
        # If defaults_file is specified it replaces the default list.
        # Otherwise take a per-instance copy of the class-level list, so that
        # appending .mylogin.cnf below cannot leak into other instances.
        if defaults_file:
            self.cnf_files = [defaults_file]
        else:
            self.cnf_files = list(self.cnf_files)

        # Load config.
        config_files = ([self.default_config_file] + self.system_config_files +
                        [myclirc])
        c = self.config = read_config_files(config_files)
        self.multi_line = c['main'].as_bool('multi_line')
        self.key_bindings = c['main']['key_bindings']
        special.set_timing_enabled(c['main'].as_bool('timing'))
        self.formatter = TabularOutputFormatter(
            format_name=c['main']['table_format'])
        sql_format.register_new_formatter(self.formatter)
        self.formatter.mycli = self
        self.syntax_style = c['main']['syntax_style']
        self.less_chatty = c['main'].as_bool('less_chatty')
        self.cli_style = c['colors']
        self.output_style = style_factory(self.syntax_style, self.cli_style)
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        c_dest_warning = c['main'].as_bool('destructive_warning')
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.login_path_as_host = c['main'].as_bool('login_path_as_host')

        # read from cli argument or user config file
        self.auto_vertical_output = auto_vertical_output or \
            c['main'].as_bool('auto_vertical_output')

        # Write user config if system config wasn't the last config loaded.
        if c.filename not in self.system_config_files:
            write_default_config(self.default_config_file, myclirc)

        # audit log
        if self.logfile is None and 'audit_log' in c['main']:
            try:
                self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
            except (IOError, OSError):
                self.echo(
                    'Error: Unable to open the audit log file. Your queries will not be logged.',
                    err=True, fg='red')
                # False (not None) marks "logging was requested but failed".
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
        self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
            self.default_prompt
        self.prompt_continuation_format = c['main']['prompt_continuation']
        keyword_casing = c['main'].get('keyword_casing', 'auto')

        self.query_history = []

        # Initialize completer.
        self.smart_completion = c['main'].as_bool('smart_completion')
        self.completer = SQLCompleter(
            self.smart_completion,
            supported_formats=self.formatter.supported_formats,
            keyword_casing=keyword_casing)
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Load .mylogin.cnf if it exists.
        mylogin_cnf_path = get_mylogin_cnf_path()
        if mylogin_cnf_path:
            mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
            if mylogin_cnf:
                # .mylogin.cnf gets read last, even if defaults_file is specified.
                self.cnf_files.append(mylogin_cnf)
            else:
                # There was an error reading the login path file.
                print('Error: Unable to read login path file.')

        self.cli = None

    def register_special_commands(self):
        """Register mycli's own backslash commands with ``special``."""
        special.register_special_command(
            self.change_db, 'use', '\\u', 'Change to a new database.',
            aliases=('\\u', ))
        special.register_special_command(
            self.change_db, 'connect', '\\r',
            'Reconnect to the database. Optional database argument.',
            aliases=('\\r', ), case_sensitive=True)
        special.register_special_command(
            self.refresh_completions, 'rehash', '\\#',
            'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#', ))
        special.register_special_command(
            self.change_table_format, 'tableformat', '\\T',
            'Change the table format used to output results.',
            aliases=('\\T', ), case_sensitive=True)
        special.register_special_command(
            self.execute_from_file, 'source', '\\. filename',
            'Execute commands from file.', aliases=('\\.', ))
        special.register_special_command(
            self.change_prompt_format, 'prompt', '\\R',
            'Change prompt format.', aliases=('\\R', ), case_sensitive=True)

    def change_table_format(self, arg, **_):
        """Switch the output table format; yields one status result tuple."""
        try:
            self.formatter.format_name = arg
            yield (None, None, None, 'Changed table format to {}'.format(arg))
        except ValueError:
            msg = 'Table format {} not recognized. Allowed formats:'.format(
                arg)
            for table_type in self.formatter.supported_formats:
                msg += "\n\t{}".format(table_type)
            yield (None, None, None, msg)

    def change_db(self, arg, **_):
        """Reconnect, optionally to database *arg*; yields one status tuple."""
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))

    def execute_from_file(self, arg, **_):
        """Run the SQL contained in file *arg* (a path; ``~`` is expanded)."""
        if not arg:
            message = 'Missing required argument, filename.'
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(arg), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        if (self.destructive_warning and
                confirm_destructive_query(query) is False):
            message = 'Wise choice. Command execution stopped.'
            return [(None, None, None, message)]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.
        """
        if not arg:
            message = 'Missing required argument, format.'
            return [(None, None, None, message)]

        self.prompt_format = self.get_prompt(arg)
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def initialize_logging(self):
        """Attach a file (or null) handler to the 'mycli' logger per config."""
        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG
        }

        # Disable logging if value is NONE by switching to a no-op handler
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if log_level.upper() == "NONE":
            handler = logging.NullHandler()
            log_level = "CRITICAL"
        else:
            handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('mycli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug('Initializing mycli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_uri(self, uri, local_infile=None, ssl=None):
        """Connect using the parts of a URI (user, password, host, port, db)."""
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        self.connect(database, uri.username, uri.password, uri.hostname,
                     uri.port, local_infile=local_infile, ssl=ssl)

    def read_my_cnf_files(self, files, keys):
        """
        Reads a list of config files and merges them. The last one will win.
        :param files: list of files to read
        :param keys: list of keys to retrieve
        :returns: dict mapping each key to its value, with None for missing keys.
        """
        cnf = read_config_files(files)

        sections = ['client']
        if self.login_path and self.login_path != 'client':
            sections.append(self.login_path)

        if self.defaults_suffix:
            sections.extend([sect + self.defaults_suffix for sect in sections])

        def get(key):
            # Later matching sections override earlier ones.
            result = None
            for sect in cnf:
                if sect in sections and key in cnf[sect]:
                    result = cnf[sect][key]
            return result

        return {x: get(x) for x in keys}

    def merge_ssl_with_cnf(self, ssl, cnf):
        """Merge SSL configuration dict with cnf dict"""

        merged = {}
        merged.update(ssl)
        prefix = 'ssl-'
        for k, v in cnf.items():
            # skip unrelated options
            if not k.startswith(prefix):
                continue
            if v is None:
                continue
            # special case because PyMySQL argument is significantly different
            # from commandline
            if k == 'ssl-verify-server-cert':
                merged['check_hostname'] = v
            else:
                # use argument name just strip "ssl-" prefix
                arg = k[len(prefix):]
                merged[arg] = v

        return merged

    def connect(self, database='', user='', passwd='', host='', port='',
                socket='', charset='', local_infile='', ssl=''):
        """Connect to MySQL, filling unset arguments from the my.cnf files.

        Without host, port and socket (and not on Windows) a local socket is
        tried first, falling back to TCP/IP localhost:3306 on socket errors.
        Exits the process (status 1) if the connection fails.
        """
        cnf = {
            'database': None,
            'user': None,
            'password': None,
            'host': None,
            'port': None,
            'socket': None,
            'default-character-set': None,
            'local-infile': None,
            'loose-local-infile': None,
            'ssl-ca': None,
            'ssl-cert': None,
            'ssl-key': None,
            'ssl-cipher': None,
            # Must match the key merge_ssl_with_cnf() looks for.
            'ssl-verify-server-cert': None,
        }

        cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())

        # Fall back to config values only if user did not specify a value.

        database = database or cnf['database']
        if port or host:
            socket = ''
        else:
            socket = socket or cnf['socket']
        user = user or cnf['user'] or os.getenv('USER')
        host = host or cnf['host']
        port = port or cnf['port']
        ssl = ssl or {}

        passwd = passwd or cnf['password']
        charset = charset or cnf['default-character-set'] or 'utf8'

        # Favor whichever local_infile option is set.
        for local_infile_option in (local_infile, cnf['local-infile'],
                                    cnf['loose-local-infile'], False):
            try:
                local_infile = str_to_bool(local_infile_option)
                break
            except (TypeError, ValueError):
                pass

        ssl = self.merge_ssl_with_cnf(ssl, cnf)
        # prune lone check_hostname=False
        if not any(v for v in ssl.values()):
            ssl = None

        # Connect to the database.

        def _connect():
            # Reads socket/host/port from the enclosing scope at call time,
            # so the retry logic below can change them between calls.
            try:
                self.sqlexecute = SQLExecute(database, user, passwd, host, port,
                                             socket, charset, local_infile, ssl)
            except OperationalError as e:
                if ('Access denied for user' in e.args[1]):
                    new_passwd = click.prompt('Password', hide_input=True,
                                              show_default=False, type=str)
                    self.sqlexecute = SQLExecute(database, user, new_passwd, host, port,
                                                 socket, charset, local_infile, ssl)
                else:
                    raise e

        try:
            if (socket is host is port is None) and not WIN:
                # Try a sensible default socket first (simplifies auth)
                # If we get a connection error, try tcp/ip localhost
                try:
                    socket = '/var/run/mysqld/mysqld.sock'
                    _connect()
                except OperationalError as e:
                    # These are "Can't open socket" and 2x "Can't connect"
                    if e.args[0] in (2001, 2002, 2003):
                        self.logger.debug('Database connection failed: %r.', e)
                        self.logger.error("traceback: %r", traceback.format_exc())
                        self.logger.debug('Retrying over TCP/IP')
                        self.echo(str(e), err=True)
                        self.echo(
                            'Failed to connect by socket, retrying over TCP/IP',
                            err=True)

                        # Else fall back to TCP/IP localhost
                        socket = ""
                        host = 'localhost'
                        port = 3306
                        _connect()
                    else:
                        raise e
            else:
                host = host or 'localhost'
                port = port or 3306

                # Bad ports give particularly daft error messages
                try:
                    port = int(port)
                except ValueError:
                    self.echo(
                        "Error: Invalid port number: '{0}'.".format(port),
                        err=True, fg='red')
                    exit(1)

                _connect()
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.echo(str(e), err=True, fg='red')
            exit(1)

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        """
        # FIXME: using application.pre_run_callables like this here is not the best solution.
        # It's internal api of prompt_toolkit that may change. This was added to fix
        # https://github.com/dbcli/pgcli/issues/668. We may find a better way to do it in the future.
        saved_callables = cli.application.pre_run_callables
        try:
            while special.editor_command(document.text):
                filename = special.get_filename(document.text)
                query = (special.get_editor_query(document.text) or
                         self.get_last_query())
                sql, message = special.open_external_editor(filename, sql=query)
                if message:
                    # Something went wrong. Raise an exception and bail.
                    raise RuntimeError(message)
                cli.current_buffer.document = Document(sql,
                                                       cursor_position=len(sql))
                cli.application.pre_run_callables = []
                document = cli.run()
        finally:
            # Restore the callables even when the editor step fails, so the
            # next prompt is not left without them.
            cli.application.pre_run_callables = saved_callables
        return document

    def run_cli(self):
        """Run the interactive prompt loop until EOF (Ctrl-D)."""
        iterations = 0
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()

        if self.smart_completion:
            self.refresh_completions()

        author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
        sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')

        key_binding_manager = mycli_bindings()

        if not self.less_chatty:
            print('Version:', __version__)
            print('Chat: https://gitter.im/dbcli/mycli')
            print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
            print('Home: http://mycli.net')
            print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))

        def prompt_tokens(cli):
            prompt = self.get_prompt(self.prompt_format)
            # Shorten the default prompt when it grows too long.
            if self.prompt_format == self.default_prompt and len(
                    prompt) > self.max_len_prompt:
                prompt = self.get_prompt('\\d> ')
            return [(Token.Prompt, prompt)]

        def get_continuation_tokens(cli, width):
            continuation_prompt = self.get_prompt(
                self.prompt_continuation_format)
            return [(Token.Continuation, ' ' * (width - len(continuation_prompt)) +
                     continuation_prompt)]

        def show_suggestion_tip():
            return iterations < 2

        def one_iteration(document=None):
            # A document is passed in only when re-running after a reconnect;
            # in that case the prompt and editor handling are skipped.
            if document is None:
                document = self.cli.run()

                special.set_expanded_output(False)

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg='red')
                    return

            if not document.text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(document.text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo('Your call!')
                else:
                    self.echo('Wise choice!')
                    return

            # Keep track of whether or not the query is mutating. In case
            # of a multi-statement query, the overall query is considered
            # mutating if any one of the component statements is mutating
            mutating = False
            # Initialized before the try block so the Query record below is
            # always built, even if an exception fires before execution.
            successful = False

            try:
                logger.debug('sql: %r', document.text)

                special.write_tee(
                    self.get_prompt(self.prompt_format) + document.text)
                if self.logfile:
                    self.logfile.write('\n# %s\n' % datetime.now())
                    self.logfile.write(document.text)
                    self.logfile.write('\n')

                start = time()
                res = sqlexecute.run(document.text)
                self.formatter.query = document.text
                successful = True
                result_count = 0
                for title, cur, headers, status in res:
                    logger.debug("headers: %r", headers)
                    logger.debug("rows: %r", cur)
                    logger.debug("status: %r", status)
                    threshold = 1000
                    if (is_select(status) and
                            cur and cur.rowcount > threshold):
                        self.echo(
                            'The result set has more than {} rows.'.format(
                                threshold), fg='red')
                        if not click.confirm('Do you want to continue?'):
                            self.echo("Aborted!", err=True, fg='red')
                            break

                    if self.auto_vertical_output:
                        max_width = self.cli.output.get_size().columns
                    else:
                        max_width = None

                    formatted = self.format_output(
                        title, cur, headers, special.is_expanded_output(),
                        max_width)

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo('')
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        if special.is_timing_enabled():
                            self.echo('Time: %0.03fs' % t)
                    except KeyboardInterrupt:
                        pass

                    start = time()
                    result_count += 1
                    mutating = mutating or is_mutating(status)
                special.unset_once_if_written()
            except EOFError as e:
                raise e
            except KeyboardInterrupt:
                # get last connection id
                connection_id_to_kill = sqlexecute.connection_id
                logger.debug("connection id to kill: %r", connection_id_to_kill)
                # Restart connection to the database
                sqlexecute.connect()
                try:
                    for title, cur, headers, status in sqlexecute.run(
                            'kill %s' % connection_id_to_kill):
                        status_str = str(status).lower()
                        if status_str.find('ok') > -1:
                            logger.debug(
                                "cancelled query, connection id: %r, sql: %r",
                                connection_id_to_kill, document.text)
                            self.echo("cancelled query", err=True, fg='red')
                except Exception as e:
                    self.echo(
                        'Encountered error while cancelling query: {}'.format(
                            e), err=True, fg='red')
            except NotImplementedError:
                self.echo('Not Yet Implemented.', fg="yellow")
            except OperationalError as e:
                logger.debug("Exception: %r", e)
                if (e.args[0] in (2003, 2006, 2013)):
                    logger.debug('Attempting to reconnect.')
                    self.echo('Reconnecting...', fg='yellow')
                    try:
                        sqlexecute.connect()
                        logger.debug('Reconnected successfully.')
                        one_iteration(document)
                        return  # OK to just return, cuz the recursion call runs to the end.
                    except OperationalError as e:
                        logger.debug('Reconnect failed. e: %r', e)
                        self.echo(str(e), err=True, fg='red')
                        # If reconnection failed, don't proceed further.
                        return
                else:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg='red')
            except Exception as e:
                logger.error("sql: %r, error: %r", document.text, e)
                logger.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
            else:
                if is_dropping_database(document.text, self.sqlexecute.dbname):
                    self.sqlexecute.dbname = None
                    self.sqlexecute.connect()

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_completions(
                        reset=need_completion_reset(document.text))
            finally:
                if self.logfile is False:
                    self.echo("Warning: This query was not logged.",
                              err=True, fg='red')
            query = Query(document.text, successful, mutating)
            self.query_history.append(query)

        get_toolbar_tokens = create_toolbar_tokens_func(
            self.completion_refresher.is_refreshing, show_suggestion_tip)

        layout = create_prompt_layout(
            lexer=MyCliLexer,
            multiline=True,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ],
            reserve_space_for_menu=self.get_reserved_space())
        with self._completer_lock:
            buf = CLIBuffer(always_multiline=self.multi_line,
                            completer=self.completer,
                            history=FileHistory(
                                os.path.expanduser(
                                    os.environ.get('MYCLI_HISTFILE',
                                                   '~/.mycli-history'))),
                            auto_suggest=AutoSuggestFromHistory(),
                            complete_while_typing=Always(),
                            accept_action=AcceptAction.RETURN_DOCUMENT)

            if self.key_bindings == 'vi':
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            application = Application(
                style=style_from_pygments(style_cls=self.output_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                editing_mode=editing_mode,
                ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        try:
            while True:
                one_iteration()
                iterations += 1
        except EOFError:
            special.close_tee()
            if not self.less_chatty:
                self.echo('Goodbye!')

    def log_output(self, output):
        """Log the output in the audit log, if it's enabled."""
        if self.logfile:
            click.echo(utf8tounicode(output), file=self.logfile)

    def echo(self, s, **kwargs):
        """Print a message to stdout.

        The message will be logged in the audit log, if enabled.

        All keyword arguments are passed to click.echo().

        """
        self.log_output(s)
        click.secho(s, **kwargs)

    def get_output_margin(self, status=None):
        """Get the output margin (number of rows for the prompt, footer and
        timing message."""
        margin = self.get_reserved_space() + self.get_prompt(
            self.prompt_format).count('\n') + 1
        if special.is_timing_enabled():
            margin += 1
        if status:
            margin += 1 + status.count('\n')

        return margin

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.

        """
        if output:
            size = self.cli.output.get_size()

            margin = self.get_output_margin(status)

            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled()
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled():
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for buffered_line in buf:
                                click.secho(buffered_line)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for buffered_line in buf:
                        click.secho(buffered_line)

        if status:
            self.log_output(status)
            click.secho(status)

    def configure_pager(self):
        """Set LESS defaults and apply the pager/skip-pager my.cnf options."""
        # Provide sane defaults for less if they are empty.
        if not os.environ.get('LESS'):
            os.environ['LESS'] = '-RXF'

        cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
        if cnf['pager']:
            special.set_pager(cnf['pager'])
            self.explicit_pager = True
        else:
            self.explicit_pager = False

        if cnf['skip-pager']:
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background completion refresh; returns a status result."""
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(
            self.sqlexecute, self._on_completions_refreshed,
            {
                'smart_completion': self.smart_completion,
                'supported_formats': self.formatter.supported_formats,
                'keyword_casing': self.completer.keyword_casing
            })

        return [(None, None, None,
                 'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When mycli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor position."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)

    def get_prompt(self, string):
        """Expand the backslash escapes of prompt format *string*."""
        sqlexecute = self.sqlexecute
        host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host
        now = datetime.now()
        string = string.replace('\\u', sqlexecute.user or '(none)')
        string = string.replace('\\h', host or '(none)')
        string = string.replace('\\d', sqlexecute.dbname or '(none)')
        string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
        string = string.replace('\\n', "\n")
        string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
        string = string.replace('\\m', now.strftime('%M'))
        string = string.replace('\\P', now.strftime('%p'))
        string = string.replace('\\R', now.strftime('%H'))
        string = string.replace('\\r', now.strftime('%I'))
        string = string.replace('\\s', now.strftime('%S'))
        string = string.replace('\\p', str(sqlexecute.port))
        string = string.replace('\\_', ' ')
        return string

    def run_query(self, query, new_line=True):
        """Runs *query*."""
        results = self.sqlexecute.run(query)
        for result in results:
            title, cur, headers, status = result
            self.formatter.query = query
            output = self.format_output(title, cur, headers)
            for line in output:
                click.echo(line, nl=new_line)

    def format_output(self, title, cur, headers, expanded=False,
                      max_width=None):
        """Format one result set as an iterable of output lines.

        :param expanded: force the vertical format
        :param max_width: if set and the first table line is wider, the
            result is re-rendered vertically (the cursor is materialized
            into a list so it can be formatted twice)
        :return: iterable of lines: optional title, then the table lines
        """
        expanded = expanded or self.formatter.format_name == 'vertical'
        output = []

        output_kwargs = {
            'disable_numparse': True,
            'preserve_whitespace': True,
            'preprocessors': (preprocessors.align_decimals, ),
            'style': self.output_style
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, 'description'):
                def get_col_type(col):
                    col_type = FIELD_TYPES.get(col[1], text_type)
                    return col_type if type(col_type) is type else text_type
                column_types = [get_col_type(col) for col in cur.description]

            if max_width is not None:
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur, headers, format_name='vertical' if expanded else None,
                column_types=column_types,
                **output_kwargs)

            if isinstance(formatted, (text_type)):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            # The formatter may yield no lines at all (e.g. vertical format
            # with zero rows); next() without a default would then raise
            # StopIteration.
            first_line = next(formatted, None)
            if first_line is not None:
                formatted = itertools.chain([first_line], formatted)

                if (not expanded and max_width and headers and cur and
                        len(first_line) > max_width):
                    formatted = self.formatter.format_output(
                        cur, headers, format_name='vertical',
                        column_types=column_types, **output_kwargs)
                    if isinstance(formatted, (text_type)):
                        formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu."""
        reserved_space_ratio = .45
        max_reserved_space = 8
        _, height = click.get_terminal_size()
        return min(int(round(height * reserved_space_ratio)), max_reserved_space)

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None
def format_output(title, cur, headers, status, settings):
    """Render a query result as an iterable of printable lines.

    :param title: optional line emitted before the table (skipped if falsy).
    :param cur: iterable of result rows (e.g. a DB cursor); skipped if falsy.
    :param headers: column names; each one is passed through
        ``utf8tounicode`` and then ``settings.case_function``.
    :param status: optional line emitted after the table (skipped if falsy).
    :param settings: output settings (``expanded``, ``table_format``,
        ``max_width``, ``case_function``, ``missingval``, ``dcmlfmt``,
        ``floatfmt``, ``style_output``).
    :return: an iterable of output lines (title, table lines, status).
    """
    output = []
    expanded = (settings.expanded or settings.table_format == 'vertical')
    table_format = ('vertical' if settings.expanded else
                    settings.table_format)
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        """Render a (possibly nested) list as a ``{a,b,...}`` literal."""
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return '{' + ','.join(text_type(format_array(e)) for e in val) + '}'

    def format_arrays(data, headers, **_):
        """Preprocessor: replace list-valued cells with their array literal."""
        # Build fresh row lists instead of slice-assigning into ``row[:]``:
        # rows can arrive as tuples (when format_numbers passes the data
        # through untouched), and tuples do not support slice assignment.
        data = [
            [format_array(val) if isinstance(val, list) else val
             for val in row]
            for row in data
        ]
        return data, headers

    output_kwargs = {
        'sep_title': 'RECORD {n}',
        'sep_character': '-',
        'sep_length': (1, 25),
        'missing_value': settings.missingval,
        'integer_format': settings.dcmlfmt,
        'float_format': settings.floatfmt,
        'preprocessors': (format_numbers, format_arrays),
        'disable_numparse': True,
        'preserve_whitespace': True,
        'style': settings.style_output
    }
    if not settings.floatfmt:
        output_kwargs['preprocessors'] = (align_decimals, )

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(utf8tounicode(x)) for x in headers]
        if max_width is not None:
            # Materialize the rows so they can be rendered a second time
            # (vertically) if the table turns out to be too wide.
            cur = list(cur)

        # column_types is deliberately not passed: the tabular formatter
        # infers the types from the data itself (the vertical fallback
        # below passes column_types=None for the same reason).
        formatted = formatter.format_output(cur, headers, **output_kwargs)
        if isinstance(formatted, text_type):
            formatted = formatted.splitlines()
        formatted = iter(formatted)

        # An empty result can render to zero lines (e.g. vertical format
        # with no rows); do not let StopIteration escape to the caller.
        first_line = next(formatted, None)
        if first_line is not None:
            formatted = itertools.chain([first_line], formatted)
            if (not expanded and max_width and len(first_line) > max_width
                    and headers):
                # Too wide for the terminal: fall back to vertical output.
                formatted = formatter.format_output(
                    cur, headers, format_name='vertical', column_types=None,
                    **output_kwargs)
                if isinstance(formatted, text_type):
                    formatted = iter(formatted.splitlines())

        output = itertools.chain(output, formatted)

    if status:  # Only print the status if it's not None.
        output = itertools.chain(output, [status])

    return output
def format_output(title, cur, headers, status, settings):
    """Render a query result as an iterable of printable lines.

    :param title: optional line emitted before the table (skipped if falsy).
    :param cur: iterable of result rows (e.g. a DB cursor); skipped if falsy.
    :param headers: column names; each one is passed through
        ``settings.case_function``.
    :param status: optional line emitted after the table (skipped if falsy).
    :param settings: output settings (``expanded``, ``table_format``,
        ``max_width``, ``case_function``, ``missingval``, ``dcmlfmt``,
        ``floatfmt``, ``style_output``).
    :return: an iterable of output lines (title, table lines, status).
    """
    output = []
    expanded = settings.expanded or settings.table_format == "vertical"
    table_format = "vertical" if settings.expanded else settings.table_format
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        """Render a (possibly nested) list as a ``{a,b,...}`` literal."""
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return "{" + ",".join(str(format_array(e)) for e in val) + "}"

    def format_arrays(data, headers, **_):
        """Preprocessor: replace list-valued cells with their array literal."""
        # Build fresh row lists instead of slice-assigning into ``row[:]``:
        # rows can arrive as tuples (when format_numbers passes the data
        # through untouched), and tuples do not support slice assignment.
        data = [
            [format_array(val) if isinstance(val, list) else val for val in row]
            for row in data
        ]
        return data, headers

    output_kwargs = {
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": settings.missingval,
        "integer_format": settings.dcmlfmt,
        "float_format": settings.floatfmt,
        "preprocessors": (format_numbers, format_arrays),
        "disable_numparse": True,
        "preserve_whitespace": True,
        "style": settings.style_output,
    }
    if not settings.floatfmt:
        output_kwargs["preprocessors"] = (align_decimals,)

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(x) for x in headers]
        if max_width is not None:
            # Materialize the rows so they can be rendered a second time
            # (vertically) if the table turns out to be too wide.
            cur = list(cur)

        # column_types is deliberately not passed: the tabular formatter
        # infers the types from the data itself (the vertical fallback
        # below passes column_types=None for the same reason).
        formatted = formatter.format_output(cur, headers, **output_kwargs)
        if isinstance(formatted, str):
            formatted = formatted.splitlines()
        formatted = iter(formatted)

        # An empty result can render to zero lines (e.g. vertical format
        # with no rows); do not let StopIteration escape to the caller.
        first_line = next(formatted, None)
        if first_line is not None:
            formatted = itertools.chain([first_line], formatted)
            if not expanded and max_width and len(first_line) > max_width and headers:
                # Too wide for the terminal: fall back to vertical output.
                formatted = formatter.format_output(
                    cur, headers, format_name="vertical", column_types=None, **output_kwargs
                )
                if isinstance(formatted, str):
                    formatted = iter(formatted.splitlines())

        output = itertools.chain(output, formatted)

    if status:  # Only print the status if it's not None.
        output = itertools.chain(output, [status])

    return output
class LiteCli(object):
    """Interactive SQLite command-line client.

    Loads the user configuration, registers the special (dot/backslash)
    commands, drives the prompt_toolkit session and renders query results.
    """

    default_prompt = "\\d> "
    max_len_prompt = 45

    def __init__(
        self,
        sqlexecute=None,
        prompt=None,
        logfile=None,
        auto_vertical_output=False,
        warn=None,
        liteclirc=None,
    ):
        self.sqlexecute = sqlexecute
        self.logfile = logfile

        # Load config.
        c = self.config = get_config(liteclirc)

        self.multi_line = c["main"].as_bool("multi_line")
        self.key_bindings = c["main"]["key_bindings"]
        special.set_favorite_queries(self.config)
        self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
        self.formatter.litecli = self
        self.syntax_style = c["main"]["syntax_style"]
        self.less_chatty = c["main"].as_bool("less_chatty")
        self.cli_style = c["colors"]
        self.output_style = style_factory_output(self.syntax_style, self.cli_style)
        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
        c_dest_warning = c["main"].as_bool("destructive_warning")
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.login_path_as_host = c["main"].as_bool("login_path_as_host")

        # read from cli argument or user config file
        self.auto_vertical_output = auto_vertical_output or c["main"].as_bool(
            "auto_vertical_output"
        )

        # audit log
        if self.logfile is None and "audit_log" in c["main"]:
            try:
                self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a")
            except (IOError, OSError):
                self.echo(
                    "Error: Unable to open the audit log file. Your queries will not be logged.",
                    err=True,
                    fg="red",
                )
                # False (not None) marks "logging requested but unavailable";
                # one_iteration warns about this after every query.
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"]
        self.prompt_format = (
            prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt
        )
        self.prompt_continuation_format = c["main"]["prompt_continuation"]
        keyword_casing = c["main"].get("keyword_casing", "auto")

        self.query_history = []

        # Initialize completer.
        self.completer = SQLCompleter(
            supported_formats=self.formatter.supported_formats,
            keyword_casing=keyword_casing,
        )
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        self.prompt_app = None

    def register_special_commands(self):
        """Register litecli's own special commands with the special registry."""
        special.register_special_command(
            self.change_db,
            ".open",
            ".open",
            "Change to a new database.",
            aliases=("use", "\\u"),
        )
        special.register_special_command(
            self.refresh_completions,
            "rehash",
            "\\#",
            "Refresh auto-completions.",
            arg_type=NO_QUERY,
            aliases=("\\#",),
        )
        special.register_special_command(
            self.change_table_format,
            ".mode",
            "\\T",
            "Change the table format used to output results.",
            aliases=("tableformat", "\\T"),
            case_sensitive=True,
        )
        special.register_special_command(
            self.execute_from_file,
            "source",
            "\\. filename",
            "Execute commands from file.",
            aliases=("\\.",),
        )
        special.register_special_command(
            self.change_prompt_format,
            "prompt",
            "\\R",
            "Change prompt format.",
            aliases=("\\R",),
            case_sensitive=True,
        )

    def change_table_format(self, arg, **_):
        """Switch the output table format; yields one result tuple."""
        try:
            self.formatter.format_name = arg
            yield (None, None, None, "Changed table format to {}".format(arg))
        except ValueError:
            msg = "Table format {} not recognized. Allowed formats:".format(arg)
            for table_type in self.formatter.supported_formats:
                msg += "\n\t{}".format(table_type)
            yield (None, None, None, msg)

    def change_db(self, arg, **_):
        """(Re)connect to database *arg* (or the current one if None)."""
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        self.refresh_completions()
        yield (
            None,
            None,
            None,
            'You are now connected to database "%s"' % (self.sqlexecute.dbname),
        )

    def execute_from_file(self, arg, **_):
        """Run the SQL contained in file *arg* and return its results."""
        if not arg:
            message = "Missing required argument, filename."
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(arg), encoding="utf-8") as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        if self.destructive_warning and confirm_destructive_query(query) is False:
            message = "Wise choice. Command execution stopped."
            return [(None, None, None, message)]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.
        """
        if not arg:
            message = "Missing required argument, format."
            return [(None, None, None, message)]

        # Store the raw format: get_message() expands it on every redraw, so
        # escapes such as \d (database) and \D (time) stay current.
        self.prompt_format = arg
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def initialize_logging(self):
        """Attach a file (or null) handler to the "litecli" logger per config."""
        log_file = self.config["main"]["log_file"]
        if log_file == "default":
            log_file = config_location() + "log"
        ensure_dir_exists(log_file)

        log_level = self.config["main"]["log_level"]

        level_map = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
        }

        # Disable logging if value is NONE by switching to a no-op handler
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if log_level.upper() == "NONE":
            handler = logging.NullHandler()
            log_level = "CRITICAL"
        elif dir_path_exists(log_file):
            handler = logging.FileHandler(log_file)
        else:
            self.echo(
                'Error: Unable to open the log file "{}".'.format(log_file),
                err=True,
                fg="red",
            )
            return

        formatter = logging.Formatter(
            "%(asctime)s (%(process)d/%(threadName)s) "
            "%(name)s %(levelname)s - %(message)s"
        )

        handler.setFormatter(formatter)

        root_logger = logging.getLogger("litecli")
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug("Initializing litecli logging.")
        root_logger.debug("Log file %r.", log_file)

    def read_my_cnf_files(self, keys):
        """Look up *keys* in the "main" section of the loaded config.

        :param keys: iterable of option names to retrieve.
        :returns: dict mapping each key to its value, or None if missing.
        """
        cnf = self.config

        sections = ["main"]

        def get(key):
            result = None
            for sect in cnf:
                if sect in sections and key in cnf[sect]:
                    result = cnf[sect][key]
            return result

        return {x: get(x) for x in keys}

    def connect(self, database=""):
        """Open *database* (falling back to the configured one); exit on failure."""
        cnf = {"database": None}

        cnf = self.read_my_cnf_files(cnf.keys())

        # Fall back to config values only if user did not specify a value.
        database = database or cnf["database"]

        # Connect to the database.
        def _connect():
            self.sqlexecute = SQLExecute(database)

        try:
            _connect()
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.echo(str(e), err=True, fg="red")
            # SystemExit is what sys.exit() raises; unlike the site-provided
            # ``exit`` builtin it is always available.
            raise SystemExit(1)

    def handle_editor_command(self, text):
        r"""Editor command is any query that is prefixed or suffixed by a '\e'.
        The reason for a while loop is because a user might edit a query
        multiple times. For eg:

        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param text: Document
        :return: Document

        """
        while special.editor_command(text):
            filename = special.get_filename(text)
            query = special.get_editor_query(text) or self.get_last_query()
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            while True:
                try:
                    text = self.prompt_app.prompt(default=sql)
                    break
                except KeyboardInterrupt:
                    # Ctrl-C in the re-prompt clears the edited text.
                    sql = ""
        return text

    def run_cli(self):
        """Run the interactive read-eval-print loop until EOF."""
        iterations = 0
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()
        self.refresh_completions()

        history_file = config_location() + "history"
        if dir_path_exists(history_file):
            history = FileHistory(history_file)
        else:
            history = None
            self.echo(
                'Error: Unable to open the history file "{}". '
                "Your query history will not be saved.".format(history_file),
                err=True,
                fg="red",
            )

        key_bindings = cli_bindings(self)

        if not self.less_chatty:
            print("Version:", __version__)
            print("Mail: https://groups.google.com/forum/#!forum/litecli-users")
            print("GitHub: https://github.com/dbcli/litecli")

        def get_message():
            prompt = self.get_prompt(self.prompt_format)
            if (
                self.prompt_format == self.default_prompt
                and len(prompt) > self.max_len_prompt
            ):
                prompt = self.get_prompt("\\d> ")
            return [("class:prompt", prompt)]

        def get_continuation(width, line_number, is_soft_wrap):
            continuation = " " * (width - 1) + " "
            return [("class:continuation", continuation)]

        def show_suggestion_tip():
            return iterations < 2

        def one_iteration(text=None):
            if text is None:
                try:
                    text = self.prompt_app.prompt()
                except KeyboardInterrupt:
                    return

                special.set_expanded_output(False)

                try:
                    text = self.handle_editor_command(text)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg="red")
                    return

            if not text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo("Your call!")
                else:
                    self.echo("Wise choice!")
                    return

            # Keep track of whether or not the query is mutating. In case
            # of a multi-statement query, the overall query is considered
            # mutating if any one of the component statements is mutating
            mutating = False
            # Initialized before the try so the history entry below can always
            # be built, even if logging/tee output fails before execution.
            successful = False

            try:
                logger.debug("sql: %r", text)

                special.write_tee(self.get_prompt(self.prompt_format) + text)
                if self.logfile:
                    self.logfile.write("\n# %s\n" % datetime.now())
                    self.logfile.write(text)
                    self.logfile.write("\n")

                start = time()
                res = sqlexecute.run(text)
                self.formatter.query = text
                successful = True
                result_count = 0
                for title, cur, headers, status in res:
                    logger.debug("headers: %r", headers)
                    logger.debug("rows: %r", cur)
                    logger.debug("status: %r", status)
                    threshold = 1000
                    if is_select(status) and cur and cur.rowcount > threshold:
                        self.echo(
                            "The result set has more than {} rows.".format(threshold),
                            fg="red",
                        )
                        if not confirm("Do you want to continue?"):
                            self.echo("Aborted!", err=True, fg="red")
                            break

                    if self.auto_vertical_output:
                        max_width = self.prompt_app.output.get_size().columns
                    else:
                        max_width = None

                    formatted = self.format_output(
                        title, cur, headers, special.is_expanded_output(), max_width
                    )

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo("")
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        self.echo("Time: %0.03fs" % t)
                    except KeyboardInterrupt:
                        pass

                    start = time()
                    result_count += 1
                    mutating = mutating or is_mutating(status)
                special.unset_once_if_written()
            except EOFError:
                raise
            except KeyboardInterrupt:
                # get last connection id
                connection_id_to_kill = sqlexecute.connection_id
                logger.debug("connection id to kill: %r", connection_id_to_kill)
                # Restart connection to the database
                sqlexecute.connect()
                try:
                    for title, cur, headers, status in sqlexecute.run(
                        "kill %s" % connection_id_to_kill
                    ):
                        status_str = str(status).lower()
                        if status_str.find("ok") > -1:
                            logger.debug(
                                "cancelled query, connection id: %r, sql: %r",
                                connection_id_to_kill,
                                text,
                            )
                            self.echo("cancelled query", err=True, fg="red")
                except Exception as e:
                    self.echo(
                        "Encountered error while cancelling query: {}".format(e),
                        err=True,
                        fg="red",
                    )
            except NotImplementedError:
                self.echo("Not Yet Implemented.", fg="yellow")
            except OperationalError as e:
                logger.debug("Exception: %r", e)
                if e.args[0] in (2003, 2006, 2013):
                    logger.debug("Attempting to reconnect.")
                    self.echo("Reconnecting...", fg="yellow")
                    try:
                        sqlexecute.connect()
                        logger.debug("Reconnected successfully.")
                        one_iteration(text)
                        return  # OK to just return, cuz the recursion call runs to the end.
                    except OperationalError as e:
                        logger.debug("Reconnect failed. e: %r", e)
                        self.echo(str(e), err=True, fg="red")
                        # If reconnection failed, don't proceed further.
                        return
                else:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg="red")
            except Exception as e:
                logger.error("sql: %r, error: %r", text, e)
                logger.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg="red")
            else:
                if is_dropping_database(text, self.sqlexecute.dbname):
                    self.sqlexecute.dbname = None
                    self.sqlexecute.connect()

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(text):
                    self.refresh_completions(reset=need_completion_reset(text))
            finally:
                if self.logfile is False:
                    self.echo("Warning: This query was not logged.", err=True, fg="red")
            query = Query(text, successful, mutating)
            self.query_history.append(query)

        get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip)

        if self.wider_completion_menu:
            complete_style = CompleteStyle.MULTI_COLUMN
        else:
            complete_style = CompleteStyle.COLUMN

        with self._completer_lock:
            if self.key_bindings == "vi":
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            self.prompt_app = PromptSession(
                lexer=PygmentsLexer(LiteCliLexer),
                reserve_space_for_menu=self.get_reserved_space(),
                message=get_message,
                prompt_continuation=get_continuation,
                bottom_toolbar=get_toolbar_tokens,
                complete_style=complete_style,
                input_processors=[
                    ConditionalProcessor(
                        processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                    )
                ],
                tempfile_suffix=".sql",
                completer=DynamicCompleter(lambda: self.completer),
                history=history,
                auto_suggest=AutoSuggestFromHistory(),
                complete_while_typing=True,
                multiline=cli_is_multiline(self),
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=editing_mode,
                search_ignore_case=True,
            )

        try:
            while True:
                one_iteration()
                iterations += 1
        except EOFError:
            special.close_tee()
            if not self.less_chatty:
                self.echo("Goodbye!")

    def log_output(self, output):
        """Log the output in the audit log, if it's enabled."""
        if self.logfile:
            click.echo(utf8tounicode(output), file=self.logfile)

    def echo(self, s, **kwargs):
        """Print a message to stdout.

        The message will be logged in the audit log, if enabled.

        All keyword arguments are passed to click.echo().

        """
        self.log_output(s)
        click.secho(s, **kwargs)

    def get_output_margin(self, status=None):
        """Get the output margin (number of rows for the prompt, footer and
        timing message."""
        margin = (
            self.get_reserved_space()
            + self.get_prompt(self.prompt_format).count("\n")
            + 2
        )
        if status:
            margin += 1 + status.count("\n")

        return margin

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.

        """
        if output:
            size = self.prompt_app.output.get_size()

            margin = self.get_output_margin(status)

            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled()
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled():
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for buffered in buf:
                                click.secho(buffered)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for buffered in buf:
                        click.secho(buffered)

        if status:
            self.log_output(status)
            click.secho(status)

    def configure_pager(self):
        """Set up the pager from the environment and config options."""
        # Provide sane defaults for less if they are empty.
        if not os.environ.get("LESS"):
            os.environ["LESS"] = "-RXF"

        cnf = self.read_my_cnf_files(["pager", "skip-pager"])
        if cnf["pager"]:
            special.set_pager(cnf["pager"])
            self.explicit_pager = True
        else:
            self.explicit_pager = False

        if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"):
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background completion refresh, optionally resetting first."""
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(
            self.sqlexecute,
            self._on_completions_refreshed,
            {
                "supported_formats": self.formatter.supported_formats,
                "keyword_casing": self.completer.keyword_casing,
            },
        )

        return [
            (None, None, None, "Auto-completion refresh started in the background.")
        ]

    def _on_completions_refreshed(self, new_completer):
        """Swap the completer object in cli with the newly created completer."""
        with self._completer_lock:
            self.completer = new_completer

        if self.prompt_app:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.prompt_app.app.invalidate()

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor position."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None
            )

    def get_prompt(self, string):
        r"""Expand prompt escapes in *string*.

        Supported: \d database, \n newline, \D full date, \m minutes,
        \P AM/PM, \R 24h hour, \r 12h hour, \s seconds, \_ space.
        """
        self.logger.debug("Getting prompt")
        sqlexecute = self.sqlexecute
        now = datetime.now()
        string = string.replace("\\d", sqlexecute.dbname or "(none)")
        string = string.replace("\\n", "\n")
        string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y"))
        string = string.replace("\\m", now.strftime("%M"))
        string = string.replace("\\P", now.strftime("%p"))
        string = string.replace("\\R", now.strftime("%H"))
        string = string.replace("\\r", now.strftime("%I"))
        string = string.replace("\\s", now.strftime("%S"))
        string = string.replace("\\_", " ")
        return string

    def run_query(self, query, new_line=True):
        """Runs *query*."""
        results = self.sqlexecute.run(query)
        for result in results:
            title, cur, headers, status = result
            self.formatter.query = query
            output = self.format_output(title, cur, headers)
            for line in output:
                click.echo(line, nl=new_line)

    def format_output(self, title, cur, headers, expanded=False, max_width=None):
        """Render one result set as an iterable of output lines.

        Falls back to vertical output when *max_width* is given and the
        first table line is wider than it.
        """
        expanded = expanded or self.formatter.format_name == "vertical"
        output = []

        output_kwargs = {
            "dialect": "unix",
            "disable_numparse": True,
            "preserve_whitespace": True,
            "preprocessors": (preprocessors.align_decimals,),
            "style": self.output_style,
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, "description"):
                # Every column is treated as text.
                column_types = [text_type for _ in cur.description]

            if max_width is not None:
                # Materialize so the rows can be rendered a second time.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur,
                headers,
                format_name="vertical" if expanded else None,
                column_types=column_types,
                **output_kwargs
            )

            if isinstance(formatted, text_type):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            # An empty result can render to zero lines (e.g. vertical format
            # with no rows); do not let StopIteration escape to the caller.
            first_line = next(formatted, None)
            if first_line is not None:
                formatted = itertools.chain([first_line], formatted)

                if (
                    not expanded
                    and max_width
                    and headers
                    and cur
                    and len(first_line) > max_width
                ):
                    formatted = self.formatter.format_output(
                        cur,
                        headers,
                        format_name="vertical",
                        column_types=column_types,
                        **output_kwargs
                    )
                    if isinstance(formatted, text_type):
                        formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu."""
        # click.get_terminal_size was removed in Click 8.1; shutil provides
        # the same (columns, lines) information.
        import shutil

        reserved_space_ratio = 0.45
        max_reserved_space = 8
        _, height = shutil.get_terminal_size()
        return min(int(round(height * reserved_space_ratio)), max_reserved_space)

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None
class AthenaCli(object): DEFAULT_PROMPT = '\\d@\\r> ' MAX_LEN_PROMPT = 45 def __init__(self, region, aws_access_key_id, aws_secret_access_key, s3_staging_dir, athenaclirc, profile, database): config_files = (DEFAULT_CONFIG_FILE, athenaclirc) _cfg = self.config = read_config_files(config_files) self.init_logging(_cfg['main']['log_file'], _cfg['main']['log_level']) aws_config = AWSConfig(aws_access_key_id, aws_secret_access_key, region, s3_staging_dir, profile, _cfg) try: self.connect(aws_config, database) except Exception as e: self.echo(str(e), err=True, fg='red') err_msg = ''' There was an error while connecting to AWS Athena. It could be caused due to missing/incomplete configuration. Please verify the configuration in %s and run athenacli again. For more details about the error, you can check the log file: %s''' % ( ATHENACLIRC, _cfg['main']['log_file']) self.echo(err_msg) LOGGER.exception('error: %r', e) sys.exit(1) special.set_timing_enabled(_cfg['main'].as_bool('timing')) self.multi_line = _cfg['main'].as_bool('multi_line') self.key_bindings = _cfg['main']['key_bindings'] self.prompt = _cfg['main']['prompt'] or self.DEFAULT_PROMPT self.destructive_warning = _cfg['main']['destructive_warning'] self.syntax_style = _cfg['main']['syntax_style'] self.prompt_continuation_format = _cfg['main']['prompt_continuation'] self.formatter = TabularOutputFormatter(_cfg['main']['table_format']) self.formatter.cli = self sql_format.register_new_formatter(self.formatter) self.cli_style = _cfg['colors'] self.output_style = style_factory_output(self.syntax_style, self.cli_style) self.completer = AthenaCompleter() self._completer_lock = threading.Lock() self.completion_refresher = CompletionRefresher() self.prompt_app = None self.query_history = [] # Register custom special commands. 
self.register_special_commands() def init_logging(self, log_file, log_level_str): file_path = os.path.expanduser(log_file) if not os.path.exists(file_path): mkdir_p(os.path.dirname(file_path)) handler = logging.FileHandler(os.path.expanduser(log_file)) log_level_map = { 'CRITICAL': logging.CRITICAL, 'ERROR': logging.ERROR, 'WARNING': logging.WARNING, 'INFO': logging.INFO, 'DEBUG': logging.DEBUG, } log_level = log_level_map[log_level_str.upper()] formatter = logging.Formatter( '%(asctime)s (%(process)d/%(threadName)s) ' '%(name)s %(levelname)s - %(message)s') handler.setFormatter(formatter) LOGGER.addHandler(handler) LOGGER.setLevel(log_level) root_logger = logging.getLogger('athenacli') root_logger.addHandler(handler) root_logger.setLevel(log_level) root_logger.debug('Initializing athenacli logging.') root_logger.debug('Log file %r.', log_file) pgspecial_logger = logging.getLogger('special') pgspecial_logger.addHandler(handler) pgspecial_logger.setLevel(log_level) def register_special_commands(self): special.register_special_command(self.change_db, 'use', '\\u', 'Change to a new database.', aliases=('\\u', )) special.register_special_command(self.change_prompt_format, 'prompt', '\\R', 'Change prompt format.', aliases=('\\R', ), case_sensitive=True) special.register_special_command( self.change_table_format, 'tableformat', '\\T', 'Change the table format used to output results.', aliases=('\\T', ), case_sensitive=True) def change_table_format(self, arg, **_): try: self.formatter.format_name = arg yield (None, None, None, 'Changed table format to {}'.format(arg)) except ValueError: msg = 'Table format {} not recognized. 
Allowed formats:'.format( arg) for table_type in self.formatter.supported_formats: msg += "\n\t{}".format(table_type) yield (None, None, None, msg) def change_db(self, arg, **_): if arg is None: self.sqlexecute.connect() else: self.sqlexecute.connect(database=arg) yield (None, None, None, 'You are now connected to database "%s"' % self.sqlexecute.database) def change_prompt_format(self, arg, **_): """ Change the prompt format. """ if not arg: message = 'Missing required argument, format.' return [(None, None, None, message)] self.prompt = self.get_prompt(arg) return [(None, None, None, "Changed prompt format to %s" % arg)] def connect(self, aws_config, database): self.sqlexecute = SQLExecute( aws_access_key_id=aws_config.aws_access_key_id, aws_secret_access_key=aws_config.aws_secret_access_key, region_name=aws_config.region, s3_staging_dir=aws_config.s3_staging_dir, role_arn=aws_config.role_arn, database=database) def handle_editor_command(self, text): """ Editor command is any query that is prefixed or suffixed by a '\e'. The reason for a while loop is because a user might edit a query multiple times. For eg: "select * from \e"<enter> to edit it in vim, then come back to the prompt with the edited query "select * from blah where q = 'abc'\e" to edit it again. :param text: str :return: Document """ while special.editor_command(text): filename = special.get_filename(text) query = (special.get_editor_query(text) or self.get_last_query()) sql, message = special.open_external_editor(filename, sql=query) if message: # Something went wrong. Raise an exception and bail. raise RuntimeError(message) while True: try: text = self.prompt_app.prompt(default=sql) break except KeyboardInterrupt: sql = '' continue return text def run_query(self, query, new_line=True): """Runs *query*.""" if (self.destructive_warning and confirm_destructive_query(query) is False): message = 'Wise choice. Command execution stopped.' 
click.echo(message) return results = self.sqlexecute.run(query) for result in results: title, rows, headers, _ = result self.formatter.query = query output = self.format_output(title, rows, headers) for line in output: click.echo(line, nl=new_line) def run_cli(self): self.iterations = 0 self.configure_pager() self.refresh_completions() history_file = os.path.expanduser(self.config['main']['history_file']) history = FileHistory(history_file) self._build_prompt_app(history) def one_iteration(): try: text = self.prompt_app.prompt() except KeyboardInterrupt: return special.set_expanded_output(False) try: text = self.handle_editor_command(text) except RuntimeError as e: LOGGER.error("sql: %r, error: %r", text, e) LOGGER.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') return if not text.strip(): return if self.destructive_warning: destroy = confirm_destructive_query(text) if destroy is None: pass # Query was not destructive. Nothing to do here. elif destroy is True: self.echo('Your call!') else: self.echo('Wise choice!') return mutating = False try: LOGGER.debug('sql: %r', text) special.write_tee(self.get_prompt(self.prompt) + text) successful = False start = time() res = self.sqlexecute.run(text) successful = True threshold = 1000 result_count = 0 for title, rows, headers, status in res: if rows and len(rows) > threshold: self.echo( 'The result set has more than {} rows.'.format( threshold), fg='red') if not confirm('Do you want to continue?'): self.echo('Aborted!', err=True, fg='red') break formatted = self.format_output( title, rows, headers, special.is_expanded_output(), None) t = time() - start try: if result_count > 0: self.echo('') try: self.output(formatted, status) except KeyboardInterrupt: pass if special.is_timing_enabled(): self.echo('Time: %0.03fs' % t) except KeyboardInterrupt: pass start = time() result_count += 1 mutating = mutating or is_mutating(status) special.unset_once_if_written() except EOFError as e: raise e 
except KeyboardInterrupt: pass except NotImplementedError: self.echo('Not Yet Implemented.', fg="yellow") except OperationalError as e: LOGGER.debug("Exception: %r", e) LOGGER.error("sql: %r, error: %r", text, e) LOGGER.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') except Exception as e: LOGGER.error("sql: %r, error: %r", text, e) LOGGER.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') else: # Refresh the table names and column names if necessary. if need_completion_refresh(text): self.refresh_completions() query = Query(text, successful, mutating) self.query_history.append(query) try: while True: one_iteration() self.iterations += 1 except EOFError: special.close_tee() def get_output_margin(self, status=None): """Get the output margin (number of rows for the prompt, footer and timing message.""" margin = self.get_reserved_space() + self.get_prompt( self.prompt).count('\n') + 1 if special.is_timing_enabled(): margin += 1 if status: margin += 1 + status.count('\n') return margin def output(self, output, status=None): """Output text to stdout or a pager command. The status text is not outputted to pager or files. The message will be logged in the audit log, if enabled. The message will be written to the tee file, if enabled. The message will be written to the output file, if enabled. 
""" if output: size = self.prompt_app.output.get_size() margin = self.get_output_margin(status) fits = True buf = [] output_via_pager = self.explicit_pager and special.is_pager_enabled( ) for i, line in enumerate(output, 1): special.write_tee(line) special.write_once(line) if fits or output_via_pager: # buffering buf.append(line) if len(line) > size.columns or i > (size.rows - margin): fits = False if not self.explicit_pager and special.is_pager_enabled( ): # doesn't fit, use pager output_via_pager = True if not output_via_pager: # doesn't fit, flush buffer for line in buf: click.secho(line) buf = [] else: click.secho(line) if buf: if output_via_pager: # sadly click.echo_via_pager doesn't accept generators click.echo_via_pager("\n".join(buf)) else: for line in buf: click.secho(line) if status: click.secho(status) def configure_pager(self): self.explicit_pager = False if not self.config['main'].as_bool('enable_pager'): special.disable_pager() def format_output(self, title, cur, headers, expanded=False, max_width=None): expanded = expanded or self.formatter.format_name == 'vertical' output = [] output_kwargs = { 'disable_numparse': True, 'preserve_whitespace': True, 'preprocessors': (preprocessors.align_decimals, ), 'style': self.output_style } if title: # Only print the title if it's not None. 
output = itertools.chain(output, [title]) if cur: column_types = None if hasattr(cur, 'description'): column_types = [str for col in cur.description] if max_width is not None: cur = list(cur) formatted = self.formatter.format_output( cur, headers, format_name='vertical' if expanded else None, column_types=column_types, **output_kwargs) if isinstance(formatted, str): formatted = formatted.splitlines() formatted = iter(formatted) first_line = next(formatted) formatted = itertools.chain([first_line], formatted) if (not expanded and max_width and headers and cur and len(first_line) > max_width): formatted = self.formatter.format_output( cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) if isinstance(formatted, str): formatted = iter(formatted.splitlines()) output = itertools.chain(output, formatted) return output def echo(self, s, **kwargs): """Print a message to stdout. The message will be logged in the audit log, if enabled. All keyword arguments are passed to click.echo(). """ click.secho(s, **kwargs) def refresh_completions(self): with self._completer_lock: self.completer.reset_completions() completer_options = { 'smart_completion': True, 'supported_formats': self.formatter.supported_formats, 'keyword_casing': self.completer.keyword_casing } self.completion_refresher.refresh(self.sqlexecute, self._on_completions_refreshed, completer_options) def _on_completions_refreshed(self, new_completer): """Swap the completer object in cli with the newly created completer. """ with self._completer_lock: self.completer = new_completer if self.prompt_app: # After refreshing, redraw the CLI to clear the statusbar # "Refreshing completions..." 
indicator self.prompt_app.app.invalidate() def _build_prompt_app(self, history): key_bindings = cli_bindings(self) def get_message(): prompt = self.get_prompt(self.prompt) if len(prompt) > self.MAX_LEN_PROMPT: prompt = self.get_prompt('\\r:\\d> ') return [('class:prompt', prompt)] def get_continuation(width, line_number, is_soft_wrap): continuation = ' ' * (width - 1) + ' ' return [('class:continuation', continuation)] def show_suggestion_tip(): return self.iterations < 2 get_toolbar_tokens = create_toolbar_tokens_func( self, show_suggestion_tip) with self._completer_lock: if self.key_bindings == 'vi': editing_mode = EditingMode.VI else: editing_mode = EditingMode.EMACS self.prompt_app = PromptSession( lexer=PygmentsLexer(Lexer), reserve_space_for_menu=self.get_reserved_space(), message=get_message, prompt_continuation=get_continuation, bottom_toolbar=get_toolbar_tokens, complete_style=CompleteStyle.COLUMN, input_processors=[ ConditionalProcessor( processor=HighlightMatchingBracketProcessor( chars='[](){}'), filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()) ], tempfile_suffix='.sql', completer=DynamicCompleter(lambda: self.completer), history=history, auto_suggest=AutoSuggestFromHistory(), complete_while_typing=True, multiline=cli_is_multiline(self), style=style_factory(self.syntax_style, self.cli_style), include_default_pygments_style=False, key_bindings=key_bindings, enable_open_in_editor=True, enable_system_prompt=True, editing_mode=editing_mode, search_ignore_case=True) def get_prompt(self, string): sqlexecute = self.sqlexecute now = datetime.now() string = string.replace('\\r', sqlexecute.region_name or '(none)') string = string.replace('\\d', sqlexecute.database or '(none)') string = string.replace('\\n', "\n") string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) string = string.replace('\\P', now.strftime('%p')) string = string.replace('\\R', now.strftime('%H')) string = string.replace('\\s', 
now.strftime('%S')) return string def get_reserved_space(self): """Get the number of lines to reserve for the completion menu.""" reserved_space_ratio = .45 max_reserved_space = 8 _, height = click.get_terminal_size() return min(int(round(height * reserved_space_ratio)), max_reserved_space) def get_last_query(self): """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None
class MyCli(object):
    """Interactive MySQL command-line client.

    Loads myclirc / my.cnf configuration, connects through ``SQLExecute``,
    runs the prompt_toolkit read-eval-print loop and renders query results
    through a ``TabularOutputFormatter``.
    """

    default_prompt = '\\t \\u@\\h:\\d> '
    max_len_prompt = 45
    defaults_suffix = None

    # In order of being loaded. Files lower in list override earlier ones.
    cnf_files = [
        '/etc/my.cnf',
        '/etc/mysql/my.cnf',
        '/usr/local/etc/my.cnf',
        '~/.my.cnf'
    ]

    system_config_files = [
        '/etc/myclirc',
    ]

    default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')

    def __init__(self, sqlexecute=None, prompt=None,
                 logfile=None, defaults_suffix=None, defaults_file=None,
                 login_path=None, auto_vertical_output=False, warn=None,
                 myclirc="~/.myclirc"):
        """Read configuration and set up formatter, logging and completer.

        :param sqlexecute: an already-connected SQLExecute, or None.
        :param prompt: prompt format; overrides my.cnf and myclirc values.
        :param logfile: open audit-log file object, or None to use the
                        ``audit_log`` setting from myclirc.
        :param defaults_suffix: suffix appended to my.cnf section names.
        :param defaults_file: if given, the only my.cnf file to read.
        :param login_path: extra my.cnf section to read.
        :param auto_vertical_output: switch to vertical output when a table
                                     is wider than the terminal.
        :param warn: overrides the ``destructive_warning`` setting if not None.
        :param myclirc: path of the user config file.
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile
        self.defaults_suffix = defaults_suffix
        self.login_path = login_path

        # self.cnf_files is a class variable that stores the list of mysql
        # config files to read in at launch.
        # If defaults_file is specified then override the class variable with
        # defaults_file.
        if defaults_file:
            self.cnf_files = [defaults_file]

        # Load config.
        config_files = ([self.default_config_file] +
                        self.system_config_files +
                        [myclirc])
        c = self.config = read_config_files(config_files)
        self.multi_line = c['main'].as_bool('multi_line')
        self.key_bindings = c['main']['key_bindings']
        special.set_timing_enabled(c['main'].as_bool('timing'))
        self.formatter = TabularOutputFormatter(
            format_name=c['main']['table_format'])
        sql_format.register_new_formatter(self.formatter)
        self.formatter.mycli = self
        self.syntax_style = c['main']['syntax_style']
        self.less_chatty = c['main'].as_bool('less_chatty')
        self.cli_style = c['colors']
        self.output_style = style_factory(self.syntax_style, self.cli_style)
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        c_dest_warning = c['main'].as_bool('destructive_warning')
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.login_path_as_host = c['main'].as_bool('login_path_as_host')

        # read from cli argument or user config file
        self.auto_vertical_output = auto_vertical_output or \
            c['main'].as_bool('auto_vertical_output')

        # Write user config if system config wasn't the last config loaded.
        if c.filename not in self.system_config_files:
            write_default_config(self.default_config_file, myclirc)

        # audit log
        if self.logfile is None and 'audit_log' in c['main']:
            try:
                self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
            except (IOError, OSError):
                self.echo('Error: Unable to open the audit log file. Your queries will not be logged.',
                          err=True, fg='red')
                # False (not None) records that auditing was requested but
                # failed; one_iteration warns about it after every query.
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
        self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
            self.default_prompt
        self.prompt_continuation_format = c['main']['prompt_continuation']
        keyword_casing = c['main'].get('keyword_casing', 'auto')

        self.query_history = []

        # Initialize completer.
        self.smart_completion = c['main'].as_bool('smart_completion')
        self.completer = SQLCompleter(
            self.smart_completion,
            supported_formats=self.formatter.supported_formats,
            keyword_casing=keyword_casing)
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Load .mylogin.cnf if it exists.
        mylogin_cnf_path = get_mylogin_cnf_path()
        if mylogin_cnf_path:
            mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
            if mylogin_cnf_path and mylogin_cnf:
                # .mylogin.cnf gets read last, even if defaults_file is specified.
                # Build a new list: self.cnf_files may still be the
                # class-level list, and appending to it in place would leak
                # into every other MyCli instance.
                self.cnf_files = self.cnf_files + [mylogin_cnf]
            elif mylogin_cnf_path and not mylogin_cnf:
                # There was an error reading the login path file.
                print('Error: Unable to read login path file.')

        self.cli = None

    def register_special_commands(self):
        """Register mycli-specific backslash/special commands."""
        special.register_special_command(self.change_db, 'use',
                                         '\\u', 'Change to a new database.', aliases=('\\u',))
        special.register_special_command(self.change_db, 'connect',
                                         '\\r', 'Reconnect to the database. Optional database argument.',
                                         aliases=('\\r', ), case_sensitive=True)
        special.register_special_command(self.refresh_completions, 'rehash',
                                         '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
        special.register_special_command(
            self.change_table_format, 'tableformat', '\\T',
            'Change the table format used to output results.',
            aliases=('\\T',), case_sensitive=True)
        special.register_special_command(self.execute_from_file, 'source', '\\. filename',
                                         'Execute commands from file.', aliases=('\\.',))
        special.register_special_command(self.change_prompt_format, 'prompt',
                                         '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)

    def change_table_format(self, arg, **_):
        """Switch the output table format.

        Yields a single ``(title, cur, headers, status)`` result whose status
        reports success or lists the supported formats.
        """
        try:
            self.formatter.format_name = arg
            yield (None, None, None,
                   'Changed table format to {}'.format(arg))
        except ValueError:
            msg = 'Table format {} not recognized. Allowed formats:'.format(
                arg)
            for table_type in self.formatter.supported_formats:
                msg += "\n\t{}".format(table_type)
            yield (None, None, None, msg)

    def change_db(self, arg, **_):
        """Reconnect, optionally to database *arg*, and yield a status."""
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))

    def execute_from_file(self, arg, **_):
        """Run the SQL contained in file *arg*.

        :return: the results of ``sqlexecute.run``, or a one-element list
                 holding an error/status message.
        """
        if not arg:
            message = 'Missing required argument, filename.'
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(arg), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        if (self.destructive_warning and
                confirm_destructive_query(query) is False):
            message = 'Wise choice. Command execution stopped.'
            return [(None, None, None, message)]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.

        The raw template is stored (not its current expansion) because every
        reader of ``prompt_format`` expands it again through ``get_prompt``;
        this keeps time, database and host escapes up to date.
        """
        if not arg:
            message = 'Missing required argument, format.'
            return [(None, None, None, message)]

        self.prompt_format = arg
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def initialize_logging(self):
        """Attach a handler for the configured log file/level to 'mycli'."""

        log_file = os.path.expanduser(self.config['main']['log_file'])
        log_level = self.config['main']['log_level']

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG
                     }

        # Disable logging if value is NONE by switching to a no-op handler
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if log_level.upper() == "NONE":
            handler = logging.NullHandler()
            log_level = "CRITICAL"
        elif dir_path_exists(log_file):
            handler = logging.FileHandler(log_file)
        else:
            self.echo(
                'Error: Unable to open the log file "{}".'.format(log_file),
                err=True, fg='red')
            return

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('mycli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug('Initializing mycli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_uri(self, uri, local_infile=None, ssl=None):
        """Connect using a ``mysql://user:pass@host:port/db`` style URI.

        Percent-encoded user names and passwords are decoded. Missing
        components are passed as None so ``connect`` falls back to my.cnf.
        """
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash

        def unquote_optional(value):
            # unquote() raises TypeError on None, which urlparse returns for
            # a URI without user/password.
            return unquote(value) if value else value

        self.connect(database, unquote_optional(uri.username),
                     unquote_optional(uri.password), uri.hostname, uri.port,
                     local_infile=local_infile, ssl=ssl)

    def read_my_cnf_files(self, files, keys):
        """
        Reads a list of config files and merges them. The last one will win.
        :param files: list of files to read
        :param keys: list of keys to retrieve
        :returns: dict mapping each key to its value, with None for missing keys.
        """
        cnf = read_config_files(files)

        sections = ['client']
        if self.login_path and self.login_path != 'client':
            sections.append(self.login_path)

        if self.defaults_suffix:
            sections.extend([sect + self.defaults_suffix for sect in sections])

        def get(key):
            result = None
            for sect in cnf:
                if sect in sections and key in cnf[sect]:
                    result = cnf[sect][key]
            return result

        return {x: get(x) for x in keys}

    def merge_ssl_with_cnf(self, ssl, cnf):
        """Merge SSL configuration dict with cnf dict.

        ``ssl-*`` keys from *cnf* with a non-None value are added with the
        prefix stripped; ``ssl-verify-server-cert`` becomes PyMySQL's
        ``check_hostname``.
        """
        merged = {}
        merged.update(ssl)
        prefix = 'ssl-'
        for k, v in cnf.items():
            # skip unrelated options
            if not k.startswith(prefix):
                continue
            if v is None:
                continue
            # special case because PyMySQL argument is significantly different
            # from commandline
            if k == 'ssl-verify-server-cert':
                # cnf values are strings; '0'/'false' must not become truthy.
                try:
                    merged['check_hostname'] = str_to_bool(v)
                except (TypeError, ValueError):
                    merged['check_hostname'] = v
            else:
                # use argument name just strip "ssl-" prefix
                arg = k[len(prefix):]
                merged[arg] = v

        return merged

    def connect(self, database='', user='', passwd='', host='', port='',
                socket='', charset='', local_infile='', ssl=''):
        """Connect to MySQL, filling unset arguments from my.cnf.

        With no socket/host/port on a non-Windows system, the default Unix
        socket is tried first and TCP/IP localhost:3306 is used as fallback.
        Exits the process with status 1 if the connection cannot be made.
        """
        import sys

        cnf = {'database': None,
               'user': None,
               'password': None,
               'host': None,
               'port': None,
               'socket': None,
               'default-character-set': None,
               'local-infile': None,
               'loose-local-infile': None,
               'ssl-ca': None,
               'ssl-cert': None,
               'ssl-key': None,
               'ssl-cipher': None,
               'ssl-verify-server-cert': None,
               }

        cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())

        # Fall back to config values only if user did not specify a value.

        database = database or cnf['database']
        if port or host:
            socket = ''
        else:
            socket = socket or cnf['socket']
        user = user or cnf['user'] or os.getenv('USER')
        host = host or cnf['host']
        port = port or cnf['port']
        ssl = ssl or {}

        passwd = passwd or cnf['password']
        charset = charset or cnf['default-character-set'] or 'utf8'

        # Favor whichever local_infile option is set.
        for local_infile_option in (local_infile, cnf['local-infile'],
                                    cnf['loose-local-infile'], False):
            try:
                local_infile = str_to_bool(local_infile_option)
                break
            except (TypeError, ValueError):
                pass

        ssl = self.merge_ssl_with_cnf(ssl, cnf)
        # prune lone check_hostname=False
        if not any(v for v in ssl.values()):
            ssl = None

        # Connect to the database.

        def _connect():
            # Reads socket/host/port at call time, so the TCP/IP fallback
            # below sees the reassigned values.
            try:
                self.sqlexecute = SQLExecute(database, user, passwd, host, port,
                                             socket, charset, local_infile, ssl)
            except OperationalError as e:
                if ('Access denied for user' in e.args[1]):
                    new_passwd = click.prompt('Password', hide_input=True,
                                              show_default=False, type=str, err=True)
                    self.sqlexecute = SQLExecute(database, user, new_passwd, host, port,
                                                 socket, charset, local_infile, ssl)
                else:
                    raise e

        try:
            if (socket is host is port is None) and not WIN:
                # Try a sensible default socket first (simplifies auth)
                # If we get a connection error, try tcp/ip localhost
                try:
                    socket = '/var/run/mysqld/mysqld.sock'
                    _connect()
                except OperationalError as e:
                    # These are "Can't open socket" and 2x "Can't connect"
                    if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
                        self.logger.debug('Database connection failed: %r.', e)
                        self.logger.error(
                            "traceback: %r", traceback.format_exc())
                        self.logger.debug('Retrying over TCP/IP')
                        self.echo(str(e), err=True)
                        self.echo(
                            'Failed to connect by socket, retrying over TCP/IP', err=True)

                        # Else fall back to TCP/IP localhost
                        socket = ""
                        host = 'localhost'
                        port = 3306
                        _connect()
                    else:
                        raise e
            else:
                host = host or 'localhost'
                port = port or 3306

                # Bad ports give particularly daft error messages
                try:
                    port = int(port)
                except ValueError:
                    self.echo("Error: Invalid port number: '{0}'.".format(port),
                              err=True, fg='red')
                    sys.exit(1)

                _connect()
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.echo(str(e), err=True, fg='red')
            sys.exit(1)

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: if the external editor reports an error.
        """
        # FIXME: using application.pre_run_callables like this here is not the best solution.
        # It's internal api of prompt_toolkit that may change. This was added to fix
        # https://github.com/dbcli/pgcli/issues/668. We may find a better way to do it in the future.
        saved_callables = cli.application.pre_run_callables
        try:
            while special.editor_command(document.text):
                filename = special.get_filename(document.text)
                query = (special.get_editor_query(document.text) or
                         self.get_last_query())
                sql, message = special.open_external_editor(filename, sql=query)
                if message:
                    # Something went wrong. Raise an exception and bail.
                    raise RuntimeError(message)
                cli.current_buffer.document = Document(sql, cursor_position=len(sql))
                cli.application.pre_run_callables = []
                document = cli.run()
        finally:
            # Restore the callables even when the editor or prompt raised;
            # otherwise every later prompt would run without them.
            cli.application.pre_run_callables = saved_callables
        return document

    def run_cli(self):
        """Run the interactive prompt loop until EOF (Ctrl-D)."""
        iterations = 0
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()

        if self.smart_completion:
            self.refresh_completions()

        author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
        sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')

        history_file = os.path.expanduser(
            os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
        if dir_path_exists(history_file):
            history = FileHistory(history_file)
        else:
            history = None
            self.echo(
                'Error: Unable to open the history file "{}". '
                'Your query history will not be saved.'.format(history_file),
                err=True, fg='red')

        key_binding_manager = mycli_bindings()

        if not self.less_chatty:
            print('Version:', __version__)
            print('Chat: https://gitter.im/dbcli/mycli')
            print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
            print('Home: http://mycli.net')
            print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))

        def prompt_tokens(cli):
            prompt = self.get_prompt(self.prompt_format)
            if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
                prompt = self.get_prompt('\\d> ')
            return [(Token.Prompt, prompt)]

        def get_continuation_tokens(cli, width):
            continuation_prompt = self.get_prompt(self.prompt_continuation_format)
            return [(Token.Continuation, ' ' * (width - len(continuation_prompt)) + continuation_prompt)]

        def show_suggestion_tip():
            return iterations < 2

        def one_iteration(document=None):
            """Read (unless *document* is given), execute and render one query."""
            if document is None:
                document = self.cli.run()

                special.set_expanded_output(False)

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg='red')
                    return

            if not document.text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(document.text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo('Your call!')
                else:
                    self.echo('Wise choice!')
                    return

            # Keep track of whether or not the query is mutating. In case
            # of a multi-statement query, the overall query is considered
            # mutating if any one of the component statements is mutating
            mutating = False
            # Initialised before the try so the Query record below never sees
            # an unbound name when tee/audit-log writing fails early.
            successful = False

            try:
                logger.debug('sql: %r', document.text)

                special.write_tee(self.get_prompt(self.prompt_format) + document.text)
                if self.logfile:
                    self.logfile.write('\n# %s\n' % datetime.now())
                    self.logfile.write(document.text)
                    self.logfile.write('\n')

                start = time()
                res = sqlexecute.run(document.text)
                self.formatter.query = document.text
                successful = True
                result_count = 0
                for title, cur, headers, status in res:
                    logger.debug("headers: %r", headers)
                    logger.debug("rows: %r", cur)
                    logger.debug("status: %r", status)
                    threshold = 1000
                    if (is_select(status) and
                            cur and cur.rowcount > threshold):
                        self.echo('The result set has more than {} rows.'.format(
                            threshold), fg='red')
                        if not confirm('Do you want to continue?'):
                            self.echo("Aborted!", err=True, fg='red')
                            break

                    if self.auto_vertical_output:
                        max_width = self.cli.output.get_size().columns
                    else:
                        max_width = None

                    formatted = self.format_output(
                        title, cur, headers, special.is_expanded_output(),
                        max_width)

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo('')
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        if special.is_timing_enabled():
                            self.echo('Time: %0.03fs' % t)
                    except KeyboardInterrupt:
                        pass

                    start = time()
                    result_count += 1
                    mutating = mutating or is_mutating(status)
                special.unset_once_if_written()
            except EOFError:
                raise
            except KeyboardInterrupt:
                # get last connection id
                connection_id_to_kill = sqlexecute.connection_id
                logger.debug("connection id to kill: %r", connection_id_to_kill)
                # Restart connection to the database
                sqlexecute.connect()
                try:
                    for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill):
                        status_str = str(status).lower()
                        if status_str.find('ok') > -1:
                            logger.debug("cancelled query, connection id: %r, sql: %r",
                                         connection_id_to_kill, document.text)
                            self.echo("cancelled query", err=True, fg='red')
                except Exception as e:
                    self.echo('Encountered error while cancelling query: {}'.format(e),
                              err=True, fg='red')
            except NotImplementedError:
                self.echo('Not Yet Implemented.', fg="yellow")
            except OperationalError as e:
                logger.debug("Exception: %r", e)
                if (e.args[0] in (2003, 2006, 2013)):
                    logger.debug('Attempting to reconnect.')
                    self.echo('Reconnecting...', fg='yellow')
                    try:
                        sqlexecute.connect()
                        logger.debug('Reconnected successfully.')
                        one_iteration(document)
                        return  # OK to just return, cuz the recursion call runs to the end.
                    except OperationalError as e:
                        logger.debug('Reconnect failed. e: %r', e)
                        self.echo(str(e), err=True, fg='red')
                        # If reconnection failed, don't proceed further.
                        return
                else:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg='red')
            except Exception as e:
                logger.error("sql: %r, error: %r", document.text, e)
                logger.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
            else:
                if is_dropping_database(document.text, self.sqlexecute.dbname):
                    self.sqlexecute.dbname = None
                    self.sqlexecute.connect()

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_completions(
                        reset=need_completion_reset(document.text))
            finally:
                if self.logfile is False:
                    self.echo("Warning: This query was not logged.",
                              err=True, fg='red')
            query = Query(document.text, successful, mutating)
            self.query_history.append(query)

        get_toolbar_tokens = create_toolbar_tokens_func(
            self.completion_refresher.is_refreshing,
            show_suggestion_tip)

        layout = create_prompt_layout(
            lexer=MyCliLexer,
            multiline=True,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            extra_input_processors=[ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()
            )],
            reserve_space_for_menu=self.get_reserved_space()
        )
        with self._completer_lock:
            buf = CLIBuffer(
                always_multiline=self.multi_line, completer=self.completer,
                history=history, auto_suggest=AutoSuggestFromHistory(),
                complete_while_typing=Always(),
                accept_action=AcceptAction.RETURN_DOCUMENT)

            if self.key_bindings == 'vi':
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            application = Application(
                style=style_from_pygments(style_cls=self.output_style),
                layout=layout, buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                editing_mode=editing_mode,
                ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        try:
            while True:
                one_iteration()
                iterations += 1
        except EOFError:
            special.close_tee()
            if not self.less_chatty:
                self.echo('Goodbye!')

    def log_output(self, output):
        """Log the output in the audit log, if it's enabled."""
        if self.logfile:
            click.echo(utf8tounicode(output), file=self.logfile)

    def echo(self, s, **kwargs):
        """Print a message to stdout.

        The message will be logged in the audit log, if enabled.

        All keyword arguments are passed to click.echo().

        """
        self.log_output(s)
        click.secho(s, **kwargs)

    def get_output_margin(self, status=None):
        """Get the output margin (number of rows for the prompt, footer and
        timing message."""
        margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1
        if special.is_timing_enabled():
            margin += 1
        if status:
            margin += 1 + status.count('\n')

        return margin

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.

        """
        if output:
            size = self.cli.output.get_size()

            margin = self.get_output_margin(status)

            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled()
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled():
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for buffered in buf:
                                click.secho(buffered)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for buffered in buf:
                        click.secho(buffered)

        if status:
            self.log_output(status)
            click.secho(status)

    def configure_pager(self):
        """Set up the pager from $LESS, my.cnf and the enable_pager option."""
        # Provide sane defaults for less if they are empty.
        if not os.environ.get('LESS'):
            os.environ['LESS'] = '-RXF'

        cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
        if cnf['pager']:
            special.set_pager(cnf['pager'])
            self.explicit_pager = True
        else:
            self.explicit_pager = False

        if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'):
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background refresh of the completer's metadata.

        :param reset: clear the current completions first.
        :return: a one-element result list with a status message.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(
            self.sqlexecute, self._on_completions_refreshed,
            {'smart_completion': self.smart_completion,
             'supported_formats': self.formatter.supported_formats,
             'keyword_casing': self.completer.keyword_casing})

        return [(None, None, None,
                 'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When mycli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor position."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)

    def get_prompt(self, string):
        """Expand the backslash escapes of a prompt template."""
        sqlexecute = self.sqlexecute
        host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host
        now = datetime.now()
        string = string.replace('\\u', sqlexecute.user or '(none)')
        string = string.replace('\\h', host or '(none)')
        string = string.replace('\\d', sqlexecute.dbname or '(none)')
        string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
        string = string.replace('\\n', "\n")
        string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
        string = string.replace('\\m', now.strftime('%M'))
        string = string.replace('\\P', now.strftime('%p'))
        string = string.replace('\\R', now.strftime('%H'))
        string = string.replace('\\r', now.strftime('%I'))
        string = string.replace('\\s', now.strftime('%S'))
        string = string.replace('\\p', str(sqlexecute.port))
        string = string.replace('\\_', ' ')
        return string

    def run_query(self, query, new_line=True):
        """Runs *query*."""
        results = self.sqlexecute.run(query)
        for result in results:
            title, cur, headers, status = result
            self.formatter.query = query
            output = self.format_output(title, cur, headers)
            for line in output:
                click.echo(line, nl=new_line)

    def format_output(self, title, cur, headers, expanded=False,
                      max_width=None):
        """Render a result set to an iterable of text lines.

        :param title: optional line emitted before the table.
        :param cur: cursor or rows; falsy means no table.
        :param headers: column headers.
        :param expanded: force the vertical format.
        :param max_width: if set, fall back to vertical format when the
                          first rendered line is wider than this.
        :return: an iterator of output lines.
        """
        expanded = expanded or self.formatter.format_name == 'vertical'
        output = []

        output_kwargs = {
            'disable_numparse': True,
            'preserve_whitespace': True,
            'preprocessors': (preprocessors.align_decimals, ),
            'style': self.output_style
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, 'description'):
                def get_col_type(col):
                    # Entries of FIELD_TYPES that are not types fall back to text.
                    col_type = FIELD_TYPES.get(col[1], text_type)
                    return col_type if type(col_type) is type else text_type
                column_types = [get_col_type(col) for col in cur.description]

            if max_width is not None:
                # The rows may be rendered twice (table, then vertical), so a
                # one-shot cursor must be materialized first.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur, headers, format_name='vertical' if expanded else None,
                column_types=column_types,
                **output_kwargs)

            if isinstance(formatted, (text_type)):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            # Peek at the first line to measure the rendered width. The
            # formatter may render nothing, in which case next() without a
            # default would raise StopIteration.
            first_line = next(formatted, None)
            if first_line is None:
                formatted = iter(())
            else:
                formatted = itertools.chain([first_line], formatted)

                if (not expanded and max_width and headers and cur and
                        len(first_line) > max_width):
                    formatted = self.formatter.format_output(
                        cur, headers, format_name='vertical', column_types=column_types, **output_kwargs)
                    if isinstance(formatted, (text_type)):
                        formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu."""
        reserved_space_ratio = .45
        max_reserved_space = 8
        _, height = click.get_terminal_size()
        return min(int(round(height * reserved_space_ratio)), max_reserved_space)

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None
def format_output(title, cur, headers, status, settings):
    """Render one query result as an iterable of output lines.

    :param title: optional line emitted before the table (skipped when falsy).
    :param cur: cursor (or other iterable of rows) holding the result; skipped
        when falsy.
    :param headers: column names; each is passed through ``utf8tounicode``
        and then ``settings.case_function``.
    :param status: optional line emitted after the table (skipped when falsy).
    :param settings: output settings; reads ``expanded``, ``table_format``,
        ``max_width``, ``case_function``, ``missingval``, ``dcmlfmt``,
        ``floatfmt`` and ``style_output``.
    :return: an iterable of strings -- a plain list when neither ``cur`` nor
        ``status`` contributes output, otherwise an ``itertools.chain``.
    """
    output = []
    expanded = (settings.expanded or settings.table_format == 'vertical')
    table_format = ('vertical' if settings.expanded else settings.table_format)
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        """Render a (possibly nested) list as a ``{a,b,...}`` literal."""
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return '{' + ','.join(text_type(format_array(e)) for e in val) + '}'

    def format_arrays(data, headers, **_):
        """Preprocessor: replace list-valued cells with their array literal."""
        data = list(data)
        for row in data:
            row[:] = [
                format_array(val) if isinstance(val, list) else val
                for val in row
            ]
        return data, headers

    output_kwargs = {
        'sep_title': 'RECORD {n}',
        'sep_character': '-',
        'sep_length': (1, 25),
        'missing_value': settings.missingval,
        'integer_format': settings.dcmlfmt,
        'float_format': settings.floatfmt,
        'preprocessors': (format_numbers, format_arrays),
        'disable_numparse': True,
        'preserve_whitespace': True,
        'style': settings.style_output
    }
    if not settings.floatfmt:
        output_kwargs['preprocessors'] = (align_decimals, )

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(utf8tounicode(x)) for x in headers]

        # Derive one type per column from the cursor description. This must
        # happen before the cursor is materialised below, because a plain
        # list has no ``description`` attribute.
        column_types = None
        if hasattr(cur, 'description'):
            column_types = []
            for d in cur.description:
                type_code = d[1]
                if (type_code in psycopg2.extensions.DECIMAL.values or
                        type_code in psycopg2.extensions.FLOAT.values):
                    column_types.append(float)
                elif (type_code in psycopg2.extensions.INTEGER.values or
                        type_code in psycopg2.extensions.LONGINTEGER.values):
                    column_types.append(int)
                else:
                    column_types.append(text_type)

        if max_width is not None:
            # The rows may be rendered a second time (vertical fallback
            # below), so they cannot be streamed from the cursor.
            cur = list(cur)

        formatted = formatter.format_output(
            cur, headers, column_types=column_types, **output_kwargs)
        if isinstance(formatted, text_type):
            formatted = formatted.splitlines()
        formatted = iter(formatted)
        first_line = next(formatted)
        formatted = itertools.chain([first_line], formatted)

        # Table wider than the allowed width: re-render it vertically.
        if (not expanded and max_width and len(first_line) > max_width
                and headers):
            formatted = formatter.format_output(
                cur, headers, format_name='vertical',
                column_types=column_types, **output_kwargs)
            if isinstance(formatted, text_type):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

        output = itertools.chain(output, formatted)

    if status:  # Only print the status if it's not None.
        output = itertools.chain(output, [status])

    return output
class SQLiteCli(object):
    """Interactive SQLite command-line client built on prompt_toolkit."""

    DEFAULT_PROMPT = 'sqlite> '
    MAX_LEN_PROMPT = 45

    def __init__(self, prompt=None, sqliteclirc=None):
        self.config = self.init_config(sqliteclirc)
        self.prompt = prompt or self.config['main']['prompt'] or self.DEFAULT_PROMPT
        self.prompt_continuation = self.config['main']['prompt_continuation']
        self.multi_line = self.config['main'].as_bool('multi_line')
        self.key_bindings = self.config['main']['key_bindings']
        self.explicit_pager = False
        self.logfile = None
        # Executed queries, newest last; read by get_last_query().
        self.query_history = []

        # Init formatter
        self.formatter = TabularOutputFormatter(
            format_name=self.config['main']['table_format']
        )
        self.formatter.cli = self
        sql_format.register_new_formatter(self.formatter)

        # Init style
        self.syntax_style = self.config['main']['syntax_style']
        self.cli_style = self.config['colors']
        self.output_style = style_factory(self.syntax_style, self.cli_style)

        # Init completer.
        self.smart_completion = self.config['main'].as_bool('smart_completion')
        self.completer = SQLCompleter(
            self.smart_completion,
            supported_formats=self.formatter.supported_formats,
            keyword_casing=self.config['main'].get('keyword_casing', 'auto')
        )
        self._completer_lock = threading.Lock()
        self.completion_refresher = CompletionRefresher()

        # Register custom special commands
        self.register_special_commands()

        self.cli = None

    def init_config(self, sqliteclirc):
        """Load and merge the packaged, system-wide and user config files."""
        # Order matters, the settings in later file will override those from
        # previously file
        config_files = [
            os.path.join(PACKAGE_ROOT, 'sqliteclirc'),
            '/etc/sqliteclirc',
            sqliteclirc
        ]
        return read_config_files(config_files)

    def connect(self, filename=None):
        """Open *filename* through a new SQLExecute instance."""
        self.sqlexecute = SQLExecute(filename)

    def run_cli(self):
        """Run the read-eval-print loop until EOF (Ctrl-D)."""
        self.iterations = 0
        self.refresh_completions()

        history_file = os.path.expanduser(
            os.environ.get('SQLiteCLI_HISTFILE', '~/.sqlitecli-history')
        )
        history = FileHistory(history_file)
        self.cli = self._build_cli(history)

        def one_iteration():
            document = self.cli.run()

            special.set_expanded_output(False)

            try:
                document = self.handle_editor_command(self.cli, document)
            except RuntimeError as e:
                self.echo(str(e), err=True, fg='red')
                return

            if not document.text.strip():
                return

            # Bound before the try so the Query record below can always be
            # built, even if an early statement inside the try raises.
            mutating = False
            successful = False
            try:
                special.write_tee(self.get_prompt(self.prompt) + document.text)
                start = time()
                res = self.sqlexecute.run(document.text)
                successful = True
                threshold = 1000
                result_count = 0
                for title, cur, headers, status in res:
                    if (is_select(status) and cur and cur.rowcount > threshold):
                        self.echo(
                            'The result set has more than {} rows.'.format(threshold),
                            fg='red'
                        )
                        if not confirm('Do you want to continue?'):
                            self.echo('Aborted!', err=True, fg='red')
                            break

                    formatted = self.format_output(
                        title, cur, headers, special.is_expanded_output(), None
                    )

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo('')
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        if special.is_timing_enabled():
                            self.echo('Time: %0.03fs' % t)
                    except KeyboardInterrupt:
                        pass

                    start = time()
                    result_count += 1
                    mutating = mutating or is_mutating(status)
                special.unset_once_if_written()
            except EOFError:
                raise
            except KeyboardInterrupt:
                pass
            except NotImplementedError:
                self.echo('Not Yet Implemented.', fg="yellow")
            except Exception as e:
                # A failing statement should be reported, not end the REPL.
                self.echo(str(e), err=True, fg='red')

            query = Query(document.text, successful, mutating)
            self.query_history.append(query)

        try:
            while True:
                one_iteration()
                self.iterations += 1
        except EOFError:
            special.close_tee()

    def _build_cli(self, history):
        """Assemble the prompt_toolkit CommandLineInterface."""
        key_binding_manager = cli_bindings()

        def prompt_tokens(cli):
            prompt = self.get_prompt(self.prompt)
            if len(prompt) > self.MAX_LEN_PROMPT:
                prompt = self.get_prompt('\\d> ')
            return [(Token.Prompt, prompt)]

        def get_continuation_tokens(cli, width):
            prompt = self.get_prompt(self.prompt_continuation)
            token = (
                Token.Continuation,
                ' ' * (width - len(prompt)) + prompt
            )
            return [token]

        def show_suggestion_tip():
            return self.iterations < 2

        get_toolbar_tokens = create_toolbar_tokens_func(
            self.completion_refresher.is_refreshing,
            show_suggestion_tip)

        layout = create_prompt_layout(
            lexer=Lexer,
            multiline=True,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.config['main'].as_bool('wider_completion_menu'),
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ],
            reserve_space_for_menu=self.get_reserved_space()
        )

        with self._completer_lock:
            buf = CLIBuffer(
                always_multiline=self.multi_line,
                completer=self.completer,
                history=history,
                auto_suggest=AutoSuggestFromHistory(),
                complete_while_typing=Always(),
                accept_action=AcceptAction.RETURN_DOCUMENT)

            if self.key_bindings == 'vi':
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            application = Application(
                style=style_from_pygments(style_cls=self.output_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                editing_mode=editing_mode,
                ignore_case=True)

            cli = CommandLineInterface(
                application=application,
                eventloop=create_eventloop())

            return cli

    def get_prompt(self, string):
        """Expand the backslash escapes of a prompt template *string*."""
        now = datetime.now()
        string = string.replace('\\u', '(none)')
        string = string.replace('\\h', '(none)')
        string = string.replace('\\d', '(none)')
        string = string.replace('\\t', 'sqlite')
        string = string.replace('\\n', "\n")
        string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
        string = string.replace('\\m', now.strftime('%M'))
        string = string.replace('\\P', now.strftime('%p'))
        string = string.replace('\\R', now.strftime('%H'))
        string = string.replace('\\r', now.strftime('%I'))
        string = string.replace('\\s', now.strftime('%S'))
        string = string.replace('\\_', ' ')
        return string

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu."""
        reserved_space_ratio = .45
        max_reserved_space = 8
        _, height = click.get_terminal_size()
        return min(int(round(height * reserved_space_ratio)), max_reserved_space)

    def register_special_commands(self):
        """Hook for registering extra special commands (none yet)."""
        pass

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        """
        # FIXME: using application.pre_run_callables like this here is not the best solution.
        # It's internal api of prompt_toolkit that may change. This was added to fix
        # https://github.com/dbcli/pgcli/issues/668. We may find a better way to do it in the future.
        saved_callables = cli.application.pre_run_callables
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            query = (special.get_editor_query(document.text) or
                     self.get_last_query())
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            cli.current_buffer.document = Document(sql, cursor_position=len(sql))
            cli.application.pre_run_callables = []
            document = cli.run()
            continue
        cli.application.pre_run_callables = saved_callables
        return document

    def log_output(self, output):
        """Log the output in the audit log, if it's enabled."""
        if self.logfile:
            click.echo(utf8tounicode(output), file=self.logfile)

    def echo(self, s, **kwargs):
        """Print a message to stdout.

        The message will be logged in the audit log, if enabled.

        All keyword arguments are passed to click.echo().
        """
        self.log_output(s)
        click.secho(s, **kwargs)

    def get_output_margin(self, status=None):
        """Get the output margin (number of rows for the prompt, footer and
        timing message."""
        margin = self.get_reserved_space() + self.get_prompt(self.prompt).count('\n') + 1
        if special.is_timing_enabled():
            margin += 1
        if status:
            margin += 1 + status.count('\n')

        return margin

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.
        """
        if output:
            size = self.cli.output.get_size()

            margin = self.get_output_margin(status)

            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled()
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled():
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for buffered in buf:
                                click.secho(buffered)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for buffered in buf:
                        click.secho(buffered)

        if status:
            self.log_output(status)
            click.secho(status)

    def refresh_completions(self, reset=False):
        """Start a background completion refresh; optionally reset first."""
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(
            self.sqlexecute, self._on_completions_refreshed,
            {'smart_completion': self.smart_completion,
             'supported_formats': self.formatter.supported_formats,
             'keyword_casing': self.completer.keyword_casing})

        return [(None, None, None,
                 'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When cli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor position."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)

    def run_query(self, query, new_line=True):
        """Runs *query*."""
        results = self.sqlexecute.run(query)
        for result in results:
            title, cur, headers, status = result
            self.formatter.query = query
            output = self.format_output(title, cur, headers)
            for line in output:
                click.echo(line, nl=new_line)

    def format_output(self, title, cur, headers, expanded=False,
                      max_width=None):
        """Render a result set as an iterable of lines.

        Falls back to the vertical format when *max_width* is given and the
        first rendered line is wider than it.
        """
        expanded = expanded or self.formatter.format_name == 'vertical'
        output = []

        output_kwargs = {
            'disable_numparse': True,
            'preserve_whitespace': True,
            'preprocessors': (preprocessors.align_decimals, ),
            'style': self.output_style
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, 'description'):
                def get_col_type(col):
                    col_type = FIELD_TYPES.get(col[1], text_type)
                    return col_type if type(col_type) is type else text_type
                column_types = [get_col_type(col) for col in cur.description]

            if max_width is not None:
                # May be rendered twice, so materialise the rows.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur, headers, format_name='vertical' if expanded else None,
                column_types=column_types,
                **output_kwargs)

            if isinstance(formatted, (text_type)):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            first_line = next(formatted)
            formatted = itertools.chain([first_line], formatted)

            if (not expanded and max_width and headers and cur and
                    len(first_line) > max_width):
                formatted = self.formatter.format_output(
                    cur, headers, format_name='vertical',
                    column_types=column_types, **output_kwargs)
                if isinstance(formatted, (text_type)):
                    formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None
class PreviewElement:
    """
    Preview pane for the currently selected database object.

    Entry points:
        create_container: returns the preview container itself, meant to be
            shown inside a Float.
        create_completion_float: returns the Float holding the completion
            menu for the filter input box, meant for the FloatContainer that
            hosts the preview container's float.
    """

    def __init__(self, my_app: "sqlApp"):
        self.my_app = my_app
        help_text = """ Press Enter in the input box to page through the table. Alternatively, enter a filtering SQL statement and then press Enter to page through the results. """
        self.formatter = TabularOutputFormatter()
        self.completer = PreviewCompleter(
            my_app=self.my_app,
            completer=MssqlCompleter(
                smart_completion=True,
                get_conn=lambda: self.my_app.selected_object.conn))

        history_path = config_location() + 'preview_history'
        ensure_dir_exists(history_path)
        preview_history = PreviewHistory(
            my_app=self.my_app, filename=expanduser(history_path))

        def completion_enabled():
            # Only complete while an object is selected and its connection
            # is live.
            selected = self.my_app.selected_object
            return selected is not None and selected.conn.connected()

        self.input_buffer = PreviewBuffer(
            name="previewbuffer",
            tempfile_suffix=".sql",
            history=ThreadedHistory(preview_history),
            auto_suggest=ThreadedAutoSuggest(PreviewSuggestFromHistory(my_app)),
            completer=ThreadedCompleter(self.completer),
            complete_while_typing=Condition(completion_enabled),
            multiline=False)

        self.input_window = Window(BufferControl(
            buffer=self.input_buffer,
            include_default_input_processors=False,
            input_processors=[AppendAutoSuggestion()],
            preview_search=False))

        self.search_field = SearchToolbar(Buffer(name="previewsearchbuffer"))

        search_highlighter = ConditionalProcessor(
            processor=HighlightIncrementalSearchProcessor(),
            filter=has_focus("previewsearchbuffer") | has_focus(self.search_field.control),
        )
        self.output_field = TextArea(
            style="class:preview-output-field",
            text=help_text,
            height=D(preferred=50),
            search_field=self.search_field,
            wrap_lines=False,
            focusable=True,
            read_only=True,
            preview_search=True,
            input_processors=[search_highlighter, HighlightSelectionProcessor()])

        def populate_output(window_height) -> bool:
            """
            Fill the output area with the current page of results, or with
            the connection's error text when execution failed. Runs after
            the app restarts from executing a preview query, or directly
            when only the next page is needed.
            """
            conn = self.my_app.selected_object.conn
            if conn.execution_status == executionStatus.FAIL:
                # Show the error message to the user instead of results.
                text = conn.execution_err
            else:
                description = conn.cursor.description
                columns = [col.name for col in description] if description else []
                if not columns:
                    text = "No rows returned\n"
                else:
                    page = conn.fetch_from_cache(size=window_height - 4, wait=True)
                    text = "\n".join(
                        self.formatter.format_output(page, columns, format_name="psql"))
            self.output_field.buffer.set_document(
                Document(text=text, cursor_position=0), True)
            return True

        def on_accept(buff: Buffer) -> bool:
            """
            Enter in the filter box: run the filter query if it changed,
            otherwise fetch the next page of the current results.
            """
            selected = self.my_app.selected_object
            conn = selected.conn
            identifier = object_to_identifier(selected)
            query = conn.preview_query(
                name=identifier,
                obj_type=selected.otype,
                filter_query=buff.text,
                limit=self.my_app.preview_limit_rows)
            if query is None:
                return True

            refresh = partial(
                populate_output,
                window_height=self.output_field.window.render_info.window_height)
            if conn.query == query:
                # Same query: just fetch the next page in place.
                refresh()
            else:
                # New query: exit the app so it gets executed, then refresh.
                self.my_app.application.exit(result=["preview", query])
                self.my_app.application.pre_run_callables.append(refresh)
            return True  # Keep filter text

        def on_cancel() -> None:
            """Close the preview and hand focus back to the sidebar."""
            self.my_app.selected_object.conn.close_cursor()
            self.input_buffer.text = ""
            self.output_field.buffer.set_document(
                Document(text=help_text, cursor_position=0), True)
            self.my_app.show_preview = False
            self.my_app.show_sidebar = True
            layout = self.my_app.application.layout
            layout.focus(self.input_buffer)
            layout.focus("sidebarbuffer")
            return None

        self.input_buffer.accept_handler = on_accept
        self.cancel_button = Button(text="Done", handler=on_cancel)

    def create_completion_float(self) -> Float:
        """Float carrying the completion menu for the filter input box."""
        menu = CompletionsMenu(
            scroll_offset=1,
            max_height=16,
            extra_filter=has_focus(self.input_buffer))
        return Float(
            xcursor=True,
            ycursor=True,
            transparent=True,
            attach_to_window=self.input_window,
            content=menu)

    def create_container(self):
        """Framed preview dialog, shown only while the preview is enabled."""
        input_row = Box(
            body=VSplit([self.input_window, self.cancel_button], padding=1),
            padding=1,
            style="class:preview-input-field")
        divider = Window(height=1, char="-", style="class:preview-divider-line")
        body = HSplit([input_row, divider, self.output_field, self.search_field])
        dialog = Frame(
            title=lambda: "Preview: " + self.my_app.selected_object.name,
            body=body,
            style="class:dialog.body",
            width=D(preferred=180, min=30),
            modal=True)
        return ConditionalContainer(
            content=Shadow(body=dialog),
            filter=ShowPreview(self.my_app) & ~is_done)
class OCli(object): default_prompt = '\\t \\u@\\h:\\d> ' max_len_prompt = 45 defaults_suffix = None # In order of being loaded. Files lower in list override earlier ones. cnf_files = [ '~/login.conf', ] system_config_files = [] default_config_file = os.path.join(PACKAGE_ROOT, 'okclirc') def __init__(self, sqlexecute=None, prompt=None, logfile=None, defaults_suffix=None, defaults_file=None, login_path=None, auto_vertical_output=False, warn=None, okclirc="~/.okclirc"): self.sqlexecute = sqlexecute self.logfile = logfile self.defaults_suffix = defaults_suffix self.login_path = login_path # self.cnf_files is a class variable that stores the list of oracle # config files to read in at launch. # If defaults_file is specified then override the class variable with # defaults_file. if defaults_file: self.cnf_files = [defaults_file] # Load config. config_files = ([self.default_config_file] + self.system_config_files + [okclirc]) c = self.config = read_config_files(config_files) self.multi_line = c['main'].as_bool('multi_line') self.key_bindings = c['main']['key_bindings'] special.set_timing_enabled(c['main'].as_bool('timing')) self.formatter = TabularOutputFormatter( format_name=c['main']['table_format']) self.syntax_style = c['main']['syntax_style'] self.cli_style = c['colors'] self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') c_ddl_warning = c['main'].as_bool('ddl_warning') self.ddl_warning = c_ddl_warning if warn is None else warn self.login_path_as_host = c['main'].as_bool('login_path_as_host') # read from cli argument or user config file self.auto_vertical_output = auto_vertical_output or \ c['main'].as_bool('auto_vertical_output') # Write user config if system config wasn't the last config loaded. 
if c.filename not in self.system_config_files: write_default_config(self.default_config_file, okclirc) # audit log if self.logfile is None and 'audit_log' in c['main']: try: self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a') except (IOError, OSError) as e: self.echo( 'Error: Unable to open the audit log file. Your queries will not be logged.', err=True, fg='red') self.logfile = False self.completion_refresher = CompletionRefresher() self.logger = logging.getLogger(__name__) self.initialize_logging() prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ self.default_prompt self.prompt_continuation_format = c['main']['prompt_continuation'] self.query_history = [] # Initialize completer. self.smart_completion = c['main'].as_bool('smart_completion') self.completer = SQLCompleter( self.smart_completion, supported_formats=self.formatter.supported_formats) self._completer_lock = threading.Lock() # Register custom special commands. self.register_special_commands() self.cli = None def register_special_commands(self): special.register_special_command(self.change_schema, 'use', 'use [schema]', 'Change to a new schema.', aliases=['\\u']) special.register_special_command( self.change_db, 'connect', 'connect [database]', 'Reconnect to the database. 
Optional database argument.', aliases=('\\r', ), case_sensitive=True) special.register_special_command(self.refresh_completions, 'refresh', 'refresh', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#', )) special.register_special_command( self.change_table_format, 'format', '\\T [format]', 'Change the format used to output results (html, csv etc.).', aliases=('\\T', ), case_sensitive=True) special.register_special_command(self.execute_from_file, '@', 'a [filename]', 'Execute commands from file.', aliases=['\\.', 'source']) special.register_special_command(self.change_prompt_format, 'prompt', '\\R', 'Change prompt format.', aliases=('\\R', ), case_sensitive=True) def change_table_format(self, arg, **_): try: self.formatter.format_name = arg yield (None, None, None, 'Changed table format to {}'.format(arg)) except ValueError: msg = 'Table format {} not recognized. Allowed formats:'.format( arg) for table_type in self.formatter.supported_formats: msg += "\n\t{}".format(table_type) yield (None, None, None, msg) def change_db(self, arg, **_): if arg is None: self.sqlexecute.connect() else: self.sqlexecute.connect(database=arg) yield (None, None, None, 'You are now connected to database "%s" as ' 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) def change_schema(self, arg, **_): if arg is None: yield (None, None, None, 'Specify the schema to switch to.') else: schema_name = str(arg).upper() self.sqlexecute.conn.current_schema = schema_name self.sqlexecute.dbname = schema_name yield (None, None, None, 'Schema updated to {}'.format(arg)) def execute_from_file(self, arg, **_): if not arg: message = 'Missing required argument, filename.' return [(None, None, None, message)] try: with open(os.path.expanduser(arg), encoding='utf-8') as f: query = f.read() except IOError as e: return [(None, None, None, str(e))] if (self.ddl_warning and confirm_ddl_query(query) is False): message = 'Command execution stopped.' 
return [(None, None, None, message)] return self.sqlexecute.run(query) def change_prompt_format(self, arg, **_): """ Change the prompt format. """ if not arg: message = 'Missing required argument, format.' return [(None, None, None, message)] self.prompt_format = self.get_prompt(arg) return [(None, None, None, "Changed prompt format to %s" % arg)] def initialize_logging(self): log_file = self.config['main']['log_file'] log_level = self.config['main']['log_level'] level_map = { 'CRITICAL': logging.CRITICAL, 'ERROR': logging.ERROR, 'WARNING': logging.WARNING, 'INFO': logging.INFO, 'DEBUG': logging.DEBUG } # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. if log_level.upper() == "NONE": handler = NullHandler() log_level = "CRITICAL" else: handler = logging.FileHandler(os.path.expanduser(log_file)) formatter = logging.Formatter( '%(asctime)s (%(process)d/%(threadName)s) ' '%(name)s %(levelname)s - %(message)s') handler.setFormatter(formatter) root_logger = logging.getLogger('okcli') root_logger.addHandler(handler) root_logger.setLevel(level_map[log_level.upper()]) logging.captureWarnings(True) root_logger.debug('Initializing okcli logging.') root_logger.debug('Log file %r.', log_file) def read_my_cnf_files(self, files, keys): """ Reads a list of config files and merges them. The last one will win. :param files: list of files to read :param keys: list of keys to retrieve :returns: tuple, with None for missing keys. 
""" cnf = read_config_files(files) sections = ['client'] if self.login_path and self.login_path != 'client': sections.append(self.login_path) if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) def get(key): result = None for sect in cnf: if sect in sections and key in cnf[sect]: result = cnf[sect][key] return result return {x: get(x) for x in keys} def connect(self, database='', user='', passwd='', host=''): cnf = { 'database': None, 'user': None, 'password': None, 'host': None, } cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. database = database or cnf['database'] user = user or cnf['user'] or os.getenv('USER') host = host or cnf['host'] or 'localhost' passwd = cnf['password'] or passwd if not passwd: passwd = click.prompt('Password', hide_input=True, show_default=False, type=str) # Assume connecting to schema with same name as user by default if not database: database = user.upper() # Connect to the database. try: from cx_Oracle import DatabaseError try: sqlexecute = SQLExecute(database, user, passwd, host) except DatabaseError as e: if ('invalid username/password' in str(e)): passwd = click.prompt('Password', hide_input=True, show_default=False, type=str) sqlexecute = SQLExecute(database, user, passwd, host) else: raise except Exception as e: # Connecting to a database could fail. self.logger.debug('Database connection failed', exc_info=True) self.echo(str(e), err=True, fg='red') exit(1) self.sqlexecute = sqlexecute def handle_editor_command(self, cli, document): """ Editor command is any query that is prefixed by a ed. The reason for a while loop is because a user might edit a query multiple times. For eg: r"select * from \e"<enter> to edit it in vim, then come back to the prompt with the edited query "select * from blah where q = 'abc'\e" to edit it again. 
:param cli: CommandLineInterface :param document: Document :return: Document """ # FIXME: using application.pre_run_callables like this here is not the best solution. # It's internal api of prompt_toolkit that may change. This was added to fix # https://github.com/dbcli/pgcli/issues/668. We may find a better way to do it in the future. saved_callables = cli.application.pre_run_callables while special.editor_command(document.text): filename = special.get_filename(document.text) query = (special.get_editor_query(document.text) or self.get_last_query()) sql, message = special.open_external_editor(filename, sql=query) if message: # Something went wrong. Raise an exception and bail. raise RuntimeError(message) cli.current_buffer.document = Document(sql, cursor_position=len(sql)) cli.application.pre_run_callables = [] document = cli.run() continue cli.application.pre_run_callables = saved_callables return document def run_cli(self): sqlexecute = self.sqlexecute logger = self.logger self.configure_pager() if self.smart_completion: self.refresh_completions() key_binding_manager = okcli_bindings() def prompt_tokens(cli): prompt = self.get_prompt(self.prompt_format) if self.prompt_format == self.default_prompt and len( prompt) > self.max_len_prompt: prompt = self.get_prompt('\\d> ') return [(Token.Prompt, prompt)] def get_continuation_tokens(cli, width): continuation_prompt = self.get_prompt( self.prompt_continuation_format) return [(Token.Continuation, ' ' * (width - len(continuation_prompt)) + continuation_prompt)] def one_iteration(document=None): if document is None: document = self.cli.run() special.set_expanded_output(False) try: document = self.handle_editor_command(self.cli, document) except RuntimeError as e: logger.error("sql: %r, error: %r", document.text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') return if not document.text.strip(): return if self.ddl_warning: destroy = confirm_ddl_query(document.text) if 
destroy is None: pass # Query was not destructive. Nothing to do here. elif destroy is True: self.echo('OK') else: self.echo('Cancelled') return # Keep track of whether or not the query is mutating. In case # of a multi-statement query, the overall query is considered # mutating if any one of the component statements is mutating mutating = False try: logger.debug('sql: %r', document.text) special.write_tee( self.get_prompt(self.prompt_format) + document.text) if self.logfile: self.logfile.write('\n# %s\n' % datetime.now()) self.logfile.write(document.text) self.logfile.write('\n') successful = False start = time() res = sqlexecute.run(document.text) successful = True result_count = 0 for title, cur, headers, status in res: logger.debug("headers: %r", headers) logger.debug("rows: %r", cur) logger.debug("status: %r", status) threshold = 1000 if (is_select(status) and cur and cur.rowcount > threshold): self.echo( 'The result set has more than {} rows.'.format( threshold), fg='red') if not click.confirm('Do you want to continue?'): self.echo("Aborted!", err=True, fg='red') break if self.auto_vertical_output: max_width = self.cli.output.get_size().columns else: max_width = None formatted = self.format_output( title, cur, headers, special.is_expanded_output(), max_width) if cur is not None: status = self.sqlexecute.get_status(cur) t = time() - start try: if result_count > 0: self.echo('') try: self.output('\n'.join(formatted), status) except KeyboardInterrupt: pass if special.is_timing_enabled(): self.echo('Time: %0.03fs' % t) except KeyboardInterrupt: pass start = time() result_count += 1 mutating = mutating or is_mutating(status) special.unset_once_if_written() except EOFError as e: raise e except KeyboardInterrupt: # get last connection id connection_id_to_kill = sqlexecute.connection_id logger.debug("connection id to kill: %r", connection_id_to_kill) # Restart connection to the database sqlexecute.connect() try: for title, cur, headers, status in sqlexecute.run( 
'kill %s' % connection_id_to_kill): status_str = str(status).lower() if status_str.find('ok') > -1: logger.debug( "cancelled query, connection id: %r, sql: %r", connection_id_to_kill, document.text) self.echo("cancelled query", err=True, fg='red') except Exception as e: self.echo( 'Encountered error while cancelling query: {}'.format( e), err=True, fg='red') except NotImplementedError: self.echo('Not Yet Implemented.', fg="yellow") except Exception as e: logger.debug("Error", exc_info=True) if (e.args[0] in (2003, 2006, 2013)): logger.debug('Attempting to reconnect.') self.echo('Reconnecting...', fg='yellow') try: sqlexecute.connect() logger.debug('Reconnected successfully.') one_iteration(document) return # OK to just return, cuz the recursion call runs to the end. except Exception as e: logger.debug('Reconnect failed', exc_info=True) self.echo(str(e), err=True, fg='red') # If reconnection failed, don't proceed further. return else: logger.error("sql: %r, error: %r", document.text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') except Exception as e: logger.error("sql: %r, error: %r", document.text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') else: # Refresh the table names and column names if necessary. 
if need_completion_refresh(document.text): self.refresh_completions( reset=need_completion_reset(document.text)) finally: if self.logfile is False: self.echo("Warning: This query was not logged.", err=True, fg='red') query = Query(document.text, successful, mutating) self.query_history.append(query) get_toolbar_tokens = create_toolbar_tokens_func( self.completion_refresher.is_refreshing) layout = create_prompt_layout( lexer=OracleLexer, multiline=True, get_prompt_tokens=prompt_tokens, get_continuation_tokens=get_continuation_tokens, get_bottom_toolbar_tokens=get_toolbar_tokens, display_completions_in_columns=self.wider_completion_menu, extra_input_processors=[ ConditionalProcessor( processor=HighlightMatchingBracketProcessor( chars='[](){}'), filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()) ], reserve_space_for_menu=self.get_reserved_space()) with self._completer_lock: buf = CLIBuffer(always_multiline=self.multi_line, completer=self.completer, history=FileHistory( os.path.expanduser( os.environ.get('okcli_HISTFILE', '~/.okcli-history'))), complete_while_typing=Always(), accept_action=AcceptAction.RETURN_DOCUMENT) if self.key_bindings == 'vi': editing_mode = EditingMode.VI else: editing_mode = EditingMode.EMACS application = Application( style=style_factory(self.syntax_style, self.cli_style), layout=layout, buffer=buf, key_bindings_registry=key_binding_manager.registry, on_exit=AbortAction.RAISE_EXCEPTION, on_abort=AbortAction.RETRY, editing_mode=editing_mode, ignore_case=True) self.cli = CommandLineInterface(application=application, eventloop=create_eventloop()) try: while True: one_iteration() except EOFError: special.close_tee() def log_output(self, output): """Log the output in the audit log, if it's enabled.""" if self.logfile: self.logfile.write(utf8tounicode(output)) self.logfile.write('\n') def echo(self, s, **kwargs): """Print a message to stdout. The message will be logged in the audit log, if enabled. All keyword arguments are passed to click.echo(). 
""" self.log_output(s) click.secho(s, **kwargs) def output_fits_on_screen(self, output, status=None): """Check if the given output fits on the screen.""" size = self.cli.output.get_size() margin = self.get_reserved_space() + self.get_prompt( self.prompt_format).count('\n') + 1 if special.is_timing_enabled(): margin += 1 if status: margin += 1 + status.count('\n') for i, line in enumerate(output.splitlines(), 1): if len(line) > size.columns or i > (size.rows - margin): return False return True def output(self, output, status=None): """Output text to stdout or a pager command. The status text is not outputted to pager or files. The message will be logged in the audit log, if enabled. The message will be written to the tee file, if enabled. The message will be written to the output file, if enabled. """ if output: self.log_output(output) special.write_tee(output) special.write_once(output) if (self.explicit_pager or (special.is_pager_enabled() and not self.output_fits_on_screen(output, status))): click.echo_via_pager(output) else: click.secho(output) if status: self.log_output(status) click.secho(status) def configure_pager(self): # Provide sane defaults for less if they are empty. 
if not os.environ.get('LESS'): os.environ['LESS'] = '-RXF' cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager']) if cnf['pager']: special.set_pager(cnf['pager']) self.explicit_pager = True else: self.explicit_pager = False if cnf['skip-pager']: special.disable_pager() def refresh_completions(self, reset=False): if reset: with self._completer_lock: self.completer.reset_completions() self.completion_refresher.refresh( self.sqlexecute, self._on_completions_refreshed, { 'smart_completion': self.smart_completion, 'supported_formats': self.formatter.supported_formats }) return [(None, None, None, 'Auto-completion refresh started in the background.')] def _on_completions_refreshed(self, new_completer): self._swap_completer_objects(new_completer) if self.cli: # After refreshing, redraw the CLI to clear the statusbar # "Refreshing completions..." indicator self.cli.request_redraw() def _swap_completer_objects(self, new_completer): """Swap the completer object in cli with the newly created completer. """ with self._completer_lock: self.completer = new_completer # When okcli is first launched we call refresh_completions before # instantiating the cli object. So it is necessary to check if cli # exists before trying the replace the completer object in cli. 
if self.cli: self.cli.current_buffer.completer = new_completer def get_completions(self, text, cursor_positition): with self._completer_lock: return self.completer.get_completions( Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string): sqlexecute = self.sqlexecute host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host now = datetime.now() string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\h', host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') string = string.replace('\\t', sqlexecute.server_type()[0] or 'okcli') string = string.replace('\\n', "\n") string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) string = string.replace('\\P', now.strftime('%p')) string = string.replace('\\R', now.strftime('%H')) string = string.replace('\\r', now.strftime('%I')) string = string.replace('\\s', now.strftime('%S')) return string def run_query(self, query, new_line=True): """Runs *query*.""" results = self.sqlexecute.run(query) for result in results: title, cur, headers, status = result output = self.format_output(title, cur, headers) for line in output: click.echo(line, nl=new_line) def format_output(self, title, cur, headers, expanded=False, max_width=None): expanded = expanded or self.formatter.format_name == 'vertical' output = [] if title: # Only print the title if it's not None. 
output.append(title) if cur: rows = list(cur) formatted = self.formatter.format_output( rows, headers, format_name='vertical' if expanded else None) if (not expanded and max_width and rows and content_exceeds_width(rows[0], max_width) and headers): formatted = self.formatter.format_output( rows, headers, format_name='vertical') output.append(formatted) return output def get_reserved_space(self): """Get the number of lines to reserve for the completion menu.""" reserved_space_ratio = .45 max_reserved_space = 8 _, height = click.get_terminal_size() return min(round(height * reserved_space_ratio), max_reserved_space) def get_last_query(self): """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None
class AthenaCli(object):
    """Interactive CLI: reads queries, runs them through SQLExecute and
    prints formatted results; also handles GRANT/REVOKE special commands."""

    # NOTE(review): '\\r' is used here and in _build_cli's fallback prompt,
    # but get_prompt() does not substitute it, so it is shown literally.
    DEFAULT_PROMPT = '\\d@\\r> '
    MAX_LEN_PROMPT = 45

    def __init__(self, region, aws_access_key_id, aws_secret_access_key,
                 aws_account_id, athenaclirc, profile, scan):
        """Load config, set up logging and AWS config, optionally scan
        permissions, connect, and initialise CLI state.

        Exits the process with status 1 if connecting fails.
        """
        config_files = (DEFAULT_CONFIG_FILE, athenaclirc)
        _cfg = self.config = read_config_files(config_files)

        self.init_logging(_cfg['main']['log_file'], _cfg['main']['log_level'])

        self.aws_config = AWSConfig(aws_access_key_id, aws_secret_access_key,
                                    region, aws_account_id, profile, _cfg)

        if scan:
            LOGGER.info("Scanning Lake Formation Permissions")
            scanner = Scanner(self.aws_config, _cfg['main']['iam_db_path'])
            scanner.scan()

        try:
            self.connect(_cfg['main']['iam_db_path'])
        except Exception as e:
            self.echo(str(e), err=True, fg='red')
            err_msg = ''' There was an error while connecting to AWS Athena. It could be caused due to missing/incomplete configuration. Please verify the configuration in %s and run lakecli again. For more details about the error, you can check the log file: %s''' % (
                LAKE_CLI_RC, _cfg['main']['log_file'])
            self.echo(err_msg)
            LOGGER.exception('error: %r', e)
            sys.exit(1)

        special.set_timing_enabled(_cfg['main'].as_bool('timing'))
        self.multi_line = _cfg['main'].as_bool('multi_line')
        self.key_bindings = _cfg['main']['key_bindings']
        self.prompt = _cfg['main']['prompt'] or self.DEFAULT_PROMPT
        # as_bool, like the options above: a raw config value such as the
        # string 'False' would otherwise be truthy and keep warnings on.
        self.destructive_warning = _cfg['main'].as_bool('destructive_warning')
        self.syntax_style = _cfg['main']['syntax_style']
        self.prompt_continuation_format = _cfg['main']['prompt_continuation']

        self.formatter = TabularOutputFormatter(_cfg['main']['table_format'])
        self.formatter.cli = self
        sql_format.register_new_formatter(self.formatter)

        self.output_style = style_factory(self.syntax_style, _cfg['colors'])

        self.completer = AthenaCompleter()
        self._completer_lock = threading.Lock()
        self.completion_refresher = CompletionRefresher()

        self.cli = None

        self.query_history = []
        # Register custom special commands.
        self.register_special_commands()

    def init_logging(self, log_file, log_level_str):
        """Attach a file handler at *log_level_str* to this module's logger
        and to the 'lakecli' and 'special' loggers.

        Raises KeyError for an unknown level name.
        """
        file_path = os.path.expanduser(log_file)
        if not os.path.exists(file_path):
            mkdir_p(os.path.dirname(file_path))

        handler = logging.FileHandler(file_path)
        log_level_map = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG,
        }

        log_level = log_level_map[log_level_str.upper()]

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        LOGGER.addHandler(handler)
        LOGGER.setLevel(log_level)

        root_logger = logging.getLogger('lakecli')
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        root_logger.debug('Initializing lakecli logging.')
        root_logger.debug('Log file %r.', log_file)

        pgspecial_logger = logging.getLogger('special')
        pgspecial_logger.addHandler(handler)
        pgspecial_logger.setLevel(log_level)

    def register_special_commands(self):
        """Register use, prompt, tableformat, GRANT and REVOKE commands."""
        special.register_special_command(self.change_db,
                                         'use',
                                         '\\u',
                                         'Change to a new database.',
                                         aliases=('\\u', ))
        special.register_special_command(self.change_prompt_format,
                                         'prompt',
                                         '\\R',
                                         'Change prompt format.',
                                         aliases=('\\R', ),
                                         case_sensitive=True)
        special.register_special_command(
            self.change_table_format,
            'tableformat',
            '\\T',
            'Change the table format used to output results.',
            aliases=('\\T', ),
            case_sensitive=True)

        from lakecli.packages.special.main import RAW_QUERY
        special.register_special_command(self.grant_execute,
                                         'GRANT',
                                         '',
                                         'Execute GRANT Statement',
                                         arg_type=RAW_QUERY)

        special.register_special_command(self.revoke_execute,
                                         'REVOKE',
                                         '',
                                         'Execute REVOKE Statement',
                                         arg_type=RAW_QUERY)

    def change_table_format(self, arg, **_):
        """Yield one status tuple after trying to switch the table format."""
        try:
            self.formatter.format_name = arg
            yield (None, None, None, 'Changed table format to {}'.format(arg))
        except ValueError:
            msg = 'Table format {} not recognized. Allowed formats:'.format(
                arg)
            for table_type in self.formatter.supported_formats:
                msg += "\n\t{}".format(table_type)
            yield (None, None, None, msg)

    def change_db(self, arg, **_):
        """Reconnect (optionally to database *arg*) and yield a status tuple."""
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        yield (None, None, None, 'You are now connected to database "%s"' %
               self.sqlexecute.database)

    def grant_execute(self, cur, query):
        """Parse and execute a GRANT statement; yield one status tuple.

        A RuntimeError is reported through the status as its message text.
        """
        LOGGER.debug(query)
        from lakecli.privileges.grant_or_revoke import Grant
        grant = Grant(self.aws_config, sqlparse.parse(query)[0])
        try:
            grant.process()
            grant.execute()
            yield (None, None, None, 'GRANT')
        except RuntimeError as err:
            # str(): status is a string on every other path and downstream
            # code (e.g. get_output_margin) calls str methods on it.
            yield (None, None, None, str(err))

    def revoke_execute(self, cur, query):
        """Parse and execute a REVOKE statement; yield one status tuple.

        A RuntimeError is reported through the status as its message text.
        """
        LOGGER.debug(query)
        from lakecli.privileges.grant_or_revoke import Revoke
        revoke = Revoke(self.aws_config, sqlparse.parse(query)[0])
        try:
            revoke.process()
            revoke.execute()
            yield (None, None, None, 'REVOKE')
        except RuntimeError as err:
            # str(): keep the status a string, as on the success path.
            yield (None, None, None, str(err))

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.
        """
        if not arg:
            message = 'Missing required argument, format.'
            return [(None, None, None, message)]

        self.prompt = self.get_prompt(arg)
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def connect(self, path):
        """Create the SQLExecute instance for *path*."""
        self.sqlexecute = SQLExecute(path)

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        """
        # FIXME: using application.pre_run_callables like this here is not the best solution.
        # It's internal api of prompt_toolkit that may change. This was added to fix
        # https://github.com/dbcli/pgcli/issues/668. We may find a better way to do it in the future.
        saved_callables = cli.application.pre_run_callables
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            query = (special.get_editor_query(document.text) or
                     self.get_last_query())
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            cli.current_buffer.document = Document(sql,
                                                   cursor_position=len(sql))
            cli.application.pre_run_callables = []
            document = cli.run()
            continue
        cli.application.pre_run_callables = saved_callables
        return document

    def run_query(self, query, new_line=True):
        """Runs *query*."""
        if (self.destructive_warning and
                confirm_destructive_query(query) is False):
            message = 'Wise choice. Command execution stopped.'
            click.echo(message)
            return

        results = self.sqlexecute.run(query)
        for result in results:
            title, rows, headers, _ = result
            self.formatter.query = query
            output = self.format_output(title, rows, headers)
            for line in output:
                click.echo(line, nl=new_line)

    def run_cli(self):
        """Run the interactive prompt loop until an EOFError escapes it."""
        self.iterations = 0
        self.configure_pager()
        self.refresh_completions()

        history_file = os.path.expanduser(self.config['main']['history_file'])
        history = FileHistory(history_file)
        self.cli = self._build_cli(history)

        def one_iteration():
            """Read one input from the prompt, execute and print it."""
            document = self.cli.run()

            special.set_expanded_output(False)

            try:
                document = self.handle_editor_command(self.cli, document)
            except RuntimeError as e:
                LOGGER.error("sql: %r, error: %r", document.text, e)
                LOGGER.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
                return

            if not document.text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(document.text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo('Your call!')
                else:
                    self.echo('Wise choice!')
                    return

            mutating = False
            # Bound before the ``try`` so the Query record below never raises
            # UnboundLocalError when write_tee/get_prompt fails first.
            successful = False

            try:
                LOGGER.debug('sql: %r', document.text)

                special.write_tee(self.get_prompt(self.prompt) + document.text)
                start = time()
                res = self.sqlexecute.run(document.text)
                successful = True
                threshold = 1000
                result_count = 0

                for title, rows, headers, status in res:
                    if rows and len(rows) > threshold:
                        self.echo(
                            'The result set has more than {} rows.'.format(
                                threshold), fg='red')
                        if not confirm('Do you want to continue?'):
                            self.echo('Aborted!', err=True, fg='red')
                            break

                    formatted = self.format_output(
                        title, rows, headers, special.is_expanded_output(),
                        None)

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo('')
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        if special.is_timing_enabled():
                            self.echo('Time: %0.03fs' % t)
                    except KeyboardInterrupt:
                        pass

                    start = time()
                    result_count += 1
                    mutating = mutating or is_mutating(status)
                special.unset_once_if_written()
            except EOFError:
                raise
            except KeyboardInterrupt:
                pass
            except NotImplementedError:
                self.echo('Not Yet Implemented.', fg="yellow")
            except OperationalError as e:
                LOGGER.debug("Exception: %r", e)
                LOGGER.error("sql: %r, error: %r", document.text, e)
                LOGGER.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
            except Exception as e:
                LOGGER.error("sql: %r, error: %r", document.text, e)
                LOGGER.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
            else:
                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    LOGGER.debug("=" * 10)
                    self.refresh_completions()

            query = Query(document.text, successful, mutating)
            self.query_history.append(query)

        try:
            while True:
                one_iteration()
                self.iterations += 1
        except EOFError:
            special.close_tee()

    def get_output_margin(self, status=None):
        """Get the output margin (number of rows for the prompt, footer and
        timing message."""
        margin = self.get_reserved_space() + self.get_prompt(
            self.prompt).count('\n') + 1
        if special.is_timing_enabled():
            margin += 1
        if status:
            margin += 1 + status.count('\n')

        return margin

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.
        """
        if output:
            size = self.cli.output.get_size()

            margin = self.get_output_margin(status)

            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled(
            )
            for i, line in enumerate(output, 1):
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled(
                        ):
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for buffered in buf:
                                click.secho(buffered)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for buffered in buf:
                        click.secho(buffered)

        if status:
            click.secho(status)

    def configure_pager(self):
        """Disable the pager unless the 'enable_pager' option is true."""
        self.explicit_pager = False
        if not self.config['main'].as_bool('enable_pager'):
            special.disable_pager()

    def format_output(self, title, cur, headers, expanded=False,
                      max_width=None):
        """Return an iterable of output lines for one result.

        The title (if any) comes first, followed by the formatted rows. When
        not expanded and the first formatted line is wider than *max_width*,
        the rows are re-formatted vertically.
        """
        expanded = expanded or self.formatter.format_name == 'vertical'
        output = []

        output_kwargs = {
            'disable_numparse': True,
            'preserve_whitespace': True,
            'preprocessors': (preprocessors.align_decimals, ),
            'style': self.output_style
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, 'description'):

                def get_col_type(col):
                    col_type = text_type
                    return col_type if type(col_type) is type else text_type

                column_types = [get_col_type(col) for col in cur.description]

            if max_width is not None:
                # Materialise so the rows can be formatted a second time.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur,
                headers,
                format_name='vertical' if expanded else None,
                column_types=column_types,
                **output_kwargs)

            if isinstance(formatted, (text_type)):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            # A default avoids StopIteration when the formatter produced no
            # lines; the exhausted iterator then simply contributes nothing.
            first_line = next(formatted, None)
            if first_line is not None:
                formatted = itertools.chain([first_line], formatted)

                if (not expanded and max_width and headers and cur and
                        len(first_line) > max_width):
                    formatted = self.formatter.format_output(
                        cur,
                        headers,
                        format_name='vertical',
                        column_types=column_types,
                        **output_kwargs)
                    if isinstance(formatted, (text_type)):
                        formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def echo(self, s, **kwargs):
        """Print a message to stdout.

        All keyword arguments are passed to click.echo().
        """
        click.secho(s, **kwargs)

    def refresh_completions(self):
        """Reset the completer and start a background completion refresh."""
        with self._completer_lock:
            self.completer.reset_completions()

        completer_options = {
            'smart_completion': True,
            'supported_formats': self.formatter.supported_formats,
            'keyword_casing': self.completer.keyword_casing
        }
        self.completion_refresher.refresh(self.sqlexecute,
                                          self._on_completions_refreshed,
                                          completer_options)

    def _on_completions_refreshed(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When cli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _build_cli(self, history):
        """Build and return the prompt_toolkit CommandLineInterface."""
        key_binding_manager = cli_bindings()

        def prompt_tokens(cli):
            prompt = self.get_prompt(self.prompt)
            if len(prompt) > self.MAX_LEN_PROMPT:
                prompt = self.get_prompt('\\r:\\d> ')
            return [(Token.Prompt, prompt)]

        def get_continuation_tokens(cli, width):
            prompt = self.get_prompt(self.prompt_continuation_format)
            token = (Token.Continuation, ' ' * (width - len(prompt)) + prompt)
            return [token]

        def show_suggestion_tip():
            return self.iterations < 2

        get_toolbar_tokens = create_toolbar_tokens_func(
            self.completion_refresher.is_refreshing, show_suggestion_tip)

        layout = create_prompt_layout(
            lexer=Lexer,
            multiline=True,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=False,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ],
            reserve_space_for_menu=self.get_reserved_space())

        with self._completer_lock:
            buf = CLIBuffer(always_multiline=self.multi_line,
                            completer=self.completer,
                            history=history,
                            auto_suggest=AutoSuggestFromHistory(),
                            complete_while_typing=Always(),
                            accept_action=AcceptAction.RETURN_DOCUMENT)

            if self.key_bindings == 'vi':
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            application = Application(
                style=style_from_pygments(style_cls=self.output_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                editing_mode=editing_mode,
                ignore_case=True)

            cli = CommandLineInterface(application=application,
                                       eventloop=create_eventloop())

            return cli

    def get_prompt(self, string):
        """Expand the backslash placeholders in a prompt format *string*.

        Supported: \\d (basename of sqlexecute.path), \\n, \\D, \\m, \\P,
        \\R, \\s.
        """
        sqlexecute = self.sqlexecute
        now = datetime.now()
        string = string.replace('\\d',
                                os.path.basename(sqlexecute.path) or '(none)')
        string = string.replace('\\n', "\n")
        string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
        string = string.replace('\\m', now.strftime('%M'))
        string = string.replace('\\P', now.strftime('%p'))
        string = string.replace('\\R', now.strftime('%H'))
        string = string.replace('\\s', now.strftime('%S'))
        return string

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu."""
        reserved_space_ratio = .45
        max_reserved_space = 8
        _, height = click.get_terminal_size()
        return min(int(round(height * reserved_space_ratio)),
                   max_reserved_space)

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None