Пример #1
0
def ssh_con_fabric(test_vars):
    """Create an SSH connection to the controller.

    Generator: yields one open ``Connection`` and, when resumed, closes it
    along with the connection to the public IP.

    Args:
        test_vars: mapping that must provide ``public_ip``, ``controller_ip``,
            ``controller_user`` and ``ssh_priv_key``.

    Yields:
        A connected ``Connection`` to the controller. When ``public_ip`` and
        ``controller_ip`` differ, this is a connection to ``127.0.0.1``
        through a local port forwarded via the jumpbox at ``public_ip``.

    Raises:
        NoValidConnectionsError: if opening the tunnelled connection fails
            with an unexpected error, or if all 10 local-port attempts fail
            with "Unable to connect to port ... on 127.0.0.1".
    """
    log = logging.getLogger("ssh_con_fabric")

    # SSH connection/client to the public IP.
    pub_client = Connection(test_vars["public_ip"],
                            user=test_vars["controller_user"],
                            connect_kwargs={
                                "key_filename": test_vars["ssh_priv_key"],
                            })

    msg_con = "SSH connection to controller ({})".format(
        test_vars["controller_ip"])
    try:
        # If the controller's IP is not the same as the public IP, then we
        # are using a jumpbox to get into the VNET containing the controller.
        # In that case, create an SSH tunnel before connecting to the
        # controller.
        if test_vars["public_ip"] != test_vars["controller_ip"]:
            last_error = None
            for port_attempt in range(1, 11):
                tunnel_local_port = get_unused_local_port()
                tunnel_remote_port = 22

                # Build the message fresh for each attempt so retries do not
                # keep appending another "via jumpbox ..." suffix.
                attempt_msg = msg_con + " via jumpbox ({0}), local port {1}".format(
                    test_vars["public_ip"], tunnel_local_port)

                log.debug("Opening {}".format(attempt_msg))
                with pub_client.forward_local(
                        local_port=tunnel_local_port,
                        remote_port=tunnel_remote_port,
                        remote_host=test_vars["controller_ip"]):
                    client = Connection("127.0.0.1",
                                        user=test_vars["controller_user"],
                                        port=tunnel_local_port,
                                        connect_kwargs={
                                            "key_filename":
                                            test_vars["ssh_priv_key"],
                                        })
                    try:
                        client.open()
                    except NoValidConnectionsError as ex:
                        exp_err = "Unable to connect to port {} on 127.0.0.1".format(
                            tunnel_local_port)
                        if exp_err not in str(ex):
                            raise
                        # The forwarded local port was not reachable; leave
                        # this tunnel and retry with a newly chosen port.
                        log.warning("{0} (attempt #{1}, retrying)".format(
                            exp_err, str(port_attempt)))
                        last_error = ex
                        continue

                    try:
                        yield client
                    finally:
                        # Close the tunnelled connection before the tunnel
                        # itself is torn down by leaving the `with` block.
                        client.close()
                log.debug("{} closed".format(attempt_msg))
                break  # no need to iterate again
            else:
                # Every attempt failed with the expected "unable to connect"
                # error: fail loudly instead of finishing without a yield.
                raise last_error
        else:
            log.debug("Opening {}".format(msg_con))
            pub_client.open()
            yield pub_client
            log.debug("Closing {}".format(msg_con))
    finally:
        pub_client.close()
Пример #2
0
def ssh_con_fabric(test_vars):
    """Create an SSH connection to the controller.

    Generator: yields one open ``Connection`` and, when resumed, closes it
    along with the connection to the public IP.

    Args:
        test_vars: mapping that must provide ``public_ip``, ``controller_ip``,
            ``controller_user`` and ``ssh_priv_key``.

    Yields:
        A connected ``Connection`` to the controller. When ``public_ip`` and
        ``controller_ip`` differ, this is a connection to ``127.0.0.1``
        through a local port forwarded via the jumpbox at ``public_ip``.
    """
    log = logging.getLogger("ssh_con_fabric")

    # SSH connection/client to the public IP.
    pub_client = Connection(test_vars["public_ip"],
                            user=test_vars["controller_user"],
                            connect_kwargs={
                                "key_filename": test_vars["ssh_priv_key"],
                            })

    msg_con = "SSH connection to controller ({})".format(
        test_vars["controller_ip"])
    try:
        # If the controller's IP is not the same as the public IP, then we
        # are using a jumpbox to get into the VNET containing the controller.
        # In that case, create an SSH tunnel before connecting to the
        # controller.
        if test_vars["public_ip"] != test_vars["controller_ip"]:
            tunnel_local_port = get_unused_local_port()
            tunnel_remote_port = 22

            msg_con += " via jumpbox ({0}), local port {1}".format(
                test_vars["public_ip"], tunnel_local_port)

            log.debug("Opening {}".format(msg_con))
            with pub_client.forward_local(local_port=tunnel_local_port,
                                          remote_port=tunnel_remote_port,
                                          remote_host=test_vars["controller_ip"]):
                client = Connection("127.0.0.1",
                                    user=test_vars["controller_user"],
                                    port=tunnel_local_port,
                                    connect_kwargs={
                                        "key_filename": test_vars["ssh_priv_key"],
                                    })
                client.open()
                try:
                    yield client
                finally:
                    # Close the tunnelled connection before the tunnel itself
                    # is torn down by leaving the `with` block.
                    client.close()
            log.debug("{} closed".format(msg_con))
        else:
            log.debug("Opening {}".format(msg_con))
            pub_client.open()
            yield pub_client
            log.debug("Closing {}".format(msg_con))
    finally:
        # Always release the public-IP connection, even if an exception was
        # thrown into the generator.
        pub_client.close()
Пример #3
0
class RemoteRunner:
    """
    Starts Jupyter lab on a remote resource and port forwards session to
    local machine.

    Returns
    -------
    RemoteRunner
        An object that is responsible for connecting to remote host and launching jupyter lab.

    Raises
    ------
    SystemExit
        When the specified local port is not available, when authentication
        fails, or when no jupyter executable / log directory is found remotely.
    """

    host: str
    port: int = 8888
    conda_env: str = None
    notebook_dir: str = None
    port_forwarding: bool = True
    launch_command: str = None
    identity: str = None
    shell: str = '/usr/bin/env bash'

    def __post_init__(self):
        """Check the local port and open an authenticated SSH session to ``host``.

        Key/agent authentication is tried first. If it is rejected, up to two
        interactive (password / 2FA) attempts are made. Exits with status 1
        if the local port is taken or authentication ultimately fails.
        """
        # Default keyword arguments for remote `session.run` calls.
        self.run_kwargs = dict(pty=True)
        console.rule('[bold green]Authentication', characters='*')
        if self.port_forwarding and not is_port_available(self.port):
            console.log(
                f'''[bold red]Specified port={self.port} is already in use on your local machine. Try a different port'''
            )
            sys.exit(1)

        connect_kwargs = {}
        if self.identity:
            connect_kwargs['key_filename'] = [str(self.identity)]

        self.session = Connection(self.host, connect_kwargs=connect_kwargs, forward_agent=True)
        console.log(
            f'[bold cyan]Authenticating user ({self.session.user}) from client ({socket.gethostname()}) to remote host ({self.session.host})'
        )
        # Try passwordless authentication
        try:
            self.session.open()
        except (
            paramiko.ssh_exception.BadAuthenticationType,
            paramiko.ssh_exception.AuthenticationException,
        ):
            # Fall through to interactive authentication below.
            pass

        # Prompt for password and token (2FA)
        if not self.session.is_connected:
            for _ in range(2):
                try:
                    loc_transport = self.session.client.get_transport()
                    try:
                        loc_transport.auth_interactive_dumb(
                            self.session.user, _authentication_handler
                        )
                    except paramiko.ssh_exception.BadAuthenticationType:
                        # It is not clear why auth_interactive_dumb fails in some cases, but
                        # in the examples we could generate auth_password was successful
                        loc_transport.auth_password(self.session.user, getpass.getpass())
                    # Attach the authenticated transport to the session.
                    self.session.transport = loc_transport
                    break
                except Exception:
                    console.log('[bold red]:x: Failed to Authenticate your connection')
            if not self.session.is_connected:
                sys.exit(1)

        console.log('[bold cyan]:white_check_mark: The client is authenticated successfully')

    def _jupyter_info(self, command='sh -c "command -v jupyter"'):
        """Exit with status 1 unless ``command`` succeeds on the remote host."""
        console.rule('[bold green]Running jupyter sanity checks', characters='*')
        out = self.session.run(command, warn=True, hide='out', **self.run_kwargs)
        if out.failed:
            console.log(f"[bold red]:x: Couldn't find jupyter executable with: '{command}'")
            sys.exit(1)
        console.log('[bold cyan]:white_check_mark: Found jupyter executable')

    def envvar_exists(self, envvar):
        """
        Checks if an environment variable is defined on remote host.
        """
        message = 'variable is not defined'
        cmd = f'''printenv {envvar} || echo "{message}"'''
        out = self.session.run(cmd, hide='out', **self.run_kwargs).stdout.strip()
        return message not in out

    def dir_exists(self, directory):
        """
        Checks if a given directory exists on remote host.
        """
        message = "couldn't find the directory"
        cmd = f'''cd {directory} || echo "{message}"'''
        out = self.session.run(cmd, hide='out', **self.run_kwargs).stdout.strip()
        return message not in out

    def setup_port_forwarding(self):
        """
        Sets up SSH port forwarding
        """
        console.rule('[bold green]Setting up port forwarding', characters='*')
        local_port = int(self.port)
        remote_port = int(self.parsed_result['port'])
        with self.session.forward_local(
            local_port,
            remote_port=remote_port,
            remote_host=self.parsed_result['hostname'],
        ):
            time.sleep(
                3
            )  # don't want open_browser to run before the forwarding is actually working
            open_browser(port=local_port, token=self.parsed_result['token'])
            # Blocks while following the remote log file.
            self.session.run(f'tail -f {self.log_file}', pty=True)

    def close(self):
        """Close the SSH session."""
        self.session.close()

    def start(self):
        """
        Launches Jupyter Lab on remote host, sets up ssh tunnel and opens browser on local machine.

        Any ``Exception`` raised along the way is reported on the console
        (not re-raised). The SSH session is closed on every exit path,
        including ``SystemExit`` and ``KeyboardInterrupt``.
        """
        # jupyter lab will pipe output to logfile, which should not exist prior to running
        # Logfile will be in $TMPDIR if defined on the remote machine, otherwise in $HOME

        try:
            check_jupyter_status = 'sh -c "command -v jupyter"'
            if self.conda_env:
                check_jupyter_status = (
                    f'conda activate {self.conda_env} && sh -c "command -v jupyter"'
                )
            self._jupyter_info(check_jupyter_status)
            if self.envvar_exists('TMPDIR') and self.dir_exists('$TMPDIR'):
                self.log_dir = '$TMPDIR'
            elif self.envvar_exists('HOME') and self.dir_exists('$HOME'):
                self.log_dir = '$HOME'
            else:
                message = (
                    '$TMPDIR/ is not a directory'
                    if self.envvar_exists('TMPDIR')
                    else '$TMPDIR is not defined'
                )
                console.log(f'[bold red]{message}')
                message = (
                    '$HOME/ is not a directory'
                    if self.envvar_exists('HOME')
                    else '$HOME is not defined'
                )
                console.log(f'[bold red]{message}')
                console.log('[bold red]Can not determine directory for log file')
                sys.exit(1)

            self.log_dir = f'{self.log_dir}/.jupyter_forward'
            self.session.run(f'mkdir -p {self.log_dir}', **self.run_kwargs)
            timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
            self.log_file = f'{self.log_dir}/log.{timestamp}'
            self.session.run(f'touch {self.log_file}', **self.run_kwargs)

            command = 'jupyter lab --no-browser'
            if self.launch_command:
                # The command is later written into a script with a double-quoted
                # `echo`; the backslash defers `$(hostname)` until the script runs.
                command = f'{command} --ip=\\$(hostname)'
            else:
                command = f'{command} --ip=`hostname`'
            if self.notebook_dir:
                command = f'{command} --notebook-dir={self.notebook_dir}'
            command = f'{command} >& {self.log_file}'
            if self.conda_env:
                command = f'conda activate {self.conda_env} && {command}'

            if self.launch_command:
                script_file = f'{self.log_dir}/batch-script.{timestamp}'
                cmd = f"""echo "#!{self.shell}\n\n{command}" > {script_file}"""
                self.session.run(cmd, **self.run_kwargs, echo=True)
                self.session.run(f'chmod +x {script_file}', **self.run_kwargs)
                command = f'{self.launch_command} {script_file}'

            self.session.run(command, asynchronous=True, **self.run_kwargs, echo=False)

            # wait for logfile to contain access info, then write it to screen
            condition = True
            stdout = None
            pattern = 'is running at:'
            with console.status(
                f'[bold cyan]Parsing {self.log_file} log file on {self.session.host} for jupyter information',
                spinner='weather',
            ):
                while condition:
                    try:
                        result = self.session.run(
                            f'cat {self.log_file}', **self.run_kwargs, echo=False
                        )
                        if pattern in result.stdout:
                            condition = False
                            stdout = result.stdout
                    except invoke.exceptions.UnexpectedExit:
                        pass
                    if condition:
                        # Pause between polls instead of re-running `cat` back to back.
                        time.sleep(1)
            self.parsed_result = parse_stdout(stdout)

            if self.port_forwarding:
                self.setup_port_forwarding()
            else:
                open_browser(url=self.parsed_result['url'])
                self.session.run(f'tail -f {self.log_file}', **self.run_kwargs)
        except Exception:
            # Report the failure instead of silently discarding it.
            console.print_exception()

        finally:
            self.close()
            console.rule(
                '[bold red]Terminated the network 📡 connection to the remote end', characters='*'
            )
Пример #4
0
class RemoteRunner:
    """
    Starts Jupyter lab on a remote resource and port forwards session to
    local machine.

    Returns
    -------
    RemoteRunner
        An object that is responsible for connecting to remote host and launching jupyter lab.

    Raises
    ------
    SystemExit
        When the specified local port is not available.
    """

    host: str
    port: int = 8888
    conda_env: str = None
    notebook_dir: str = None
    port_forwarding: bool = True
    launch_command: str = None
    identity: str = None
    shell: str = '/usr/bin/env bash'

    def __post_init__(self):
        """Check the local port and open an SSH session to ``host``.

        If key/agent authentication is rejected, retries over the session's
        transport with keyboard-interactive authentication.

        Raises
        ------
        SystemExit
            When port forwarding is requested and ``port`` is already in use locally.
        """
        if self.port_forwarding and not is_port_available(self.port):
            raise SystemExit((
                f'''Specified port={self.port} is already in use on your local machine. Try a different port'''
            ))

        connect_kwargs = {}
        if self.identity:
            # Normalize to str: path-like objects may not be handled by the
            # underlying key-file loading.
            connect_kwargs['key_filename'] = [str(self.identity)]

        self.session = Connection(self.host,
                                  connect_kwargs=connect_kwargs,
                                  forward_agent=True)
        try:
            self.session.open()
        except (
            paramiko.ssh_exception.BadAuthenticationType,
            paramiko.ssh_exception.AuthenticationException,
        ):
            # Key/agent authentication was rejected: fall back to
            # keyboard-interactive authentication on the existing transport.
            loc_transport = self.session.client.get_transport()
            loc_transport.auth_interactive_dumb(self.session.user,
                                                _authentication_handler)
            self.session.transport = loc_transport

    def dir_exists(self, directory):
        """
        Checks if a given directory exists on remote host.
        """
        message = "couldn't find the directory"
        cmd = f'''if [[ ! -d "{directory}" ]] ; then echo "{message}"; fi'''
        out = self.session.run(cmd, hide='out').stdout.strip()
        return message not in out

    def setup_port_forwarding(self):
        """
        Sets up SSH port forwarding
        """
        print('**********************************')
        print('*** Setting up port forwarding ***')
        print('**********************************\n\n')
        local_port = int(self.port)
        remote_port = int(self.parsed_result['port'])
        with self.session.forward_local(
                local_port,
                remote_port=remote_port,
                remote_host=self.parsed_result['hostname'],
        ):
            time.sleep(
                3
            )  # don't want open_browser to run before the forwarding is actually working
            open_browser(port=local_port, token=self.parsed_result['token'])
            # Blocks while following the remote log file.
            self.session.run(f'tail -f {self.log_file}', pty=True)

    def close(self):
        """Close the SSH session."""
        self.session.close()

    def start(self):
        """
        Launches Jupyter Lab on remote host, sets up ssh tunnel and opens browser on local machine.

        The SSH session is always closed on exit; exceptions propagate.
        """
        # jupyter lab will pipe output to logfile, which should not exist prior to running
        # Logfile will be in $TMPDIR if defined on the remote machine, otherwise in $HOME

        try:

            if self.dir_exists('$TMPDIR'):
                self.log_dir = '$TMPDIR'
            else:
                self.log_dir = '$HOME'

            self.log_dir = f'{self.log_dir}/.jupyter_forward'
            kwargs = dict(pty=True, shell=self.shell)
            self.session.run(f'mkdir -p {self.log_dir}', **kwargs)
            timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
            self.log_file = f'{self.log_dir}/log.{timestamp}'
            self.session.run(f'touch {self.log_file}', **kwargs)

            command = 'jupyter lab --no-browser'
            if self.launch_command:
                # The command is later written into a script with a double-quoted
                # `echo`; the backslash defers `$(hostname)` until the script runs.
                command = f'{command} --ip=\\$(hostname)'
            else:
                command = f'{command} --ip=`hostname`'

            if self.notebook_dir:
                command = f'{command} --notebook-dir={self.notebook_dir}'

            command = f'{command} > {self.log_file} 2>&1'

            if self.conda_env:
                command = f'conda activate {self.conda_env} && {command}'

            if self.launch_command:
                script_file = f'{self.log_dir}/batch-script.{timestamp}'
                cmd = f"""echo "#!{self.shell}\n\n{command}" > {script_file}"""
                self.session.run(cmd, **kwargs, echo=True)
                self.session.run(f'chmod +x {script_file}', **kwargs)
                command = f'{self.launch_command} {script_file}'

            self.session.run(command, asynchronous=True, **kwargs, echo=True)

            # wait for logfile to contain access info, then write it to screen
            condition = True
            stdout = None
            pattern = 'is running at:'
            while condition:
                try:
                    result = self.session.run(f'cat {self.log_file}', **kwargs)
                    if pattern in result.stdout:
                        condition = False
                        stdout = result.stdout
                except invoke.exceptions.UnexpectedExit:
                    print(
                        f'Trying to access {self.log_file} on {self.session.host} again...'
                    )
                if condition:
                    # Pause between polls instead of re-running `cat` back to back.
                    time.sleep(1)
            self.parsed_result = parse_stdout(stdout)

            if self.port_forwarding:
                self.setup_port_forwarding()
            else:
                open_browser(url=self.parsed_result['url'])
                self.session.run(f'tail -f {self.log_file}', **kwargs)
        finally:
            self.close()
            print(
                '\n***********************************************************'
            )
            print(
                '*** Terminated the network connection to the remote end ***')
            print(
                '***********************************************************\n'
            )