Example #1
0
 def run(self):
     """Run ``self.application`` inside a freshly created event loop.

     The loop is closed unconditionally once the interface returns or
     raises, so it is never leaked.
     """
     loop = create_eventloop()
     try:
         interface = CommandLineInterface(
             application=self.application, eventloop=loop)
         interface.run()
     finally:
         loop.close()
Example #2
0
class AqPrompt(object):
    """Interactive SQL prompt for aq, built on prompt_toolkit.

    Args:
        parser: Query parser, handed to ``QueryValidator`` to validate input.
        engine: Query engine whose ``available_schemas`` and
            ``available_tables`` feed the autocompleter.
        options: Optional settings mapping; defaults to a new empty dict.
    """

    def __init__(self, parser, engine, options=None):
        self.parser = parser
        self.engine = engine
        if options is None:
            options = {}
        self.options = options
        util.ensure_data_dir_exists()
        history_path = os.path.expanduser('~/.aq/history')
        application = create_prompt_application(
            message='> ',
            lexer=PygmentsLexer(SqlLexer),
            history=FileHistory(history_path),
            completer=AqCompleter(schemas=engine.available_schemas,
                                  tables=engine.available_tables),
            auto_suggest=AutoSuggestFromHistory(),
            validator=QueryValidator(parser),
            on_abort=AbortAction.RETRY,
        )
        self.cli = CommandLineInterface(application=application,
                                        eventloop=create_eventloop())
        # One stdout-patching context, re-entered on every prompt() call.
        self.patch_context = self.cli.patch_stdout_context()

    def prompt(self):
        """Read one query (with stdout patched) and return its text."""
        with self.patch_context:
            document = self.cli.run(reset_current_buffer=True)
        return document.text

    def update_with_result(self, query_metadata):
        """Hook for reacting to a query result; not implemented yet."""
        # TODO
        return None
Example #3
0
def run():
    """Start the interactive OS shell and process input until Ctrl-D.

    Validates the environment and prints the banner, then builds a
    prompt_toolkit interface from ``OSBuffer``/``OSLayout``. Each entered
    document is handed to ``process_document``. A KeyboardInterrupt is
    reported and the prompt is shown again. An EOFError (raised on exit
    because of ``on_exit=AbortAction.RAISE_EXCEPTION``) prints a message
    and terminates the process via ``sys.exit()``.
    """
    validate_osenvironment()
    print_banner()

    cli_buffer = OSBuffer()
    ltobj = OSLayout(multiwindow=False)

    application = Application(style=PygmentsStyle(OSStyle),
                              layout=ltobj.layout,
                              buffers=cli_buffer.buffers,
                              on_exit=AbortAction.RAISE_EXCEPTION,
                              key_bindings_registry=OSKeyBinder.registry)

    cli = CommandLineInterface(application=application,
                               eventloop=create_eventloop())

    while True:
        try:
            document = cli.run(reset_current_buffer=True)
            process_document(document)
        except KeyboardInterrupt:
            # A KeyboardInterrupt generated possibly due to Ctrl-C; report it
            # and loop back to the prompt.
            # print(...) with a single string argument behaves identically on
            # Python 2 and 3, unlike the old print statement.
            print("Keyboard Interrupt Generated")
        except EOFError:
            print("cntl-D")
            sys.exit()
Example #4
0
class AqPrompt(object):
    """prompt_toolkit-based SQL prompt used by aq.

    Args:
        parser: Parser used by ``QueryValidator`` to vet input.
        engine: Engine providing ``available_schemas`` and
            ``available_tables`` for completion.
        options: Optional settings; ``None`` becomes a fresh ``{}``.
    """

    def __init__(self, parser, engine, options=None):
        self.parser = parser
        self.engine = engine
        self.options = {} if options is None else options
        util.ensure_data_dir_exists()
        app = create_prompt_application(
            message='> ',
            lexer=PygmentsLexer(SqlLexer),
            history=FileHistory(os.path.expanduser('~/.aq/history')),
            completer=AqCompleter(
                schemas=engine.available_schemas,
                tables=engine.available_tables,
            ),
            auto_suggest=AutoSuggestFromHistory(),
            validator=QueryValidator(parser),
            on_abort=AbortAction.RETRY,
        )
        event_loop = create_eventloop()
        self.cli = CommandLineInterface(application=app, eventloop=event_loop)
        # Reused across prompt() calls so background output doesn't
        # clobber the prompt line.
        self.patch_context = self.cli.patch_stdout_context()

    def prompt(self):
        """Return the text of the next query the user enters."""
        with self.patch_context:
            result = self.cli.run(reset_current_buffer=True)
            return result.text

    def update_with_result(self, query_metadata):
        """Placeholder for updating prompt state from a query result."""
        # TODO
        return
Example #5
0
    def run(self):
      """Run cycli against the connected Neo4j server.

      Labels, relationship types and property keys are always fetched
      first; they feed the interactive autocompleter.

      If ``self.filename`` is set, the file is split into queries with
      ``split_queries_on_semicolons``. Each query is echoed with a ``> ``
      prefix, passed to ``self.handle_query`` and followed by a blank line.
      The method then returns without starting the REPL.

      Otherwise a banner plus version/bug-report info is printed and an
      interactive prompt loop hands each entered query to
      ``self.handle_query``. ``UserWantsOut`` ends the loop with
      "Goodbye!"; any other ``Exception`` ends it after printing the
      exception.
      """
      labels = self.neo4j.get_labels()
      relationship_types = self.neo4j.get_relationship_types()
      properties = self.neo4j.get_property_keys()

      if self.filename:
        # NOTE(review): the file is opened in binary mode, so on Python 3
        # f.read() returns bytes; confirm split_queries_on_semicolons yields
        # str, otherwise "> " + query below raises TypeError.
        with open(self.filename, "rb") as f:
          queries = split_queries_on_semicolons(f.read())

          for query in queries:
            print("> " + query)
            self.handle_query(query)
            print()

          return

      # Raw strings: the banner is full of backslashes that are not valid
      # escape sequences. In plain literals they emit a DeprecationWarning
      # (a SyntaxWarning since Python 3.12); the printed text is unchanged.
      click.secho(r" ______     __  __     ______     __         __    ", fg="red")
      click.secho(r"/\  ___\   /\ \_\ \   /\  ___\   /\ \       /\ \   ", fg="yellow")
      click.secho(r"\ \ \____  \ \____ \  \ \ \____  \ \ \____  \ \ \  ", fg="green")
      click.secho(r" \ \_____\  \/\_____\  \ \_____\  \ \_____\  \ \_\ ", fg="blue")
      click.secho(r"  \/_____/   \/_____/   \/_____/   \/_____/   \/_/ ", fg="magenta")

      print("Cycli version: {}".format(__version__))
      print("Neo4j version: {}".format(".".join(map(str, self.neo4j.neo4j_version))))
      print("Bug reports: https://github.com/nicolewhite/cycli/issues\n")

      completer = CypherCompleter(labels, relationship_types, properties)

      layout = create_prompt_layout(
        lexer=CypherLexer,
        get_prompt_tokens=get_tokens,
        reserve_space_for_menu=8,
      )

      buff = CypherBuffer(
        accept_action=AcceptAction.RETURN_DOCUMENT,
        history=FileHistory(filename=os.path.expanduser('~/.cycli_history')),
        completer=completer,
        complete_while_typing=True,
      )

      application = Application(
        style=PygmentsStyle(CypherStyle),
        buffer=buff,
        layout=layout,
        on_exit=AbortAction.RAISE_EXCEPTION,
        key_bindings_registry=CypherBinder.registry
      )

      cli = CommandLineInterface(application=application, eventloop=create_eventloop())

      try:
        while True:
          document = cli.run()
          query = document.text
          self.handle_query(query)
      except UserWantsOut:
        print("Goodbye!")
      except Exception as e:
        # Any other error ends the session; show why.
        print(e)
Example #6
0
    def run(self):
        """Connect to Neo4j and run an interactive Cypher REPL.

        Fetches labels, relationship types and property keys for the
        completer. It returns early with a message if the server rejects the
        credentials (``Unauthorized``) or refuses the connection
        (``SocketError``).

        In the loop, ``quit``/``exit`` leave, ``help`` prints
        ``help_text()``, and anything else is sent to ``neo4j.cypher`` and
        the results are printed. An EOFError from the prompt
        (``on_exit=AbortAction.RAISE_EXCEPTION``) also ends the session.
        Any other exception ends the session after printing its message.
        "Goodbye!" is printed on every one of these exits.
        """
        neo4j = Neo4j(self.host, self.port, self.username, self.password)

        try:
            labels = neo4j.labels()
            relationship_types = neo4j.relationship_types()
            properties = neo4j.properties()

        except Unauthorized:
            print("Unauthorized. See cycli --help for authorization instructions.")
            return

        except SocketError:
            print("Connection refused. Is Neo4j turned on?")
            return

        completer = CypherCompleter(labels, relationship_types, properties)

        layout = create_default_layout(
            lexer=CypherLexer,
            get_prompt_tokens=get_tokens,
            reserve_space_for_menu=True
        )

        buff = CypherBuffer(
            history=History(),
            completer=completer,
            complete_while_typing=Always()
        )

        application = Application(
            style=CypherStyle,
            buffer=buff,
            layout=layout,
            on_exit=AbortAction.RAISE_EXCEPTION
        )

        cli = CommandLineInterface(application=application, eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()
                query = document.text

                if query in ["quit", "exit"]:
                    # Leave normally instead of raising a bare Exception as
                    # control flow, which also masked real errors.
                    break

                elif query == "help":
                    print(help_text())

                else:
                    results = neo4j.cypher(query)
                    print(results)

        except EOFError:
            # Ctrl-D at the prompt: a normal way to exit.
            pass

        except Exception as e:
            # Previously swallowed silently; surface the reason the session
            # ended before saying goodbye.
            print(e)

        print("Goodbye!")
Example #7
0
class SaltCli(object):
    """
    The CLI implementation.
    """
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # NOTE(review): placeholder ids, not yet attached to the prompt.
        self.id_completer = WordCompleter([
            'id1', 'id2', 'id3'
        ])

    def run_cli(self):
        """
        Run the main loop.

        Prints version/home info, then reads commands until a quit command
        or EOFError (Ctrl-D). A KeyboardInterrupt (Ctrl + C) echoes a blank
        line and re-prompts. Any other exception is logged with its
        traceback, echoed in red, and ends the loop. 'Goodbye!' is printed
        on exit.
        """
        # Local import so the block does not depend on these names being
        # imported at module level (prompt_toolkit 1.x API).
        from prompt_toolkit.buffer import AcceptAction, Buffer

        print(u'Version:', __version__)
        print(u'Home: https://github.com/glasslion/saltcli')

        history = FileHistory(os.path.expanduser('~/.saltcli-history'))

        layout = create_prompt_layout(
            message=u'saltcli> ',
        )

        # The history used to be created but never attached to anything, so
        # ~/.saltcli-history was never read or written. Attach it to the
        # input buffer; RETURN_DOCUMENT makes Enter hand the text back to
        # cli.run().
        cli_buffer = Buffer(
            history=history,
            accept_action=AcceptAction.RETURN_DOCUMENT)

        application = Application(
            layout=layout,
            buffer=cli_buffer,
        )

        eventloop = create_eventloop()

        self.cli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

        while True:
            try:
                document = self.cli.run()

                if quit_command(document.text):
                    break

            except KeyboardInterrupt:
                # user pressed Ctrl + C
                click.echo('')
            except EOFError:
                break
            except Exception as ex:
                self.logger.debug('Exception: %r.', ex)
                self.logger.error("traceback: %r", traceback.format_exc())
                click.secho("{0}".format(ex), fg='red')
                break

        print('Goodbye!')
Example #8
0
def loop(cmd, history_file):
    """Run the interactive ``cr>`` SQL shell until the user exits.

    Args:
        cmd: Command object whose ``process``, ``should_autocomplete`` and
            ``logger`` are used. Its ``get_num_columns`` attribute is
            replaced with a function reporting the current terminal width.
        history_file: Path handed to ``TruncatedFileHistory`` (capped at
            ``MAX_HISTORY_LENGTH``).

    A KeyboardInterrupt raised while reading or processing a statement logs
    a "query not cancelled" warning and re-prompts. An EOFError
    (``on_exit=AbortAction.RAISE_EXCEPTION``) logs 'Bye!' and returns.
    """
    bindings = KeyBindingManager(enable_search=True,
                                 enable_abort_and_exit_bindings=True)
    # Highlight matching brackets only while the default buffer has focus
    # and input has not been accepted yet.
    bracket_matcher = ConditionalProcessor(
        processor=HighlightMatchingBracketProcessor(chars='[](){}'),
        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
    layout = create_prompt_layout(message=u'cr> ',
                                  multiline=True,
                                  lexer=SqlLexer,
                                  extra_input_processors=[bracket_matcher])

    sql_buffer = CrashBuffer(
        history=TruncatedFileHistory(history_file,
                                     max_length=MAX_HISTORY_LENGTH),
        accept_action=AcceptAction.RETURN_DOCUMENT,
        completer=SQLCompleter(cmd))

    def complete_while_typing(cli=None):
        # A callable, so cmd.should_autocomplete() is consulted on each call.
        return cmd.should_autocomplete()

    sql_buffer.complete_while_typing = complete_while_typing

    application = Application(
        layout=layout,
        buffer=sql_buffer,
        style=PygmentsStyle.from_defaults(pygments_style_cls=CrateStyle),
        key_bindings_registry=bindings.registry,
        editing_mode=_get_editing_mode(),
        on_exit=AbortAction.RAISE_EXCEPTION,
        on_abort=AbortAction.RETRY,
    )
    event_loop = create_eventloop()
    terminal_output = create_output()
    cli = CommandLineInterface(application=application,
                               eventloop=event_loop,
                               output=terminal_output)

    def terminal_columns():
        return terminal_output.get_size().columns

    cmd.get_num_columns = terminal_columns

    while True:
        try:
            document = cli.run(reset_current_buffer=True)
            if not document:
                continue
            cmd.process(document.text)
        except KeyboardInterrupt:
            cmd.logger.warn("Query not cancelled. Run KILL <jobId> to cancel it")
        except EOFError:
            cmd.logger.warn(u'Bye!')
            return
Example #9
0
class Prompt(object):  #pragma: no cover
    """Interactive prompt for ``renv`` commands, built on prompt_toolkit.

    ``initialize()`` (or ``activate(reinitialize=True)``) must run before
    input can be read. ``activate()`` returns the entered line split into
    shell-style tokens with ``shlex.split``.
    """
    def __init__(self, renv):
        self.renv = renv
        # Read by the bottom toolbar and toggled through the key manager.
        self.is_long = True
        self.cli = None

    def initialize(self):
        """Build a fresh CommandLineInterface and store it on ``self.cli``."""
        history = InMemoryHistory()
        toolbar_handler = create_toolbar_handler(self.get_long_options)

        layout = create_prompt_layout(
            get_prompt_tokens=self.get_prompt_tokens,
            lexer=create_lexer(),
            get_bottom_toolbar_tokens=toolbar_handler)

        buf = Buffer(history=history,
                     completer=CrutchCompleter(self.renv),
                     complete_while_typing=Always(),
                     accept_action=AcceptAction.RETURN_DOCUMENT)

        manager = get_key_manager(self.set_long_options, self.get_long_options)

        application = Application(style=style_factory(),
                                  layout=layout,
                                  buffer=buf,
                                  key_bindings_registry=manager.registry,
                                  on_exit=AbortAction.RAISE_EXCEPTION,
                                  on_abort=AbortAction.RETRY,
                                  ignore_case=True)

        eventloop = create_eventloop()

        self.cli = CommandLineInterface(application=application,
                                        eventloop=eventloop)

    def get_prompt_tokens(self, _):
        """Return the fixed prompt tokens (the ``' Y '`` marker plus a space)."""
        return [(Token.Pound, ' Y '), (Token.Text, ' ')]

    def set_long_options(self, value):
        """Set the long-options flag."""
        self.is_long = value

    def get_long_options(self):
        """Return the long-options flag."""
        return self.is_long

    def activate(self, reinitialize=False):
        """Read one line and return it split into shell-style tokens.

        Args:
            reinitialize: Rebuild the interface before reading.

        Raises:
            RuntimeError: If no interface exists (``initialize`` never ran).
            ValueError: From ``shlex.split`` on unbalanced quotes.
        """
        if reinitialize:
            self.initialize()
        # Explicit check instead of `assert`, which is stripped under -O.
        if self.cli is None:
            raise RuntimeError(
                'Prompt.activate() called before initialize()')
        document = self.cli.run(reset_current_buffer=True)
        return shlex.split(document.text)
Example #10
0
File: repl.py Project: boseca/crash
def loop(cmd, history_file):
    """Drive the crash ``cr>`` prompt, feeding statements to ``cmd``.

    Args:
        cmd: Shell command object. Its ``process`` receives each statement,
            ``should_autocomplete`` gates completion, ``logger`` receives
            warnings, and ``get_num_columns`` is overwritten to report the
            terminal width.
        history_file: File backing the truncated input history.

    Returns when an EOFError reaches the loop (logging 'Bye!'). A
    KeyboardInterrupt only logs a warning that the query keeps running.
    """
    manager = KeyBindingManager(
        enable_search=True, enable_abort_and_exit_bindings=True)

    processors = [
        ConditionalProcessor(
            processor=HighlightMatchingBracketProcessor(chars='[](){}'),
            filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
    ]
    prompt_layout = create_prompt_layout(
        message=u'cr> ',
        multiline=True,
        lexer=SqlLexer,
        extra_input_processors=processors)

    history = TruncatedFileHistory(history_file,
                                   max_length=MAX_HISTORY_LENGTH)
    input_buffer = CrashBuffer(history=history,
                               accept_action=AcceptAction.RETURN_DOCUMENT,
                               completer=SQLCompleter(cmd))
    # Completion-while-typing is decided by cmd at call time.
    input_buffer.complete_while_typing = (
        lambda cli=None: cmd.should_autocomplete())

    app = Application(
        layout=prompt_layout,
        buffer=input_buffer,
        style=PygmentsStyle.from_defaults(pygments_style_cls=CrateStyle),
        key_bindings_registry=manager.registry,
        editing_mode=_get_editing_mode(),
        on_exit=AbortAction.RAISE_EXCEPTION,
        on_abort=AbortAction.RETRY,
    )
    loop_ = create_eventloop()
    out = create_output()
    interface = CommandLineInterface(application=app,
                                     eventloop=loop_,
                                     output=out)

    def get_num_columns_override():
        return out.get_size().columns

    cmd.get_num_columns = get_num_columns_override

    while True:
        try:
            doc = interface.run(reset_current_buffer=True)
            if doc:
                cmd.process(doc.text)
        except EOFError:
            cmd.logger.warn(u'Bye!')
            return
        except KeyboardInterrupt:
            cmd.logger.warn(
                "Query not cancelled. Run KILL <jobId> to cancel it")
Example #11
0
def run():
    """Start the interactive Vault shell and process input until Ctrl-D.

    Builds a multi-window prompt_toolkit interface from ``VaultBuffer`` and
    ``VaultLayout`` (without the alternate screen). Each entered document
    is handed to ``process_document``. A KeyboardInterrupt is reported and
    the prompt is shown again. An EOFError (raised on exit because of
    ``on_exit=AbortAction.RAISE_EXCEPTION``) prints a message and terminates
    the process via ``sys.exit()``.
    """
    cli_buffer = VaultBuffer()
    vault_layout = VaultLayout(multiwindow=True)

    application = Application(style=PygmentsStyle(VaultStyle),
                              layout=vault_layout.layout,
                              buffers=cli_buffer.buffers,
                              on_exit=AbortAction.RAISE_EXCEPTION,
                              key_bindings_registry=VaultKeyBinder.registry,
                              use_alternate_screen=False)

    cli = CommandLineInterface(application=application,
                               eventloop=create_eventloop())

    while True:
        try:
            document = cli.run(reset_current_buffer=True)
            process_document(document)
        except KeyboardInterrupt:
            # print(...) with one string argument works the same on Python 2
            # and 3, unlike the Python-2-only print statement.
            print("Keyboard interrupt generated")
        except EOFError:
            print("ctrl-D")
            sys.exit()
Example #12
0
class Saws(object):
    """Encapsulates the Saws CLI.

    Attributes:
        * aws_cli: An instance of prompt_toolkit's CommandLineInterface.
        * key_manager: An instance of KeyManager.
        * config: An instance of Config.
        * config_obj: An instance of ConfigObj, reads from ~/.sawsrc.
        * theme: A string representing the lexer theme.
        * logger: An instance of SawsLogger.
        * all_commands: A list of all commands, sub_commands, options, etc
            from data/SOURCES.txt.
        * commands: A list of commands from data/SOURCES.txt.
        * sub_commands: A list of sub_commands from data/SOURCES.txt.
        * completer: An instance of AwsCompleter.
    """

    PYGMENTS_CMD = ' | pygmentize -l json'

    def __init__(self, refresh_resources=True):
        """Inits Saws.

        Args:
            * refresh_resources: A boolean that determines whether to
                refresh resources.

        Returns:
            None.
        """
        self.aws_cli = None
        self.key_manager = None
        self.config = Config()
        self.config_obj = self.config.read_configuration()
        self.theme = self.config_obj[self.config.MAIN][self.config.THEME]
        self.logger = SawsLogger(
            __name__,
            self.config_obj[self.config.MAIN][self.config.LOG_FILE],
            self.config_obj[self.config.MAIN][self.config.LOG_LEVEL]).logger
        self.all_commands = AwsCommands().all_commands
        self.commands = \
            self.all_commands[AwsCommands.CommandType.COMMANDS.value]
        self.sub_commands = \
            self.all_commands[AwsCommands.CommandType.SUB_COMMANDS.value]
        self.completer = AwsCompleter(
            awscli_completer,
            self.all_commands,
            self.config,
            self.config_obj,
            self.log_exception,
            fuzzy_match=self.get_fuzzy_match(),
            shortcut_match=self.get_shortcut_match())
        if refresh_resources:
            self.completer.refresh_resources_and_options()
        self._create_cli()

    def log_exception(self, e, traceback, echo=False):
        """Logs the exception and traceback to the log file ~/.saws.log.

        Args:
            * e: A Exception that specifies the exception.
            * traceback: A Traceback that specifies the traceback.
            * echo: A boolean that specifies whether to echo the exception
                to the console using click.

        Returns:
            None.
        """
        self.logger.debug('exception: %r.', str(e))
        self.logger.error("traceback: %r", traceback.format_exc())
        if echo:
            click.secho(str(e), fg='red')

    def set_color(self, color):
        """Setter for color output mode.

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * color: A boolean that represents the color flag.

        Returns:
            None.
        """
        self.config_obj[self.config.MAIN][self.config.COLOR] = color

    def get_color(self):
        """Getter for color output mode.

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the color flag.
        """
        return self.config_obj[self.config.MAIN].as_bool(self.config.COLOR)

    def set_fuzzy_match(self, fuzzy):
        """Setter for fuzzy matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * fuzzy: A boolean that represents the fuzzy flag.

        Returns:
            None.
        """
        self.config_obj[self.config.MAIN][self.config.FUZZY] = fuzzy
        self.completer.fuzzy_match = fuzzy

    def get_fuzzy_match(self):
        """Getter for fuzzy matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the fuzzy flag.
        """
        return self.config_obj[self.config.MAIN].as_bool(self.config.FUZZY)

    def set_shortcut_match(self, shortcut):
        """Setter for shortcut matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * shortcut: A boolean that represents the shortcut flag.

        Returns:
            None.
        """
        self.config_obj[self.config.MAIN][self.config.SHORTCUT] = shortcut
        self.completer.shortcut_match = shortcut

    def get_shortcut_match(self):
        """Getter for shortcut matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the shortcut flag.
        """
        return self.config_obj[self.config.MAIN].as_bool(self.config.SHORTCUT)

    def refresh_resources_and_options(self):
        """Convenience function to refresh resources and options for completion.

        Used by prompt_toolkit's KeyBindingManager.

        Args:
            * None.

        Returns:
            None.
        """
        self.completer.refresh_resources_and_options(force_refresh=True)

    def handle_docs(self, text=None, from_fkey=False):
        """Displays contextual web docs for `F9` or the `docs` command.

        Displays the web docs specific to the currently entered:

        * (optional) command
        * (optional) subcommand

        If no command or subcommand is present, the docs index page is shown.

        Docs are only displayed if:

        * from_fkey is True
        * from_fkey is False and `docs` is found in text

        At most one browser page is opened per call.

        Args:
            * text: A string representing the input command text.
            * from_fkey: A boolean representing whether this function is
                being executed from an `F9` key press.

        Returns:
            A boolean representing whether the web docs were shown.
        """
        base_url = 'http://docs.aws.amazon.com/cli/latest/reference/'
        index_html = 'index.html'
        if text is None:
            text = self.aws_cli.current_buffer.document.text
        # If the user hit the F9 key, append 'docs' to the text
        if from_fkey:
            text = text.strip() + ' ' + AwsCommands.AWS_DOCS
        tokens = text.split()
        if len(tokens) > 2 and tokens[-1] == AwsCommands.AWS_DOCS:
            prev_word = tokens[-2]
            # If we have a command, build the url
            if prev_word in self.commands:
                prev_word = prev_word + '/'
                url = base_url + prev_word + index_html
                webbrowser.open(url)
                return True
            # if we have a command and subcommand, build the url
            elif prev_word in self.sub_commands:
                command_url = tokens[-3] + '/'
                sub_command_url = tokens[-2] + '.html'
                url = base_url + command_url + sub_command_url
                webbrowser.open(url)
                return True
            # Otherwise fall through: the index page is opened below. (The
            # index used to be opened here too, so it was opened twice.)
        # If we still haven't opened the help doc at this point and the
        # user hit the F9 key or typed docs, just open the main docs index
        if from_fkey or AwsCommands.AWS_DOCS in tokens:
            webbrowser.open(base_url + index_html)
            return True
        return False

    def _handle_cd(self, text):
        """Handles a `cd` shell command by calling python's os.chdir.

        Simply passing in the `cd` command to subprocess.call doesn't work.
        Note: Changing the directory within Saws will only be in effect while
        running Saws.  Exiting the program will return you to the directory
        you were in prior to running Saws.

        Only input whose first word is exactly `cd` is handled, so commands
        that merely begin with the letters "cd" (e.g. `cdk`) are left for the
        shell. Everything after `cd` is the target directory, and a leading
        `~` in it is expanded.

        Args:
            * text: A string representing the input command text.

        Returns:
            A boolean representing a `cd` command was found and handled.
        """
        CD_CMD = 'cd'
        tokens = text.strip().split(None, 1)
        if not tokens or tokens[0] != CD_CMD:
            return False
        if len(tokens) == 1:
            # Treat a bare `cd` as a change to the user's home directory.
            # os.path.expanduser does this in a cross platform manner.
            directory = os.path.expanduser('~')
        else:
            # No shell is involved, so expand `~` ourselves.
            directory = os.path.expanduser(tokens[1])
        try:
            os.chdir(directory)
        except OSError as e:
            self.log_exception(e, traceback, echo=True)
        return True

    def _colorize_output(self, text):
        """Highlights output with pygments.

        Only highlights the output if all of the following conditions are True:

        * The color option is enabled
        * The text does not contain the `configure` command
        * The text does not contain the `help` command, which already does
            output highlighting

        Args:
            * text: A string that represents the input command text.

        Returns:
            A string that represents:
                * The original command text if no highlighting was performed.
                * The pygments highlighted command text otherwise.
        """
        stripped_text = text.strip()
        if not self.get_color() or stripped_text == '':
            return text
        excludes = [AwsCommands.AWS_CONFIGURE,
                    AwsCommands.AWS_HELP,
                    '|']
        if not any(substring in stripped_text for substring in excludes):
            return text.strip() + self.PYGMENTS_CMD
        else:
            return text

    def _handle_keyboard_interrupt(self, e, platform):
        """Handles keyboard interrupts more gracefully on Mac/Unix/Linux.

        Allows Mac/Unix/Linux to continue running on keyboard interrupt,
        as the user might interrupt a long-running AWS command with Control-C
        while continuing to work with Saws.

        On Windows, the "Terminate batch job (Y/N)" confirmation makes it
        tricky to handle this gracefully.  Thus, we re-raise KeyboardInterrupt.

        Args:
            * e: A KeyboardInterrupt.
            * platform: A string that denotes platform such as
                'Windows', 'Darwin', etc.

        Returns:
            None

        Raises:
            Exception: A KeyboardInterrupt if running on Windows.
        """
        if platform == 'Windows':
            raise e
        else:
            # Clear the renderer and send a carriage return
            self.aws_cli.renderer.clear()
            self.aws_cli.input_processor.feed_key(KeyPress(Keys.ControlM, ''))

    def _process_command(self, text):
        """Processes the input command, called by the cli event loop

        Args:
            * text: A string that represents the input command text.

        Returns:
            None.
        """
        if AwsCommands.AWS_COMMAND in text:
            text = self.completer.replace_shortcut(text)
            if self.handle_docs(text):
                return
        try:
            if not self._handle_cd(text):
                text = self._colorize_output(text)
                # Pass the command onto the shell so aws-cli can execute it
                subprocess.call(text, shell=True)
            print('')
        except KeyboardInterrupt as e:
            self._handle_keyboard_interrupt(e, platform.system())
        except Exception as e:
            self.log_exception(e, traceback, echo=True)

    def _create_cli(self):
        """Creates the prompt_toolkit's CommandLineInterface.

        Args:
            * None.

        Returns:
            None.
        """
        history = FileHistory(os.path.expanduser('~/.saws-history'))
        toolbar = Toolbar(self.get_color,
                          self.get_fuzzy_match,
                          self.get_shortcut_match)
        layout = create_default_layout(
            message='saws> ',
            reserve_space_for_menu=True,
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar.handler,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ]
        )
        cli_buffer = Buffer(
            history=history,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT)
        self.key_manager = KeyManager(
            self.set_color,
            self.get_color,
            self.set_fuzzy_match,
            self.get_fuzzy_match,
            self.set_shortcut_match,
            self.get_shortcut_match,
            self.refresh_resources_and_options,
            self.handle_docs)
        style_factory = StyleFactory(self.theme)
        application = Application(
            mouse_support=False,
            style=style_factory.style,
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=self.key_manager.manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            on_abort=AbortAction.RETRY,
            ignore_case=True)
        eventloop = create_eventloop()
        self.aws_cli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

    def run_cli(self):
        """Runs the main loop.

        Args:
            * None.

        Returns:
            None.
        """
        print('Version:', __version__)
        print('Theme:', self.theme)
        while True:
            document = self.aws_cli.run()
            self._process_command(document.text)
Example #13
0
    def run_cli(self):
        pgexecute = self.pgexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        completer = self.completer
        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        def prompt_tokens(cli):
            return [(Token.Prompt, '%s> ' % pgexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.vi_mode)
        layout = create_default_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])
        history_file = self.config['main']['history_file']
        buf = PGBuffer(always_multiline=self.multi_line,
                       completer=completer,
                       history=FileHistory(os.path.expanduser(history_file)),
                       complete_while_typing=Always())

        application = Application(
            style=style_factory(self.syntax_style, self.cli_style),
            layout=layout,
            buffer=buf,
            key_bindings_registry=key_binding_manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            ignore_case=True)
        cli = CommandLineInterface(application=application,
                                   eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    successful = False
                    # Initialized to [] because res might never get initialized
                    # if an exception occurs in pgexecute.run(). Which causes
                    # finally clause to fail.
                    res = []
                    start = time()
                    # Run the query.
                    res = pgexecute.run(document.text, self.pgspecial)
                    duration = time() - start
                    successful = True
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        start = time()
                        threshold = 1000
                        if (is_select(status) and cur
                                and cur.rowcount > threshold):
                            click.secho(
                                'The result set has more than %s rows.' %
                                threshold,
                                fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                break

                        formatted = format_output(
                            title, cur, headers, status, self.table_format,
                            self.pgspecial.expanded_output)
                        output.extend(formatted)
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    reconnect = True
                    if ('server closed the connection'
                            in utf8tounicode(e.args[0])):
                        reconnect = click.prompt(
                            'Connection reset. Reconnect (Y/n)',
                            show_default=False,
                            type=bool,
                            default=True)
                        if reconnect:
                            try:
                                pgexecute.connect()
                                click.secho(
                                    'Reconnected!\nTry the command again.',
                                    fg='green')
                            except OperationalError as e:
                                click.secho(str(e), err=True, fg='red')
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    try:
                        click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        pass
                    if self.pgspecial.timing_enabled:
                        print('Command Time: %0.03fs' % duration)
                        print('Format Time: %0.03fs' % total)

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_completions()

                # Refresh search_path to set default schema.
                if need_search_path_refresh(document.text):
                    logger.debug('Refreshing search path')
                    completer.set_search_path(pgexecute.search_path())
                    logger.debug('Search path: %r', completer.search_path)

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts
Example #14
0
    def run(self):
        """Run cycli, either over a script file or as an interactive REPL.

        When ``self.filename`` is set, its contents are split on ``;`` and every
        non-empty statement is echoed (prefixed with ``> ``) and passed to
        ``self.handle_query``; the method then returns without starting the
        prompt. Otherwise a banner is printed and a prompt_toolkit REPL reads
        queries until ``UserWantsOut`` is raised (prints "Goodbye!") or any
        other exception escapes (its message is printed).
        """
        labels = self.neo4j.get_labels()
        relationship_types = self.neo4j.get_relationship_types()
        properties = self.neo4j.get_property_keys()

        if self.filename:
            # Naive split on ";" (a ";" inside a string literal also splits).
            # Blank fragments -- the whitespace after the final ";" or an empty
            # ";;" -- are skipped, and a last statement without a trailing ";"
            # is still executed instead of being silently dropped.
            for statement in self.filename.read().split(";"):
                query = statement.strip()
                if not query:
                    continue
                query += ";"

                print("> " + query)
                self.handle_query(query)
                print()

            return

        # Raw strings: the art is full of backslashes that are not valid escape
        # sequences (SyntaxWarning on Python 3.12+); the values are unchanged.
        click.secho(r" ______     __  __     ______     __         __    ", fg="red")
        click.secho(r"/\  ___\   /\ \_\ \   /\  ___\   /\ \       /\ \   ", fg="yellow")
        click.secho(r"\ \ \____  \ \____ \  \ \ \____  \ \ \____  \ \ \  ", fg="green")
        click.secho(r" \ \_____\  \/\_____\  \ \_____\  \ \_____\  \ \_\ ", fg="blue")
        click.secho(r"  \/_____/   \/_____/   \/_____/   \/_____/   \/_/ ", fg="magenta")

        print("Cycli version: {}".format(__version__))
        print("Neo4j version: {}".format(".".join(map(str, self.neo4j.neo4j_version))))
        print("Bug reports: https://github.com/nicolewhite/cycli/issues\n")

        completer = CypherCompleter(labels, relationship_types, properties)

        layout = create_prompt_layout(
            lexer=CypherLexer,
            get_prompt_tokens=get_tokens,
            reserve_space_for_menu=8,
        )

        buff = CypherBuffer(
            accept_action=AcceptAction.RETURN_DOCUMENT,
            history=FileHistory(filename=os.path.expanduser('~/.cycli_history')),
            completer=completer,
            complete_while_typing=True,
        )

        application = Application(
            style=PygmentsStyle(CypherStyle),
            buffer=buff,
            layout=layout,
            on_exit=AbortAction.RAISE_EXCEPTION,
            key_bindings_registry=CypherBinder.registry
        )

        cli = CommandLineInterface(application=application, eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()
                query = document.text
                self.handle_query(query)
        except UserWantsOut:
            print("Goodbye!")
        except Exception as e:
            print(e)
Example #15
0
    def run(self):
        """Connect to Neo4j and run cycli over a script file or as a REPL.

        Connects with the stored credentials and fetches labels, relationship
        types and property keys for completion; prints a hint and returns on
        ``Unauthorized`` or ``SocketError``. When ``self.filename`` is set, each
        non-empty ``;``-separated statement is echoed and passed to
        ``self.handle_query``, then the method returns. Otherwise a banner is
        printed and an interactive prompt_toolkit loop is started.
        """
        neo4j = Neo4j(self.host, self.port, self.username, self.password, self.ssl)
        neo4j.connect()
        self.neo4j = neo4j

        try:
            labels = neo4j.labels()
            relationship_types = neo4j.relationship_types()
            properties = neo4j.properties()

        except Unauthorized:
            print("Unauthorized. See cycli --help for authorization instructions.")
            return

        except SocketError:
            print("Connection refused. Is Neo4j turned on?")
            return

        if self.filename:
            # Naive split on ";" (a ";" inside a string literal also splits).
            # Blank fragments -- the whitespace after the final ";" or an empty
            # ";;" -- are skipped, and a last statement without a trailing ";"
            # is still executed instead of being silently dropped.
            for statement in self.filename.read().split(";"):
                query = statement.strip()
                if not query:
                    continue
                query += ";"

                print("> " + query)
                self.handle_query(query)
                print()

            return

        # Raw strings: the art is full of backslashes that are not valid escape
        # sequences (SyntaxWarning on Python 3.12+); the values are unchanged.
        click.secho(r" ______     __  __     ______     __         __    ", fg="red")
        click.secho(r"/\  ___\   /\ \_\ \   /\  ___\   /\ \       /\ \   ", fg="yellow")
        click.secho(r"\ \ \____  \ \____ \  \ \ \____  \ \ \____  \ \ \  ", fg="green")
        click.secho(r" \ \_____\  \/\_____\  \ \_____\  \ \_____\  \ \_\ ", fg="blue")
        click.secho(r"  \/_____/   \/_____/   \/_____/   \/_____/   \/_/ ", fg="magenta")

        print("\nVersion: {}".format(__version__))
        print("Bug reports: https://github.com/nicolewhite/cycli/issues\n")

        completer = CypherCompleter(labels, relationship_types, properties)

        layout = create_default_layout(
            lexer=CypherLexer,
            get_prompt_tokens=get_tokens,
            reserve_space_for_menu=True
        )

        buff = CypherBuffer(
            history=FileHistory(filename=os.path.expanduser('~/.cycli_history')),
            completer=completer,
            complete_while_typing=Always()
        )

        application = Application(
            style=CypherStyle,
            buffer=buff,
            layout=layout,
            on_exit=AbortAction.RAISE_EXCEPTION,
            key_bindings_registry=CypherBinder.registry
        )

        cli = CommandLineInterface(application=application, eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()
                query = document.text
                self.handle_query(query)

        # NOTE(review): this catch-all is the only exit path of the loop; it
        # handles the exception raised on exit (on_exit=RAISE_EXCEPTION) but
        # also hides any unexpected error behind "Goodbye!" -- consider
        # narrowing it.
        except Exception:
            print("Goodbye!")
Example #16
0
class Saws(object):
    """Encapsulates the Saws CLI.

    Attributes:
        * aws_cli: An instance of prompt_toolkit's CommandLineInterface.
        * config: An instance of Config.
        * config_obj: An instance of ConfigObj, reads from ~/.sawsrc.
        * aws_commands: An instance of AwsCommands
        * commands: A list of commands from data/SOURCES.txt.
        * sub_commands: A list of sub_commands from data/SOURCES.txt.
        * global_options: A list of global_options from data/SOURCES.txt.
        * resource_options: A list of resource_options from data/SOURCES.txt,
            used for syntax coloring.
        * ec2_states: A list of ec2_states from data/SOURCES.txt.
        * completer: An instance of AwsCompleter.
        * key_manager: An instance of KeyManager
        * logger: An instance of SawsLogger.
        * theme: A string representing the lexer theme.
            Currently only 'vim' is supported.
    """

    def __init__(self):
        """Inits Saws.

        Args:
            * None.

        Returns:
            None.
        """
        self.aws_cli = None
        self.key_manager = None
        self.PYGMENTS_CMD = ' | pygmentize -l json'
        self.config = Config()
        self.config_obj = self.config.read_configuration()
        self.theme = self.config_obj['main']['theme']
        self.logger = SawsLogger(__name__,
                                 self.config_obj['main']['log_file'],
                                 self.config_obj['main']['log_level']).logger
        self.aws_commands = AwsCommands()
        self.commands, self.sub_commands, self.global_options, \
            self.resource_options, self.ec2_states \
            = self.aws_commands.generate_all_commands()
        self.completer = AwsCompleter(
            awscli_completer,
            self.commands,
            self.config_obj,
            self.log_exception,
            ec2_states=self.ec2_states,
            fuzzy_match=self.get_fuzzy_match(),
            shortcut_match=self.get_shortcut_match())
        self.create_cli()

    def log_exception(self, e, traceback, echo=False):
        """Logs the exception and traceback to the log file ~/.saws.log.

        Args:
            * e: A Exception that specifies the exception.
            * traceback: A Traceback that specifies the traceback.
            * echo: A boolean that specifies whether to echo the exception
                to the console using click.

        Returns:
            None.
        """
        self.logger.debug('exception: %r.', str(e))
        self.logger.error("traceback: %r", traceback.format_exc())
        if echo:
            click.secho(str(e), fg='red')

    def set_color(self, color):
        """Setter for color output mode.

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * color: A boolean that represents the color flag.

        Returns:
            None.
        """
        self.config_obj['main']['color_output'] = color

    def get_color(self):
        """Getter for color output mode.

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the color flag.
        """
        return self.config_obj['main'].as_bool('color_output')

    def set_fuzzy_match(self, fuzzy):
        """Setter for fuzzy matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * fuzzy: A boolean that represents the fuzzy flag.

        Returns:
            None.
        """
        self.config_obj['main']['fuzzy_match'] = fuzzy
        self.completer.fuzzy_match = fuzzy

    def get_fuzzy_match(self):
        """Getter for fuzzy matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the fuzzy flag.
        """
        return self.config_obj['main'].as_bool('fuzzy_match')

    def set_shortcut_match(self, shortcut):
        """Setter for shortcut matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * shortcut: A boolean that represents the shortcut flag.

        Returns:
            None.
        """
        self.config_obj['main']['shortcut_match'] = shortcut
        self.completer.shortcut_match = shortcut

    def get_shortcut_match(self):
        """Getter for shortcut matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the shortcut flag.
        """
        return self.config_obj['main'].as_bool('shortcut_match')

    def refresh_resources(self):
        """Convenience function to refresh resources for completion.

        Used by prompt_toolkit's KeyBindingManager.

        Args:
            * None.

        Returns:
            None.
        """
        self.completer.refresh_resources(force_refresh=True)

    def handle_docs(self, text=None, from_fkey=False):
        """Displays contextual web docs for `F9` or the `docs` command.

        Displays the web docs specific to the currently entered:

        * (optional) command
        * (optional) subcommand

        If no command or subcommand is present, the docs index page is shown.

        Docs are only displayed if:

        * from_fkey is True
        * from_fkey is False and `docs` is found in text

        At most one browser page is opened per call.

        Args:
            * text: A string representing the input command text. Defaults
                to the text currently in the CLI buffer.
            * from_fkey: A boolean representing whether this function is
                being executed from an `F9` key press.

        Returns:
            A boolean representing whether the web docs were shown.
        """
        base_url = 'http://docs.aws.amazon.com/cli/latest/reference/'
        index_html = 'index.html'
        if text is None:
            text = self.aws_cli.current_buffer.document.text
        # If the user hit the F9 key, append 'docs' to the text
        if from_fkey:
            text = text.strip() + ' ' + AwsCommands.AWS_DOCS
        tokens = text.split()
        if len(tokens) > 2 and tokens[-1] == AwsCommands.AWS_DOCS:
            prev_word = tokens[-2]
            # If we have a command, build the url
            if prev_word in self.commands:
                prev_word = prev_word + '/'
                url = base_url + prev_word + index_html
                webbrowser.open(url)
                return True
            # if we have a command and subcommand, build the url
            elif prev_word in self.sub_commands:
                command_url = tokens[-3] + '/'
                sub_command_url = tokens[-2] + '.html'
                url = base_url + command_url + sub_command_url
                webbrowser.open(url)
                return True
            # Unknown word before `docs`: show the index, and return here so
            # the generic fallback below does not open it a second time.
            webbrowser.open(base_url + index_html)
            return True
        # If we still haven't opened the help doc at this point and the
        # user hit the F9 key or typed docs, just open the main docs index
        if from_fkey or AwsCommands.AWS_DOCS in tokens:
            webbrowser.open(base_url + index_html)
            return True
        return False

    def handle_cd(self, text):
        """Handles a `cd` shell command by calling python's os.chdir.

        Simply passing in the `cd` command to subprocess.call doesn't work.
        Note: Changing the directory within Saws will only be in effect while
        running Saws.  Exiting the program will return you to the directory
        you were in prior to running Saws.

        Only a first word of exactly `cd` is treated as the cd command, so
        commands that merely start with the letters "cd" (e.g. `cdk`) are left
        for the shell. A leading `~` in the target is expanded.

        Args:
            * text: A string representing the input command text.

        Returns:
            A boolean representing a `cd` command was found and handled.
        """
        CD_CMD = 'cd'
        # Split off the first word; the remainder (whitespace-trimmed) is the
        # target directory, which may itself contain spaces.
        tokens = text.strip().split(None, 1)
        if not tokens or tokens[0] != CD_CMD:
            return False
        if len(tokens) == 1:
            # Treat a bare `cd` as a change to the user's home directory.
            # os.path.expanduser does this in a cross platform manner.
            directory = os.path.expanduser('~')
        else:
            directory = os.path.expanduser(tokens[1].strip())
        try:
            os.chdir(directory)
        except OSError as e:
            self.log_exception(e, traceback, echo=True)
        return True

    def colorize_output(self, text):
        """Highlights output with pygments.

        Only highlights the output if all of the following conditions are True:

        * The color option is enabled
        * The text does not contain the `configure` command
        * The text does not contain the `help` command, which already does
            output highlighting
        * The text does not already contain a pipe (`|`)

        Args:
            * text: A string that represents the input command text.

        Returns:
            A string that represents:
                * The original command text if no highlighting was performed.
                * The pygments highlighted command text otherwise.
        """
        if not self.get_color():
            return text
        stripped_text = text.strip()
        excludes = [AwsCommands.AWS_CONFIGURE,
                    AwsCommands.AWS_HELP,
                    '|']
        if not any(substring in stripped_text for substring in excludes):
            return text.strip() + self.PYGMENTS_CMD
        else:
            return text

    def process_command(self, text):
        """Processes the input command, called by the cli event loop

        Args:
            * text: A string that represents the input command text.

        Returns:
            None.
        """
        if AwsCommands.AWS_COMMAND in text:
            text = self.completer.replace_shortcut(text)
            if self.handle_docs(text):
                return
            text = self.colorize_output(text)
        try:
            if not self.handle_cd(text):
                # Pass the command onto the shell so aws-cli can execute it.
                # shell=True is intentional: Saws is an interactive shell
                # wrapper and the text is typed by the local user.
                subprocess.call(text, shell=True)
            print('')
        except Exception as e:
            self.log_exception(e, traceback, echo=True)

    def create_cli(self):
        """Creates the prompt_toolkit's CommandLineInterface.

        Wires together the history file (~/.saws-history), the bottom
        toolbar, the input layout (with matching-bracket highlighting), the
        completion buffer, the KeyManager key bindings and the theme style,
        and stores the resulting CommandLineInterface in self.aws_cli.

        Args:
            * None.

        Returns:
            None.
        """
        history = FileHistory(os.path.expanduser('~/.saws-history'))
        toolbar = Toolbar(self.get_color,
                          self.get_fuzzy_match,
                          self.get_shortcut_match)
        layout = create_default_layout(
            message='saws> ',
            reserve_space_for_menu=True,
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar.handler,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ]
        )
        cli_buffer = Buffer(
            history=history,
            completer=self.completer,
            complete_while_typing=Always())
        self.key_manager = KeyManager(
            self.set_color,
            self.get_color,
            self.set_fuzzy_match,
            self.get_fuzzy_match,
            self.set_shortcut_match,
            self.get_shortcut_match,
            self.refresh_resources,
            self.handle_docs)
        style_factory = StyleFactory(self.theme)
        application = Application(
            style=style_factory.style,
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=self.key_manager.manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            ignore_case=True)
        eventloop = create_eventloop()
        self.aws_cli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

    def run_cli(self):
        """Runs the main loop.

        The loop has no exit condition of its own; it ends only when an
        exception propagates out of it (create_cli sets
        on_exit=AbortAction.RAISE_EXCEPTION) -- presumably handled by the
        caller.

        Args:
            * None.

        Returns:
            None.
        """
        print('Version:', __version__)
        print('Theme:', self.theme)
        while True:
            document = self.aws_cli.run()
            self.process_command(document.text)
Example #17
0
class WharfeeCli(object):
    """
    The CLI implementation.
    """

    dcli = None
    keyword_completer = None
    handler = None
    saved_less_opts = None
    config = None
    config_template = 'wharfeerc'
    config_name = '~/.wharfeerc'

    def __init__(self, no_completion=False):
        """
        Initialize class members.

        Reads the configuration, creates the logger, the Docker client
        handler and the completer, primes the completer with container,
        image and volume names, and overrides the LESS environment variable
        (the previous value is kept in saved_less_opts).

        :param no_completion: boolean: if True, the completer is disabled.
        """

        self.config = self.read_configuration()
        self.theme = self.config['main']['theme']

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']
        self.logger = create_logger(__name__, log_file, log_level)

        # set_completer_options refreshes all by default
        self.handler = DockerClient(
            self.config['main'].as_int('client_timeout'),
            self.clear,
            self.refresh_completions_force,
            self.logger)

        self.completer = DockerCompleter(
            long_option_names=self.get_long_options(),
            fuzzy=self.get_fuzzy_match())
        self.set_completer_options()
        self.completer.set_enabled(not no_completion)
        self.saved_less_opts = self.set_less_opts()

    def read_configuration(self):
        """
        Write the packaged default config to ~/.wharfeerc (via
        write_default_config) and read the user config merged with it.

        :return: the config object returned by read_config
        """
        default_config = os.path.join(
            self.get_package_path(), self.config_template)
        write_default_config(default_config, self.config_name)
        return read_config(self.config_name, default_config)

    def get_package_path(self):
        """
        Find out package root path.
        :return: string: path
        """
        from wharfee import __file__ as package_root
        return os.path.dirname(package_root)

    def set_less_opts(self):
        """
        Set the "less" options and save the old settings.

        What we're setting:
          -F:
            --quit-if-one-screen: Quit if entire file fits on first screen.
          -R:
            --raw-control-chars: Output "raw" control characters.
          -X:
            --no-init: Don't send the termcap init/deinit strings to the
            terminal. This also disables launching "less" in an alternate
            screen.

        :return: string with old options
        """
        opts = os.environ.get('LESS', '')
        os.environ['LESS'] = '-RXF'
        return opts

    def revert_less_opts(self):
        """
        Restore the previous "less" options.
        """
        os.environ['LESS'] = self.saved_less_opts

    def write_config_file(self):
        """
        Write config file on exit.
        """
        self.config.write()

    def clear(self):
        """
        Clear the screen.
        """
        click.clear()

    def set_completer_options(self, cons=True, runs=True, imgs=True, vols=True):
        """
        Set image and container names in Completer.
        Re-read if needed after a command.
        :param cons: boolean: need to refresh containers
        :param runs: boolean: need to refresh running containers
        :param imgs: boolean: need to refresh images
        :param vols: boolean: need to refresh volumes
        """

        # The handler results are only used when they are a non-empty list of
        # dicts; anything else (e.g. an error result) leaves the completer as is.
        if cons:
            cs = self.handler.containers(all=True)
            if cs and isinstance(cs[0], dict):
                containers = [name for c in cs for name in c['Names']]
                self.completer.set_containers(containers)

        if runs:
            cs = self.handler.containers()
            if cs and isinstance(cs[0], dict):
                running = [name for c in cs for name in c['Names']]
                self.completer.set_running(running)

        if imgs:
            def format_tagged(tagname, img_id):
                # Untagged images are referred to by their short (11 char) id.
                if tagname == '<none>:<none>':
                    return img_id[:11]
                return tagname

            def parse_image_name(tag, img_id):
                # Split off only the last ":" (the tag): the repository part
                # may contain one itself, e.g. "localhost:5000/foo:latest".
                result = tag.rsplit(':', 1)[0]
                if result == '<none>':
                    result = img_id[:11]
                return result

            ims = self.handler.images()
            if ims and isinstance(ims[0], dict):
                images = set()
                tagged = set()
                for im in ims:
                    repo_tag = '{0}:{1}'.format(im['Repository'], im['Tag'])
                    images.add(parse_image_name(repo_tag, im['Id']))
                    tagged.add(format_tagged(repo_tag, im['Id']))
                self.completer.set_images(images)
                self.completer.set_tagged(tagged)

        if vols:
            vs = self.handler.volume_ls(quiet=True)
            self.completer.set_volumes(vs)

    def set_fuzzy_match(self, is_fuzzy):
        """
        Setter for fuzzy matching mode
        :param is_fuzzy: boolean
        """
        self.config['main']['fuzzy_match'] = is_fuzzy
        self.completer.set_fuzzy_match(is_fuzzy)

    def get_fuzzy_match(self):
        """
        Getter for fuzzy matching mode
        :return: boolean
        """
        return self.config['main'].as_bool('fuzzy_match')

    def set_long_options(self, is_long):
        """
        Setter for long option names.
        :param is_long: boolean
        """
        self.config['main']['suggest_long_option_names'] = is_long
        self.completer.set_long_options(is_long)

    def get_long_options(self):
        """
        Getter for long option names.
        :return: boolean
        """
        return self.config['main'].as_bool('suggest_long_option_names')

    def refresh_completions_force(self):
        """Force refresh and make it visible."""
        self.set_completer_options()
        click.echo('Refreshed completions.')

    def refresh_completions(self):
        """
        After processing the command, refresh the lists of
        containers and images as needed
        """
        self.set_completer_options(self.handler.is_refresh_containers,
                                   self.handler.is_refresh_running,
                                   self.handler.is_refresh_images,
                                   self.handler.is_refresh_volumes)

    def run_cli(self):
        """
        Run the main loop.

        Reads commands until EOFError, dispatching each to the Docker
        handler and rendering its output. The LESS environment variable is
        restored and the config file written when the loop ends, even if an
        exception escapes it.
        """
        print('Version:', __version__)
        print('Home: http://wharfee.com')

        history = FileHistory(os.path.expanduser('~/.wharfee-history'))
        toolbar_handler = create_toolbar_handler(self.get_long_options, self.get_fuzzy_match)

        layout = create_prompt_layout(
            message='wharfee> ',
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar_handler,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ]
        )

        cli_buffer = Buffer(
            history=history,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT)

        manager = get_key_manager(
            self.set_long_options,
            self.get_long_options,
            self.set_fuzzy_match,
            self.get_fuzzy_match)

        application = Application(
            style=style_factory(self.theme),
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            on_abort=AbortAction.RETRY,
            ignore_case=True)

        eventloop = create_eventloop()

        self.dcli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

        try:
            while True:
                try:
                    document = self.dcli.run(reset_current_buffer=True)
                    self.handler.handle_input(document.text)

                    if isinstance(self.handler.output, GeneratorType):
                        output_stream(self.handler.command,
                                      self.handler.output,
                                      self.handler.log)

                    elif self.handler.output is not None:
                        lines = format_data(
                            self.handler.command,
                            self.handler.output)
                        click.echo_via_pager('\n'.join(lines))

                    if self.handler.after:
                        for line in self.handler.after():
                            click.echo(line)

                    if self.handler.exception:
                        # This was handled, just log it.
                        self.logger.warning('An error was handled: %r',
                                            self.handler.exception)

                    self.refresh_completions()

                except OptionError as ex:
                    self.logger.debug('Error: %r.', ex)
                    self.logger.error("traceback: %r", traceback.format_exc())
                    click.secho(ex.msg, fg='red')

                except KeyboardInterrupt:
                    # user pressed Ctrl + C
                    if self.handler.after:
                        click.echo('')
                        for line in self.handler.after():
                            click.echo(line)

                    self.refresh_completions()

                except DockerPermissionException as ex:
                    self.logger.debug('Permission exception: %r.', ex)
                    self.logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(ex), fg='red')

                except EOFError:
                    # exit out of the CLI
                    break

                except Exception as ex:
                    # Last-resort handler: report the error and keep the REPL alive.
                    self.logger.debug('Exception: %r.', ex)
                    self.logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(ex), fg='red')
        finally:
            # Always undo the LESS override and persist settings toggled
            # during the session, however the loop ended.
            self.revert_less_opts()
            self.write_config_file()

        print('Goodbye!')
Example #18
0
class WharfeeCli(object):
    """
    The CLI implementation.

    Owns the configuration, the Docker command handler, the completer and
    the prompt_toolkit interface, and drives the read-eval-print loop.
    """

    dcli = None
    keyword_completer = None
    handler = None
    saved_less_opts = None
    config = None
    config_template = "wharfeerc"
    config_name = "~/.wharfeerc"

    def __init__(self):
        """
        Read the configuration, set up logging, the Docker client handler and
        the completer, and override the "less" options used for paging.
        """

        self.config = self.read_configuration()
        self.theme = self.config["main"]["theme"]

        log_file = self.config["main"]["log_file"]
        log_level = self.config["main"]["log_level"]
        self.logger = create_logger(__name__, log_file, log_level)

        # set_completer_options refreshes all by default
        self.handler = DockerClient(
            self.config["main"].as_int("client_timeout"), self.clear, self.set_completer_options
        )

        self.completer = DockerCompleter(long_option_names=self.get_long_options(), fuzzy=self.get_fuzzy_match())
        self.set_completer_options()
        self.saved_less_opts = self.set_less_opts()

    def read_configuration(self):
        """
        Make sure the user config file exists, then load it.

        The packaged "wharfeerc" template is handed to write_default_config
        (presumably it only writes when ~/.wharfeerc is missing - confirm) and
        is also passed to read_config as the defaults.
        :return: the loaded config object (indexed as config["main"][...])
        """
        default_config = os.path.join(self.get_package_path(), self.config_template)
        write_default_config(default_config, self.config_name)
        return read_config(self.config_name, default_config)

    def get_package_path(self):
        """
        Find out the package root path.
        :return: string: path
        """
        from wharfee import __file__ as package_root

        return os.path.dirname(package_root)

    def set_less_opts(self):
        """
        Set the "less" options and save the old settings.

        What we're setting:
          -F:
            --quit-if-one-screen: Quit if entire file fits on first screen.
          -R:
            --RAW-CONTROL-CHARS: Pass ANSI color escape sequences through
            in "raw" form.
          -X:
            --no-init: Don't send termcap init/deinit strings to the terminal.
            This also disables launching "less" in an alternate screen.

        :return: string with old options ("" if LESS was not set)
        """
        opts = os.environ.get("LESS", "")
        os.environ["LESS"] = "-RXF"
        return opts

    def revert_less_opts(self):
        """
        Restore the previous "less" options.
        """
        os.environ["LESS"] = self.saved_less_opts

    def write_config_file(self):
        """
        Write config file on exit.
        """
        self.config.write()

    def clear(self):
        """
        Clear the screen.
        """
        click.clear()

    def set_completer_options(self, cons=True, runs=True, imgs=True):
        """
        Set image and container names in Completer.
        Re-read if needed after a command.
        :param cons: boolean: need to refresh containers
        :param runs: boolean: need to refresh running containers
        :param imgs: boolean: need to refresh images
        """

        if cons:
            cs = self.handler.containers(all=True)
            # The handler may return something other than a list of dicts
            # (hence the type check); only update the completer on real data.
            if cs and isinstance(cs[0], dict):
                containers = [name for c in cs for name in c["Names"]]
                self.completer.set_containers(containers)

        if runs:
            cs = self.handler.containers()
            if cs and isinstance(cs[0], dict):
                running = [name for c in cs for name in c["Names"]]
                self.completer.set_running(running)

        if imgs:

            def format_tagged(tagname, img_id):
                # Untagged images are referred to by their short id.
                if tagname == "<none>:<none>":
                    return img_id[:11]
                return tagname

            def parse_image_name(tag, img_id):
                # Strip the tag at the LAST colon: the repository part may
                # itself contain one (registry port, e.g. "localhost:5000/app").
                result = tag.rsplit(":", 1)[0]
                if result == "<none>":
                    result = img_id[:11]
                return result

            ims = self.handler.images()
            if ims and isinstance(ims[0], dict):
                images = set()
                tagged = set()
                for im in ims:
                    repo_tag = "{0}:{1}".format(im["Repository"], im["Tag"])
                    images.add(parse_image_name(repo_tag, im["Id"]))
                    tagged.add(format_tagged(repo_tag, im["Id"]))
                self.completer.set_images(images)
                self.completer.set_tagged(tagged)

    def set_fuzzy_match(self, is_fuzzy):
        """
        Setter for fuzzy matching mode
        :param is_fuzzy: boolean
        """
        self.config["main"]["fuzzy_match"] = is_fuzzy
        self.completer.set_fuzzy_match(is_fuzzy)

    def get_fuzzy_match(self):
        """
        Getter for fuzzy matching mode
        :return: boolean
        """
        return self.config["main"].as_bool("fuzzy_match")

    def set_long_options(self, is_long):
        """
        Setter for long option names.
        :param is_long: boolean
        """
        self.config["main"]["suggest_long_option_names"] = is_long
        self.completer.set_long_options(is_long)

    def get_long_options(self):
        """
        Getter for long option names.
        :return: boolean
        """
        return self.config["main"].as_bool("suggest_long_option_names")

    def refresh_completions(self):
        """
        After processing the command, refresh the lists of
        containers and images as needed
        """
        self.set_completer_options(
            self.handler.is_refresh_containers, self.handler.is_refresh_running, self.handler.is_refresh_images
        )

    def run_cli(self):
        """
        Run the main loop.

        Reads a command, hands it to the Docker handler, prints its output,
        and refreshes completions. Ends on Ctrl+D (EOFError) or on an
        unexpected error, then restores "less" options and saves the config.
        """
        print("Version:", __version__)
        print("Home: http://wharfee.com")

        history = FileHistory(os.path.expanduser("~/.wharfee-history"))
        toolbar_handler = create_toolbar_handler(self.get_long_options, self.get_fuzzy_match)

        layout = create_default_layout(
            message="wharfee> ",
            reserve_space_for_menu=True,
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar_handler,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                )
            ],
        )

        cli_buffer = Buffer(history=history, completer=self.completer, complete_while_typing=Always())

        manager = get_key_manager(
            self.set_long_options, self.get_long_options, self.set_fuzzy_match, self.get_fuzzy_match
        )

        application = Application(
            style=style_factory(self.theme),
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            ignore_case=True,
        )

        eventloop = create_eventloop()

        self.dcli = CommandLineInterface(application=application, eventloop=eventloop)

        while True:
            try:
                document = self.dcli.run()
                self.handler.handle_input(document.text)

                if isinstance(self.handler.output, GeneratorType):
                    # Generator output is streamed instead of paged.
                    # NOTE(review): confirm DockerClient exposes `logs` (not `log`).
                    output_stream(self.handler.command, self.handler.output, self.handler.logs)

                elif self.handler.output is not None:
                    lines = format_data(self.handler.command, self.handler.output)
                    click.echo_via_pager("\n".join(lines))

                if self.handler.after:
                    for line in self.handler.after():
                        click.echo(line)

                self.refresh_completions()

            except OptionError as ex:
                self.logger.debug("Error: %r.", ex)
                self.logger.error("traceback: %r", traceback.format_exc())
                click.secho(ex.msg, fg="red")

            except KeyboardInterrupt:
                # user pressed Ctrl + C
                if self.handler.after:
                    click.echo("")
                    for line in self.handler.after():
                        click.echo(line)

                self.refresh_completions()

            except DockerPermissionException as ex:
                self.logger.debug("Permission exception: %r.", ex)
                self.logger.error("traceback: %r", traceback.format_exc())
                # Python 3 exceptions have no .message attribute; an
                # AttributeError raised here would escape the loop entirely.
                click.secho(getattr(ex, "message", None) or str(ex), fg="red")

            except EOFError:
                # exit out of the CLI
                break

            except Exception as ex:
                # Catch-all: report the unexpected error and end the session
                # (cleanup below still runs).
                self.logger.debug("Exception: %r.", ex)
                self.logger.error("traceback: %r", traceback.format_exc())
                click.secho("{0}".format(ex), fg="red")
                break

        self.revert_less_opts()
        self.write_config_file()
        print("Goodbye!")
Example #19
0
class Prompt:
    """
    Interactive command prompt run on a background thread while a run is in
    progress. Commands are delivered as Exit/Skip/Set objects on ``self.q``.
    """

    def __init__(self, run_name, state_obj):
        self.run_name = run_name
        self.state_obj = state_obj
        self.cli = None
        self.q = queue.Queue()
        self.thread = threading.Thread(target=self.run)

    def start(self):
        """Start the prompt thread."""
        self.thread.start()

    def stop(self):
        """
        Ask the prompt to exit and wait for its thread to finish.

        Safe to call even if start() was never called.
        """
        if self.cli:
            self.cli.exit()
        # Thread.join() raises RuntimeError on a thread that was never started;
        # ident stays None until start() has been called.
        if self.thread.ident is not None:
            self.thread.join()

    def get_bottom_toolbar_tokens(self, cli):
        return [(Token.Toolbar, 'Run '), (Token.Name, self.run_name),
                (Token.Toolbar, ' in progress.')]

    def get_prompt_tokens(self, cli):
        return [(Token.Prompt, '> ')]

    def run(self):
        """
        Prompt-thread body: read commands until exit and translate them into
        queue messages.

        Commands: ``exit``/``quit`` (puts Exit and stops), ``help``, ``skip``
        (puts Skip) and ``set <name> <python literal>`` (puts Set). Malformed
        input is reported and the prompt keeps running; Ctrl+D or an
        unexpected error puts Exit and stops.
        """
        style = style_from_dict({
            Token.Prompt: 'bold',
            Token.Toolbar: '#ccc bg:#333',
            Token.Name: '#fff bold bg:#333',
        })

        history = InMemoryHistory()
        eventloop = create_eventloop()
        app = create_prompt_application(
            history=history,
            style=style,
            get_bottom_toolbar_tokens=self.get_bottom_toolbar_tokens,
            get_prompt_tokens=self.get_prompt_tokens)
        self.cli = CommandLineInterface(app, eventloop)

        with self.cli.patch_stdout_context(raw=True):
            while True:
                try:
                    self.cli.run()
                    doc = self.cli.return_value()
                    if doc is None:
                        return
                    # Clear the input line and record it in history before
                    # parsing, so even a malformed line can be recalled.
                    app.buffer.reset(append_to_history=True)

                    try:
                        cmd = shlex.split(doc.text)
                    except ValueError as err:
                        # e.g. unbalanced quotes: a typo must not end the run.
                        print('Could not parse command: {}'.format(err))
                        continue

                    if not cmd:
                        continue
                    elif cmd[0] in ('exit', 'quit'):
                        self.q.put(Exit())
                        return
                    elif cmd[0] == 'help':
                        print('Help text forthcoming.')
                    elif cmd[0] == 'skip':
                        self.q.put(Skip())
                    elif cmd[0] == 'set':
                        self._handle_set(cmd[1:])
                    else:
                        print('Unknown command. Try \'help\'.')
                except KeyboardInterrupt:
                    continue
                except EOFError:
                    self.q.put(Exit())
                    return
                except Exception as err:
                    print(err)
                    self.q.put(Exit())
                    return

    def _handle_set(self, args):
        """Queue Set(name, value) for ``set <name> <literal>``; report bad input instead of exiting."""
        if len(args) < 2:
            print('Usage: set <name> <python literal>')
            return
        try:
            value = ast.literal_eval(' '.join(args[1:]))
        except (ValueError, SyntaxError, TypeError) as err:
            print('Invalid value for {!r}: {}'.format(args[0], err))
            return
        self.q.put(Set(args[0], value))
Example #20
0
class MyCli(object):

    default_prompt = "\\t \\u@\\h:\\d> "
    defaults_suffix = None

    # In order of being loaded. Files lower in list override earlier ones.
    cnf_files = ["/etc/my.cnf", "/etc/mysql/my.cnf", "/usr/local/etc/my.cnf", "~/.my.cnf"]

    system_config_files = ["/etc/myclirc"]

    default_config_file = os.path.join(PACKAGE_ROOT, "myclirc")
    user_config_file = "~/.myclirc"

    def __init__(
        self,
        sqlexecute=None,
        prompt=None,
        logfile=None,
        defaults_suffix=None,
        defaults_file=None,
        login_path=None,
        auto_vertical_output=False,
        warn=None,
    ):
        """
        Load the mycli configuration and set up logging, the completer and
        the special commands.

        :param sqlexecute: an existing SQLExecute instance, or None.
        :param prompt: prompt format; overrides option files and config.
        :param logfile: open file used as the audit log; when None, the
            ``audit_log`` config entry (if present) is opened for appending.
        :param defaults_suffix: suffix appended to each option-file section
            name that is read (e.g. ``client`` -> ``client<suffix>``).
        :param defaults_file: if given, the only MySQL option file read
            (``.mylogin.cnf`` is still appended afterwards).
        :param login_path: extra option-file section read besides ``client``.
        :param auto_vertical_output: when true, run_cli passes the terminal
            width to the output formatter.
        :param warn: overrides the ``destructive_warning`` setting when not None.
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile
        self.defaults_suffix = defaults_suffix
        self.login_path = login_path
        self.auto_vertical_output = auto_vertical_output

        # self.cnf_files starts out as the class-level default list. Take a
        # private copy so that appending .mylogin.cnf below never mutates the
        # list shared by every MyCli instance. defaults_file replaces it.
        if defaults_file:
            self.cnf_files = [defaults_file]
        else:
            self.cnf_files = list(self.cnf_files)

        # Load config.
        config_files = [self.default_config_file] + self.system_config_files + [self.user_config_file]
        c = self.config = read_config_files(config_files)
        self.multi_line = c["main"].as_bool("multi_line")
        self.key_bindings = c["main"]["key_bindings"]
        special.set_timing_enabled(c["main"].as_bool("timing"))
        self.table_format = c["main"]["table_format"]
        self.syntax_style = c["main"]["syntax_style"]
        self.cli_style = c["colors"]
        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
        c_dest_warning = c["main"].as_bool("destructive_warning")
        self.destructive_warning = c_dest_warning if warn is None else warn

        # Write user config if system config wasn't the last config loaded.
        if c.filename not in self.system_config_files:
            write_default_config(self.default_config_file, self.user_config_file)

        # audit log
        if self.logfile is None and "audit_log" in c["main"]:
            try:
                self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a")
            except (IOError, OSError):
                self.output(
                    "Error: Unable to open the audit log file. Your queries will not be logged.", err=True, fg="red"
                )
                # False (not None) marks "audit log requested but unavailable".
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ["prompt"])["prompt"]
        self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt

        self.query_history = []

        # Initialize completer.
        smart_completion = c["main"].as_bool("smart_completion")
        self.completer = SQLCompleter(smart_completion)
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Load .mylogin.cnf if it exists.
        mylogin_cnf_path = get_mylogin_cnf_path()
        if mylogin_cnf_path:
            try:
                mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
                if mylogin_cnf:
                    # .mylogin.cnf gets read last, even if defaults_file is specified.
                    self.cnf_files.append(mylogin_cnf)
                else:
                    # There was an error reading the login path file.
                    print("Error: Unable to read login path file.")
            except CryptoError:
                click.secho("Warning: .mylogin.cnf was not read: pycrypto " "module is not available.")

        self.cli = None

    def register_special_commands(self):
        """
        Register the special commands implemented by this class:
        use, connect, rehash, tableformat, source and prompt.
        """
        # Each entry: (handler, command, short form/usage string,
        # description, extra keyword options for register_special_command).
        commands = (
            (self.change_db, "use", "\\u", "Change to a new database.", {"aliases": ("\\u",)}),
            (
                self.change_db,
                "connect",
                "\\r",
                "Reconnect to the database. Optional database argument.",
                {"aliases": ("\\r",), "case_sensitive": True},
            ),
            (
                self.refresh_completions,
                "rehash",
                "\\#",
                "Refresh auto-completions.",
                {"arg_type": NO_QUERY, "aliases": ("\\#",)},
            ),
            (
                self.change_table_format,
                "tableformat",
                "\\T",
                "Change Table Type.",
                {"aliases": ("\\T",), "case_sensitive": True},
            ),
            (
                self.execute_from_file,
                "source",
                "\\. filename",
                "Execute commands from file.",
                {"aliases": ("\\.",)},
            ),
            (
                self.change_prompt_format,
                "prompt",
                "\\R",
                "Change prompt format.",
                {"aliases": ("\\R",), "case_sensitive": True},
            ),
        )
        for handler, command, shortcut, description, options in commands:
            special.register_special_command(handler, command, shortcut, description, **options)

    def change_table_format(self, arg, **_):
        """
        Special-command handler: switch the output table format to *arg*.

        Generator yielding a single (title, cur, headers, status) tuple whose
        status either confirms the change or lists the allowed formats.
        """
        if arg in table_formats():
            self.table_format = arg
            yield (None, None, None, "Changed table Type to %s" % self.table_format)
            return

        lines = ["Table type %s not yet implemented.  Allowed types:" % arg]
        lines.extend("\t%s" % table_type for table_type in table_formats())
        yield (None, None, None, "\n".join(lines))

    def change_db(self, arg, **_):
        """
        Special-command handler for ``use``/``connect``: reconnect, switching
        to database *arg* when one is given.

        Generator; the reconnect happens when it is first iterated. Yields one
        status tuple naming the current database and user.
        """
        connect_kwargs = {} if arg is None else {"database": arg}
        self.sqlexecute.connect(**connect_kwargs)

        status = 'You are now connected to database "%s" as user "%s"' % (
            self.sqlexecute.dbname,
            self.sqlexecute.user,
        )
        yield (None, None, None, status)

    def execute_from_file(self, arg, **_):
        """
        Special-command handler for ``source``: run the SQL in file *arg*.

        Returns a list with one status tuple when the filename is missing, the
        file cannot be read, or the user declines a destructive query;
        otherwise returns whatever ``sqlexecute.run`` returns.
        """
        if not arg:
            return [(None, None, None, "Missing required argument, filename.")]

        try:
            with open(os.path.expanduser(arg), encoding="utf-8") as sql_file:
                query = sql_file.read()
        except IOError as err:
            return [(None, None, None, str(err))]

        # confirm_destructive_query returns False only when the user declined.
        declined = self.destructive_warning and confirm_destructive_query(query) is False
        if declined:
            return [(None, None, None, "Wise choice. Command execution stopped.")]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.

        Special-command handler for ``prompt``; returns a list with one status
        tuple describing the outcome.
        """
        if arg:
            self.prompt_format = self.get_prompt(arg)
            return [(None, None, None, "Changed prompt format to %s" % arg)]
        return [(None, None, None, "Missing required argument, format.")]

    def initialize_logging(self):
        """
        Attach a file handler to the "mycli" logger using the configured
        ``log_file`` and ``log_level``, and route warnings through logging.

        An unrecognised level name raises KeyError.
        """
        main_section = self.config["main"]
        log_file = main_section["log_file"]
        log_level = main_section["log_level"]

        levels_by_name = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
        }

        file_handler = logging.FileHandler(os.path.expanduser(log_file))
        file_handler.setFormatter(
            logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) %(name)s %(levelname)s - %(message)s")
        )

        mycli_logger = logging.getLogger("mycli")
        mycli_logger.addHandler(file_handler)
        mycli_logger.setLevel(levels_by_name[log_level.upper()])

        # logging.captureWarnings does not exist before Python 2.7.
        try:
            logging.captureWarnings(True)
        except AttributeError:
            pass

        mycli_logger.debug("Initializing mycli logging.")
        mycli_logger.debug("Log file %r.", log_file)

    def connect_uri(self, uri, local_infile=None, ssl=None):
        """
        Connect using a URI such as ``mysql://user:password@host:port/database``.

        urlparse() leaves the user-info part percent-encoded, so the user name
        and password are decoded here. This allows credentials containing
        ``@``, ``:`` or ``/`` (which must be percent-encoded in a URI).
        Missing components stay None so connect() can fall back to config.
        """
        try:
            from urllib.parse import unquote
        except ImportError:  # Python 2
            from urllib import unquote

        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        user = unquote(uri.username) if uri.username is not None else None
        passwd = unquote(uri.password) if uri.password is not None else None
        self.connect(database, user, passwd, uri.hostname, uri.port, local_infile=local_infile, ssl=ssl)

    def read_my_cnf_files(self, files, keys):
        """
        Read and merge a list of MySQL option files; the last one wins.

        Only the ``client`` section, the login-path section (when set and not
        ``client``) and their ``defaults_suffix`` variants are consulted.

        :param files: list of files to read
        :param keys: iterable of option names to retrieve
        :returns: dict mapping each key to its value, or None when not found.
        """
        cnf = read_config_files(files)

        wanted = ["client"]
        if self.login_path and self.login_path != "client":
            wanted.append(self.login_path)

        if self.defaults_suffix:
            wanted += [name + self.defaults_suffix for name in wanted]

        def lookup(key):
            # Walk every section; a later matching section overrides earlier ones.
            found = None
            for section in cnf:
                if section in wanted and key in cnf[section]:
                    found = cnf[section][key]
            return found

        return {key: lookup(key) for key in keys}

    def merge_ssl_with_cnf(self, ssl, cnf):
        """
        Return a new dict combining *ssl* with the ``ssl-*`` options in *cnf*.

        Non-None ``ssl-*`` values from *cnf* override entries from *ssl*. They
        are stored without the ``ssl-`` prefix, except ``ssl-verify-server-cert``,
        which becomes PyMySQL's ``check_hostname`` argument.
        """
        prefix = "ssl-"
        merged = dict(ssl)
        for key, value in cnf.items():
            if not key.startswith(prefix) or value is None:
                continue
            if key == "ssl-verify-server-cert":
                merged["check_hostname"] = value
            else:
                merged[key[len(prefix):]] = value
        return merged

    def connect(
        self, database="", user="", passwd="", host="", port="", socket="", charset="", local_infile="", ssl=""
    ):
        """
        Connect to MySQL and store the resulting SQLExecute on self.sqlexecute.

        Arguments left empty are filled from the MySQL option files, then from
        built-in defaults ($USER, "localhost", 3306, "utf8"). On an
        "Access denied" error the user is prompted for a password once. An
        invalid port or any other connection failure is reported and the
        process exits with status 1.
        """
        import sys

        # Option names to look up in the option files; only the keys matter.
        cnf = {
            "database": None,
            "user": None,
            "password": None,
            "host": None,
            "port": None,
            "socket": None,
            "default-character-set": None,
            "local-infile": None,
            "loose-local-infile": None,
            "ssl-ca": None,
            "ssl-cert": None,
            "ssl-key": None,
            "ssl-cipher": None,
            # Must match the key merge_ssl_with_cnf maps to check_hostname.
            "ssl-verify-server-cert": None,
        }

        cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())

        # Fall back to config values only if user did not specify a value.

        database = database or cnf["database"]
        if port or host:
            # An explicit host or port disables socket connections.
            socket = ""
        else:
            socket = socket or cnf["socket"]
        user = user or cnf["user"] or os.getenv("USER")
        host = host or cnf["host"] or "localhost"
        port = port or cnf["port"] or 3306
        ssl = ssl or {}

        try:
            port = int(port)
        except ValueError:
            self.output("Error: Invalid port number: '{0}'.".format(port), err=True, fg="red")
            sys.exit(1)

        passwd = passwd or cnf["password"]
        charset = charset or cnf["default-character-set"] or "utf8"

        # Favor whichever local_infile option is set.
        for local_infile_option in (local_infile, cnf["local-infile"], cnf["loose-local-infile"], False):
            try:
                local_infile = str_to_bool(local_infile_option)
                break
            except (TypeError, ValueError):
                pass

        ssl = self.merge_ssl_with_cnf(ssl, cnf)
        # prune lone check_hostname=False
        if not any(v for v in ssl.values()):
            ssl = None

        # Connect to the database.

        try:
            try:
                sqlexecute = SQLExecute(database, user, passwd, host, port, socket, charset, local_infile, ssl)
            except OperationalError as e:
                # Guard the index: an error without a message argument must not
                # be masked by an IndexError raised from this handler.
                if len(e.args) > 1 and "Access denied for user" in str(e.args[1]):
                    passwd = click.prompt("Password", hide_input=True, show_default=False, type=str)
                    sqlexecute = SQLExecute(database, user, passwd, host, port, socket, charset, local_infile, ssl)
                else:
                    raise
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.output(str(e), err=True, fg="red")
            sys.exit(1)

        self.sqlexecute = sqlexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: if the external editor reports an error.
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename, sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited SQL back in the input buffer (cursor at the end)
            # and let the user accept it; it may end in \e again, which loops.
            cli.current_buffer.document = Document(sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()

        self.refresh_completions()

        project_root = os.path.dirname(PACKAGE_ROOT)
        author_file = os.path.join(project_root, "AUTHORS")
        sponsor_file = os.path.join(project_root, "SPONSORS")

        key_binding_manager = mycli_bindings()

        print("Version:", __version__)
        print("Chat: https://gitter.im/dbcli/mycli")
        print("Mail: https://groups.google.com/forum/#!forum/mycli-users")
        print("Home: http://mycli.net")
        print("Thanks to the contributor -", thanks_picker([author_file, sponsor_file]))

        def prompt_tokens(cli):
            return [(Token.Prompt, self.get_prompt(self.prompt_format))]

        def get_continuation_tokens(cli, width):
            return [(Token.Continuation, " " * (width - 3) + "-> ")]

        get_toolbar_tokens = create_toolbar_tokens_func(self.completion_refresher.is_refreshing)

        layout = create_prompt_layout(
            lexer=MyCliLexer,
            multiline=True,
            get_prompt_tokens=prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                )
            ],
        )
        with self._completer_lock:
            buf = CLIBuffer(
                always_multiline=self.multi_line,
                completer=self.completer,
                history=FileHistory(os.path.expanduser("~/.mycli-history")),
                complete_while_typing=Always(),
                accept_action=AcceptAction.RETURN_DOCUMENT,
            )

            if self.key_bindings == "vi":
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                editing_mode=editing_mode,
                ignore_case=True,
            )
            self.cli = CommandLineInterface(application=application, eventloop=create_eventloop())

        try:
            while True:
                document = self.cli.run(reset_current_buffer=True)

                special.set_expanded_output(False)

                # The reason we check here instead of inside the sqlexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the
                # sqlexecute.run() statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg="red")
                    continue
                if self.destructive_warning:
                    destroy = confirm_destructive_query(document.text)
                    if destroy is None:
                        pass  # Query was not destructive. Nothing to do here.
                    elif destroy is True:
                        self.output("Your call!")
                    else:
                        self.output("Wise choice!")
                        continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug("sql: %r", document.text)

                    if self.logfile:
                        self.logfile.write("\n# %s\n" % datetime.now())
                        self.logfile.write(document.text)
                        self.logfile.write("\n")

                    successful = False
                    start = time()
                    res = sqlexecute.run(document.text)
                    successful = True
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        threshold = 1000
                        if is_select(status) and cur and cur.rowcount > threshold:
                            self.output("The result set has more than %s rows." % threshold, fg="red")
                            if not click.confirm("Do you want to continue?"):
                                self.output("Aborted!", err=True, fg="red")
                                break

                        if self.auto_vertical_output:
                            max_width = self.cli.output.get_size().columns
                        else:
                            max_width = None

                        formatted = format_output(
                            title, cur, headers, status, self.table_format, special.is_expanded_output(), max_width
                        )

                        output.extend(formatted)
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)
                except UnicodeDecodeError as e:
                    import pymysql

                    if pymysql.VERSION < (0, 6, 7):
                        message = (
                            "You are running an older version of pymysql.\n"
                            "Please upgrade to 0.6.7 or above to view binary data.\n"
                            "Try 'pip install -U pymysql'."
                        )
                        self.output(message)
                    else:
                        raise e
                except KeyboardInterrupt:
                    # Restart connection to the database
                    sqlexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    self.output("cancelled query", err=True, fg="red")
                except NotImplementedError:
                    self.output("Not Yet Implemented.", fg="yellow")
                except OperationalError as e:
                    logger.debug("Exception: %r", e)
                    reconnect = True
                    if e.args[0] in (2003, 2006, 2013):
                        reconnect = click.prompt(
                            "Connection reset. Reconnect (Y/n)", show_default=False, type=bool, default=True
                        )
                        if reconnect:
                            logger.debug("Attempting to reconnect.")
                            try:
                                sqlexecute.connect()
                                logger.debug("Reconnected successfully.")
                                self.output("Reconnected!\nTry the command again.", fg="green")
                            except OperationalError as e:
                                logger.debug("Reconnect failed. e: %r", e)
                                self.output(str(e), err=True, fg="red")
                                continue  # If reconnection failed, don't proceed further.
                        else:  # If user chooses not to reconnect, don't proceed further.
                            continue
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        self.output(str(e), err=True, fg="red")
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg="red")
                else:
                    try:
                        if special.is_pager_enabled():
                            self.output_via_pager("\n".join(output))
                        else:
                            self.output("\n".join(output))
                    except KeyboardInterrupt:
                        pass
                    if special.is_timing_enabled():
                        self.output("Time: %0.03fs" % total)

                    # Refresh the table names and column names if necessary.
                    if need_completion_refresh(document.text):
                        self.refresh_completions(reset=need_completion_reset(document.text))
                finally:
                    if self.logfile is False:
                        self.output("Warning: This query was not logged.", err=True, fg="red")
                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            self.output("Goodbye!")

    def output(self, text, **kwargs):
        """Echo *text* to the terminal, mirroring it to the logfile when set.

        Any keyword arguments (``fg``, ``err``, ...) are forwarded unchanged
        to ``click.secho``.
        """
        logfile = self.logfile
        if logfile:
            # Log copy first, followed by a line terminator.
            for chunk in (utf8tounicode(text), "\n"):
                logfile.write(chunk)
        click.secho(text, **kwargs)

    def output_via_pager(self, text):
        """Show *text* through click's pager, mirroring it to the logfile.

        The logfile copy is passed through ``utf8tounicode`` in the same way
        ``output`` does, so both output paths write the same text type to the
        log.
        """
        if self.logfile:
            self.logfile.write(utf8tounicode(text))
            self.logfile.write("\n")
        click.echo_via_pager(text)

    def configure_pager(self):
        """Set LESS defaults, then apply pager options read from my.cnf files.

        Reads the ``pager`` and ``skip-pager`` options. A truthy ``pager``
        value is installed via ``special.set_pager``, and a truthy
        ``skip-pager`` disables the pager.
        """
        # Provide sane defaults for less.
        os.environ["LESS"] = "-RXF"

        settings = self.read_my_cnf_files(self.cnf_files, ["pager", "skip-pager"])
        pager_cmd = settings["pager"]
        if pager_cmd:
            special.set_pager(pager_cmd)
        skip_pager = settings["skip-pager"]
        if skip_pager:
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background refresh of the auto-completion metadata.

        :param reset: when true, the current completer is cleared (under the
            completer lock) before the refresh is scheduled.
        :return: a one-row status result suitable for display.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        refresher = self.completion_refresher
        refresher.refresh(self.sqlexecute, self._on_completions_refreshed)
        status = "Auto-completion refresh started in the background."
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer):
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When mycli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* with the cursor at the given offset.

        The misspelled ``cursor_positition`` parameter name is kept as-is for
        interface compatibility. The completer is consulted while holding
        the completer lock.
        """
        doc = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(doc, None)

    def get_prompt(self, string):
        """Expand the prompt escape sequences in *string*.

        Supported sequences:
          * ``\\u`` -> connected user, ``(none)`` when unset
          * ``\\h`` -> host, ``(none)`` when unset
          * ``\\d`` -> database name, ``(none)`` when unset
          * ``\\t`` -> first element of ``sqlexecute.server_type()``,
            ``mycli`` when falsy
          * ``\\n`` -> newline

        All sequences are expanded in one left-to-right pass. Text substituted
        for one sequence (e.g. a user or database name that itself contains a
        backslash) is therefore never re-interpreted as another sequence, as
        it would be with chained ``str.replace`` calls.
        """
        import re

        sqlexecute = self.sqlexecute
        replacements = {
            "u": sqlexecute.user or "(none)",
            "h": sqlexecute.host or "(none)",
            "d": sqlexecute.dbname or "(none)",
            "t": sqlexecute.server_type()[0] or "mycli",
            "n": "\n",
        }
        # A callable replacement is used so substituted values are inserted
        # verbatim (no backslash processing by re.sub).
        return re.sub(r"\\([uhdtn])", lambda match: replacements[match.group(1)], string)
Example #21
0
class VCli(object):
    """Interactive command-line client built on prompt_toolkit.

    Holds the database executor (``self.vexecute``), the loaded
    configuration, the special-command registry (``self.vspecial``) and the
    auto-completer. The completer is refreshed in the background by a
    ``CompletionRefresher`` and swapped in under ``self._completer_lock``.
    """

    def __init__(self, vexecute=None, vclirc_file=None):
        """Load configuration, initialize logging and build the completer.

        :param vexecute: an executor object, or None to connect later via
            connect() / connect_uri().
        :param vclirc_file: path of the user's config file; passed to
            write_default_config together with the packaged default config
            (presumably to create it when missing — TODO confirm).
        """
        self.vexecute = vexecute

        from vcli import __file__ as package_root
        package_root = os.path.dirname(package_root)

        default_config = os.path.join(package_root, 'vclirc')
        write_default_config(default_config, vclirc_file)

        self.vspecial = VSpecial()

        # Load config.
        c = self.config = load_config(vclirc_file, default_config)
        self.multi_line = c['main'].as_bool('multi_line')
        self.vi_mode = c['main'].as_bool('vi')
        self.vspecial.timing_enabled = c['main'].as_bool('timing')
        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer
        smart_completion = c['main'].as_bool('smart_completion')
        completer = VCompleter(smart_completion, vspecial=self.vspecial)
        self.completer = completer
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        # Created lazily in run_cli(); callbacks check for None.
        self.cli = None

    def register_special_commands(self):
        """Register the connect and refresh special commands with vspecial."""
        self.vspecial.register(self.change_db,
                               '\\c',
                               '\\c[onnect] [DBNAME]',
                               'Connect to a new database',
                               aliases=('use', '\\connect', 'USE'))
        self.vspecial.register(self.refresh_completions,
                               '\\#',
                               '\\#',
                               'Refresh auto-completions',
                               arg_type=NO_QUERY)
        self.vspecial.register(self.refresh_completions,
                               '\\refresh',
                               '\\refresh',
                               'Refresh auto-completions',
                               arg_type=NO_QUERY)

    def change_db(self, pattern, **_):
        """Reconnect, optionally to another database.

        :param pattern: database name; surrounding double quotes are
            stripped. When empty, vexecute.connect() is called without a
            database argument.
        :return: generator yielding a single status tuple.
        """
        if pattern:
            db = pattern[1:-1] if pattern[0] == pattern[-1] == '"' else pattern
            self.vexecute.connect(database=db)
        else:
            self.vexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.vexecute.dbname, self.vexecute.user), True)

    def initialize_logging(self):
        """Attach a file handler to the 'vcli' logger using config values.

        Reads ``log_file`` and ``log_level`` from the [main] config section;
        an unknown level name raises KeyError.
        """

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG
        }

        handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('vcli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        root_logger.debug('Initializing vcli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_uri(self, uri):
        """Connect using a URI of the form scheme://user:pass@host:port/db.

        Missing parts default to host 'localhost', the current OS user,
        port 5433 and an empty password.
        """
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        host = uri.hostname or 'localhost'
        user = uri.username or getpass.getuser()
        port = uri.port or 5433
        password = uri.password or ''
        self.connect(database, host, user, port, password)

    def connect(self, database, host, user, port, passwd):
        """Create self.vexecute; on DatabaseError print it and exit(1)."""
        # Connect to the database
        try:
            self.vexecute = VExecute(database, user, passwd, host, port)
        except errors.DatabaseError as e:  # Connection can fail
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            error_msg = str(e) or type(e).__name__
            click.secho(error_msg, err=True, fg='red')
            # sys.exit rather than the site-provided exit() builtin, which
            # is missing under `python -S` and in some frozen executables.
            sys.exit(1)

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports an error.
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                        sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            cli.current_buffer.document = Document(sql,
                                                   cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive prompt loop until the user quits.

        Each submitted document is checked for a quit command, passed
        through the external editor when it carries an editor marker, then
        executed with vexecute.run(). Results go to the pager, or to
        self.vspecial.output when that is not sys.stdout. The LESS
        environment variable is overridden for the session and restored on
        exit.
        """
        vexecute = self.vexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = vcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        click.secho('Version: %s' % __version__)

        def prompt_tokens(cli):
            return [(Token.Prompt, '%s=> ' % vexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode, self.completion_refresher.is_refreshing)
        input_processors = [
            # Highlight matching brackets while editing.
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
        ]
        layout = create_prompt_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=8,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=input_processors)
        history_file = self.config['main']['history_file']
        # The lock keeps a background completion refresh from swapping
        # self.completer while the buffer and cli are being built.
        with self._completer_lock:
            buf = VBuffer(always_multiline=self.multi_line,
                          completer=self.completer,
                          history=FileHistory(
                              os.path.expanduser(history_file)),
                          complete_while_typing=Always(),
                          accept_action=AcceptAction.RETURN_DOCUMENT)

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                on_abort=AbortAction.RETRY,
                ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        try:
            while True:
                document = self.cli.run()

                # The reason we check here instead of inside the vexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the vexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    successful = False
                    # Initialized to [] because res might never get initialized
                    # if an exception occurs in vexecute.run(). Which causes
                    # finally clause to fail.
                    res = []
                    start = time()
                    # Run the query.
                    res = vexecute.run(document.text, self.vspecial)

                    file_output = None
                    stdout_output = []

                    for title, cur, headers, status, force_stdout in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        threshold = 1000
                        if (is_select(status) and cur
                                and cur.rowcount > threshold):
                            click.secho(
                                'The result set has more than %s rows.' %
                                threshold,
                                fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                break

                        formatted = format_output(
                            title, cur, headers, status, self.table_format,
                            self.vspecial.expanded_output,
                            self.vspecial.aligned, self.vspecial.show_header)

                        if self.vspecial.output is not sys.stdout:
                            file_output = self.vspecial.output

                        if force_stdout or not file_output:
                            output = stdout_output
                        else:
                            output = file_output

                        write_output(output, formatted)

                        if hasattr(cur, 'rowcount'):
                            if self.vspecial.show_header:
                                if cur.rowcount == 1:
                                    write_output(output, '(1 row)')
                                elif headers:
                                    rowcount = max(cur.rowcount, 0)
                                    write_output(output,
                                                 '(%d rows)' % rowcount)
                            if document.text.startswith(
                                    '\\') and cur.rowcount == 0:
                                stdout_output = [
                                    'No matching relations found.'
                                ]

                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    vexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except errors.ConnectionError as e:
                    reconnect = click.prompt(
                        'Connection reset. Reconnect (Y/n)',
                        show_default=False,
                        type=bool,
                        default=True)
                    if reconnect:
                        try:
                            vexecute.connect()
                            click.secho('Reconnected!\nTry the command again.',
                                        fg='green')
                        except errors.DatabaseError as e:
                            click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(format_error(e), err=True, fg='red')
                else:
                    successful = True
                    if stdout_output:
                        output = '\n'.join(stdout_output)
                        try:
                            click.echo_via_pager(output)
                        except KeyboardInterrupt:
                            pass

                    if file_output:
                        try:
                            file_output.flush()
                        except KeyboardInterrupt:
                            pass
                    if self.vspecial.timing_enabled:
                        print('Time: %0.03fs' % (time() - start))

                    # Refresh the table names and column names if necessary.
                    if need_completion_refresh(document.text):
                        self.refresh_completions(
                            need_completion_reset(document.text))

                    # Refresh search_path to set default schema.
                    if need_search_path_refresh(document.text):
                        logger.debug('Refreshing search path')
                        with self._completer_lock:
                            self.completer.set_search_path(
                                vexecute.search_path())
                            # Logged under the lock so a concurrent
                            # completer swap cannot change what is reported.
                            logger.debug('Search path: %r',
                                         self.completer.search_path)

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts

    def adjust_less_opts(self):
        """Set LESS to '-RXF' and return its previous value ('' if unset)."""
        less_opts = os.environ.get('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', less_opts)
        os.environ['LESS'] = '-RXF'

        return less_opts

    def refresh_completions(self, reset=False):
        """Start a background completion refresh; return a status row.

        :param reset: clear the current completer (under the lock) first.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(self.vexecute, self.vspecial,
                                          self._on_completions_refreshed)
        return [(None, None, None,
                 'Auto-completion refresh started in the background.', True)]

    def _on_completions_refreshed(self, new_completer):
        """Refresher callback: install *new_completer* and redraw the cli."""
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When vcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying to replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor offset.

        The misspelled parameter name is kept for interface compatibility.
        """
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)
Example #22
0
class Saws(object):
    """Encapsulates the Saws CLI.

    Attributes:
        * aws_cli: An instance of prompt_toolkit's CommandLineInterface.
        * key_manager: An instance of KeyManager.
        * config: An instance of Config.
        * config_obj: An instance of ConfigObj, reads from ~/.sawsrc.
        * theme: A string representing the lexer theme.
        * logger: An instance of SawsLogger.
        * all_commands: A list of all commands, sub_commands, options, etc
            from data/SOURCES.txt.
        * commands: A list of commands from data/SOURCES.txt.
        * sub_commands: A list of sub_commands from data/SOURCES.txt.
        * completer: An instance of AwsCompleter.
    """

    PYGMENTS_CMD = ' | pygmentize -l json'

    def __init__(self, refresh_resources=True):
        """Inits Saws.

        Reads the configuration, sets up logging, loads the command lists and
        builds the completer, then creates the prompt_toolkit cli.

        Args:
            * refresh_resources: A boolean; when True the completer's
                resources and options are refreshed during start-up.

        Returns:
            None.
        """
        self.aws_cli = None
        self.key_manager = None
        self.config = Config()
        self.config_obj = self.config.read_configuration()
        main_section = self.config_obj[self.config.MAIN]
        self.theme = main_section[self.config.THEME]
        self.logger = SawsLogger(
            __name__,
            main_section[self.config.LOG_FILE],
            main_section[self.config.LOG_LEVEL]).logger
        self.all_commands = AwsCommands().all_commands
        command_type = AwsCommands.CommandType
        self.commands = self.all_commands[command_type.COMMANDS.value]
        self.sub_commands = self.all_commands[command_type.SUB_COMMANDS.value]
        self.completer = AwsCompleter(
            awscli_completer,
            self.all_commands,
            self.config,
            self.config_obj,
            self.log_exception,
            fuzzy_match=self.get_fuzzy_match(),
            shortcut_match=self.get_shortcut_match())
        if refresh_resources:
            self.completer.refresh_resources_and_options()
        self._create_cli()

    def log_exception(self, e, traceback, echo=False):
        """Write an exception and the current traceback to the Saws log.

        Args:
            * e: The Exception being reported.
            * traceback: The ``traceback`` module (its ``format_exc`` is
                called), so this should be used from inside an except block.
            * echo: A boolean; when True the exception text is also printed
                in red via click.

        Returns:
            None.
        """
        log = self.logger
        log.debug('exception: %r.', str(e))
        log.error("traceback: %r", traceback.format_exc())
        if not echo:
            return
        click.secho(str(e), fg='red')

    def set_color(self, color):
        """Store the color-output flag in the [main] config section.

        Plain setter method (not a property) because prompt_toolkit's
        KeyBindingManager expects a callable.

        Args:
            * color: A boolean; whether output should be colorized.

        Returns:
            None.
        """
        main_section = self.config_obj[self.config.MAIN]
        main_section[self.config.COLOR] = color

    def get_color(self):
        """Return the color-output flag from the [main] config section.

        Plain getter method (not a property) because prompt_toolkit's
        KeyBindingManager expects a callable.

        Returns:
            A boolean, parsed with the section's ``as_bool``.
        """
        main_section = self.config_obj[self.config.MAIN]
        return main_section.as_bool(self.config.COLOR)

    def set_fuzzy_match(self, fuzzy):
        """Setter for fuzzy matching mode.

        Updates both the [main] config section and the live completer.
        Plain setter method (not a property) because prompt_toolkit's
        KeyBindingManager expects a callable.

        Args:
            * fuzzy: A boolean that represents the fuzzy flag.

        Returns:
            None.
        """
        self.config_obj[self.config.MAIN][self.config.FUZZY] = fuzzy
        self.completer.fuzzy_match = fuzzy

    def get_fuzzy_match(self):
        """Return the fuzzy-matching flag from the [main] config section.

        Plain getter method (not a property) because prompt_toolkit's
        KeyBindingManager expects a callable.

        Returns:
            A boolean, parsed with the section's ``as_bool``.
        """
        main_section = self.config_obj[self.config.MAIN]
        return main_section.as_bool(self.config.FUZZY)

    def set_shortcut_match(self, shortcut):
        """Setter for shortcut matching mode.

        Updates both the [main] config section and the live completer.
        Plain setter method (not a property) because prompt_toolkit's
        KeyBindingManager expects a callable.

        Args:
            * shortcut: A boolean that represents the shortcut flag.

        Returns:
            None.
        """
        self.config_obj[self.config.MAIN][self.config.SHORTCUT] = shortcut
        self.completer.shortcut_match = shortcut

    def get_shortcut_match(self):
        """Return the shortcut-matching flag from the [main] config section.

        Plain getter method (not a property) because prompt_toolkit's
        KeyBindingManager expects a callable.

        Returns:
            A boolean, parsed with the section's ``as_bool``.
        """
        main_section = self.config_obj[self.config.MAIN]
        return main_section.as_bool(self.config.SHORTCUT)

    def refresh_resources_and_options(self):
        """Force the completer to refresh its resources and options.

        Convenience wrapper used by prompt_toolkit's KeyBindingManager.

        Returns:
            None.
        """
        completer = self.completer
        completer.refresh_resources_and_options(force_refresh=True)

    def handle_docs(self, text=None, from_fkey=False):
        """Displays contextual web docs for `F9` or the `docs` command.

        Displays the web docs specific to the currently entered:

        * (optional) command
        * (optional) subcommand

        If no command or subcommand is present, the docs index page is shown.

        Docs are only displayed if:

        * from_fkey is True
        * from_fkey is False and `docs` is found in text

        Each call opens at most one browser page.

        Args:
            * text: A string representing the input command text; when None
                the current buffer text of the cli is used.
            * from_fkey: A boolean representing whether this function is
                being executed from an `F9` key press.

        Returns:
            A boolean representing whether the web docs were shown.
        """
        base_url = 'http://docs.aws.amazon.com/cli/latest/reference/'
        index_html = 'index.html'
        if text is None:
            text = self.aws_cli.current_buffer.document.text
        # If the user hit the F9 key, append 'docs' to the text
        if from_fkey:
            text = text.strip() + ' ' + AwsCommands.AWS_DOCS
        tokens = text.split()
        if len(tokens) > 2 and tokens[-1] == AwsCommands.AWS_DOCS:
            prev_word = tokens[-2]
            # If we have a command, build the url
            if prev_word in self.commands:
                url = base_url + prev_word + '/' + index_html
                webbrowser.open(url)
                return True
            # if we have a command and subcommand, build the url
            if prev_word in self.sub_commands:
                command_url = tokens[-3] + '/'
                sub_command_url = prev_word + '.html'
                url = base_url + command_url + sub_command_url
                webbrowser.open(url)
                return True
            # Neither a command nor a sub-command: fall through. `docs` is
            # the last token, so the index page is opened exactly once below.
        # If we still haven't opened the help doc at this point and the
        # user hit the F9 key or typed docs, just open the main docs index
        if from_fkey or AwsCommands.AWS_DOCS in tokens:
            webbrowser.open(base_url + index_html)
            return True
        return False

    def _handle_cd(self, text):
        """Handles a `cd` shell command by calling python's os.chdir.

        Simply passing in the `cd` command to subprocess.call doesn't work.
        Note: Changing the directory within Saws will only be in effect while
        running Saws.  Exiting the program will return you to the directory
        you were in prior to running Saws.

        Attributes:
            * text: A string representing the input command text.

        Returns:
            A boolean representing a `cd` command was found and handled.
        """
        CD_CMD = 'cd'
        stripped_text = text.strip()
        if stripped_text.startswith(CD_CMD):
            directory = ''
            if stripped_text == CD_CMD:
                # Treat `cd` as a change to the root directory.
                # os.path.expanduser does this in a cross platform manner.
                directory = os.path.expanduser('~')
            else:
                tokens = text.split(CD_CMD + ' ')
                directory = tokens[-1]
            try:
                os.chdir(directory)
            except OSError as e:
                self.log_exception(e, traceback, echo=True)
            return True
        return False

    def _colorize_output(self, text):
        """Highlights output with pygments.

        Only highlights the output if all of the following conditions are True:

        * The color option is enabled
        * The command will be handled by the `aws-cli`
        * The text does not contain the `configure` command
        * The text does not contain the `help` command, which already does
            output highlighting

        Args:
            * text: A string that represents the input command text.

        Returns:
            A string that represents:
                * The original command text if no highlighting was performed.
                * The pygments highlighted command text otherwise.
        """
        stripped_text = text.strip()
        if not self.get_color() or stripped_text == '':
            return text
        if AwsCommands.AWS_COMMAND not in stripped_text.split():
            return text
        excludes = [AwsCommands.AWS_CONFIGURE,
                    AwsCommands.AWS_HELP,
                    '|']
        if not any(substring in stripped_text for substring in excludes):
            return text.strip() + self.PYGMENTS_CMD
        else:
            return text

    def _handle_keyboard_interrupt(self, e, platform):
        """Handles keyboard interrupts more gracefully on Mac/Unix/Linux.

        Allows Mac/Unix/Linux to continue running on keyboard interrupt,
        as the user might interrupt a long-running AWS command with Control-C
        while continuing to work with Saws.

        On Windows, the "Terminate batch job (Y/N)" confirmation makes it
        tricky to handle this gracefully.  Thus, we re-raise KeyboardInterrupt.

        Args:
            * e: A KeyboardInterrupt.
            * platform: A string that denotes platform such as
                'Windows', 'Darwin', etc.

        Returns:
            None

        Raises:
            Exception: A KeyboardInterrupt if running on Windows.
        """
        if platform == 'Windows':
            raise e
        else:
            # Clear the renderer and send a carriage return
            self.aws_cli.renderer.clear()
            self.aws_cli.input_processor.feed(KeyPress(Keys.ControlM, u''))
            self.aws_cli.input_processor.process_keys()

    def _process_command(self, text):
        """Execute one line of user input; called by the cli event loop.

        Input mentioning the aws command first has its shortcuts expanded and
        is checked for a `docs` request; when docs are opened nothing else
        happens.  A `cd` command changes Saws' own working directory; any
        other command is (optionally) colorized and run through the shell.
        Control-C is delegated to ``_handle_keyboard_interrupt``; any other
        exception is logged and echoed.

        Args:
            * text: A string that represents the input command text.

        Returns:
            None.
        """
        if AwsCommands.AWS_COMMAND in text:
            text = self.completer.replace_shortcut(text)
            if self.handle_docs(text):
                return
        try:
            handled_as_cd = self._handle_cd(text)
            if not handled_as_cd:
                command = self._colorize_output(text)
                # Hand the command to the shell so aws-cli can execute it.
                subprocess.call(command, shell=True)
            print('')
        except KeyboardInterrupt as interrupt:
            self._handle_keyboard_interrupt(interrupt, platform.system())
        except Exception as error:
            self.log_exception(error, traceback, echo=True)

    def _create_cli(self):
        """Build the prompt_toolkit CommandLineInterface for Saws.

        Wires the history file (~/.saws-history), the bottom toolbar, the
        prompt layout, the input buffer, the key manager and the style into
        an Application.  Sets ``self.key_manager`` and stores the resulting
        interface on ``self.aws_cli``.

        Args:
            * None.

        Returns:
            None.
        """
        history = FileHistory(os.path.expanduser('~/.saws-history'))
        toolbar = Toolbar(self.get_color,
                          self.get_fuzzy_match,
                          self.get_shortcut_match)
        # Bracket highlighting applies only while the default buffer has
        # focus and the prompt is not done (~IsDone()).
        bracket_matcher = ConditionalProcessor(
            processor=HighlightMatchingBracketProcessor(chars='[](){}'),
            filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
        layout = create_default_layout(
            message='saws> ',
            reserve_space_for_menu=8,
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar.handler,
            extra_input_processors=[bracket_matcher])
        input_buffer = Buffer(
            history=history,
            auto_suggest=AutoSuggestFromHistory(),
            enable_history_search=True,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT)
        self.key_manager = KeyManager(
            self.set_color,
            self.get_color,
            self.set_fuzzy_match,
            self.get_fuzzy_match,
            self.set_shortcut_match,
            self.get_shortcut_match,
            self.refresh_resources_and_options,
            self.handle_docs)
        application = Application(
            mouse_support=False,
            style=StyleFactory(self.theme).style,
            layout=layout,
            buffer=input_buffer,
            key_bindings_registry=self.key_manager.manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            on_abort=AbortAction.RETRY,
            ignore_case=True)
        self.aws_cli = CommandLineInterface(
            application=application,
            eventloop=create_eventloop())

    def run_cli(self):
        """Print version and theme, then read and run commands forever.

        Each iteration asks ``self.aws_cli`` for a document (resetting the
        current buffer first) and hands its text to
        ``self._process_command``.  The loop has no exit of its own; it ends
        only when an exception propagates out of it.

        Args:
            * None.

        Returns:
            None.
        """
        print('Version:', __version__)
        print('Theme:', self.theme)
        while True:
            accepted = self.aws_cli.run(reset_current_buffer=True)
            self._process_command(accepted.text)
Example #23
0
def loop(cmd, history_file):
    """Run the interactive ``cr> `` prompt for ``cmd``.

    Builds the key bindings, layout, Application and CommandLineInterface,
    replaces ``cmd.get_num_columns`` with a function that reads the terminal
    width from this CLI's output object, then repeatedly reads input and
    passes the text of each truthy document to ``cmd.process``.

    NOTE(review): this copy of the function is damaged.  The
    ``cmd.username = input(...)`` line inside the ``ProgrammingError``
    handler is corrupted and is not valid Python.  The remainder of the
    original error/exit handling cannot be recovered from what is visible
    here.

    Args:
        cmd: The command/session object.  This function reads
            ``is_conn_available``, ``username``, ``password`` and
            ``connection.client.active_servers`` from it, calls
            ``cmd.process`` and assigns ``cmd.get_num_columns``.
        history_file: Passed through to ``create_buffer``.
    """
    def session_toolbar(cli):
        # The ``cli`` argument is unused; the toolbar tokens are computed
        # from the state of ``cmd`` at call time.
        return _get_toolbar_tokens(cmd.is_conn_available, cmd.username,
                                   cmd.connection.client.active_servers)

    key_binding_manager = KeyBindingManager(
        enable_search=True,
        enable_abort_and_exit_bindings=True,
        enable_system_bindings=True,
        enable_open_in_editor=True)
    # Add this project's own bindings on top of the defaults above.
    bind_keys(key_binding_manager.registry)
    layout = create_layout(
        message='cr> ',
        multiline=True,
        lexer=SqlLexer,
        extra_input_processors=[
            # Highlight matching brackets only while the default buffer has
            # focus and the prompt is not done.
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
        ],
        get_bottom_toolbar_tokens=session_toolbar)
    application = Application(
        layout=layout,
        buffer=create_buffer(cmd, history_file),
        style=PygmentsStyle.from_defaults(pygments_style_cls=CrateStyle),
        key_bindings_registry=key_binding_manager.registry,
        editing_mode=_get_editing_mode(),
        on_exit=AbortAction.RAISE_EXCEPTION,
        on_abort=AbortAction.RETRY,
    )
    eventloop = create_eventloop()
    # Kept in a local so get_num_columns_override below can query it.
    output = create_output()
    cli = CommandLineInterface(application=application,
                               eventloop=eventloop,
                               output=output)

    def get_num_columns_override():
        return output.get_size().columns

    # Make cmd take its column count from this CLI's output object.
    cmd.get_num_columns = get_num_columns_override

    while True:
        try:
            doc = cli.run(reset_current_buffer=True)
            if doc:
                cmd.process(doc.text)
        except ProgrammingError as e:
            # NOTE(review): ``e.message`` is not a standard attribute of
            # Python 3 exceptions; presumably ProgrammingError provides it --
            # confirm.  A '401' in the message leads to prompting for new
            # credentials (apparently an authentication failure).
            if '401' in e.message:
                username = cmd.username
                password = cmd.password
                # NOTE(review): the next line is corrupted -- a run of '*'
                # characters joins the 'Username: ' prompt to a later
                # 'Bye!' string and makes it invalid Python.  The original
                # re-authentication and exit handling are missing here.
                cmd.username = input('Username: '******'Bye!')
            return
Example #24
0
class Haxor(object):
    """Encapsulate the Hacker News CLI.

    :type cli: :class:`prompt_toolkit.CommandLineInterface`
    :param cli: An instance of `prompt_toolkit.CommandLineInterface`.

    :type CMDS_ENABLE_PAGINATE: list (const)
    :param CMDS_ENABLE_PAGINATE: Substrings that, when present in a command,
        turn on comment pagination.

    :type CMDS_NO_PAGINATE: list (const)
    :param CMDS_NO_PAGINATE: Substrings that, when present in a command,
        prevent pagination (browser mode and shell redirection).

    :type completer: :class:`Completer`
    :param completer: The completer used by the input buffer.

    :type hacker_news_cli: :class:`hacker_news_cli.HackerNewsCli`
    :param hacker_news_cli: An instance of `hacker_news_cli.HackerNewsCli`.

    :type key_manager: :class:`KeyManager`
    :param key_manager: The key manager; its ``manager`` attribute provides
        the key bindings registry used by the application.

    :type PAGINATE_CMD: str (const)
    :param PAGINATE_CMD: The pipe appended to enable pagination on
        non-Windows platforms.

    :type PAGINATE_CMD_WIN: str (const)
    :param PAGINATE_CMD_WIN: The pipe appended to enable pagination on
        Windows.

    :type paginate_comments: bool
    :param paginate_comments: Determines whether to paginate
            comments.

    :type text_utils: :class:`util.TextUtils`
    :param text_utils: An instance of `util.TextUtils`.

    :type theme: str
    :param theme: The prompt_toolkit lexer theme.
    """

    CMDS_NO_PAGINATE = [
        '-b',
        '--browser',
        '>',
        '<',
    ]
    CMDS_ENABLE_PAGINATE = [
        '-cq',
        '--comments_regex_query',
        '-c',
        '--comments',
        '-cr',
        '--comments_recent',
        '-cu',
        '--comments_unseen',
        '-ch',
        '--comments_hide_non_matching',
        'hiring',
        'freelance',
    ]
    PAGINATE_CMD = ' | less -r'
    PAGINATE_CMD_WIN = ' | more'

    def __init__(self):
        self.cli = None
        self.key_manager = None
        self.theme = 'vim'
        self.paginate_comments = True
        self.hacker_news_cli = HackerNewsCli()
        self.text_utils = TextUtils()
        self.completer = Completer(fuzzy_match=False,
                                   text_utils=self.text_utils)
        self._create_cli()
        if platform.system() == 'Windows':
            # Build a per-instance copy rather than appending to the shared
            # class-level list, which would grow with every new instance.
            self.CMDS_ENABLE_PAGINATE = self.CMDS_ENABLE_PAGINATE + ['view']

    def _create_key_manager(self):
        """Create the :class:`KeyManager`.

        The inputs to KeyManager are expected to be callable, so we can't
        use the standard @property and @attrib.setter for these attributes.
        Lambdas cannot contain assignments so we're forced to define setters.

        :rtype: :class:`KeyManager`
        :return: KeyManager with callables to set and get the toolbar's
            paginate-comments option.
        """

        def set_paginate_comments(paginate_comments):
            """Setter for paginating comments mode.

            :type paginate_comments: bool
            :param paginate_comments: The paginate comments mode.
            """
            self.paginate_comments = paginate_comments

        return KeyManager(
            set_paginate_comments, lambda: self.paginate_comments)

    def _create_cli(self):
        """Create the prompt_toolkit's CommandLineInterface."""
        history = FileHistory(os.path.expanduser('~/.haxornewshistory'))
        toolbar = Toolbar(lambda: self.paginate_comments)
        layout = create_default_layout(
            message=u'haxor> ',
            reserve_space_for_menu=True,
            get_bottom_toolbar_tokens=toolbar.handler,
        )
        cli_buffer = Buffer(
            history=history,
            auto_suggest=AutoSuggestFromHistory(),
            enable_history_search=True,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT)
        self.key_manager = self._create_key_manager()
        style_factory = StyleFactory(self.theme)
        application = Application(
            mouse_support=False,
            style=style_factory.style,
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=self.key_manager.manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            on_abort=AbortAction.RETRY,
            ignore_case=True)
        eventloop = create_eventloop()
        self.cli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

    def _add_comment_pagination(self, document_text):
        """Add the command to enable comment pagination where applicable.

        Pagination is enabled if the command contains one of
        ``CMDS_ENABLE_PAGINATE`` and none of ``CMDS_NO_PAGINATE`` (browser
        flag or shell redirection).  The pager pipe is ``PAGINATE_CMD_WIN``
        on Windows and ``PAGINATE_CMD`` elsewhere.

        :type document_text: str
        :param document_text: The input command.

        :rtype: str
        :return: the input command, with the pager pipe appended when
            pagination applies.
        """
        if not any(sub in document_text for sub in self.CMDS_NO_PAGINATE):
            if any(sub in document_text for sub in self.CMDS_ENABLE_PAGINATE):
                if platform.system() == 'Windows':
                    document_text += self.PAGINATE_CMD_WIN
                else:
                    document_text += self.PAGINATE_CMD
        return document_text

    def handle_exit(self, document):
        """Exits if the user typed exit or quit

        :type document: :class:`prompt_toolkit.document.Document`
        :param document: An instance of `prompt_toolkit.document.Document`.
        """
        if document.text in ('exit', 'quit'):
            sys.exit()

    def run_command(self, document):
        """Run the given command through the shell.

        When ``paginate_comments`` is set, a pager pipe may be appended
        first (see ``_add_comment_pagination``).  Any exception raised while
        preparing or running the command is echoed in red instead of being
        propagated.

        :type document: :class:`prompt_toolkit.document.Document`
        :param document: An instance of `prompt_toolkit.document.Document`.
        """
        try:
            # Work on a local copy: Document.text is a read-only property in
            # prompt_toolkit, and the accepted document must not be mutated.
            text = document.text
            if self.paginate_comments:
                text = self._add_comment_pagination(text)
            subprocess.call(text, shell=True)
        except Exception as e:
            click.secho(e, fg='red')

    def run_cli(self):
        """Run the main loop."""
        click.echo('Version: ' + __version__)
        click.echo('Syntax: hn <command> [params] [options]')
        while True:
            # Start each prompt with an empty input buffer.
            document = self.cli.run(reset_current_buffer=True)
            self.handle_exit(document)
            self.run_command(document)
Example #25
0
class Haxor(object):
    """Encapsulate the Hacker News CLI.

    :type cli: :class:`prompt_toolkit.CommandLineInterface`
    :param cli: An instance of `prompt_toolkit.CommandLineInterface`.

    :type CMDS_ENABLE_PAGINATE: list (const)
    :param CMDS_ENABLE_PAGINATE: Substrings that, when present in a command,
        turn on comment pagination.

    :type CMDS_NO_PAGINATE: list (const)
    :param CMDS_NO_PAGINATE: Substrings that, when present in a command,
        prevent pagination (browser mode and shell redirection).

    :type completer: :class:`Completer`
    :param completer: The completer used by the input buffer.

    :type hacker_news_cli: :class:`hacker_news_cli.HackerNewsCli`
    :param hacker_news_cli: An instance of `hacker_news_cli.HackerNewsCli`.

    :type key_manager: :class:`KeyManager`
    :param key_manager: The key manager; its ``manager`` attribute provides
        the key bindings registry used by the application.

    :type PAGINATE_CMD: str (const)
    :param PAGINATE_CMD: The pipe appended to enable pagination on
        non-Windows platforms.

    :type PAGINATE_CMD_WIN: str (const)
    :param PAGINATE_CMD_WIN: The pipe appended to enable pagination on
        Windows.

    :type paginate_comments: bool
    :param paginate_comments: Determines whether to paginate
            comments.

    :type text_utils: :class:`util.TextUtils`
    :param text_utils: An instance of `util.TextUtils`.

    :type theme: str
    :param theme: The prompt_toolkit lexer theme.
    """

    CMDS_NO_PAGINATE = [
        '-b',
        '--browser',
        '>',
        '<',
    ]
    CMDS_ENABLE_PAGINATE = [
        '-cq',
        '--comments_regex_query',
        '-c',
        '--comments',
        '-cr',
        '--comments_recent',
        '-cu',
        '--comments_unseen',
        '-ch',
        '--comments_hide_non_matching',
        'hiring',
        'freelance',
    ]
    PAGINATE_CMD = ' | less -r'
    PAGINATE_CMD_WIN = ' | more'

    def __init__(self):
        self.cli = None
        self.key_manager = None
        self.theme = 'vim'
        self.paginate_comments = True
        self.hacker_news_cli = HackerNewsCli()
        self.text_utils = TextUtils()
        self.completer = Completer(fuzzy_match=False,
                                   text_utils=self.text_utils)
        self._create_cli()
        if platform.system() == 'Windows':
            # Build a per-instance copy rather than appending to the shared
            # class-level list, which would grow with every new instance.
            self.CMDS_ENABLE_PAGINATE = self.CMDS_ENABLE_PAGINATE + ['view']

    def _create_key_manager(self):
        """Create the :class:`KeyManager`.

        The inputs to KeyManager are expected to be callable, so we can't
        use the standard @property and @attrib.setter for these attributes.
        Lambdas cannot contain assignments so we're forced to define setters.

        :rtype: :class:`KeyManager`
        :return: KeyManager with callables to set and get the toolbar's
            paginate-comments option.
        """

        def set_paginate_comments(paginate_comments):
            """Setter for paginating comments mode.

            :type paginate_comments: bool
            :param paginate_comments: The paginate comments mode.
            """
            self.paginate_comments = paginate_comments

        return KeyManager(
            set_paginate_comments, lambda: self.paginate_comments)

    def _create_cli(self):
        """Create the prompt_toolkit's CommandLineInterface."""
        history = FileHistory(os.path.expanduser('~/.haxornewshistory'))
        toolbar = Toolbar(lambda: self.paginate_comments)
        layout = create_default_layout(
            message=u'haxor> ',
            reserve_space_for_menu=8,
            get_bottom_toolbar_tokens=toolbar.handler,
        )
        cli_buffer = Buffer(
            history=history,
            auto_suggest=AutoSuggestFromHistory(),
            enable_history_search=True,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT)
        self.key_manager = self._create_key_manager()
        style_factory = StyleFactory(self.theme)
        application = Application(
            mouse_support=False,
            style=style_factory.style,
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=self.key_manager.manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            on_abort=AbortAction.RETRY,
            ignore_case=True)
        eventloop = create_eventloop()
        self.cli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

    def _add_comment_pagination(self, document_text):
        """Add the command to enable comment pagination where applicable.

        Pagination is enabled if the command contains one of
        ``CMDS_ENABLE_PAGINATE`` and none of ``CMDS_NO_PAGINATE`` (browser
        flag or shell redirection).  The pager pipe is ``PAGINATE_CMD_WIN``
        on Windows and ``PAGINATE_CMD`` elsewhere.

        :type document_text: str
        :param document_text: The input command.

        :rtype: str
        :return: the input command, with the pager pipe appended when
            pagination applies.
        """
        if not any(sub in document_text for sub in self.CMDS_NO_PAGINATE):
            if any(sub in document_text for sub in self.CMDS_ENABLE_PAGINATE):
                if platform.system() == 'Windows':
                    document_text += self.PAGINATE_CMD_WIN
                else:
                    document_text += self.PAGINATE_CMD
        return document_text

    def handle_exit(self, document):
        """Exits if the user typed exit or quit

        :type document: :class:`prompt_toolkit.document.Document`
        :param document: An instance of `prompt_toolkit.document.Document`.
        """
        if document.text in ('exit', 'quit'):
            sys.exit()

    def run_command(self, document):
        """Run the given command through the shell.

        When ``paginate_comments`` is set, a pager pipe may be appended
        first (see ``_add_comment_pagination``).  Any exception raised while
        preparing or running the command is echoed in red instead of being
        propagated.

        :type document: :class:`prompt_toolkit.document.Document`
        :param document: An instance of `prompt_toolkit.document.Document`.
        """
        try:
            # Bind the text unconditionally so the command still runs when
            # pagination is turned off.
            text = document.text
            if self.paginate_comments:
                text = self._add_comment_pagination(text)
            subprocess.call(text, shell=True)
        except Exception as e:
            click.secho(e, fg='red')

    def run_cli(self):
        """Run the main loop."""
        click.echo('Version: ' + __version__)
        click.echo('Syntax: hn <command> [params] [options]')
        while True:
            document = self.cli.run(reset_current_buffer=True)
            self.handle_exit(document)
            self.run_command(document)
Example #26
0
    def run(self, query=None, data=None):
        """Run a one-shot query/stdin batch, or start the interactive prompt.

        When ``query`` and/or ``data`` is given, the output format is
        switched to ``self.format_stdin``, verbose output is disabled, and
        the result of a single ``handle_input``/``handle_query`` call is
        returned.  Otherwise an interactive prompt is started; each submitted
        line goes to ``handle_input`` until EOF ends the session.

        Args:
            query: Optional query text (e.g. ``clickhouse-cli -q 'SELECT 1'``).
            data: Optional iterable of input lines (e.g. piped on stdin).

        Returns:
            The result of ``handle_input``/``handle_query`` in the one-shot
            modes; ``None`` for the interactive mode or when ``connect()``
            fails.
        """
        self.load_config()

        if data is not None or query is not None:
            self.format = self.format_stdin
            self.echo.verbose = False

        if self.echo.verbose:
            show_version()

        if not self.connect():
            return

        if data is not None and query is None:
            # cat stuff.sql | clickhouse-cli
            return self.handle_input('\n'.join(data), verbose=False)

        if data is None and query is not None:
            # clickhouse-cli -q 'SELECT 1'
            return self.handle_query(query, stream=True)

        if data is not None and query is not None:
            # cat stuff.csv | clickhouse-cli -q 'INSERT INTO stuff'
            return self.handle_query(query, data=data, stream=True)

        layout = create_prompt_layout(
            lexer=PygmentsLexer(CHLexer),
            get_prompt_tokens=get_prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            multiline=self.multiline,
        )

        buffer = CLIBuffer(
            client=self.client,
            multiline=self.multiline,
        )

        application = Application(
            layout=layout,
            buffer=buffer,
            style=CHStyle,
            key_bindings_registry=KeyBinder.registry,
        )

        eventloop = create_eventloop()

        cli = CommandLineInterface(application=application, eventloop=eventloop)

        try:
            while True:
                try:
                    cli_input = cli.run(reset_current_buffer=True)
                    self.handle_input(cli_input.text)
                except KeyboardInterrupt:
                    # Attempt to terminate queries
                    for query_id in self.query_ids:
                        self.client.kill_query(query_id)

                    self.echo.error("\nQuery was terminated.")
                finally:
                    # Forget the ids of this round's queries either way.
                    self.query_ids = []
        except EOFError:
            self.echo.success("Bye.")
        finally:
            # Release the event loop's resources once the session ends,
            # whether by EOF or by an unexpected exception.
            eventloop.close()
Example #27
0
class Saws(object):
    """Encapsulates the Saws CLI.

    Attributes:
        * aws_cli: An instance of prompt_toolkit's CommandLineInterface.
        * config: An instance of Config.
        * config_obj: An instance of ConfigObj, reads from ~/.sawsrc.
        * aws_commands: An instance of AwsCommands
        * commands: A list of commands from data/SOURCES.txt.
        * sub_commands: A list of sub_commands from data/SOURCES.txt.
        * global_options: A list of global_options from data/SOURCES.txt.
        * resource_options: A list of resource_options from data/SOURCES.txt,
            used for syntax coloring.
        * ec2_states: A list of ec2_states from data/SOURCES.txt.
        * completer: An instance of AwsCompleter.
        * key_manager: An instance of KeyManager
        * logger: The logger created by SawsLogger.
        * PYGMENTS_CMD: The pipe appended to commands to highlight their
            output with pygmentize.
        * theme: A string representing the lexer theme.
            Currently only 'vim' is supported.
    """

    def __init__(self):
        """Inits Saws.

        Reads the configuration, sets up logging, generates the command
        lists used for completion, creates the completer and builds the CLI.

        Args:
            * None.

        Returns:
            None.
        """
        self.aws_cli = None
        self.key_manager = None
        self.PYGMENTS_CMD = ' | pygmentize -l json'
        self.config = Config()
        self.config_obj = self.config.read_configuration()
        self.theme = self.config_obj['main']['theme']
        self.logger = SawsLogger(__name__,
                                 self.config_obj['main']['log_file'],
                                 self.config_obj['main']['log_level']).logger
        self.aws_commands = AwsCommands()
        self.commands, self.sub_commands, self.global_options, \
            self.resource_options, self.ec2_states \
            = self.aws_commands.generate_all_commands()
        self.completer = AwsCompleter(
            awscli_completer,
            self.commands,
            self.config_obj,
            self.log_exception,
            ec2_states=self.ec2_states,
            fuzzy_match=self.get_fuzzy_match(),
            shortcut_match=self.get_shortcut_match())
        self.create_cli()

    def log_exception(self, e, traceback, echo=False):
        """Logs the exception and traceback to the log file ~/.saws.log.

        Args:
            * e: A Exception that specifies the exception.
            * traceback: A Traceback that specifies the traceback.
            * echo: A boolean that specifies whether to echo the exception
                to the console using click.

        Returns:
            None.
        """
        self.logger.debug('exception: %r.', str(e))
        self.logger.error("traceback: %r", traceback.format_exc())
        if echo:
            click.secho(str(e), fg='red')

    def set_color(self, color):
        """Setter for color output mode.

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * color: A boolean that represents the color flag.

        Returns:
            None.
        """
        self.config_obj['main']['color_output'] = color

    def get_color(self):
        """Getter for color output mode.

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the color flag.
        """
        return self.config_obj['main'].as_bool('color_output')

    def set_fuzzy_match(self, fuzzy):
        """Setter for fuzzy matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * fuzzy: A boolean that represents the fuzzy flag.

        Returns:
            None.
        """
        self.config_obj['main']['fuzzy_match'] = fuzzy
        self.completer.fuzzy_match = fuzzy

    def get_fuzzy_match(self):
        """Getter for fuzzy matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the fuzzy flag.
        """
        return self.config_obj['main'].as_bool('fuzzy_match')

    def set_shortcut_match(self, shortcut):
        """Setter for shortcut matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * shortcut: A boolean that represents the shortcut flag.

        Returns:
            None.
        """
        self.config_obj['main']['shortcut_match'] = shortcut
        self.completer.shortcut_match = shortcut

    def get_shortcut_match(self):
        """Getter for shortcut matching mode

        Used by prompt_toolkit's KeyBindingManager.
        KeyBindingManager expects this function to be callable so we can't use
        @property and @attrib.setter.

        Args:
            * None.

        Returns:
            A boolean that represents the shortcut flag.
        """
        return self.config_obj['main'].as_bool('shortcut_match')

    def refresh_resources(self):
        """Convenience function to refresh resources for completion.

        Used by prompt_toolkit's KeyBindingManager.

        Args:
            * None.

        Returns:
            None.
        """
        self.completer.refresh_resources(force_refresh=True)

    def handle_docs(self, text=None, from_fkey=False):
        """Displays contextual web docs for `F9` or the `docs` command.

        Displays the web docs specific to the currently entered:

        * (optional) command
        * (optional) subcommand

        If no command or subcommand is present, the docs index page is shown.

        Docs are only displayed if:

        * from_fkey is True
        * from_fkey is False and `docs` is found in text

        Exactly one browser page is opened per call that returns True.

        Args:
            * text: A string representing the input command text.  Defaults
                to the text of the CLI's current buffer.
            * from_fkey: A boolean representing whether this function is
                being executed from an `F9` key press.

        Returns:
            A boolean representing whether the web docs were shown.
        """
        base_url = 'http://docs.aws.amazon.com/cli/latest/reference/'
        index_html = 'index.html'
        if text is None:
            text = self.aws_cli.current_buffer.document.text
        # If the user hit the F9 key, append 'docs' to the text
        if from_fkey:
            text = text.strip() + ' ' + AwsCommands.AWS_DOCS
        tokens = text.split()
        if len(tokens) > 2 and tokens[-1] == AwsCommands.AWS_DOCS:
            prev_word = tokens[-2]
            # If we have a command, build the url
            if prev_word in self.commands:
                prev_word = prev_word + '/'
                url = base_url + prev_word + index_html
                webbrowser.open(url)
                return True
            # if we have a command and subcommand, build the url
            elif prev_word in self.sub_commands:
                command_url = tokens[-3] + '/'
                sub_command_url = tokens[-2] + '.html'
                url = base_url + command_url + sub_command_url
                webbrowser.open(url)
                return True
            # Otherwise fall through: `docs` is in tokens, so the index
            # page is opened (once) below.
        # If we still haven't opened the help doc at this point and the
        # user hit the F9 key or typed docs, just open the main docs index
        if from_fkey or AwsCommands.AWS_DOCS in tokens:
            webbrowser.open(base_url + index_html)
            return True
        return False

    def handle_cd(self, text):
        """Handles a `cd` shell command by calling python's os.chdir.

        Simply passing in the `cd` command to subprocess.call doesn't work.
        Note: Changing the directory within Saws will only be in effect while
        running Saws.  Exiting the program will return you to the directory
        you were in prior to running Saws.

        Only input whose first whitespace-separated word is exactly `cd` is
        handled, so commands such as `cdk` still go to the shell.  The rest
        of the line is used verbatim as the target directory (so unquoted
        paths containing spaces work), after `~` expansion.  A bare `cd`
        goes to the user's home directory.

        Args:
            * text: A string representing the input command text.

        Returns:
            A boolean representing a `cd` command was found and handled.
        """
        CD_CMD = 'cd'
        parts = text.strip().split(None, 1)
        if not parts or parts[0] != CD_CMD:
            return False
        if len(parts) == 1:
            # Treat a bare `cd` as a change to the home directory.
            # os.path.expanduser does this in a cross platform manner.
            directory = os.path.expanduser('~')
        else:
            directory = os.path.expanduser(parts[1])
        try:
            os.chdir(directory)
        except OSError as e:
            self.log_exception(e, traceback, echo=True)
        return True

    def colorize_output(self, text):
        """Highlights output with pygments.

        Only highlights the output if all of the following conditions are True:

        * The color option is enabled
        * The text does not contain the `configure` command
        * The text does not contain the `help` command, which already does
            output highlighting
        * The text does not already contain a pipe

        Args:
            * text: A string that represents the input command text.

        Returns:
            A string that represents:
                * The original command text if no highlighting was performed.
                * The pygments highlighted command text otherwise.
        """
        if not self.get_color():
            return text
        stripped_text = text.strip()
        excludes = [AwsCommands.AWS_CONFIGURE,
                    AwsCommands.AWS_HELP,
                    '|']
        if not any(substring in stripped_text for substring in excludes):
            return text.strip() + self.PYGMENTS_CMD
        else:
            return text

    def process_command(self, text):
        """Processes the input command, called by the cli event loop

        Args:
            * text: A string that represents the input command text.

        Returns:
            None.
        """
        if AwsCommands.AWS_COMMAND in text:
            text = self.completer.replace_shortcut(text)
            if self.handle_docs(text):
                return
            text = self.colorize_output(text)
        try:
            if not self.handle_cd(text):
                # Pass the command onto the shell so aws-cli can execute it
                subprocess.call(text, shell=True)
            print('')
        except Exception as e:
            self.log_exception(e, traceback, echo=True)

    def create_cli(self):
        """Creates the prompt_toolkit's CommandLineInterface.

        Builds the history (~/.saws-history), toolbar, layout, input buffer,
        key manager and style, and stores the resulting interface on
        ``self.aws_cli`` (also sets ``self.key_manager``).

        Args:
            * None.

        Returns:
            None.
        """
        history = FileHistory(os.path.expanduser('~/.saws-history'))
        toolbar = Toolbar(self.get_color,
                          self.get_fuzzy_match,
                          self.get_shortcut_match)
        layout = create_default_layout(
            message='saws> ',
            reserve_space_for_menu=True,
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar.handler,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ]
        )
        cli_buffer = Buffer(
            history=history,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT)
        self.key_manager = KeyManager(
            self.set_color,
            self.get_color,
            self.set_fuzzy_match,
            self.get_fuzzy_match,
            self.set_shortcut_match,
            self.get_shortcut_match,
            self.refresh_resources,
            self.handle_docs)
        style_factory = StyleFactory(self.theme)
        application = Application(
            mouse_support=False,
            style=style_factory.style,
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=self.key_manager.manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            on_abort=AbortAction.RETRY,
            ignore_case=True)
        eventloop = create_eventloop()
        self.aws_cli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

    def run_cli(self):
        """Runs the main loop.

        Args:
            * None.

        Returns:
            None.
        """
        print('Version:', __version__)
        print('Theme:', self.theme)
        while True:
            # Start every prompt with an empty input buffer.
            document = self.aws_cli.run(reset_current_buffer=True)
            self.process_command(document.text)
Example #28
0
class NetWork(object):
    """Interactive network shell built on prompt_toolkit.

    Keeps two toolbar toggles (fuzzy matching and long option names),
    builds the ``CommandLineInterface`` during construction, and forwards
    every entered line to ``NetWorkCommands.run_cmd``.
    """

    def __init__(self):
        self.fuzzy = True
        self.long = True
        self.cmd = NetWorkCommands()
        self._create_cli()

    def set_fuzzy_match(self, is_fuzzy):
        """Enable/disable fuzzy matching and forward it to the completer."""
        self.fuzzy = is_fuzzy
        self.completer.set_fuzzy_match(is_fuzzy)

    def get_fuzzy_match(self):
        """Return whether fuzzy matching is enabled."""
        return self.fuzzy

    def set_long_options(self, is_long):
        """Enable/disable long option names and forward it to the completer."""
        self.long = is_long
        self.completer.set_long_options(is_long)

    def get_long_options(self):
        """Return whether long option names are suggested."""
        return self.long

    def _process_command(self, text):
        """Execute one line of user input via ``NetWorkCommands``."""
        self.cmd.run_cmd(text)

    def _create_layout(self):
        """Build the prompt layout with the option-state bottom toolbar."""
        self.layout = create_default_layout(
            message='network > ',
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=create_toolbar_handler(
                self.get_long_options, self.get_fuzzy_match))

    def _create_completer(self):
        """Create the completer used by the input buffer."""
        self.completer = NetworkCompleter()

    def _create_buffer(self):
        """Create the input buffer (history, auto-suggest, completion)."""
        self._create_completer()
        self.buffer = Buffer(
            # NOTE(review): expanduser() is a no-op for './...', so history is
            # stored relative to the current working directory. Kept as-is to
            # avoid moving existing history files.
            history=FileHistory(os.path.expanduser('./network-history')),
            auto_suggest=AutoSuggestFromHistory(),
            enable_history_search=True,
            completer=self.completer,
            complete_while_typing=Always(),
            accept_action=AcceptAction.RETURN_DOCUMENT,
        )

    def _create_style(self):
        """Create the prompt style."""
        self.style = style_factory('red')

    def _create_manage(self):
        """Create the key-binding manager wired to the option toggles."""
        self.manager = get_key_manager(self.set_long_options,
                                       self.get_long_options,
                                       self.set_fuzzy_match,
                                       self.get_fuzzy_match)

    def _create_cli(self):
        """Assemble layout, buffer, style and key bindings into the CLI."""
        self._create_layout()
        self._create_buffer()
        self._create_style()
        self._create_manage()

        application = Application(layout=self.layout,
                                  buffer=self.buffer,
                                  style=self.style,
                                  key_bindings_registry=self.manager.registry,
                                  mouse_support=False,
                                  on_exit=AbortAction.RAISE_EXCEPTION,
                                  on_abort=AbortAction.RETRY,
                                  ignore_case=True)
        event_loop = create_eventloop()
        self.network_cli = CommandLineInterface(application=application,
                                                eventloop=event_loop)

    @staticmethod
    def _error_message(ex):
        """Return a printable message for ``ex`` on both Python 2 and 3.

        ``BaseException.message`` was deprecated in Python 2.6 and removed in
        Python 3, where reading it raises AttributeError. For the usual
        single-argument exception (e.g. ``KeyError('name')``) the argument
        itself is returned, which is what ``.message`` used to give
        (``str(KeyError('name'))`` would add quotes).
        """
        if len(ex.args) == 1:
            return ex.args[0]
        return str(ex)

    def _quit_command(self, text):
        """Return True if ``text`` is one of the quit commands."""
        stripped = text.strip()
        return (stripped.lower() in ('exit', 'quit')
                or stripped in (r'\q', ':q'))

    def run_cli(self):
        """Run the prompt loop until a quit command raises EOFError."""
        print('Version:', __version__)
        while True:
            document = self.network_cli.run(reset_current_buffer=True)
            if self._quit_command(document.text):
                raise EOFError
            try:
                self._process_command(document.text)
            except (KeyError, NotImplementedError) as ex:
                click.secho(self._error_message(ex), fg='red')
Example #29
0
class WharfeeCli(object):
    """
    The CLI implementation.
    """

    dcli = None
    keyword_completer = None
    handler = None
    saved_less_opts = None
    config = None
    config_template = 'wharfeerc'
    config_name = '~/.wharfeerc'

    def __init__(self):
        """
        Initialize class members.

        Reads the configuration, creates the logger, the Docker client and
        the completer, fills the completer with container/image names, and
        overrides the LESS environment variable (the previous value is kept
        in ``saved_less_opts``).
        """

        self.config = self.read_configuration()
        self.theme = self.config['main']['theme']

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']
        self.logger = create_logger(__name__, log_file, log_level)

        # set_completer_options refreshes all by default
        self.handler = DockerClient(
            self.config['main'].as_int('client_timeout'),
            self.clear,
            self.set_completer_options)

        self.completer = DockerCompleter(
            long_option_names=self.get_long_options(),
            fuzzy=self.get_fuzzy_match())
        self.set_completer_options()
        self.saved_less_opts = self.set_less_opts()

    def read_configuration(self):
        """
        Load the user configuration.

        The packaged default config (``config_template`` in the package
        directory) is passed with ``config_name`` to ``write_default_config``,
        then ``read_config`` loads the user file using the default as
        fallback (exact semantics live in those helpers).

        :return: the configuration object returned by ``read_config``
        """
        default_config = os.path.join(
            self.get_package_path(), self.config_template)
        write_default_config(default_config, self.config_name)
        return read_config(self.config_name, default_config)

    def get_package_path(self):
        """
        Find out the package root path.
        :return: string: path
        """
        from wharfee import __file__ as package_root
        return os.path.dirname(package_root)

    def set_less_opts(self):
        """
        Set the "less" options and save the old settings.

        What we're setting:
          -F / --quit-if-one-screen:
            Quit if the entire file fits on the first screen.
          -R / --RAW-CONTROL-CHARS:
            Pass ANSI color escape sequences through in "raw" form.
          -X / --no-init:
            Don't send the termcap initialization and deinitialization
            strings to the terminal. This also keeps "less" from switching
            to the alternate screen, so paged output stays visible.

        :return: string with old options ('' if LESS was not set)
        """
        opts = os.environ.get('LESS', '')
        os.environ['LESS'] = '-RXF'
        return opts

    def revert_less_opts(self):
        """
        Restore the previous "less" options.
        """
        os.environ['LESS'] = self.saved_less_opts

    def write_config_file(self):
        """
        Write config file on exit.
        """
        self.config.write()

    def clear(self):
        """
        Clear the screen.
        """
        click.clear()

    def set_completer_options(self, cons=True, runs=True, imgs=True):
        """
        Set image and container names in Completer.
        Re-read if needed after a command.
        :param cons: boolean: need to refresh containers
        :param runs: boolean: need to refresh running containers
        :param imgs: boolean: need to refresh images
        """

        # Each list is only pushed to the completer when the handler returned
        # a non-empty list of dicts; otherwise the previous names are kept.
        if cons:
            cs = self.handler.containers(all=True)
            if cs and isinstance(cs[0], dict):
                containers = [name for c in cs for name in c['Names']]
                self.completer.set_containers(containers)

        if runs:
            cs = self.handler.containers()
            if cs and isinstance(cs[0], dict):
                running = [name for c in cs for name in c['Names']]
                self.completer.set_running(running)

        if imgs:
            def format_tagged(tagname, img_id):
                # Untagged images are referred to by their short id.
                if tagname == '<none>:<none>':
                    return img_id[:11]
                return tagname

            def parse_image_name(tag, img_id):
                # Split on the LAST colon: the repository part may itself
                # contain one (registry with a port, e.g.
                # "localhost:5000/app:v1" -> "localhost:5000/app").
                if ':' in tag:
                    result = tag.rsplit(':', 1)[0]
                else:
                    result = tag
                if result == '<none>':
                    result = img_id[:11]
                return result

            ims = self.handler.images()
            if ims and isinstance(ims[0], dict):
                images = set()
                tagged = set()
                for im in ims:
                    repo_tag = '{0}:{1}'.format(im['Repository'], im['Tag'])
                    images.add(parse_image_name(repo_tag, im['Id']))
                    tagged.add(format_tagged(repo_tag, im['Id']))
                self.completer.set_images(images)
                self.completer.set_tagged(tagged)

    def set_fuzzy_match(self, is_fuzzy):
        """
        Setter for fuzzy matching mode
        :param is_fuzzy: boolean
        """
        self.config['main']['fuzzy_match'] = is_fuzzy
        self.completer.set_fuzzy_match(is_fuzzy)

    def get_fuzzy_match(self):
        """
        Getter for fuzzy matching mode
        :return: boolean
        """
        return self.config['main'].as_bool('fuzzy_match')

    def set_long_options(self, is_long):
        """
        Setter for long option names.
        :param is_long: boolean
        """
        self.config['main']['suggest_long_option_names'] = is_long
        self.completer.set_long_options(is_long)

    def get_long_options(self):
        """
        Getter for long option names.
        :return: boolean
        """
        return self.config['main'].as_bool('suggest_long_option_names')

    def refresh_completions(self):
        """
        After processing the command, refresh the lists of
        containers and images as needed
        """
        self.set_completer_options(self.handler.is_refresh_containers,
                                   self.handler.is_refresh_running,
                                   self.handler.is_refresh_images)

    def run_cli(self):
        """
        Run the main loop until the user exits (EOF) or an unexpected error.

        The LESS variable is restored and the config file is written in a
        ``finally`` clause, so settings toggled during the session are kept
        even if a BaseException escapes the loop.
        """
        print('Version:', __version__)
        print('Home: http://wharfee.com')

        history = FileHistory(os.path.expanduser('~/.wharfee-history'))
        toolbar_handler = create_toolbar_handler(self.get_long_options, self.get_fuzzy_match)

        layout = create_default_layout(
            message='wharfee> ',
            reserve_space_for_menu=True,
            lexer=CommandLexer,
            get_bottom_toolbar_tokens=toolbar_handler,
            extra_input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
            ]
        )

        cli_buffer = Buffer(
            history=history,
            completer=self.completer,
            complete_while_typing=Always())

        manager = get_key_manager(
            self.set_long_options,
            self.get_long_options,
            self.set_fuzzy_match,
            self.get_fuzzy_match)

        application = Application(
            style=style_factory(self.theme),
            layout=layout,
            buffer=cli_buffer,
            key_bindings_registry=manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION,
            ignore_case=True)

        eventloop = create_eventloop()

        self.dcli = CommandLineInterface(
            application=application,
            eventloop=eventloop)

        try:
            while True:
                try:
                    document = self.dcli.run()
                    self.handler.handle_input(document.text)

                    if isinstance(self.handler.output, GeneratorType):
                        output_stream(self.handler.command,
                                      self.handler.output,
                                      self.handler.logs)

                    elif self.handler.output is not None:
                        lines = format_data(
                            self.handler.command,
                            self.handler.output)
                        click.echo_via_pager('\n'.join(lines))

                    if self.handler.after:
                        for line in self.handler.after():
                            click.echo(line)

                    self.refresh_completions()

                except OptionError as ex:
                    self.logger.debug('Error: %r.', ex)
                    self.logger.error("traceback: %r", traceback.format_exc())
                    click.secho(ex.msg, fg='red')

                except KeyboardInterrupt:
                    # user pressed Ctrl + C
                    if self.handler.after:
                        click.echo('')
                        for line in self.handler.after():
                            click.echo(line)

                    self.refresh_completions()

                except DockerPermissionException as ex:
                    self.logger.debug('Permission exception: %r.', ex)
                    self.logger.error("traceback: %r", traceback.format_exc())
                    click.secho(ex.message, fg='red')

                except EOFError:
                    # exit out of the CLI
                    break

                except Exception as ex:
                    # Unexpected failure: log it, show it, and leave the loop
                    # rather than continue in an unknown state.
                    self.logger.debug('Exception: %r.', ex)
                    self.logger.error("traceback: %r", traceback.format_exc())
                    click.secho("{0}".format(ex), fg='red')
                    break
        finally:
            self.revert_less_opts()
            self.write_config_file()

        print('Goodbye!')
Example #30
0
File: main.py Project: d33tah/pgcli
class PGCli(object):
    """Interactive PostgreSQL client.

    Owns the configuration, logging, the ``PGExecute`` connection, the
    completer (swapped in the background by ``CompletionRefresher``) and the
    prompt_toolkit ``CommandLineInterface`` created in :meth:`run_cli`.
    """

    def __init__(self,
                 force_passwd_prompt=False,
                 never_passwd_prompt=False,
                 pgexecute=None,
                 pgclirc_file=None):
        """Load configuration, set up logging and create the completer.

        :param force_passwd_prompt: prompt for a password up front (-W) unless
            one is already available.
        :param never_passwd_prompt: never prompt for a password (-w).
        :param pgexecute: optional pre-built ``PGExecute`` connection.
        :param pgclirc_file: path of the user's config file.
        """

        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        from pgcli import __file__ as package_root
        package_root = os.path.dirname(package_root)

        default_config = os.path.join(package_root, 'pgclirc')
        write_default_config(default_config, pgclirc_file)

        self.pgspecial = PGSpecial()

        # Load config.
        c = self.config = load_config(pgclirc_file, default_config)
        self.multi_line = c['main'].as_bool('multi_line')
        self.vi_mode = c['main'].as_bool('vi')
        self.pgspecial.timing_enabled = c['main'].as_bool('timing')
        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')

        on_error_modes = {'STOP': ON_ERROR_STOP, 'RESUME': ON_ERROR_RESUME}
        self.on_error = on_error_modes[c['main']['on_error'].upper()]

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer
        smart_completion = c['main'].as_bool('smart_completion')
        completer = PGCompleter(smart_completion, pgspecial=self.pgspecial)
        self.completer = completer
        # Guards self.completer, which the background refresher replaces.
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        self.cli = None

    def register_special_commands(self):
        """Register pgcli's own backslash commands with PGSpecial."""

        self.pgspecial.register(self.change_db,
                                '\\c',
                                '\\c[onnect] database_name',
                                'Change to a new database.',
                                aliases=('use', '\\connect', 'USE'))
        self.pgspecial.register(self.refresh_completions,
                                '\\#',
                                '\\#',
                                'Refresh auto-completions.',
                                arg_type=NO_QUERY)
        self.pgspecial.register(self.refresh_completions,
                                '\\refresh',
                                '\\refresh',
                                'Refresh auto-completions.',
                                arg_type=NO_QUERY)
        self.pgspecial.register(self.execute_from_file, '\\i', '\\i filename',
                                'Execute commands from file.')

    def change_db(self, pattern, **_):
        """Reconnect to ``pattern`` (surrounding double quotes stripped), or
        to the default database when no pattern is given; yields one status
        row."""
        if pattern:
            db = pattern[1:-1] if pattern[0] == pattern[-1] == '"' else pattern
            self.pgexecute.connect(database=db)
        else:
            self.pgexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        """Run the SQL contained in file ``pattern`` (\\i command).

        Returns a single status row when the argument is missing or the file
        cannot be read; otherwise whatever ``PGExecute.run`` returns.
        """
        if not pattern:
            message = '\\i: missing required argument'
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(pattern), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        return self.pgexecute.run(query,
                                  self.pgspecial,
                                  on_error=self.on_error)

    def initialize_logging(self):
        """Attach a file handler to the 'pgcli' logger per the config."""

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {
            'CRITICAL': logging.CRITICAL,
            'ERROR': logging.ERROR,
            'WARNING': logging.WARNING,
            'INFO': logging.INFO,
            'DEBUG': logging.DEBUG
        }

        handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('pgcli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        root_logger.debug('Initializing pgcli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_dsn(self, dsn):
        """Connect using a libpq DSN string."""
        self.connect(dsn=dsn)

    def connect_uri(self, uri):
        """Connect using a postgres:// URI."""
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        self.connect(database, uri.hostname, uri.username, uri.port,
                     uri.password)

    def connect(self,
                database='',
                host='',
                user='',
                port='',
                passwd='',
                dsn=''):
        """Connect to the database, prompting for a password if allowed.

        Exits the process with status 1 if the connection cannot be made.
        """
        import sys

        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt('Password',
                                  hide_input=True,
                                  show_default=False,
                                  type=str)

        # Prompt for a password after 1st attempt to connect without a password
        # fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not passwd and not self.never_passwd_prompt

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                if ('no password supplied' in utf8tounicode(e.args[0])
                        and auto_passwd_prompt):
                    passwd = click.prompt('Password',
                                          hide_input=True,
                                          show_default=False,
                                          type=str)
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn)
                else:
                    # Bare raise keeps the original traceback on Python 2.
                    raise

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            # sys.exit, not the site-provided ``exit`` builtin: that one is
            # meant for the interactive interpreter, closes stdin, and is
            # missing when Python runs with -S.
            sys.exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: if the external editor reports an error
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                        sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            cli.current_buffer.document = Document(sql,
                                                   cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive loop until the user quits (EOF).

        The LESS environment variable is overridden for the session and
        restored on exit.
        """
        pgexecute = self.pgexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        def prompt_tokens(cli):
            return [(Token.Prompt, '%s> ' % pgexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode, self.completion_refresher.is_refreshing)

        layout = create_default_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])
        history_file = self.config['main']['history_file']
        with self._completer_lock:
            buf = PGBuffer(always_multiline=self.multi_line,
                           completer=self.completer,
                           history=FileHistory(
                               os.path.expanduser(history_file)),
                           complete_while_typing=Always())

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        # Ask for confirmation before formatting SELECT results larger than
        # this many rows.
        threshold = 1000

        try:
            while True:
                document = self.cli.run()

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    successful = False
                    # Initialized to [] because res might never get initialized
                    # if an exception occurs in pgexecute.run(). Which causes
                    # finally clause to fail.
                    res = []
                    # Run the query.
                    start = time()
                    res = pgexecute.run(document.text,
                                        self.pgspecial,
                                        on_error=self.on_error)
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        if (is_select(status) and cur
                                and cur.rowcount > threshold):
                            click.secho(
                                'The result set has more than %s rows.' %
                                threshold,
                                fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                break

                        if self.pgspecial.auto_expand:
                            max_width = self.cli.output.get_size().columns
                        else:
                            max_width = None

                        formatted = format_output(
                            title, cur, headers, status, self.table_format,
                            self.pgspecial.expanded_output, max_width)
                        output.extend(formatted)
                        end = time()
                        total += end - start
                        # Restart the clock so successive statements add
                        # disjoint intervals instead of each being measured
                        # from the very beginning (which over-counts).
                        start = end
                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    reconnect = True
                    if ('server closed the connection'
                            in utf8tounicode(e.args[0])):
                        reconnect = click.prompt(
                            'Connection reset. Reconnect (Y/n)',
                            show_default=False,
                            type=bool,
                            default=True)
                        if reconnect:
                            try:
                                pgexecute.connect()
                                click.secho(
                                    'Reconnected!\nTry the command again.',
                                    fg='green')
                            except OperationalError as err:
                                click.secho(str(err), err=True, fg='red')
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    successful = True
                    try:
                        click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        pass
                    if self.pgspecial.timing_enabled:
                        print('Time: %0.03fs' % total)

                    # Refresh the table names and column names if necessary.
                    if need_completion_refresh(document.text):
                        self.refresh_completions(
                            need_completion_reset(document.text))

                    # Refresh search_path to set default schema.
                    if need_search_path_refresh(document.text):
                        logger.debug('Refreshing search path')
                        with self._completer_lock:
                            self.completer.set_search_path(
                                pgexecute.search_path())
                        logger.debug('Search path: %r',
                                     self.completer.search_path)

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts

    def adjust_less_opts(self):
        """Set LESS to '-SRXF' and return its previous value ('' if unset)."""
        less_opts = os.environ.get('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', less_opts)
        os.environ['LESS'] = '-SRXF'

        return less_opts

    def refresh_completions(self, reset=False):
        """Start a background completion refresh.

        :param reset: clear the current completions first.
        :return: a single status row for display.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(self.pgexecute, self.pgspecial,
                                          self._on_completions_refreshed)
        return [(None, None, None,
                 'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        """Callback from the refresher: install the new completer."""
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When pgcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for ``text`` at the given cursor position."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)
Example #31
0
File: main.py Project: nosun/pgcli
    def run_cli(self):
        """Run the interactive pgcli read-eval-print loop until the user quits.

        Prints a version/links banner, then builds the key bindings, layout,
        history-backed ``PGBuffer``, ``Application`` and ``CommandLineInterface``.
        It then repeatedly reads a document and:

        * exits on a quit command;
        * applies ``handle_editor_command``;
        * executes the text with ``pgexecute.run`` and pages the formatted
          results;
        * refreshes completions / the search path when the statement calls for
          it;
        * appends a ``Query`` record to ``self.query_history``.

        An ``EOFError`` (raised here for quit commands) ends the loop with
        "Goodbye!".  The value returned by ``adjust_less_opts()`` is written back
        to the ``LESS`` environment variable on exit, including on exceptions.
        """
        pgexecute = self.pgexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        # NOTE(review): this alias is taken before refresh_completions(); if the
        # refresh ever replaces self.completer (rather than updating it in
        # place), the buffer and the search-path update below would keep using
        # the old object -- confirm against refresh_completions.
        completer = self.completer
        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        # Prompt shows the current database name, e.g. "mydb> ".
        def prompt_tokens(cli):
            return [(Token.Prompt,  '%s> ' % pgexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.vi_mode)
        layout = create_default_layout(lexer=PostgresLexer,
                                       reserve_space_for_menu=True,
                                       get_prompt_tokens=prompt_tokens,
                                       get_bottom_toolbar_tokens=get_toolbar_tokens,
                                       extra_input_processors=[
                                           # Highlight matching brackets while editing.
                                           ConditionalProcessor(
                                               processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                                               filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
                                       ])
        history_file = self.config['main']['history_file']
        buf = PGBuffer(always_multiline=self.multi_line, completer=completer,
                history=FileHistory(os.path.expanduser(history_file)),
                complete_while_typing=Always())

        application = Application(style=style_factory(self.syntax_style),
                                  layout=layout, buffer=buf,
                                  key_bindings_registry=key_binding_manager.registry,
                                  on_exit=AbortAction.RAISE_EXCEPTION)
        cli = CommandLineInterface(application=application,
                                   eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()

                # The reason we check here instead of inside the pgexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the pgexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    successful = False
                    # Initialized to [] because res might never get initialized
                    # if an exception occurs in pgexecute.run(). Which causes
                    # finally clause to fail.
                    res = []
                    start = time()
                    # Run the query.
                    res = pgexecute.run(document.text, self.pgspecial)
                    duration = time() - start
                    successful = True
                    output = []
                    # Accumulated formatting time across all result sets.
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        start = time()
                        # Ask for confirmation before formatting large SELECT results.
                        threshold = 1000
                        if (is_select(status) and
                                cur and cur.rowcount > threshold):
                            click.secho('The result set has more than %s rows.'
                                    % threshold, fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                # NOTE(review): break leaves the for-loop without
                                # an exception, so the else-branch below still
                                # pages output collected from earlier result sets.
                                break

                        formatted = format_output(title, cur, headers, status,
                                                  self.table_format,
                                                  self.pgspecial.expanded_output)
                        output.extend(formatted)
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    reconnect = True
                    # Only a dropped server connection offers a reconnect; other
                    # operational errors are logged and shown like generic errors.
                    if ('server closed the connection' in utf8tounicode(e.args[0])):
                        reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
                                show_default=False, type=bool, default=True)
                        if reconnect:
                            try:
                                pgexecute.connect()
                                click.secho('Reconnected!\nTry the command again.', fg='green')
                            except OperationalError as e:
                                click.secho(str(e), err=True, fg='red')
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    # Success path: page all formatted output, then optional timings.
                    click.echo_via_pager('\n'.join(output))
                    if self.pgspecial.timing_enabled:
                        print('Command Time: %0.03fs' % duration)
                        print('Format Time: %0.03fs' % total)

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_completions()

                # Refresh search_path to set default schema.
                if need_search_path_refresh(document.text):
                    logger.debug('Refreshing search path')
                    completer.set_search_path(pgexecute.search_path())
                    logger.debug('Search path: %r', completer.search_path)

                # Every statement that reached execution is recorded, successful or not.
                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print ('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts
Example #32
0
    def run_cli(self):
        """Run the interactive mycli read-eval-print loop until the user quits.

        Initialises completions, prints a version/links banner and a "thanks"
        line picked from the AUTHORS/SPONSORS files, then builds the key
        bindings, layout, ``CLIBuffer`` (history in ``~/.mycli-history``),
        ``Application`` and ``CommandLineInterface``.  Each statement read is:

        * checked for quit;
        * passed through ``handle_editor_command``;
        * confirmed if destructive;
        * optionally written to ``self.logfile``;
        * executed with ``sqlexecute.run``, with the formatted results paged;
        * followed by a dynamic-completion refresh when needed;
        * recorded as a ``Query`` in ``self.query_history``.

        An ``EOFError`` (raised here for quit commands) ends the loop with
        "Goodbye!".  On exit the ``LESS`` environment variable is restored to
        the value returned by ``adjust_less_opts()``, and ``PAGER`` to
        ``special.get_original_pager()``.
        """
        sqlexecute = self.sqlexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        self.initialize_completions()
        completer = self.completer

        # Unknown key-binding names fall back to emacs mode.
        def set_key_bindings(value):
            if value not in ('emacs', 'vi'):
                value = 'emacs'
            self.key_bindings = value

        project_root = os.path.dirname(PACKAGE_ROOT)
        author_file = os.path.join(project_root, 'AUTHORS')
        sponsor_file = os.path.join(project_root, 'SPONSORS')

        key_binding_manager = mycli_bindings(
            get_key_bindings=lambda: self.key_bindings,
            set_key_bindings=set_key_bindings)
        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/mycli')
        print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
        print('Home: http://mycli.net')
        print('Thanks to the contributor -',
              thanks_picker([author_file, sponsor_file]))

        # Prompt text is rendered from the configured prompt format on every redraw.
        def prompt_tokens(cli):
            return [(Token.Prompt, self.get_prompt(self.prompt_format))]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.key_bindings)
        layout = create_default_layout(
            lexer=MyCliLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            extra_input_processors=[
                # Highlight matching brackets while the buffer has focus.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])
        buf = CLIBuffer(always_multiline=self.multi_line,
                        completer=completer,
                        history=FileHistory(
                            os.path.expanduser('~/.mycli-history')),
                        complete_while_typing=Always())

        application = Application(
            style=style_factory(self.syntax_style),
            layout=layout,
            buffer=buf,
            key_bindings_registry=key_binding_manager.registry,
            on_exit=AbortAction.RAISE_EXCEPTION)
        cli = CommandLineInterface(application=application,
                                   eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()

                # Expanded output is switched off again before every statement.
                special.set_expanded_output(False)

                # The reason we check here instead of inside the sqlexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the
                # sqlexecute.run() statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                    continue

                # None: not destructive; True: user confirmed; otherwise the
                # user declined and the statement is skipped (not recorded).
                destroy = confirm_destructive_query(document.text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.output('Your call!')
                else:
                    self.output('Wise choice!')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    # Append a timestamp header and the statement text to the
                    # log file when one is configured.
                    if self.logfile:
                        self.logfile.write('\n# %s\n' % datetime.now())
                        self.logfile.write(document.text)
                        self.logfile.write('\n')
                    successful = False
                    start = time()
                    res = sqlexecute.run(document.text)
                    duration = time() - start
                    successful = True
                    output = []
                    # Accumulated formatting time across all result sets.
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        start = time()
                        # Ask for confirmation before formatting large SELECT results.
                        threshold = 1000
                        if (is_select(status) and cur
                                and cur.rowcount > threshold):
                            self.output(
                                'The result set has more than %s rows.' %
                                threshold,
                                fg='red')
                            if not click.confirm('Do you want to continue?'):
                                self.output("Aborted!", err=True, fg='red')
                                # NOTE(review): output from earlier result sets
                                # is still paged by the else-branch below.
                                break
                        output.extend(
                            format_output(title, cur, headers, status,
                                          self.table_format))
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)
                except KeyboardInterrupt:
                    # Restart connection to the database
                    sqlexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    self.output("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    self.output('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    logger.debug("Exception: %r", e)
                    reconnect = True
                    # Error codes that trigger a reconnect offer; presumably the
                    # MySQL client's "can't connect" / "server has gone away" /
                    # "lost connection" errors -- confirm against the driver.
                    if (e.args[0] in (2003, 2006, 2013)):
                        reconnect = click.prompt(
                            'Connection reset. Reconnect (Y/n)',
                            show_default=False,
                            type=bool,
                            default=True)
                        if reconnect:
                            logger.debug('Attempting to reconnect.')
                            try:
                                sqlexecute.connect()
                                logger.debug('Reconnected successfully.')
                                self.output(
                                    'Reconnected!\nTry the command again.',
                                    fg='green')
                            except OperationalError as e:
                                logger.debug('Reconnect failed. e: %r', e)
                                self.output(str(e), err=True, fg='red')
                                continue  # If reconnection failed, don't proceed further.
                        else:  # If user chooses not to reconnect, don't proceed further.
                            continue
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        self.output(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                else:
                    # Success path: page all formatted output, then optional timings.
                    self.output_via_pager('\n'.join(output))
                    if special.is_timing_enabled():
                        self.output('Command Time: %0.03fs' % duration)
                        self.output('Format Time: %0.03fs' % total)

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_dynamic_completions()

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            self.output('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts
            os.environ['PAGER'] = special.get_original_pager()
Example #33
0
class CLI:
    """Command-line client for a ClickHouse server reached over HTTP.

    Construct it with the command-line options, then call ``run``.  ``run``
    loads the config file, connects, and then either processes the supplied
    ``query``/``data`` non-interactively or starts an interactive prompt loop.
    """

    def __init__(self, host, port, user, password, database, settings, format,
                 format_stdin, multiline, stacktrace):
        """Store the command-line options; no network I/O happens here.

        ``settings`` is parsed as a URL query string (``a=1&b=2``) and only the
        first value of each key is kept.  Most falsy options are filled in
        later from the config file by ``load_config``.
        """
        self.config = None

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.settings = {k: v[0] for k, v in parse_qs(settings).items()}
        self.format = format
        self.format_stdin = format_stdin
        self.multiline = multiline
        self.stacktrace = stacktrace
        self.server_version = None

        # IDs of the queries issued for the current input; Ctrl-C in the
        # interactive loop asks the server to kill each of them.
        self.query_ids = []
        self.client = None
        # Interactive CommandLineInterface; only created by ``run`` when
        # neither a query nor input data was supplied.
        self.cli = None
        self.echo = Echo(verbose=True, colors=True)
        # Last parsed X-ClickHouse-Progress snapshot, or None.
        self.progress = None

        self.metadata = {}

    @staticmethod
    def _parse_server_version(raw):
        """Turn a ``SELECT version()`` result (e.g. ``'1.1.54343\\n'``) into a 3-tuple of ints.

        Components beyond the third are ignored; missing ones count as 0.
        Raises ``ValueError`` if a component is not an integer.
        """
        parts = raw.strip().split('.')
        parts += ['0'] * (3 - len(parts))
        return tuple(int(part) for part in parts[:3])

    def _processes_rows_column(self):
        """Name of the rows-read column of ``system.processes`` for this server."""
        # Builds before 1.1.54115 call it ``rows_read``.  Comparing the whole
        # tuple (not just the third component) keeps servers that use the newer
        # YY.M.patch numbering, e.g. 21.8.3, from being treated as old builds.
        if self.server_version >= (1, 1, 54115):
            return 'read_rows'
        return 'rows_read'

    @staticmethod
    def _progress_percents(read_rows, total_rows):
        """Percentage of ``total_rows`` read, clamped to 0..100 (0 if the total is unknown)."""
        if total_rows <= 0:
            return 0
        # The server's total is an estimate, so read_rows may exceed it.
        return max(0, min(100, int(read_rows / total_rows * 100)))

    def connect(self):
        """Create the HTTP client and probe the server with ``SELECT version();``.

        A ``scheme://`` prefix in ``self.host`` overrides the scheme (default
        ``http``) and, when it carries one, the port.  Requires ``load_config``
        to have run, because it reads the ``conn_timeout*`` attributes.

        Returns True on success.  On failure an error is echoed and False is
        returned.
        """
        self.scheme = 'http'
        if '://' in self.host:
            u = urlparse(self.host, allow_fragments=False)
            self.host = u.hostname
            self.port = u.port or self.port
            self.scheme = u.scheme
        self.url = '{scheme}://{host}:{port}/'.format(scheme=self.scheme,
                                                      host=self.host,
                                                      port=self.port)
        self.client = Client(
            self.url,
            self.user,
            self.password,
            self.database,
            self.settings,
            self.stacktrace,
            self.conn_timeout,
            self.conn_timeout_retry,
            self.conn_timeout_retry_delay,
        )

        self.echo.print("Connecting to {host}:{port}".format(host=self.host,
                                                             port=self.port))

        try:
            response = self.client.query('SELECT version();',
                                         fmt='TabSeparated')
        except TimeoutError:
            self.echo.error("Error: Connection timeout.")
            return False
        except ConnectionError:
            self.echo.error("Error: Failed to connect.")
            return False
        except DBException as e:
            self.echo.error("Error:")
            self.echo.error(e.error)

            if self.stacktrace and e.stacktrace:
                self.echo.print("Stack trace:")
                self.echo.print(e.stacktrace)

            return False

        if not response.data.endswith('\n'):
            self.echo.error(
                "Error: Request failed: `SELECT version();` query failed.")
            return False

        self.server_version = self._parse_server_version(response.data)

        self.echo.success(
            "Connected to ClickHouse server v{0}.{1}.{2}.\n".format(
                *self.server_version))
        return True

    def load_config(self):
        """Read the config file and merge it with the command-line options.

        Truthy command-line values win.  Otherwise the config value is used,
        falling back to built-in defaults for the connection options.  Settings
        given on the command line override those from the ``[settings]``
        section.
        """
        self.config = read_config()

        # Honour the multiline argument when set, like the other options below.
        self.multiline = self.multiline or self.config.getboolean(
            'main', 'multiline')
        self.format = self.format or self.config.get('main', 'format')
        self.format_stdin = self.format_stdin or self.config.get(
            'main', 'format_stdin')
        self.show_formatted_query = self.config.getboolean(
            'main', 'show_formatted_query')
        self.highlight = self.config.getboolean('main', 'highlight')
        self.highlight_output = self.config.getboolean('main',
                                                       'highlight_output')
        # True-colour output also requires COLORTERM to be set in the environment.
        self.highlight_truecolor = self.config.getboolean(
            'main', 'highlight_truecolor') and os.environ.get('COLORTERM')

        self.conn_timeout = self.config.getfloat('http', 'conn_timeout')
        self.conn_timeout_retry = self.config.getint('http',
                                                     'conn_timeout_retry')
        self.conn_timeout_retry_delay = self.config.getfloat(
            'http', 'conn_timeout_retry_delay')

        self.host = self.host or self.config.get('defaults',
                                                 'host') or '127.0.0.1'
        self.port = self.port or self.config.get('defaults', 'port') or 8123
        self.user = self.user or self.config.get('defaults',
                                                 'user') or 'default'
        self.password = self.password or self.config.get(
            'defaults', 'password') or ''
        self.database = self.database or self.config.get('defaults',
                                                         'db') or 'default'

        config_settings = dict(self.config.items('settings'))
        arg_settings = self.settings
        config_settings.update(arg_settings)
        self.settings = config_settings

        self.echo.colors = self.highlight

    def run(self, query, data):
        """Entry point: load config, connect, then dispatch on the inputs.

        * ``data`` only (files / piped stdin): each stream's text is handled as SQL.
        * ``query`` only: the query is run once with streamed output.
        * ``query`` and ``data``: the query is run once per stream, passing the
          stream as ``data`` (``.gz`` files with ``compress='gzip'``).
        * neither: an interactive prompt loop starts; an ``EOFError`` (e.g. from
          an exit command) ends it.
        """
        self.load_config()

        if data or query is not None:
            # Non-interactive: use the stdin output format and stay quiet.
            self.format = self.format_stdin
            self.echo.verbose = False

        if self.echo.verbose:
            show_version()

        if not self.connect():
            return

        if self.client:
            self.client.settings = self.settings
            self.client.cli_settings = {
                'multiline': self.multiline,
                'format': self.format,
                'format_stdin': self.format_stdin,
                'show_formatted_query': self.show_formatted_query,
                'highlight': self.highlight,
                'highlight_output': self.highlight_output,
            }

        if data and query is None:
            # cat stuff.sql | clickhouse-cli
            # clickhouse-cli stuff.sql
            for subdata in data:
                self.handle_input(subdata.read(),
                                  verbose=False,
                                  refresh_metadata=False)

            return

        if not data and query is not None:
            # clickhouse-cli -q 'SELECT 1'
            return self.handle_query(query, stream=True)

        if data and query is not None:
            # cat stuff.csv | clickhouse-cli -q 'INSERT INTO stuff'
            # clickhouse-cli -q 'INSERT INTO stuff' stuff.csv
            for subdata in data:
                compress = 'gzip' if os.path.splitext(
                    subdata.name)[1] == '.gz' else False

                self.handle_query(query,
                                  data=subdata,
                                  stream=True,
                                  compress=compress)

            return

        layout = create_prompt_layout(
            lexer=PygmentsLexer(CHLexer) if self.highlight else None,
            get_prompt_tokens=get_prompt_tokens,
            get_continuation_tokens=get_continuation_tokens,
            multiline=self.multiline,
        )

        buffer = CLIBuffer(
            client=self.client,
            multiline=self.multiline,
            metadata=self.metadata,
        )

        application = Application(
            layout=layout,
            buffer=buffer,
            style=CHStyle if self.highlight else None,
            key_bindings_registry=KeyBinder.registry,
        )

        eventloop = create_eventloop()

        self.cli = CommandLineInterface(application=application,
                                        eventloop=eventloop)
        self.cli.application.buffer.completer.refresh_metadata()

        try:
            while True:
                try:
                    cli_input = self.cli.run(reset_current_buffer=True)
                    self.handle_input(cli_input.text)
                except KeyboardInterrupt:
                    # Attempt to terminate queries
                    for query_id in self.query_ids:
                        self.client.kill_query(query_id)

                    self.echo.error("\nQuery was terminated.")
                finally:
                    self.query_ids = []
        except EOFError:
            self.echo.success("Bye.")

    def handle_input(self, input_data, verbose=True, refresh_metadata=True):
        r"""Split ``input_data`` (str or bytes) into statements and run each one.

        A trailing ``\p`` is stripped and forces the pager for every statement.
        Each statement gets a fresh UUID query ID (tracked in ``self.query_ids``).
        When ``refresh_metadata`` is true and an interactive CLI exists, the
        completer's metadata is refreshed afterwards.
        """
        force_pager = False
        if input_data.endswith(
                r'\p' if isinstance(input_data, str) else rb'\p'):
            input_data = input_data[:-2]
            force_pager = True

        # FIXME: A dirty dirty hack to make multiple queries (per one paste) work.
        self.query_ids = []
        for query in sqlparse.split(input_data):
            query_id = str(uuid4())
            self.query_ids.append(query_id)
            self.handle_query(query,
                              verbose=verbose,
                              query_id=query_id,
                              force_pager=force_pager)

        if refresh_metadata and input_data and self.cli is not None:
            self.cli.application.buffer.completer.refresh_metadata()

    def handle_query(self,
                     query,
                     data=None,
                     stream=False,
                     verbose=False,
                     query_id=None,
                     compress=False,
                     **kwargs):
        r"""Execute one statement (or CLI meta-command) and print its result.

        Meta-commands (``help``/``\?``, ``\d``/``\dt``, ``\d+``, ``\l``, ``\c``,
        ``\ps``, ``\kill``) are translated to SQL or handled locally.  Exit
        commands raise ``EOFError``.  Timeouts, connection failures and server
        exceptions are reported and swallowed.  ``force_pager`` may be passed
        through ``kwargs`` to page the (non-streamed) output.
        """
        if query.rstrip(';') == '':
            return

        elif query.lower() in EXIT_COMMANDS:
            raise EOFError

        elif query.lower() in (r'\?', 'help'):
            rows = [
                ['', ''],
                ["clickhouse-cli's custom commands:", ''],
                ['---------------------------------', ''],
                ['USE', "Change the current database."],
                ['SET', "Set an option for the current CLI session."],
                ['QUIT', "Exit clickhouse-cli."],
                ['HELP', "Show this help message."],
                ['', ''],
                ["PostgreSQL-like custom commands:", ''],
                ['--------------------------------', ''],
                [r'\l', "Show databases."],
                [r'\c', "Change the current database."],
                [r'\d, \dt', "Show tables in the current database."],
                [r'\d+', "Show table's schema."],
                [r'\ps', "Show current queries."],
                [r'\kill', "Kill query by its ID."],
                ['', ''],
                ["Query suffixes:", ''],
                ['---------------', ''],
                [r'\g, \G', "Use the Vertical format."],
                [r'\p', "Enable the pager."],
            ]

            for row in rows:
                self.echo.success('{:<8s}'.format(row[0]), nl=False)
                self.echo.info(row[1])
            return

        elif query in (r'\d', r'\dt'):
            query = 'SHOW TABLES'

        elif query.startswith(r'\d+ '):
            query = 'DESCRIBE TABLE ' + query[4:]

        elif query == r'\l':
            query = 'SHOW DATABASES'

        elif query.startswith(r'\c '):
            query = 'USE ' + query[3:]

        elif query.startswith(r'\ps'):
            # List running queries, excluding this one.
            query = (
                "SELECT query_id, user, address, elapsed, {}, memory_usage "
                "FROM system.processes WHERE query_id != '{}'").format(
                    self._processes_rows_column(), query_id)

        elif query.startswith(r'\kill '):
            self.client.kill_query(query[6:])
            return

        response = ''

        self.progress_reset()

        try:
            response = self.client.query(
                query,
                fmt=self.format,
                data=data,
                stream=stream,
                verbose=verbose,
                query_id=query_id,
                compress=compress,
            )
        except TimeoutError:
            self.echo.error("Error: Connection timeout.")
            return
        except ConnectionError:
            self.echo.error("Error: Failed to connect.")
            return
        except DBException as e:
            self.progress_reset()
            self.echo.error("\nQuery:")
            self.echo.error(query)
            self.echo.error("\n\nReceived exception from server:")
            self.echo.error(e.error)

            if self.stacktrace and e.stacktrace:
                self.echo.print("\nStack trace:")
                self.echo.print(e.stacktrace)

            self.echo.print('\nElapsed: {elapsed:.3f} sec.\n'.format(
                elapsed=e.response.elapsed.total_seconds()))

            return

        # Also clears the progress line; returns the totals of the last snapshot.
        total_rows, total_bytes = self.progress_reset()

        self.echo.print()

        if stream:
            data = response.iter_lines() if hasattr(
                response, 'iter_lines') else response.data
            for line in data:
                print(line.decode('utf-8', 'ignore'))

        else:
            if response.data != '':
                print_func = print

                if self.config.getboolean('main', 'pager') or kwargs.pop(
                        'force_pager', False):
                    print_func = self.echo.pager

                should_highlight_output = (verbose and self.highlight
                                           and self.highlight_output and
                                           response.format in PRETTY_FORMATS)

                formatter = TerminalFormatter()

                if self.highlight and self.highlight_truecolor:
                    formatter = TerminalTrueColorFormatter(
                        style=CHPygmentsStyle)

                if should_highlight_output:
                    print_func(
                        pygments.highlight(response.data,
                                           CHPrettyFormatLexer(), formatter))
                else:
                    print_func(response.data)

        if response.message != '':
            self.echo.print(response.message)
            self.echo.print()

        self.echo.success('Ok. ', nl=False)

        if response.rows is not None:
            self.echo.print('{rows_count} row{rows_plural} in set.'.format(
                rows_count=response.rows,
                rows_plural='s' if response.rows != 1 else '',
            ),
                            end=' ')

        if self.config.getboolean(
                'main', 'timing') and response.time_elapsed is not None:
            self.echo.print(
                'Elapsed: {elapsed:.3f} sec. Processed: {rows} rows, {bytes} ({avg_rps} rows/s, {avg_bps}/s)'
                .format(
                    elapsed=response.time_elapsed,
                    rows=numberunit_fmt(total_rows),
                    bytes=sizeof_fmt(total_bytes),
                    avg_rps=numberunit_fmt(total_rows /
                                           max(response.time_elapsed, 0.001)),
                    avg_bps=sizeof_fmt(total_bytes /
                                       max(response.time_elapsed, 0.001)),
                ),
                end='')

        self.echo.print('\n')

    def progress_update(self, line):
        """Handle one raw ``X-ClickHouse-Progress`` header line (bytes).

        Does nothing unless ``timing`` is enabled in ``[main]``.  Otherwise it
        stores the snapshot in ``self.progress`` and redraws the progress line,
        adding rows/s and bytes/s computed against the previous snapshot.
        """
        if not self.config.getboolean('main', 'timing'):
            return
        # Parse X-ClickHouse-Progress header; 23 == len('X-ClickHouse-Progress: ').
        now = datetime.now()
        raw = json.loads(line[23:].decode().strip())
        # Older servers report "total_rows"; newer ones send "total_rows_to_read".
        total_rows = raw.get('total_rows', raw.get('total_rows_to_read', 0))
        progress = {
            'timestamp': now,
            'read_rows': int(raw['read_rows']),
            'total_rows': int(total_rows),
            'read_bytes': int(raw['read_bytes']),
        }
        # Calculate percentage completed and format initial message
        progress['percents'] = self._progress_percents(progress['read_rows'],
                                                       progress['total_rows'])
        message = 'Progress: {} rows, {}'.format(
            numberunit_fmt(progress['read_rows']),
            sizeof_fmt(progress['read_bytes']))
        # Calculate row and byte read velocity
        if self.progress:
            delta = (now - self.progress['timestamp']).total_seconds()
            if delta > 0:
                rps = (progress['read_rows'] -
                       self.progress['read_rows']) / delta
                bps = (progress['read_bytes'] -
                       self.progress['read_bytes']) / delta
                message += ' ({} rows/s, {}/s)'.format(numberunit_fmt(rps),
                                                       sizeof_fmt(bps))
        self.progress = progress
        self.progress_print(message, progress['percents'])

    def progress_reset(self):
        """Forget the current progress snapshot and clear the progress line.

        Also (re)installs ``progress_update`` as the header-trace hook.  The
        line is only cleared when stdout is a terminal, so redirected output
        does not receive cursor-control escape sequences.

        Returns ``(read_rows, read_bytes)`` from the discarded snapshot, or
        ``(0, 0)``.
        """
        progress = self.progress
        self.progress = None
        clickhouse_cli.helpers.trace_headers_stream = self.progress_update
        # Clear printed progress (if any)
        if sys.stdout.isatty():
            columns = shutil.get_terminal_size((80, 0)).columns
            sys.stdout.write(u"\u001b[%dD" % columns + " " * columns)
            sys.stdout.flush()
        # Report totals
        if progress:
            return (progress['read_rows'], progress['read_bytes'])
        return (0, 0)

    def progress_print(self, message, percents):
        """Redraw the current line as ``message``, a progress bar and ``percents``.

        The bar uses a green background (SGR 42) and fills the remaining
        terminal width.  Nothing is written when stdout is not a terminal.
        """
        if not sys.stdout.isatty():
            return
        suffix = '%3d%%' % percents
        columns = shutil.get_terminal_size((80, 0)).columns
        bars_max = columns - (len(message) + len(suffix) + 3)
        bars = int(percents * (bars_max / 100)) if (bars_max > 0) else 0
        message = '{} \033[42m{}\033[0m{} {}'.format(message, " " * bars,
                                                     " " * (bars_max - bars),
                                                     suffix)
        # Move the cursor to column 0 before overwriting the line.
        sys.stdout.write(u"\u001b[%dD" % columns + message)
        sys.stdout.flush()
Example #34
0
File: repl.py Project: crate/crash
def loop(cmd, history_file):
    """Run crash's interactive SQL prompt until the user exits.

    Builds a prompt_toolkit ``CommandLineInterface`` (SQL lexer, matching
    bracket highlighting, a bottom toolbar and a ``cr> `` prompt) around
    ``cmd``. Every non-empty document the user submits is passed to
    ``cmd.process``.

    :param cmd: crash command object; used for toolbar tokens, buffer
        creation, credentials (``username``/``password``) and executing
        the entered SQL via ``cmd.process``.
    :param history_file: passed to ``create_buffer`` together with ``cmd``;
        presumably the prompt history file (TODO confirm).
    """

    key_binding_manager = KeyBindingManager(
        enable_search=True,
        enable_abort_and_exit_bindings=True,
        enable_system_bindings=True,
        enable_open_in_editor=True
    )
    bind_keys(key_binding_manager.registry)
    layout = create_layout(
        multiline=True,
        lexer=SqlLexer,
        extra_input_processors=[
            # Highlight matching brackets only while the default buffer has
            # focus and input is still being edited.
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
        ],
        get_bottom_toolbar_tokens=lambda cli: get_toolbar_tokens(cmd),
        get_prompt_tokens=lambda cli: [(Token.Prompt, 'cr> ')]
    )
    application = Application(
        layout=layout,
        buffer=create_buffer(cmd, history_file),
        style=PygmentsStyle.from_defaults(pygments_style_cls=CrateStyle),
        key_bindings_registry=key_binding_manager.registry,
        editing_mode=_get_editing_mode(),
        # NOTE(review): in prompt_toolkit 1.x on_exit/on_abort presumably map
        # to Ctrl-D (raise) and Ctrl-C (retry the prompt) -- confirm.
        on_exit=AbortAction.RAISE_EXCEPTION,
        on_abort=AbortAction.RETRY,
    )
    eventloop = create_eventloop()
    output = create_output()
    cli = CommandLineInterface(
        application=application,
        eventloop=eventloop,
        output=output
    )

    # Make cmd size its output to the terminal width reported by the
    # prompt_toolkit output object instead of its own default.
    def get_num_columns_override():
        return output.get_size().columns
    cmd.get_num_columns = get_num_columns_override

    while True:
        try:
            doc = cli.run(reset_current_buffer=True)
            if doc:
                cmd.process(doc.text)
        except ProgrammingError as e:
            # A '401' in the error message triggers a credential re-prompt.
            if '401' in e.message:
                # Current credentials are saved first, presumably so they can
                # be restored if the new ones fail (TODO confirm).
                username = cmd.username
                password = cmd.password
                # NOTE(review): the next line is corrupted in this copy: the
                # text between 'Username: ' and 'Bye!' was replaced by
                # '******' during scraping, so this function is not valid
                # Python as shown. Restore it from crate/crash repl.py.
                cmd.username = input('Username: '******'Bye!')
            return
Example #35
0
class MyCli(object):
    """Interactive MySQL command-line client built on prompt_toolkit.

    Responsibilities:

    - load configuration from myclirc files and MySQL ``.cnf`` files;
    - set up logging;
    - own the SQL completer and its background refresher;
    - register mycli's special commands;
    - run the REPL loop in :meth:`run_cli`.
    """

    default_prompt = '\\t \\u@\\h:\\d> '
    defaults_suffix = None

    # In order of being loaded. Files lower in list override earlier ones.
    # NOTE: this list is a class attribute shared by every instance. It must
    # never be mutated in place; __init__ rebinds self.cnf_files instead.
    cnf_files = [
        '/etc/my.cnf',
        '/etc/mysql/my.cnf',
        '/usr/local/etc/my.cnf',
        '~/.my.cnf'
    ]

    system_config_files = [
        '/etc/myclirc',
    ]

    default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
    user_config_file = '~/.myclirc'

    def __init__(self, sqlexecute=None, prompt=None,
                 logfile=None, defaults_suffix=None, defaults_file=None,
                 login_path=None, auto_vertical_output=False):
        """Load configuration and prepare (but do not start) the CLI.

        :param sqlexecute: an existing SQLExecute, or None to connect later
            via :meth:`connect` / :meth:`connect_uri`.
        :param prompt: prompt format string. Takes precedence over the
            ``prompt`` value from the .cnf files and from myclirc.
        :param logfile: open file object used as the audit log. When None,
            the ``audit_log`` path from myclirc (if set) is opened instead.
        :param defaults_suffix: suffix appended to each .cnf section name
            that is read (e.g. ``client`` -> ``client<suffix>``).
        :param defaults_file: if given, the only .cnf file read, apart from
            ``.mylogin.cnf``, which is always read last.
        :param login_path: extra .cnf section to read besides ``client``.
        :param auto_vertical_output: if true, the terminal width is passed
            to ``format_output``. Presumably this lets it switch to vertical
            output for wide results.
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile
        self.defaults_suffix = defaults_suffix
        self.login_path = login_path
        self.auto_vertical_output = auto_vertical_output

        # self.cnf_files starts out as the class-level list of MySQL config
        # files to read at launch. If defaults_file is specified, it
        # replaces that list for this instance only.
        if defaults_file:
            self.cnf_files = [defaults_file]

        # Load config.
        config_files = ([self.default_config_file] + self.system_config_files +
                        [self.user_config_file])
        c = self.config = read_config_files(config_files)
        self.multi_line = c['main'].as_bool('multi_line')
        self.destructive_warning = c['main'].as_bool('destructive_warning')
        self.key_bindings = c['main']['key_bindings']
        special.set_timing_enabled(c['main'].as_bool('timing'))
        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')

        # Write user config if system config wasn't the last config loaded.
        if c.filename not in self.system_config_files:
            write_default_config(self.default_config_file, self.user_config_file)

        # audit log
        if self.logfile is None and 'audit_log' in c['main']:
            try:
                self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
            except (IOError, OSError):
                self.output('Error: Unable to open the audit log file. Your queries will not be logged.', err=True, fg='red')
                # False (as opposed to None) means "logging was requested but
                # failed"; run_cli prints a warning after every query then.
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
        self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
                             self.default_prompt

        self.query_history = []

        # Initialize completer.
        smart_completion = c['main'].as_bool('smart_completion')
        self.completer = SQLCompleter(smart_completion)
        # Guards self.completer, which the completion refresher replaces via
        # a callback (presumably from a background thread).
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Load .mylogin.cnf if it exists.
        mylogin_cnf_path = get_mylogin_cnf_path()
        if mylogin_cnf_path:
            try:
                mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
                if mylogin_cnf:
                    # .mylogin.cnf gets read last, even if defaults_file is
                    # specified. Rebind rather than append() so the shared
                    # class-level cnf_files list is never mutated.
                    self.cnf_files = self.cnf_files + [mylogin_cnf]
                else:
                    # There was an error reading the login path file.
                    print('Error: Unable to read login path file.')
            except CryptoError:
                click.secho('Warning: .mylogin.cnf was not read: pycrypto '
                            'module is not available.')

        self.cli = None

    def register_special_commands(self):
        """Register mycli's special commands with the special registry.

        The commands are use, connect, rehash, tableformat, source and
        prompt.
        """
        special.register_special_command(self.change_db, 'use',
                '\\u', 'Change to a new database.', aliases=('\\u',))
        special.register_special_command(self.change_db, 'connect',
                '\\r', 'Reconnect to the database. Optional database argument.',
                aliases=('\\r', ), case_sensitive=True)
        special.register_special_command(self.refresh_completions, 'rehash',
                '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
        special.register_special_command(self.change_table_format, 'tableformat',
                '\\T', 'Change Table Type.', aliases=('\\T',), case_sensitive=True)
        special.register_special_command(self.execute_from_file, 'source', '\\. filename',
                'Execute commands from file.', aliases=('\\.',))
        special.register_special_command(self.change_prompt_format, 'prompt',
                '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)

    def change_table_format(self, arg, **_):
        """Switch the output table format to ``arg``.

        Generator yielding one ``(title, cur, headers, status)`` tuple. The
        status either lists the allowed formats (when ``arg`` is not one of
        them) or confirms the change.
        """
        if arg not in table_formats():
            msg = "Table type %s not yet implemented.  Allowed types:" % arg
            for table_type in table_formats():
                msg += "\n\t%s" % table_type
            yield (None, None, None, msg)
        else:
            self.table_format = arg
            yield (None, None, None, "Changed table Type to %s" % self.table_format)

    def change_db(self, arg, **_):
        """Reconnect, optionally to database ``arg``, and yield a status row."""
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        yield (None, None, None, 'You are now connected to database "%s" as '
                'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))

    def execute_from_file(self, arg, **_):
        """Run the SQL contained in file ``arg`` (``~`` is expanded).

        Returns the results of ``sqlexecute.run``. A missing argument or an
        IOError while reading yields a single status row instead.
        """
        if not arg:
            message = 'Missing required argument, filename.'
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(arg), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.

        The raw format string is stored. Its escapes are then expanded by
        get_prompt() every time the prompt is drawn. Expanding them here
        would freeze the user/host/database that were current at the time of
        the change.
        """
        if not arg:
            message = 'Missing required argument, format.'
            return [(None, None, None, message)]

        self.prompt_format = arg
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def initialize_logging(self):
        """Attach a file handler to the ``mycli`` logger.

        The handler is configured from the ``log_file`` and ``log_level``
        settings in myclirc. Warnings are also routed into logging.
        """

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG
                     }

        handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('mycli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug('Initializing mycli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_uri(self, uri):
        """Connect using a URI of the form ``scheme://user:pass@host:port/db``."""
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        self.connect(database, uri.username, uri.password, uri.hostname,
                     uri.port)

    def read_my_cnf_files(self, files, keys):
        """
        Reads a list of config files and merges them. The last one will win.
        :param files: list of files to read
        :param keys: list of keys to retrieve
        :returns: dict mapping each key to its value, with None for missing
            keys.
        """
        cnf = read_config_files(files)

        sections = ['client']
        if self.login_path and self.login_path != 'client':
            sections.append(self.login_path)

        if self.defaults_suffix:
            sections.extend([sect + self.defaults_suffix for sect in sections])

        def get(key):
            # Later matching sections in the merged config override earlier
            # ones.
            result = None
            for sect in cnf:
                if sect in sections and key in cnf[sect]:
                    result = cnf[sect][key]
            return result

        return {x: get(x) for x in keys}

    def connect(self, database='', user='', passwd='', host='', port='',
                socket='', charset=''):
        """Open a database connection and store it as ``self.sqlexecute``.

        Each value comes from the first source that provides it:

        1. the explicit argument;
        2. the MySQL .cnf files;
        3. a built-in default (``$USER``, ``localhost``, 3306, ``utf8``).

        The socket is ignored when a host or port argument was given. If
        the server denies access, the user is prompted for a password once
        and the connection is retried. Exits the process with status 1 on an
        invalid port or any connection failure.
        """
        cnf = self.read_my_cnf_files(
            self.cnf_files,
            ['database', 'user', 'password', 'host', 'port', 'socket',
             'default-character-set'])

        # Fall back to config values only if user did not specify a value.

        database = database or cnf['database']
        if port or host:
            socket = ''
        else:
            socket = socket or cnf['socket']
        user = user or cnf['user'] or os.getenv('USER')
        host = host or cnf['host'] or 'localhost'
        port = port or cnf['port'] or 3306
        try:
            port = int(port)
        except ValueError:
            self.output("Error: Invalid port number: '{0}'.".format(port),
                        err=True, fg='red')
            exit(1)

        passwd = passwd or cnf['password']
        charset = charset or cnf['default-character-set'] or 'utf8'

        # Connect to the database.

        try:
            try:
                sqlexecute = SQLExecute(database, user, passwd, host, port,
                                        socket, charset)
            except OperationalError as e:
                # Retry once with an interactively entered password when the
                # server rejects the credentials. Guard the args access so an
                # unexpected error shape re-raises the real error instead of
                # an IndexError.
                if len(e.args) > 1 and 'Access denied for user' in e.args[1]:
                    passwd = click.prompt('Password', hide_input=True,
                                          show_default=False, type=str)
                    sqlexecute = SQLExecute(database, user, passwd, host, port,
                                            socket, charset)
                else:
                    raise
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.output(str(e), err=True, fg='red')
            exit(1)

        self.sqlexecute = sqlexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: if the external editor reports an error.
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                          sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited SQL back into the prompt buffer and run the
            # prompt again so the user can review (or re-edit) it.
            cli.current_buffer.document = Document(sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive read-eval-print loop until the user exits.

        Each iteration:

        - reads a document from the prompt;
        - handles quit and editor commands;
        - asks for confirmation of destructive queries (when enabled);
        - executes the SQL and prints the formatted result (via the pager
          if enabled, with timing if enabled);
        - records the query in ``self.query_history``.

        The LESS and PAGER environment variables are restored on exit.
        """
        sqlexecute = self.sqlexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()
        self.set_pager_from_config()

        self.refresh_completions()

        def set_key_bindings(value):
            if value not in ('emacs', 'vi'):
                value = 'emacs'
            self.key_bindings = value

        project_root = os.path.dirname(PACKAGE_ROOT)
        author_file = os.path.join(project_root, 'AUTHORS')
        sponsor_file = os.path.join(project_root, 'SPONSORS')

        key_binding_manager = mycli_bindings(get_key_bindings=lambda: self.key_bindings,
                                             set_key_bindings=set_key_bindings)
        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/mycli')
        print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
        print('Home: http://mycli.net')
        print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))

        def prompt_tokens(cli):
            # Expanded on every redraw so \d etc. track the current state.
            return [(Token.Prompt, self.get_prompt(self.prompt_format))]

        get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.key_bindings,
                                                        self.completion_refresher.is_refreshing)

        layout = create_prompt_layout(lexer=MyCliLexer,
                                      multiline=True,
                                      get_prompt_tokens=prompt_tokens,
                                      get_bottom_toolbar_tokens=get_toolbar_tokens,
                                      display_completions_in_columns=self.wider_completion_menu,
                                      extra_input_processors=[
                                          ConditionalProcessor(
                                              processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                                              filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
                                      ])
        with self._completer_lock:
            buf = CLIBuffer(always_multiline=self.multi_line, completer=self.completer,
                            history=FileHistory(os.path.expanduser('~/.mycli-history')),
                            complete_while_typing=Always(), accept_action=AcceptAction.RETURN_DOCUMENT)

            application = Application(style=style_factory(self.syntax_style, self.cli_style),
                                      layout=layout, buffer=buf,
                                      key_bindings_registry=key_binding_manager.registry,
                                      on_exit=AbortAction.RAISE_EXCEPTION,
                                      on_abort=AbortAction.RETRY,
                                      ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        # Result sets with more rows than this require confirmation.
        threshold = 1000

        try:
            while True:
                document = self.cli.run()

                special.set_expanded_output(False)

                # The reason we check here instead of inside the sqlexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the
                # sqlexecute.run() statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                    continue
                if self.destructive_warning:
                    destroy = confirm_destructive_query(document.text)
                    if destroy is None:
                        pass  # Query was not destructive. Nothing to do here.
                    elif destroy is True:
                        self.output('Your call!')
                    else:
                        self.output('Wise choice!')
                        continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False
                # Set before the try so the Query record below never sees an
                # unbound name if e.g. the audit-log write raises first.
                successful = False

                try:
                    logger.debug('sql: %r', document.text)

                    if self.logfile:
                        self.logfile.write('\n# %s\n' % datetime.now())
                        self.logfile.write(document.text)
                        self.logfile.write('\n')

                    start = time()
                    res = sqlexecute.run(document.text)
                    successful = True
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        if (is_select(status) and
                                cur and cur.rowcount > threshold):
                            self.output('The result set has more than %s rows.'
                                        % threshold, fg='red')
                            if not click.confirm('Do you want to continue?'):
                                self.output("Aborted!", err=True, fg='red')
                                break

                        if self.auto_vertical_output:
                            max_width = self.cli.output.get_size().columns
                        else:
                            max_width = None

                        formatted = format_output(title, cur, headers,
                                                  status, self.table_format,
                                                  special.is_expanded_output(), max_width)

                        output.extend(formatted)
                        end = time()
                        total += end - start
                        # Restart the clock so each result set only adds its
                        # own elapsed time; without this, multi-statement
                        # queries were counted cumulatively (over-reported).
                        start = end
                        mutating = mutating or is_mutating(status)
                except UnicodeDecodeError:
                    import pymysql
                    if pymysql.VERSION < ('0', '6', '7'):
                        message = ('You are running an older version of pymysql.\n'
                                   'Please upgrade to 0.6.7 or above to view binary data.\n'
                                   'Try \'pip install -U pymysql\'.')
                        self.output(message)
                    else:
                        raise
                except KeyboardInterrupt:
                    # Restart connection to the database
                    sqlexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    self.output("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    self.output('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    logger.debug("Exception: %r", e)
                    reconnect = True
                    if (e.args[0] in (2003, 2006, 2013)):
                        reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
                                                 show_default=False, type=bool, default=True)
                        if reconnect:
                            logger.debug('Attempting to reconnect.')
                            try:
                                sqlexecute.connect()
                                logger.debug('Reconnected successfully.')
                                self.output('Reconnected!\nTry the command again.', fg='green')
                            except OperationalError as e:
                                logger.debug('Reconnect failed. e: %r', e)
                                self.output(str(e), err=True, fg='red')
                                continue  # If reconnection failed, don't proceed further.
                        else:  # If user chooses not to reconnect, don't proceed further.
                            continue
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        self.output(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                else:
                    try:
                        if special.is_pager_enabled():
                            self.output_via_pager('\n'.join(output))
                        else:
                            self.output('\n'.join(output))
                    except KeyboardInterrupt:
                        pass
                    if special.is_timing_enabled():
                        self.output('Time: %0.03fs' % total)

                    # Refresh the table names and column names if necessary.
                    if need_completion_refresh(document.text):
                        self.refresh_completions(
                                reset=need_completion_reset(document.text))
                finally:
                    if self.logfile is False:
                        self.output("Warning: This query was not logged.", err=True, fg='red')
                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            self.output('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts
            os.environ['PAGER'] = special.get_original_pager()

    def output(self, text, **kwargs):
        """Echo ``text`` with ``click.secho`` and copy it to the audit log.

        ``kwargs`` (e.g. ``err``, ``fg``) are forwarded to ``click.secho``.
        """
        if self.logfile:
            self.logfile.write(utf8tounicode(text))
            self.logfile.write('\n')
        click.secho(text, **kwargs)

    def output_via_pager(self, text):
        """Show ``text`` through click's pager and copy it to the audit log."""
        if self.logfile:
            # Same conversion as output() so both paths log identically.
            self.logfile.write(utf8tounicode(text))
            self.logfile.write('\n')
        click.echo_via_pager(text)

    def adjust_less_opts(self):
        """Set LESS to ``-SRXF`` and return its previous value.

        The previous value is ``''`` if LESS was unset; run_cli restores it
        on exit.
        """
        less_opts = os.environ.get('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', less_opts)
        os.environ['LESS'] = '-SRXF'

        return less_opts

    def set_pager_from_config(self):
        """Apply the ``pager`` / ``skip-pager`` settings from the .cnf files."""
        cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
        if cnf['pager']:
            special.set_pager(cnf['pager'])
        if cnf['skip-pager']:
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background refresh of the completer.

        :param reset: if true, clear the current completions first.
        :returns: a single status row, in special-command form.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(self.sqlexecute,
                                          self._on_completions_refreshed)

        return [(None, None, None,
                'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        """Refresher callback: install ``new_completer`` and redraw."""
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When mycli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for ``text`` at the given cursor position."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)

    def get_prompt(self, string):
        r"""Expand the prompt escapes in ``string``.

        Supported escapes:

        - ``\u``: user;
        - ``\h``: host;
        - ``\d``: database;
        - ``\t``: first element of ``sqlexecute.server_type()``, or
          ``mycli``;
        - ``\n``: newline.

        Missing user/host/database values are shown as ``(none)``.
        """
        sqlexecute = self.sqlexecute
        string = string.replace('\\u', sqlexecute.user or '(none)')
        string = string.replace('\\h', sqlexecute.host or '(none)')
        string = string.replace('\\d', sqlexecute.dbname or '(none)')
        string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
        string = string.replace('\\n', "\n")
        return string
Example #36
0
class VCli(object):
    """Interactive command-line client built on prompt_toolkit.

    Holds the ``VExecute`` connection, the loaded ``vclirc`` configuration,
    the ``VCompleter`` (swapped by the completion-refresh callback under
    ``_completer_lock``) and the ``CommandLineInterface`` created in
    :meth:`run_cli`.
    """

    def __init__(self, vexecute=None, vclirc_file=None):
        """Load configuration, set up logging and create the completer.

        :param vexecute: optional ``VExecute`` connection; normally assigned
            later by :meth:`connect` (directly or via :meth:`connect_uri`).
        :param vclirc_file: path of the user's config file. It is passed,
            together with the packaged default ``vclirc``, to
            ``write_default_config`` and ``load_config``.
        """
        self.vexecute = vexecute

        from vcli import __file__ as package_root
        package_root = os.path.dirname(package_root)

        default_config = os.path.join(package_root, 'vclirc')
        write_default_config(default_config, vclirc_file)

        self.vspecial = VSpecial()

        # Load config.
        c = self.config = load_config(vclirc_file, default_config)
        self.multi_line = c['main'].as_bool('multi_line')
        self.vi_mode = c['main'].as_bool('vi')
        self.vspecial.timing_enabled = c['main'].as_bool('timing')
        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # Initialize completer. run_cli() schedules a background refresh
        # whose result replaces it via _swap_completer_objects().
        smart_completion = c['main'].as_bool('smart_completion')
        completer = VCompleter(smart_completion, vspecial=self.vspecial)
        self.completer = completer
        # Guards self.completer against concurrent swaps by the refresher.
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        # Built in run_cli(); None until then.
        self.cli = None

    def register_special_commands(self):
        """Register vcli's own backslash commands with ``self.vspecial``."""
        self.vspecial.register(self.change_db, '\\c',
                               '\\c[onnect] [DBNAME]',
                               'Connect to a new database',
                               aliases=('use', '\\connect', 'USE'))
        self.vspecial.register(self.refresh_completions, '\\#', '\\#',
                               'Refresh auto-completions', arg_type=NO_QUERY)
        self.vspecial.register(self.refresh_completions, '\\refresh',
                               '\\refresh', 'Refresh auto-completions',
                               arg_type=NO_QUERY)

    def change_db(self, pattern, **_):
        """Switch databases (the ``\\c`` / ``use`` special command).

        :param pattern: database name, optionally wrapped in double quotes;
            when empty, ``vexecute.connect()`` is called with no arguments.
        :return: generator yielding one status row shaped
            ``(title, cur, headers, status, force_stdout)``.
        """
        if pattern:
            # Strip surrounding double quotes: \c "My DB" -> My DB
            db = pattern[1:-1] if pattern[0] == pattern[-1] == '"' else pattern
            self.vexecute.connect(database=db)
        else:
            self.vexecute.connect()

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.vexecute.dbname, self.vexecute.user), True)

    def initialize_logging(self):
        """Attach a file handler to the ``vcli`` logger.

        Reads ``log_file`` and ``log_level`` from the ``[main]`` config
        section. An unknown level name raises ``KeyError``.
        """

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG
                     }

        handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('vcli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        root_logger.debug('Initializing vcli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_uri(self, uri):
        """Connect using the components of a URI.

        Missing parts default to host ``localhost``, the current OS user,
        port ``5433`` and an empty password.
        """
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        host = uri.hostname or 'localhost'
        user = uri.username or getpass.getuser()
        port = uri.port or 5433
        password = uri.password or ''
        self.connect(database, host, user, port, password)

    def connect(self, database, host, user, port, passwd):
        """Open a ``VExecute`` connection and store it on ``self.vexecute``.

        On ``errors.DatabaseError`` the error is logged and printed in red,
        and the process exits with status 1.
        """
        # Connect to the database
        try:
            self.vexecute = VExecute(database, user, passwd, host, port)
        except errors.DatabaseError as e:  # Connection can fail
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            error_msg = str(e) or type(e).__name__
            click.secho(error_msg, err=True, fg='red')
            # sys.exit, not the site-module exit() helper, which is meant
            # for the interactive shell and is absent under ``python -S``.
            sys.exit(1)

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: if the external editor reports an error.
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                        sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited SQL back in the prompt for review/re-edit.
            cli.current_buffer.document = Document(
                sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive read-eval-print loop until the user quits.

        Builds the prompt_toolkit application, then repeatedly reads input,
        expands ``\\e`` editor commands, runs the SQL through
        ``self.vexecute`` and writes formatted results either through the
        pager or to ``self.vspecial.output`` when that is not stdout. Each
        query is appended to ``self.query_history``. ``LESS`` is set for the
        session and restored on exit.
        """
        vexecute = self.vexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        # Populate completions in the background; the new completer is
        # installed by _on_completions_refreshed.
        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = vcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        click.secho('Version: %s' % __version__)

        def prompt_tokens(cli):
            return [(Token.Prompt, '%s=> ' % vexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.vi_mode,
                                                        self.completion_refresher.is_refreshing)
        input_processors = [
            # Highlight matching brackets while editing.
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
        ]
        layout = create_default_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=input_processors)
        history_file = self.config['main']['history_file']
        # Hold the lock so the refresh callback cannot swap self.completer
        # while the buffer and CLI are being wired together.
        with self._completer_lock:
            buf = VBuffer(always_multiline=self.multi_line, completer=self.completer,
                          history=FileHistory(os.path.expanduser(history_file)),
                          complete_while_typing=Always())

            application = Application(style=style_factory(self.syntax_style, self.cli_style),
                                      layout=layout, buffer=buf,
                                      key_bindings_registry=key_binding_manager.registry,
                                      on_exit=AbortAction.RAISE_EXCEPTION,
                                      ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                            eventloop=create_eventloop())

        try:
            while True:
                document = self.cli.run()

                # The reason we check here instead of inside the vexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the vexecute.run()
                # statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False
                # Bound before the try so it is always defined when the query
                # is recorded below; set to True only in the else clause.
                successful = False

                try:
                    logger.debug('sql: %r', document.text)
                    start = time()
                    # Run the query.
                    res = vexecute.run(document.text, self.vspecial)
                    duration = time() - start

                    # Results are written to file_output when
                    # self.vspecial.output is not stdout; otherwise they are
                    # collected in stdout_output and paged after the loop.
                    file_output = None
                    stdout_output = []

                    total = 0
                    for title, cur, headers, status, force_stdout in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        start = time()
                        threshold = 1000
                        if (is_select(status) and
                                cur and cur.rowcount > threshold):
                            click.secho('The result set has more than %s rows.'
                                        % threshold, fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                break

                        formatted = format_output(title, cur, headers, status,
                                                  self.table_format,
                                                  self.vspecial.expanded_output,
                                                  self.vspecial.aligned,
                                                  self.vspecial.show_header)

                        if self.vspecial.output is not sys.stdout:
                            file_output = self.vspecial.output

                        if force_stdout or not file_output:
                            output = stdout_output
                        else:
                            output = file_output

                        write_output(output, formatted)

                        if hasattr(cur, 'rowcount'):
                            if self.vspecial.show_header:
                                if cur.rowcount == 1:
                                    write_output(output, '(1 row)')
                                elif headers:
                                    write_output(output, '(%d rows)' % cur.rowcount)
                            if document.text.startswith('\\') and cur.rowcount == 0:
                                stdout_output = ['No matching relations found.']
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    vexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except errors.ConnectionError as e:
                    if ('Connection is closed' in utf8tounicode(e.args[0])):
                        reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
                                show_default=False, type=bool, default=True)
                        if reconnect:
                            try:
                                vexecute.connect()
                                click.secho('Reconnected!\nTry the command again.', fg='green')
                            except errors.DatabaseError as e:
                                click.secho(str(e), err=True, fg='red')
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    successful = True
                    if stdout_output:
                        output = '\n'.join(stdout_output)
                        try:
                            click.echo_via_pager(output)
                        except KeyboardInterrupt:
                            pass

                    if file_output:
                        try:
                            file_output.flush()
                        except KeyboardInterrupt:
                            pass
                    if self.vspecial.timing_enabled:
                        print('Time: command: %0.03fs, total: %0.03fs' % (duration, total))

                    # Refresh the table names and column names if necessary.
                    if need_completion_refresh(document.text):
                        self.refresh_completions(need_completion_reset(document.text))

                    # Refresh search_path to set default schema.
                    if need_search_path_refresh(document.text):
                        logger.debug('Refreshing search path')
                        with self._completer_lock:
                            self.completer.set_search_path(vexecute.search_path())
                        logger.debug('Search path: %r', self.completer.search_path)

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts

    def adjust_less_opts(self):
        """Set ``LESS=-RXF`` for the session and return the previous value.

        The previous value is ``''`` when ``LESS`` was unset.
        """
        less_opts = os.environ.get('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', less_opts)
        os.environ['LESS'] = '-RXF'

        return less_opts

    def refresh_completions(self, reset=False):
        """Schedule a background rebuild of the completion metadata.

        :param reset: when true, reset the current completer first (under
            ``_completer_lock``).
        :return: one status row in the 5-tuple result format.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(self.vexecute, self.vspecial,
                                          self._on_completions_refreshed)
        return [(None, None, None,
                'Auto-completion refresh started in the background.', True)]

    def _on_completions_refreshed(self, new_completer):
        """Callback for a finished refresh: install completer and redraw."""
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When vcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying to replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* with the cursor at the given offset."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)
Example #37
0
File: main.py Project: npk/mycli
    def run_cli(self):
        """Run the interactive mycli read-eval-print loop until the user quits.

        Prints the startup banner and builds the prompt_toolkit application.
        For each input it then:

        - handles ``\\e`` editor commands;
        - asks for confirmation of destructive queries;
        - runs the SQL through ``self.sqlexecute`` and pages the formatted
          output;
        - refreshes completions when the statement requires it;
        - records a ``Query`` in ``self.query_history``.

        A query counts as successful only if execution and formatting
        finished without raising. ``LESS`` and ``PAGER`` are restored on
        exit.
        """
        sqlexecute = self.sqlexecute
        logger = self.logger
        original_less_opts = self.adjust_less_opts()

        self.initialize_completions()
        completer = self.completer

        def set_key_bindings(value):
            if value not in ('emacs', 'vi'):
                value = 'emacs'
            self.key_bindings = value

        project_root = os.path.dirname(PACKAGE_ROOT)
        author_file = os.path.join(project_root, 'AUTHORS')
        sponsor_file = os.path.join(project_root, 'SPONSORS')

        key_binding_manager = mycli_bindings(get_key_bindings=lambda: self.key_bindings,
                                             set_key_bindings=set_key_bindings)
        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/mycli')
        print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
        print('Home: http://mycli.net')
        print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))

        def prompt_tokens(cli):
            return [(Token.Prompt, self.get_prompt(self.prompt_format))]

        get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.key_bindings)
        layout = create_default_layout(lexer=MyCliLexer,
                                       reserve_space_for_menu=True,
                                       get_prompt_tokens=prompt_tokens,
                                       get_bottom_toolbar_tokens=get_toolbar_tokens,
                                       extra_input_processors=[
                                           ConditionalProcessor(
                                               processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                                               filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
                                       ])
        buf = CLIBuffer(always_multiline=self.multi_line, completer=completer,
                history=FileHistory(os.path.expanduser('~/.mycli-history')),
                complete_while_typing=Always())

        application = Application(style=style_factory(self.syntax_style),
                                  layout=layout, buffer=buf,
                                  key_bindings_registry=key_binding_manager.registry,
                                  on_exit=AbortAction.RAISE_EXCEPTION)
        cli = CommandLineInterface(application=application, eventloop=create_eventloop())

        try:
            while True:
                document = cli.run()

                special.set_expanded_output(False)

                # The reason we check here instead of inside the sqlexecute is
                # because we want to raise the Exit exception which will be
                # caught by the try/except block that wraps the
                # sqlexecute.run() statement.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                    continue

                destroy = confirm_destructive_query(document.text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.output('Your call!')
                else:
                    self.output('Wise choice!')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False
                # Bound before the try: previously it was first assigned after
                # the logfile writes, so an error there left it unbound (or
                # stale from the previous query) when Query() was built below.
                successful = False

                try:
                    logger.debug('sql: %r', document.text)
                    if self.logfile:
                        self.logfile.write('\n# %s\n' % datetime.now())
                        self.logfile.write(document.text)
                        self.logfile.write('\n')
                    start = time()
                    res = sqlexecute.run(document.text)
                    # NOTE(review): if sqlexecute.run() is a generator, this
                    # measures only its creation, not query execution.
                    duration = time() - start
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        start = time()
                        threshold = 1000
                        if (is_select(status) and
                                cur and cur.rowcount > threshold):
                            self.output('The result set has more than %s rows.'
                                    % threshold, fg='red')
                            if not click.confirm('Do you want to continue?'):
                                self.output("Aborted!", err=True, fg='red')
                                break
                        output.extend(format_output(title, cur, headers,
                            status, self.table_format))
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)
                except KeyboardInterrupt:
                    # Restart connection to the database
                    sqlexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    self.output("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    self.output('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    logger.debug("Exception: %r", e)
                    # MySQL client error codes treated as a dropped connection.
                    if (e.args[0] in (2003, 2006, 2013)):
                        reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
                                show_default=False, type=bool, default=True)
                        if reconnect:
                            logger.debug('Attempting to reconnect.')
                            try:
                                sqlexecute.connect()
                                logger.debug('Reconnected successfully.')
                                self.output('Reconnected!\nTry the command again.', fg='green')
                            except OperationalError as e:
                                logger.debug('Reconnect failed. e: %r', e)
                                self.output(str(e), err=True, fg='red')
                                continue  # If reconnection failed, don't proceed further.
                        else:  # If user chooses not to reconnect, don't proceed further.
                            continue
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        self.output(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                else:
                    # Only reached when running and formatting every result
                    # raised nothing; an error while iterating the results
                    # must not be recorded as a successful query.
                    successful = True
                    self.output_via_pager('\n'.join(output))
                    if special.is_timing_enabled():
                        self.output('Command Time: %0.03fs' % duration)
                        self.output('Format Time: %0.03fs' % total)

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_dynamic_completions()

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            self.output('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts
            os.environ['PAGER'] = special.get_original_pager()
Example #38
0
File: main.py Project: w4ngyi/pgcli
class PGCli(object):
    def __init__(self,
                 force_passwd_prompt=False,
                 never_passwd_prompt=False,
                 pgexecute=None,
                 pgclirc_file=None):
        """Load configuration, set up logging and build the completer.

        :param force_passwd_prompt: prompt for a password before the first
            connection attempt (see :meth:`connect`).
        :param never_passwd_prompt: never prompt for a password.
        :param pgexecute: optional pre-built ``PGExecute`` connection.
        :param pgclirc_file: path of the user's config file. It is passed,
            together with the packaged default ``pgclirc``, to
            ``write_default_config`` and ``load_config``.
        """
        self.force_passwd_prompt = force_passwd_prompt
        self.never_passwd_prompt = never_passwd_prompt
        self.pgexecute = pgexecute

        from pgcli import __file__ as pgcli_init_path
        default_config = os.path.join(os.path.dirname(pgcli_init_path),
                                      'pgclirc')
        write_default_config(default_config, pgclirc_file)

        self.pgspecial = PGSpecial()

        # Pull settings out of the [main] and [colors] config sections.
        self.config = load_config(pgclirc_file, default_config)
        main_cfg = self.config['main']
        self.multi_line = main_cfg.as_bool('multi_line')
        self.vi_mode = main_cfg.as_bool('vi')
        self.pgspecial.timing_enabled = main_cfg.as_bool('timing')
        self.table_format = main_cfg['table_format']
        self.syntax_style = main_cfg['syntax_style']
        self.cli_style = self.config['colors']
        self.wider_completion_menu = main_cfg.as_bool('wider_completion_menu')
        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        self.query_history = []

        # The completer is swapped by the refresh callback, hence the lock.
        self.completer = PGCompleter(main_cfg.as_bool('smart_completion'),
                                     pgspecial=self.pgspecial)
        self._completer_lock = threading.Lock()
        self.register_special_commands()

        # The prompt_toolkit CLI is created later; None until then.
        self.cli = None

    def register_special_commands(self):
        """Register pgcli's own backslash commands with ``self.pgspecial``.

        Registers ``\\c``, ``\\#``, ``\\refresh`` and ``\\i``, in that order.
        """
        commands = (
            (self.change_db, '\\c', '\\c[onnect] database_name',
             'Change to a new database.',
             {'aliases': ('use', '\\connect', 'USE')}),
            (self.refresh_completions, '\\#', '\\#',
             'Refresh auto-completions.',
             {'arg_type': NO_QUERY}),
            (self.refresh_completions, '\\refresh', '\\refresh',
             'Refresh auto-completions.',
             {'arg_type': NO_QUERY}),
            (self.execute_from_file, '\\i', '\\i filename',
             'Execute commands from file.',
             {}),
        )
        for handler, command, syntax, description, options in commands:
            self.pgspecial.register(handler, command, syntax, description,
                                    **options)

    def change_db(self, pattern, **_):
        """Switch databases (the ``\\c`` / ``use`` special command).

        :param pattern: database name, optionally wrapped in double quotes;
            when empty, ``pgexecute.connect()`` is called without arguments.
        :return: generator yielding one ``(title, cur, headers, status)``
            status row.
        """
        if not pattern:
            self.pgexecute.connect()
        else:
            is_quoted = pattern[0] == pattern[-1] == '"'
            target = pattern[1:-1] if is_quoted else pattern
            self.pgexecute.connect(database=target)

        yield (None, None, None, 'You are now connected to database "%s" as '
               'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user))

    def execute_from_file(self, pattern, **_):
        """Run the SQL stored in a file (the ``\\i`` special command).

        :param pattern: path of the file; ``~`` is expanded.
        :return: the result of ``pgexecute.run`` on the file's contents, or
            a single status row describing a missing argument or I/O error.
        """
        if not pattern:
            return [(None, None, None, '\\i: missing required argument')]

        path = os.path.expanduser(pattern)
        try:
            with open(path, encoding='utf-8') as sql_file:
                contents = sql_file.read()
        except IOError as err:
            return [(None, None, None, str(err))]
        return self.pgexecute.run(contents, self.pgspecial)

    def initialize_logging(self):
        """Attach a file handler to the ``pgcli`` logger.

        Reads ``log_file`` and ``log_level`` from the ``[main]`` config
        section. An unknown level name raises ``KeyError``; this happens
        after the handler has already been attached.
        """
        main_cfg = self.config['main']
        log_file = main_cfg['log_file']
        log_level = main_cfg['log_level']

        levels = dict(
            CRITICAL=logging.CRITICAL,
            ERROR=logging.ERROR,
            WARNING=logging.WARNING,
            INFO=logging.INFO,
            DEBUG=logging.DEBUG,
        )

        file_handler = logging.FileHandler(os.path.expanduser(log_file))
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s'))

        pgcli_logger = logging.getLogger('pgcli')
        pgcli_logger.addHandler(file_handler)
        pgcli_logger.setLevel(levels[log_level.upper()])

        pgcli_logger.debug('Initializing pgcli logging.')
        pgcli_logger.debug('Log file %r.', log_file)

    def connect_dsn(self, dsn):
        """Connect using *dsn*, passed through as ``connect(dsn=...)``."""
        connect = self.connect
        connect(dsn=dsn)

    def connect_uri(self, uri):
        """Connect using the components of a URI.

        Passes the path (leading ``/`` removed), hostname, username, port
        and password to :meth:`connect`; absent parts are ``None``.
        """
        parsed = urlparse(uri)
        self.connect(parsed.path[1:], parsed.hostname, parsed.username,
                     parsed.port, parsed.password)

    def connect(self,
                database='',
                host='',
                user='',
                port='',
                passwd='',
                dsn=''):
        """Open a ``PGExecute`` connection and store it on ``self.pgexecute``.

        A missing ``user`` defaults to ``getuser()`` and a missing
        ``database`` to the user name. The password is resolved in this
        order:

        1. the explicit ``passwd``;
        2. ``$PGPASSWORD``, unless a prompt is forced;
        3. an immediate prompt when ``force_passwd_prompt`` is set;
        4. a single retry with a prompted password if the server reports
           "no password supplied" and ``never_passwd_prompt`` is not set.

        Any connection failure is logged and printed in red, and the
        process exits with status 1.
        """
        import sys  # for sys.exit below

        # Connect to the database.

        if not user:
            user = getuser()

        if not database:
            database = user

        # If password prompt is not forced but no password is provided, try
        # getting it from environment variable.
        if not self.force_passwd_prompt and not passwd:
            passwd = os.environ.get('PGPASSWORD', '')

        # Prompt for a password immediately if requested via the -W flag. This
        # avoids wasting time trying to connect to the database and catching a
        # no-password exception.
        # If we successfully parsed a password from a URI, there's no need to
        # prompt for it, even with the -W flag
        if self.force_passwd_prompt and not passwd:
            passwd = click.prompt(
                'Password', hide_input=True, show_default=False, type=str)

        # Prompt for a password after 1st attempt to connect without a password
        # fails. Don't prompt if the -w flag is supplied
        auto_passwd_prompt = not passwd and not self.never_passwd_prompt

        # Attempt to connect to the database.
        # Note that passwd may be empty on the first attempt. If connection
        # fails because of a missing password, but we're allowed to prompt for
        # a password (no -w flag), prompt for a passwd and try again.
        try:
            try:
                pgexecute = PGExecute(database, user, passwd, host, port, dsn)
            except OperationalError as e:
                if ('no password supplied' in utf8tounicode(e.args[0])
                        and auto_passwd_prompt):
                    passwd = click.prompt(
                        'Password',
                        hide_input=True,
                        show_default=False,
                        type=str)
                    pgexecute = PGExecute(database, user, passwd, host, port,
                                          dsn)
                else:
                    # Bare raise re-raises with the original traceback.
                    raise

        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            click.secho(str(e), err=True, fg='red')
            # sys.exit, not the site-module exit() helper, which is meant
            # for the interactive shell and is absent under ``python -S``.
            sys.exit(1)

        self.pgexecute = pgexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: if the external editor reports an error.
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(
                filename, sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Put the edited SQL back in the prompt so the user can review
            # it, run it, or append \e again to keep editing.
            cli.current_buffer.document = Document(
                sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive pgcli prompt loop until the user exits.

        Sets up key bindings, layout, buffer and the prompt_toolkit
        application, then repeatedly reads a query, executes it through
        ``self.pgexecute`` and pages the formatted output. The loop ends on
        ``EOFError`` (raised below for quit commands); the ``LESS``
        environment variable overwritten by ``adjust_less_opts`` is restored
        in the ``finally`` clause on the way out.
        """
        pgexecute = self.pgexecute
        logger = self.logger
        # Previous LESS value, restored in the finally clause at the bottom.
        original_less_opts = self.adjust_less_opts()

        self.refresh_completions()

        def set_vi_mode(value):
            self.vi_mode = value

        key_binding_manager = pgcli_bindings(
            get_vi_mode_enabled=lambda: self.vi_mode,
            set_vi_mode_enabled=set_vi_mode)

        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/pgcli')
        print('Mail: https://groups.google.com/forum/#!forum/pgcli')
        print('Home: http://pgcli.com')

        def prompt_tokens(cli):
            # Prompt text is the current database name followed by "> ".
            return [(Token.Prompt, '%s> ' % pgexecute.dbname)]

        get_toolbar_tokens = create_toolbar_tokens_func(
            lambda: self.vi_mode,
            lambda: self.completion_refresher.is_refreshing())

        layout = create_default_layout(
            lexer=PostgresLexer,
            reserve_space_for_menu=True,
            get_prompt_tokens=prompt_tokens,
            get_bottom_toolbar_tokens=get_toolbar_tokens,
            display_completions_in_columns=self.wider_completion_menu,
            multiline=True,
            extra_input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
            ])
        history_file = self.config['main']['history_file']
        # Hold the same lock that _swap_completer_objects takes when it
        # installs a refreshed completer, so self.completer and self.cli are
        # not swapped while the buffer/CLI are being wired up.
        with self._completer_lock:
            buf = PGBuffer(
                always_multiline=self.multi_line,
                completer=self.completer,
                history=FileHistory(os.path.expanduser(history_file)),
                complete_while_typing=Always())

            application = Application(
                style=style_factory(self.syntax_style, self.cli_style),
                layout=layout,
                buffer=buf,
                key_bindings_registry=key_binding_manager.registry,
                on_exit=AbortAction.RAISE_EXCEPTION,
                ignore_case=True)
            self.cli = CommandLineInterface(
                application=application, eventloop=create_eventloop())

        try:
            while True:
                document = self.cli.run()

                # Quit commands are detected here rather than inside
                # pgexecute so that EOFError can be raised and caught by the
                # outer try/except that wraps this whole loop.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                    continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False

                try:
                    logger.debug('sql: %r', document.text)
                    successful = False
                    # Pre-initialized so `res` is bound even if
                    # pgexecute.run() raises. NOTE(review): nothing after this
                    # try reads `res`, so this looks purely defensive.
                    res = []
                    start = time()
                    # Run the query.
                    # NOTE(review): if pgexecute.run returns a lazy iterator,
                    # `duration` may exclude most execution time — confirm.
                    res = pgexecute.run(document.text, self.pgspecial)
                    duration = time() - start
                    successful = True
                    output = []
                    # Accumulated time spent formatting the result sets.
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        start = time()
                        # Large SELECT results require explicit confirmation.
                        threshold = 1000
                        if (is_select(status) and cur
                                and cur.rowcount > threshold):
                            click.secho(
                                'The result set has more than %s rows.' %
                                threshold,
                                fg='red')
                            if not click.confirm('Do you want to continue?'):
                                click.secho("Aborted!", err=True, fg='red')
                                break

                        # With auto-expand on, format_output receives the
                        # terminal width (presumably to decide when to switch
                        # to expanded output — confirm in format_output).
                        if self.pgspecial.auto_expand:
                            max_width = self.cli.output.get_size().columns
                        else:
                            max_width = None

                        formatted = format_output(
                            title, cur, headers, status, self.table_format,
                            self.pgspecial.expanded_output, max_width)
                        output.extend(formatted)
                        end = time()
                        total += end - start
                        mutating = mutating or is_mutating(status)

                except KeyboardInterrupt:
                    # Restart connection to the database
                    pgexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    click.secho("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    click.secho('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    reconnect = True
                    if ('server closed the connection' in utf8tounicode(
                            e.args[0])):
                        reconnect = click.prompt(
                            'Connection reset. Reconnect (Y/n)',
                            show_default=False,
                            type=bool,
                            default=True)
                        if reconnect:
                            try:
                                pgexecute.connect()
                                click.secho(
                                    'Reconnected!\nTry the command again.',
                                    fg='green')
                            except OperationalError as e:
                                click.secho(str(e), err=True, fg='red')
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        click.secho(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    click.secho(str(e), err=True, fg='red')
                else:
                    # Runs only when nothing above raised: page the output
                    # (possibly partial, if the user aborted a large result).
                    try:
                        click.echo_via_pager('\n'.join(output))
                    except KeyboardInterrupt:
                        pass
                    if self.pgspecial.timing_enabled:
                        print('Command Time: %0.03fs' % duration)
                        print('Format Time: %0.03fs' % total)

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(document.text):
                    self.refresh_completions()

                # Refresh search_path to set default schema.
                if need_search_path_refresh(document.text):
                    logger.debug('Refreshing search path')
                    with self._completer_lock:
                        self.completer.set_search_path(pgexecute.search_path())
                    logger.debug('Search path: %r', self.completer.search_path)

                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            print('Goodbye!')
        finally:  # Reset the less opts back to original.
            logger.debug('Restoring env var LESS to %r.', original_less_opts)
            os.environ['LESS'] = original_less_opts

    def adjust_less_opts(self):
        """Force ``less`` options suited to paging query output.

        Records and logs the current ``LESS`` environment variable, then
        overwrites it with ``-SRXF``.

        :return: the previous value of ``LESS`` ('' when it was unset), so
            the caller can restore it afterwards.
        """
        saved_opts = os.getenv('LESS', '')
        self.logger.debug('Original value for LESS env var: %r', saved_opts)
        os.environ.update(LESS='-SRXF')
        return saved_opts

    def refresh_completions(self):
        """Ask the completion refresher to rebuild completion metadata.

        Passes the current executor, the special-command handler and
        ``_on_completions_refreshed`` (called with the new completer) to
        ``self.completion_refresher.refresh``.

        :return: a one-element list holding a ``(None, None, None, status)``
            tuple with a status message for the user.
        """
        refresher = self.completion_refresher
        refresher.refresh(self.pgexecute, self.pgspecial,
                          self._on_completions_refreshed)
        status = 'Auto-completion refresh started in the background.'
        return [(None, None, None, status)]

    def _on_completions_refreshed(self, new_completer):
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When pgcli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* with the cursor at the given offset.

        The text is wrapped in a prompt_toolkit ``Document`` and passed to the
        current completer while ``self._completer_lock`` is held, so a
        concurrent completer swap cannot interleave. ``None`` is passed as
        the second argument (presumably the complete event).
        """
        doc = Document(text=text, cursor_position=cursor_positition)
        with self._completer_lock:
            return self.completer.get_completions(doc, None)
Example #39
0
File: main.py Project: Hsuing/mycli
class MyCli(object):
    """Interactive MySQL client: config loading, connection and the REPL.

    Class attributes hold the default search paths for MySQL option files
    (``cnf_files``) and for mycli's own config files. Each instance works on
    its own copy of ``cnf_files`` so the class-level defaults never change.
    """

    default_prompt = '\\t \\u@\\h:\\d> '
    defaults_suffix = None

    # In order of being loaded. Files lower in list override earlier ones.
    cnf_files = [
        '/etc/my.cnf',
        '/etc/mysql/my.cnf',
        '/usr/local/etc/my.cnf',
        '~/.my.cnf'
    ]

    system_config_files = [
        '/etc/myclirc',
    ]

    default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
    user_config_file = '~/.myclirc'


    def __init__(self, sqlexecute=None, prompt=None,
            logfile=None, defaults_suffix=None, defaults_file=None,
            login_path=None, auto_vertical_output=False):
        """Load configuration and prepare (but do not start) the client.

        :param sqlexecute: existing SQLExecute to use, or None until
            ``connect`` is called.
        :param prompt: prompt format overriding option-file/config values.
        :param logfile: open audit-log file; when None the ``audit_log``
            config entry (if any) is opened instead.
        :param defaults_suffix: suffix appended to option-file section names.
        :param defaults_file: single option file replacing ``cnf_files``.
        :param login_path: extra option-file section to read.
        :param auto_vertical_output: pass the terminal width to the result
            formatter.
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile
        self.defaults_suffix = defaults_suffix
        self.login_path = login_path
        self.auto_vertical_output = auto_vertical_output

        # self.cnf_files starts as the class-level list of mysql config files
        # to read at launch; defaults_file, when given, replaces it. Always
        # keep an instance-private list: the .mylogin.cnf handling below
        # appends to it, and appending to the shared class list would leak
        # entries into every later instance.
        if defaults_file:
            self.cnf_files = [defaults_file]
        else:
            self.cnf_files = list(self.cnf_files)

        # Load config.
        config_files = ([self.default_config_file] + self.system_config_files +
                        [self.user_config_file])
        c = self.config = read_config_files(config_files)
        self.multi_line = c['main'].as_bool('multi_line')
        self.destructive_warning = c['main'].as_bool('destructive_warning')
        self.key_bindings = c['main']['key_bindings']
        special.set_timing_enabled(c['main'].as_bool('timing'))
        self.table_format = c['main']['table_format']
        self.syntax_style = c['main']['syntax_style']
        self.cli_style = c['colors']
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')

        # Write user config if system config wasn't the last config loaded.
        if c.filename not in self.system_config_files:
            write_default_config(self.default_config_file, self.user_config_file)

        # audit log
        if self.logfile is None and 'audit_log' in c['main']:
            try:
                self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
            except (IOError, OSError):
                self.output('Error: Unable to open the audit log file. Your queries will not be logged.', err=True, fg='red')
                # False (not None) marks "logging was requested but failed";
                # run_cli warns after every query when it sees this value.
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
        self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
                             self.default_prompt

        self.query_history = []

        # Initialize completer.
        smart_completion = c['main'].as_bool('smart_completion')
        self.completer = SQLCompleter(smart_completion)
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Load .mylogin.cnf if it exists.
        mylogin_cnf_path = get_mylogin_cnf_path()
        if mylogin_cnf_path:
            try:
                mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
                if mylogin_cnf:
                    # .mylogin.cnf gets read last, even if defaults_file is specified.
                    self.cnf_files.append(mylogin_cnf)
                else:
                    # There was an error reading the login path file.
                    print('Error: Unable to read login path file.')
            except CryptoError:
                click.secho('Warning: .mylogin.cnf was not read: pycrypto '
                            'module is not available.')

        # Created in run_cli(); None until then.
        self.cli = None

    def register_special_commands(self):
        """Register mycli's own special commands with the special module."""
        special.register_special_command(self.change_db, 'use',
                '\\u', 'Change to a new database.', aliases=('\\u',))
        special.register_special_command(self.change_db, 'connect',
                '\\r', 'Reconnect to the database. Optional database argument.',
                aliases=('\\r', ), case_sensitive=True)
        special.register_special_command(self.refresh_completions, 'rehash',
                '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
        special.register_special_command(self.change_table_format, 'tableformat',
                '\\T', 'Change Table Type.', aliases=('\\T',), case_sensitive=True)
        special.register_special_command(self.execute_from_file, 'source', '\\. filename',
                              'Execute commands from file.', aliases=('\\.',))
        special.register_special_command(self.change_prompt_format, 'prompt',
                '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)

    def change_table_format(self, arg, **_):
        """Special command: switch the output table format to *arg*.

        Generator yielding one ``(None, None, None, message)`` tuple that
        either confirms the change or lists the allowed formats.
        """
        if arg not in table_formats():
            msg = "Table type %s not yet implemented.  Allowed types:" % arg
            for table_type in table_formats():
                msg += "\n\t%s" % table_type
            yield (None, None, None, msg)
        else:
            self.table_format = arg
            yield (None, None, None, "Changed table Type to %s" % self.table_format)

    def change_db(self, arg, **_):
        """Special command: reconnect, optionally to database *arg*.

        This is a generator, so the reconnect happens when the result is
        iterated; it yields one status tuple describing the new connection.
        """
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        yield (None, None, None, 'You are now connected to database "%s" as '
                'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))

    def execute_from_file(self, arg, **_):
        """Special command: run the SQL contained in file *arg*.

        Returns a single status tuple when *arg* is missing or the file
        cannot be read; otherwise returns ``self.sqlexecute.run(query)``.
        """
        if not arg:
            message = 'Missing required argument, filename.'
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(arg), encoding='utf-8') as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.
        """
        if not arg:
            message = 'Missing required argument, format.'
            return [(None, None, None, message)]

        self.prompt_format = self.get_prompt(arg)
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def initialize_logging(self):
        """Attach a file handler to the 'mycli' logger.

        The file and level come from ``log_file`` and ``log_level`` in the
        [main] config section; an unknown level name raises KeyError.
        """

        log_file = self.config['main']['log_file']
        log_level = self.config['main']['log_level']

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG
                     }

        handler = logging.FileHandler(os.path.expanduser(log_file))

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('mycli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug('Initializing mycli logging.')
        root_logger.debug('Log file %r.', log_file)

    def connect_uri(self, uri, local_infile=None):
        """Connect using the pieces of a URI (user:pass@host:port/database)."""
        uri = urlparse(uri)
        database = uri.path[1:]  # ignore the leading fwd slash
        self.connect(database, uri.username, uri.password, uri.hostname,
                uri.port, local_infile=local_infile)

    def read_my_cnf_files(self, files, keys):
        """
        Reads a list of config files and merges them. The last one will win.
        :param files: list of files to read
        :param keys: list of keys to retrieve
        :returns: dict mapping each key to its value, with None for keys
            missing from every consulted section.
        """
        cnf = read_config_files(files)

        # Only [client], the login-path section, and their defaults_suffix
        # variants are consulted.
        sections = ['client']
        if self.login_path and self.login_path != 'client':
            sections.append(self.login_path)

        if self.defaults_suffix:
            sections.extend([sect + self.defaults_suffix for sect in sections])

        def get(key):
            # The last matching section (in the merged config's order) wins.
            result = None
            for sect in cnf:
                if sect in sections and key in cnf[sect]:
                    result = cnf[sect][key]
            return result

        return {x: get(x) for x in keys}

    def connect(self, database='', user='', passwd='', host='', port='',
            socket='', charset='', local_infile=''):
        """Open a connection, falling back to option-file values.

        Explicit arguments win over option files, which win over built-in
        defaults. On an "Access denied" error the user is prompted for a
        password once. Any other failure is reported and the process exits.
        """

        cnf = self.read_my_cnf_files(self.cnf_files, (
            'database', 'user', 'password', 'host', 'port', 'socket',
            'default-character-set', 'local-infile', 'loose-local-infile'))

        # Fall back to config values only if user did not specify a value.

        database = database or cnf['database']
        # An explicit host/port means a TCP connection, so ignore any socket.
        if port or host:
            socket = ''
        else:
            socket = socket or cnf['socket']
        user = user or cnf['user'] or os.getenv('USER')
        host = host or cnf['host'] or 'localhost'
        port = port or cnf['port'] or 3306
        try:
            port = int(port)
        except ValueError:
            self.output("Error: Invalid port number: '{0}'.".format(port),
                        err=True, fg='red')
            exit(1)

        passwd = passwd or cnf['password']
        charset = charset or cnf['default-character-set'] or 'utf8'

        # Favor whichever local_infile option is set.
        for local_infile_option in (local_infile, cnf['local-infile'],
                                    cnf['loose-local-infile'], False):
            try:
                local_infile = str_to_bool(local_infile_option)
                break
            except (TypeError, ValueError):
                pass

        # Connect to the database.

        try:
            try:
                sqlexecute = SQLExecute(database, user, passwd, host, port,
                        socket, charset, local_infile)
            except OperationalError as e:
                if ('Access denied for user' in e.args[1]):
                    passwd = click.prompt('Password', hide_input=True,
                                          show_default=False, type=str)
                    sqlexecute = SQLExecute(database, user, passwd, host, port,
                            socket, charset, local_infile)
                else:
                    raise
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug('Database connection failed: %r.', e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.output(str(e), err=True, fg='red')
            exit(1)

        self.sqlexecute = sqlexecute

    def handle_editor_command(self, cli, document):
        r"""
        Editor command is any query that is prefixed or suffixed
        by a '\e'. The reason for a while loop is because a user
        might edit a query multiple times.
        For eg:
        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param cli: CommandLineInterface
        :param document: Document
        :return: Document
        :raises RuntimeError: when the external editor reports an error.
        """
        while special.editor_command(document.text):
            filename = special.get_filename(document.text)
            sql, message = special.open_external_editor(filename,
                                                          sql=document.text)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            # Show the edited SQL at the prompt and read input again.
            cli.current_buffer.document = Document(sql, cursor_position=len(sql))
            document = cli.run(False)
        return document

    def run_cli(self):
        """Run the interactive prompt loop until the user exits.

        Builds the prompt_toolkit application, then repeatedly reads a query,
        runs it through ``self.sqlexecute``, prints the formatted results
        (through the pager when enabled) and records the query in
        ``self.query_history``. A quit command or EOF ends the loop.
        """
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()

        self.refresh_completions()

        def set_key_bindings(value):
            if value not in ('emacs', 'vi'):
                value = 'emacs'
            self.key_bindings = value

        project_root = os.path.dirname(PACKAGE_ROOT)
        author_file = os.path.join(project_root, 'AUTHORS')
        sponsor_file = os.path.join(project_root, 'SPONSORS')

        key_binding_manager = mycli_bindings(get_key_bindings=lambda: self.key_bindings,
                                             set_key_bindings=set_key_bindings)
        print('Version:', __version__)
        print('Chat: https://gitter.im/dbcli/mycli')
        print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
        print('Home: http://mycli.net')
        print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))

        def prompt_tokens(cli):
            return [(Token.Prompt, self.get_prompt(self.prompt_format))]

        get_toolbar_tokens = create_toolbar_tokens_func(lambda: self.key_bindings,
                                                        self.completion_refresher.is_refreshing)

        layout = create_prompt_layout(lexer=MyCliLexer,
                                      multiline=True,
                                      get_prompt_tokens=prompt_tokens,
                                      get_bottom_toolbar_tokens=get_toolbar_tokens,
                                      display_completions_in_columns=self.wider_completion_menu,
                                      extra_input_processors=[
                                          ConditionalProcessor(
                                              processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                                              filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()),
                                      ])
        # Same lock as _swap_completer_objects, so a background refresh
        # cannot swap the completer while the buffer/CLI are wired up.
        with self._completer_lock:
            buf = CLIBuffer(always_multiline=self.multi_line, completer=self.completer,
                    history=FileHistory(os.path.expanduser('~/.mycli-history')),
                    complete_while_typing=Always(), accept_action=AcceptAction.RETURN_DOCUMENT)

            application = Application(style=style_factory(self.syntax_style, self.cli_style),
                                      layout=layout, buffer=buf,
                                      key_bindings_registry=key_binding_manager.registry,
                                      on_exit=AbortAction.RAISE_EXCEPTION,
                                      on_abort=AbortAction.RETRY,
                                      ignore_case=True)
            self.cli = CommandLineInterface(application=application,
                                       eventloop=create_eventloop())

        # Result sets with more rows than this need explicit confirmation.
        threshold = 1000

        try:
            while True:
                document = self.cli.run()

                special.set_expanded_output(False)

                # Quit commands are detected here rather than inside
                # sqlexecute so that EOFError propagates to the outer
                # try/except and ends the loop.
                if quit_command(document.text):
                    raise EOFError

                try:
                    document = self.handle_editor_command(self.cli, document)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                    continue
                if self.destructive_warning:
                    destroy = confirm_destructive_query(document.text)
                    if destroy is None:
                        pass  # Query was not destructive. Nothing to do here.
                    elif destroy is True:
                        self.output('Your call!')
                    else:
                        self.output('Wise choice!')
                        continue

                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = False
                # Bound before the try so the history entry below never sees
                # an unbound (or the previous query's) value if writing the
                # audit log raises.
                successful = False

                try:
                    logger.debug('sql: %r', document.text)

                    if self.logfile:
                        self.logfile.write('\n# %s\n' % datetime.now())
                        self.logfile.write(document.text)
                        self.logfile.write('\n')

                    start = time()
                    res = sqlexecute.run(document.text)
                    successful = True
                    output = []
                    total = 0
                    for title, cur, headers, status in res:
                        logger.debug("headers: %r", headers)
                        logger.debug("rows: %r", cur)
                        logger.debug("status: %r", status)
                        if (is_select(status) and
                                cur and cur.rowcount > threshold):
                            self.output('The result set has more than %s rows.'
                                    % threshold, fg='red')
                            if not click.confirm('Do you want to continue?'):
                                self.output("Aborted!", err=True, fg='red')
                                break

                        if self.auto_vertical_output:
                            max_width = self.cli.output.get_size().columns
                        else:
                            max_width = None

                        formatted = format_output(title, cur, headers,
                            status, self.table_format,
                            special.is_expanded_output(), max_width)

                        output.extend(formatted)
                        end = time()
                        # Add only the time since the previous result;
                        # measuring every result from the query start would
                        # double count multi-statement queries.
                        total += end - start
                        start = end
                        mutating = mutating or is_mutating(status)
                except UnicodeDecodeError:
                    import pymysql
                    # pymysql.VERSION is a tuple of ints such as (0, 7, 11,
                    # None); compare numerically, not against strings.
                    if tuple(pymysql.VERSION[:3]) < (0, 6, 7):
                        message = ('You are running an older version of pymysql.\n'
                                'Please upgrade to 0.6.7 or above to view binary data.\n'
                                'Try \'pip install -U pymysql\'.')
                        self.output(message)
                    else:
                        raise
                except KeyboardInterrupt:
                    # Restart connection to the database
                    sqlexecute.connect()
                    logger.debug("cancelled query, sql: %r", document.text)
                    self.output("cancelled query", err=True, fg='red')
                except NotImplementedError:
                    self.output('Not Yet Implemented.', fg="yellow")
                except OperationalError as e:
                    logger.debug("Exception: %r", e)
                    # MySQL client errors 2003 (can't connect), 2006 (server
                    # has gone away) and 2013 (lost connection) offer a
                    # reconnect.
                    if (e.args[0] in (2003, 2006, 2013)):
                        reconnect = click.prompt('Connection reset. Reconnect (Y/n)',
                                show_default=False, type=bool, default=True)
                        if reconnect:
                            logger.debug('Attempting to reconnect.')
                            try:
                                sqlexecute.connect()
                                logger.debug('Reconnected successfully.')
                                self.output('Reconnected!\nTry the command again.', fg='green')
                            except OperationalError as e:
                                logger.debug('Reconnect failed. e: %r', e)
                                self.output(str(e), err=True, fg='red')
                                continue  # If reconnection failed, don't proceed further.
                        else:  # If user chooses not to reconnect, don't proceed further.
                            continue
                    else:
                        logger.error("sql: %r, error: %r", document.text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        self.output(str(e), err=True, fg='red')
                except Exception as e:
                    logger.error("sql: %r, error: %r", document.text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.output(str(e), err=True, fg='red')
                else:
                    try:
                        if special.is_pager_enabled():
                            self.output_via_pager('\n'.join(output))
                        else:
                            self.output('\n'.join(output))
                    except KeyboardInterrupt:
                        pass
                    if special.is_timing_enabled():
                        self.output('Time: %0.03fs' % total)

                    # Refresh the table names and column names if necessary.
                    if need_completion_refresh(document.text):
                        self.refresh_completions(
                                reset=need_completion_reset(document.text))
                finally:
                    # logfile is False when the audit log could not be opened.
                    if self.logfile is False:
                        self.output("Warning: This query was not logged.", err=True, fg='red')
                query = Query(document.text, successful, mutating)
                self.query_history.append(query)

        except EOFError:
            self.output('Goodbye!')

    def output(self, text, **kwargs):
        """Echo *text* with click.secho, also appending it to the audit log."""
        if self.logfile:
            self.logfile.write(utf8tounicode(text))
            self.logfile.write('\n')
        click.secho(text, **kwargs)

    def output_via_pager(self, text):
        """Show *text* through click's pager, also appending it to the audit log."""
        if self.logfile:
            self.logfile.write(text)
            self.logfile.write('\n')
        click.echo_via_pager(text)

    def configure_pager(self):
        """Set LESS defaults and apply pager / skip-pager from option files."""
        # Provide sane defaults for less.
        os.environ['LESS'] = '-SRXF'

        cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
        if cnf['pager']:
            special.set_pager(cnf['pager'])
        if cnf['skip-pager']:
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start rebuilding completions; optionally clear the current ones first.

        :return: a one-element list with a status tuple for the user.
        """
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(self.sqlexecute,
                                          self._on_completions_refreshed)

        return [(None, None, None,
                'Auto-completion refresh started in the background.')]

    def _on_completions_refreshed(self, new_completer):
        """Install the refreshed completer and redraw the prompt if running."""
        self._swap_completer_objects(new_completer)

        if self.cli:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.cli.request_redraw()

    def _swap_completer_objects(self, new_completer):
        """Swap the completer object in cli with the newly created completer.
        """
        with self._completer_lock:
            self.completer = new_completer
            # When mycli is first launched we call refresh_completions before
            # instantiating the cli object. So it is necessary to check if cli
            # exists before trying the replace the completer object in cli.
            if self.cli:
                self.cli.current_buffer.completer = new_completer

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor offset."""
        with self._completer_lock:
            return self.completer.get_completions(
                Document(text=text, cursor_position=cursor_positition), None)

    def get_prompt(self, string):
        r"""Expand prompt escapes in *string*.

        \u user, \h host, \d database, \t server type (falling back to
        'mycli'), \n newline. Missing user/host/database become '(none)'.
        """
        sqlexecute = self.sqlexecute
        string = string.replace('\\u', sqlexecute.user or '(none)')
        string = string.replace('\\h', sqlexecute.host or '(none)')
        string = string.replace('\\d', sqlexecute.dbname or '(none)')
        string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
        string = string.replace('\\n', "\n")
        return string
Example #40
0
def loop(cmd, history_file):
    """Run the interactive ``cr> `` prompt loop.

    Builds a prompt_toolkit ``CommandLineInterface`` with a multi-line SQL
    layout, matching-bracket highlighting, file-backed history and SQL
    completion, then repeatedly reads input and passes the text of each
    (truthy) returned document to ``cmd.process``.

    :param cmd: the shell/command object. This function reads its
        ``connection`` and ``lines`` (for the completer), calls its
        ``process`` and ``logger.warn``, and replaces its ``get_num_columns``
        attribute.
    :param history_file: passed to ``TruncatedFileHistory`` together with
        ``max_length=MAX_HISTORY_LENGTH``; presumably a file path - confirm.

    Returns after logging ``'Bye!'`` when ``cli.run()`` raises ``EOFError``.
    """
    # prompt_toolkit names are imported locally; other names used below
    # (Condition, HasFocus, IsDone, Always, SqlLexer, CrashBuffer,
    # TruncatedFileHistory, SQLCompleter, CrateStyle, MAX_HISTORY_LENGTH,
    # _enable_vi_mode) must come from module scope.
    from prompt_toolkit import CommandLineInterface, AbortAction, Application
    from prompt_toolkit.interface import AcceptAction
    from prompt_toolkit.enums import DEFAULT_BUFFER
    from prompt_toolkit.layout.processors import (
        HighlightMatchingBracketProcessor,
        ConditionalProcessor
    )
    from prompt_toolkit.key_binding.manager import KeyBindingManager
    from prompt_toolkit.shortcuts import (create_default_layout,
                                          create_default_output,
                                          create_eventloop)

    # vi mode is wrapped in a Condition filter, so _enable_vi_mode() is
    # consulted via the filter rather than fixed at construction time.
    key_binding_manager = KeyBindingManager(
        enable_search=True,
        enable_abort_and_exit_bindings=True,
        enable_vi_mode=Condition(lambda cli: _enable_vi_mode()))

    # Bracket highlighting only applies while the default buffer has focus
    # and input is not yet finished (~IsDone()).
    layout = create_default_layout(
        message=u'cr> ',
        multiline=True,
        lexer=SqlLexer,
        extra_input_processors=[
            ConditionalProcessor(
                processor=HighlightMatchingBracketProcessor(chars='[](){}'),
                filter=HasFocus(DEFAULT_BUFFER) & ~IsDone())
        ]
    )
    # RETURN_DOCUMENT makes an accepted input come back from cli.run() as a
    # document, whose .text is processed in the loop below.
    cli_buffer = CrashBuffer(
        history=TruncatedFileHistory(history_file, max_length=MAX_HISTORY_LENGTH),
        accept_action=AcceptAction.RETURN_DOCUMENT,
        completer=SQLCompleter(cmd.connection, cmd.lines),
        complete_while_typing=Always()
    )
    # on_exit=RAISE_EXCEPTION surfaces exit as the EOFError handled below;
    # on_abort=RETRY presumably just re-prompts on abort (e.g. Ctrl-C) -
    # confirm against the prompt_toolkit version in use.
    application = Application(
        layout=layout,
        style=CrateStyle,
        buffer=cli_buffer,
        key_bindings_registry=key_binding_manager.registry,
        on_exit=AbortAction.RAISE_EXCEPTION,
        on_abort=AbortAction.RETRY,
    )
    eventloop = create_eventloop()
    output = create_default_output()
    cli = CommandLineInterface(
        application=application,
        eventloop=eventloop,
        output=output
    )

    # Report the current width of the prompt_toolkit output, queried on
    # every call rather than captured once.
    def get_num_columns_override():
        return output.get_size().columns
    cmd.get_num_columns = get_num_columns_override

    while True:
        try:
            doc = cli.run()
            if doc:
                cmd.process(doc.text)
        except EOFError:
            cmd.logger.warn(u'Bye!')
            return