예제 #1
0
    def _establish_ssh_session(self):
        """
        Connect to the remote host and return an authenticated SSH session.

        Opens a TCP connection to ``(self._ssh_host, self._ssh_port)``, performs
        the SSH handshake, verifies the server host key against the known hosts
        file and finally authenticates as ``self._username``.

        The known hosts file is ``KNOWN_HOSTS_FILE`` inside the directory of the
        file that belongs to the outermost frame of the current call stack. A
        host that is not listed there yet is accepted and written to that file;
        a host whose key differs from the recorded one is rejected.

        Authentication uses the SSH agent when ``self._identity_file`` is None,
        otherwise the given private key file.

        Returns:
            The authenticated session object.

        Raises:
            Exception: any error raised while connecting, during the handshake,
                during host key verification (including a host key mismatch) or
                during authentication is logged where applicable and re-raised.
                The socket is closed before the error propagates.
        """
        # Connect to remote host.
        try:
            sock = socket.create_connection(
                (str(self._ssh_host), self._ssh_port))
        except Exception:
            log.error("Cannot connect to host '%s' (%s, %d).", self.name,
                      self._ssh_host, self._ssh_port)
            raise

        try:
            # SSH handshake.
            ssh_session = Session()
            ssh_session.handshake(sock)

            # Verify host key. Accept keys from previously unknown hosts on
            # first connection.
            hosts = ssh_session.knownhost_init()
            testbed_root = os.path.dirname(
                os.path.abspath(inspect.stack()[-1][1]))
            known_hosts_path = os.path.join(testbed_root, KNOWN_HOSTS_FILE)
            try:
                hosts.readfile(known_hosts_path)
            except ssh2.exceptions.KnownHostReadFileError:
                pass  # ignore, file is created/overwritten later

            host_key, key_type = ssh_session.hostkey()
            # NOTE(review): every key type other than RSA is looked up as DSS;
            # confirm whether other key types need their own mapping.
            server_type = (LIBSSH2_KNOWNHOST_KEY_SSHRSA
                           if key_type == LIBSSH2_HOSTKEY_TYPE_RSA
                           else LIBSSH2_KNOWNHOST_KEY_SSHDSS)
            type_mask = (LIBSSH2_KNOWNHOST_TYPE_PLAIN
                         | LIBSSH2_KNOWNHOST_KEYENC_RAW
                         | server_type)

            encoded_host = str(self._ssh_host).encode('utf-8')
            try:
                hosts.checkp(
                    encoded_host, self._ssh_port, host_key, type_mask)
            except ssh2.exceptions.KnownHostCheckNotFoundError:
                log.warning("Host key of '%s' (%s, %d) added to known hosts.",
                            self.name, self._ssh_host, self._ssh_port)
                hosts.addc(encoded_host, host_key, type_mask)
                hosts.writefile(known_hosts_path)
            except ssh2.exceptions.KnownHostCheckMisMatchError:
                log.error(
                    "Host key of '%s' (%s, %d) does not match known key.",
                    self.name, self._ssh_host, self._ssh_port)
                raise

            # Authenticate at remote host.
            try:
                if self._identity_file is None:
                    ssh_session.agent_auth(self._username)
                else:
                    ssh_session.userauth_publickey_fromfile(
                        self._username, self._identity_file)
            except Exception:
                log.error("Authentication at host '%s' (%s, %d) failed.",
                          self.name, self._ssh_host, self._ssh_port)
                ssh_session.disconnect()
                raise
        except Exception:
            # No session is handed back to the caller on failure, so the
            # socket opened above would otherwise never be closed.
            sock.close()
            raise

        return ssh_session
    LIBSSH2_KNOWNHOST_KEYENC_RAW, LIBSSH2_KNOWNHOST_KEY_SSHRSA

# Connection settings
host = 'localhost'
port = 22  # single source of truth for both the TCP connect and the lookup
user = os.getlogin()
known_hosts = os.path.join(os.path.expanduser('~'), '.ssh', 'known_hosts')

# Make socket, connect
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((host, port))

# Initialise
session = Session()
session.handshake(sock)

host_key, key_type = session.hostkey()

# Map the session's host key type onto the matching known-hosts key flag.
# NOTE(review): every key type other than RSA is looked up as DSS.
server_key_type = (LIBSSH2_KNOWNHOST_KEY_SSHRSA
                   if key_type == LIBSSH2_HOSTKEY_TYPE_RSA
                   else LIBSSH2_KNOWNHOST_KEY_SSHDSS)

kh = session.knownhost_init()
_read_hosts = kh.readfile(known_hosts)
print("Read %s hosts from known hosts file at %s" % (_read_hosts, known_hosts))

# Verification
type_mask = (LIBSSH2_KNOWNHOST_TYPE_PLAIN
             | LIBSSH2_KNOWNHOST_KEYENC_RAW
             | server_key_type)
# The known-hosts lookup is given the host name as bytes, while the socket
# API above needs it as str, so encode only here.
kh.checkp(host.encode('utf-8'), port, host_key, type_mask)
print("Host verification passed.")
예제 #3
0
class SSH2Transport(Transport):
    def __init__(
        self,
        host: str,
        port: int = -1,
        auth_username: str = "",
        auth_private_key: str = "",
        auth_password: str = "",
        auth_strict_key: bool = True,
        timeout_socket: int = 5,
        timeout_transport: int = 5,
        timeout_exit: bool = True,
        ssh_config_file: str = "",
        ssh_known_hosts_file: str = "",
    ) -> None:
        """
        SSH2Transport Object

        Inherit from Transport ABC
        SSH2Transport <- Transport (ABC)

        Values that are not passed explicitly (port, username, private key) are taken from the
        matching host entry of the ssh config file when one is found there.

        Args:
            host: host ip/name to connect to
            port: port to connect to; the default of -1 means "use the port from the ssh config
                file, or 22 if the config does not provide one"
            auth_username: username for authentication; falls back to the ssh config user
            auth_private_key: path to private key for authentication; falls back to the ssh
                config identity file
            auth_password: password for authentication
            auth_strict_key: True/False to enforce strict key checking (default is True)
            timeout_socket: timeout for establishing socket in seconds
            timeout_transport: timeout for ssh2 transport in seconds
            timeout_exit: True/False close transport if timeout encountered
            ssh_config_file: string to path for ssh config file
            ssh_known_hosts_file: string to path for ssh known hosts file

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        cfg_port, cfg_user, cfg_private_key = self._process_ssh_config(host, ssh_config_file)

        if port == -1:
            port = cfg_port or 22

        super().__init__(
            host,
            port,
            timeout_socket,
            timeout_transport,
            timeout_exit,
        )

        self.auth_username: str = auth_username or cfg_user
        self.auth_private_key: str = auth_private_key or cfg_private_key
        self.auth_password: str = auth_password
        self.auth_strict_key: bool = auth_strict_key
        self.ssh_known_hosts_file: str = ssh_known_hosts_file

        # both stay None until `open` has been called
        self.session: Session = None
        self.channel: Channel = None

        self.socket = Socket(host=self.host, port=self.port, timeout=self.timeout_socket)

    @staticmethod
    def _process_ssh_config(host: str, ssh_config_file: str) -> Tuple[Optional[int], str, str]:
        """
        Method to parse ssh config file

        In the future this may move to be a "helper" function as it should be very similar between
        paramiko and ssh2-python... for now it can be a static method as there may be varying
        supported args between the two transport drivers.

        Args:
            host: host to lookup in ssh config file
            ssh_config_file: string path to ssh config file; passed down from `Scrape`, or the
                `NetworkDriver` or subclasses of it, in most cases.

        Returns:
            Tuple: port to use for ssh, username to use for ssh, identity file (private key) to
                use for ssh auth; a missing user or identity file is returned as an empty string

        Raises:
            N/A

        """
        ssh = SSHConfig(ssh_config_file)
        host_config = ssh.lookup(host)
        return host_config.port, host_config.user or "", host_config.identity_file or ""

    def open(self) -> None:
        """
        Parent method to open session, authenticate and acquire shell

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            Exception: if socket handshake fails
            ScrapliAuthenticationFailed: if all authentication means fail

        """
        if not self.socket.socket_isalive():
            self.socket.socket_open()

        self.session = Session()
        self.set_timeout(self.timeout_transport)
        try:
            self.session.handshake(self.socket.sock)
        except Exception as exc:
            LOG.critical(
                f"Failed to complete handshake with host {self.host}; Exception: {exc}"
            )
            raise

        if self.auth_strict_key:
            LOG.debug(f"Attempting to validate {self.host} public key")
            self._verify_key()

        LOG.debug(f"Session to host {self.host} opened")
        self._authenticate()
        if not self._isauthenticated():
            msg = f"Authentication to host {self.host} failed"
            LOG.critical(msg)
            raise ScrapliAuthenticationFailed(msg)

        self._open_channel()

    def _verify_key(self) -> None:
        """
        Verify target host public key, raise exception if invalid/unknown

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            KeyVerificationFailed: if public key verification fails

        """
        known_hosts = SSHKnownHosts(self.ssh_known_hosts_file)

        if self.host not in known_hosts.hosts:
            raise KeyVerificationFailed(f"{self.host} not in known_hosts!")

        # hostkey() yields the key as its first element; it is base64 encoded on a single line
        # so that it can be compared with the text stored for the host in known_hosts
        remote_server_key_info = self.session.hostkey()
        encoded_remote_server_key = remote_server_key_info[0]
        raw_remote_public_key = base64.encodebytes(encoded_remote_server_key)
        remote_public_key = raw_remote_public_key.replace(b"\n", b"").decode()

        if known_hosts.hosts[self.host]["public_key"] != remote_public_key:
            raise KeyVerificationFailed(
                f"{self.host} in known_hosts but public key does not match!"
            )

    def _authenticate(self) -> None:
        """
        Parent method to try all means of authentication

        Order of attempts: public key (only if a private key is configured), password, keyboard
        interactive. Returns as soon as one of them succeeds; if none succeeds this method returns
        without raising and the caller is expected to check the authentication state.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            ScrapliAuthenticationFailed: if public key auth fails and there is no username and/or
                password to continue with

        """
        if self.auth_private_key:
            self._authenticate_public_key()
            if self._isauthenticated():
                LOG.debug(f"Authenticated to host {self.host} with public key auth")
                return
            if not self.auth_password or not self.auth_username:
                msg = (
                    f"Failed to authenticate to host {self.host} with private key "
                    f"`{self.auth_private_key}`. Unable to continue authentication, "
                    "missing username, password, or both."
                )
                LOG.critical(msg)
                raise ScrapliAuthenticationFailed(msg)

        self._authenticate_password()
        if self._isauthenticated():
            LOG.debug(f"Authenticated to host {self.host} with password")
            return
        self._authenticate_keyboard_interactive()
        if self._isauthenticated():
            LOG.debug(f"Authenticated to host {self.host} with keyboard interactive")

    def _authenticate_public_key(self) -> None:
        """
        Attempt to authenticate with public key authentication

        Failures are logged and swallowed so that the caller can fall back to other methods.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        try:
            self.session.userauth_publickey_fromfile(
                self.auth_username, self.auth_private_key.encode()
            )
        except AuthenticationError as exc:
            LOG.critical(
                f"Public key authentication with host {self.host} failed. Exception: {exc}."
            )
        except SSH2Error as exc:
            LOG.critical(
                "Unknown error occurred during public key authentication with host "
                f"{self.host}; Exception: {exc}"
            )

    def _authenticate_password(self) -> None:
        """
        Attempt to authenticate with password authentication

        An `AuthenticationError` is logged and swallowed so that the caller can fall back to
        other methods; anything else is logged and re-raised.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            Exception: if unknown (i.e. not auth failed) exception occurs

        """
        try:
            self.session.userauth_password(self.auth_username, self.auth_password)
        except AuthenticationError:
            LOG.critical(
                f"Password authentication with host {self.host} failed. Exception: "
                f"`AuthenticationError`."
            )
        except Exception as exc:
            LOG.critical(
                "Unknown error occurred during password authentication with host "
                f"{self.host}; Exception: {exc}"
            )
            raise

    def _authenticate_keyboard_interactive(self) -> None:
        """
        Attempt to authenticate with keyboard interactive authentication

        An `AttributeError` (logged as "not supported in your ssh2-python version") and an
        `AuthenticationError` are logged and swallowed; anything else is logged and re-raised.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            Exception: if unknown (i.e. not auth failed) exception occurs

        """
        try:
            self.session.userauth_keyboardinteractive(  # pylint: disable=C0415
                self.auth_username, self.auth_password
            )
        except AttributeError as exc:
            LOG.critical(
                "Keyboard interactive authentication not supported in your ssh2-python version. "
                f"Exception: {exc}"
            )
        except AuthenticationError:
            LOG.critical(
                f"Keyboard interactive authentication with host {self.host} failed. "
                f"Exception: `AuthenticationError`."
            )
        except Exception as exc:
            LOG.critical(
                "Unknown error occurred during keyboard interactive authentication with host "
                f"{self.host}; Exception: {exc}"
            )
            raise

    def _isauthenticated(self) -> bool:
        """
        Check if session is authenticated

        Args:
            N/A

        Returns:
            bool: True if authenticated, else False

        Raises:
            N/A

        """
        authenticated: bool = self.session.userauth_authenticated()
        return authenticated

    def _open_channel(self) -> None:
        """
        Open channel, acquire pty, request interactive shell

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        self.channel = self.session.open_session()
        self.channel.pty()
        self.channel.shell()
        LOG.debug(f"Channel to host {self.host} opened")

    def close(self) -> None:
        """
        Close channel (if one was opened) and socket

        Safe to call on a transport whose `open` was never called or failed before the channel
        was acquired; in that case only the socket is closed.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        # channel is still None when open() never got as far as _open_channel(); skipping it
        # here keeps the socket close below reachable instead of dying on AttributeError
        if self.channel is not None:
            self.channel.close()
            LOG.debug(f"Channel to host {self.host} closed")
        self.socket.socket_close()

    def isalive(self) -> bool:
        """
        Check if socket is alive and session is authenticated

        Args:
            N/A

        Returns:
            bool: True if socket is alive, channel is not at EOF and session is authenticated,
                else False (including when the transport has never been opened)

        Raises:
            N/A

        """
        # nothing can be alive before a session/channel exists
        if self.session is None or self.channel is None:
            return False
        if self.socket.socket_isalive() and not self.channel.eof() and self._isauthenticated():
            return True
        return False

    def read(self) -> bytes:
        """
        Read data from the channel

        Args:
            N/A

        Returns:
            bytes: bytes output as read from channel

        Raises:
            N/A

        """
        output: bytes
        # channel.read returns a pair; only the second element (the data) is of interest here
        _, output = self.channel.read(65535)
        return output

    def write(self, channel_input: str) -> None:
        """
        Write data to the channel

        Args:
            channel_input: string to send to channel

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        self.channel.write(channel_input)

    def set_timeout(self, timeout: int) -> None:
        """
        Set session timeout

        Args:
            timeout: timeout in seconds

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        # ssh2-python expects timeout in milliseconds
        self.session.set_timeout(timeout * 1000)
예제 #4
0
class Ssh2Transport(Transport):
    def __init__(self, base_transport_args: BaseTransportArgs,
                 plugin_transport_args: PluginTransportArgs) -> None:
        """
        Ssh2 transport plugin

        Nothing is connected here; socket, session and channel are created by `open`.

        Args:
            base_transport_args: base transport args (host, port, timeouts) handed to the parent
                transport
            plugin_transport_args: ssh2 specific transport args (authentication settings, strict
                key checking, known hosts file)

        Returns:
            None

        Raises:
            N/A

        """
        super().__init__(base_transport_args=base_transport_args)
        self.plugin_transport_args = plugin_transport_args

        self.socket: Optional[Socket] = None
        self.session: Optional[Session] = None
        self.session_channel: Optional[Channel] = None

    def open(self) -> None:
        """
        Open socket and session, verify host key (if strict), authenticate and open a channel

        Args:
            N/A

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if the ssh handshake fails
            ScrapliAuthenticationFailed: if host key verification or all authentication methods
                fail

        """
        self._pre_open_closing_log(closing=False)

        if not self.socket:
            self.socket = Socket(
                host=self._base_transport_args.host,
                port=self._base_transport_args.port,
                timeout=self._base_transport_args.timeout_socket,
            )

        if not self.socket.isalive():
            self.socket.open()

        self.session = Session()
        self._set_timeout(value=self._base_transport_args.timeout_transport)

        try:
            self.session.handshake(self.socket.sock)
        except Exception as exc:
            self.logger.critical("failed to complete handshake with host")
            raise ScrapliConnectionNotOpened from exc

        if self.plugin_transport_args.auth_strict_key:
            self.logger.debug(
                f"attempting to validate {self._base_transport_args.host} public key"
            )
            self._verify_key()

        self._authenticate()

        if not self.session.userauth_authenticated():
            msg = "all authentication methods failed"
            self.logger.critical(msg)
            raise ScrapliAuthenticationFailed(msg)

        self._open_channel()

        self._post_open_closing_log(closing=False)

    def _verify_key(self) -> None:
        """
        Verify target host public key, raise exception if invalid/unknown

        Args:
            N/A

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if session is unopened/None
            ScrapliAuthenticationFailed: if public key verification fails

        """
        if not self.session:
            raise ScrapliConnectionNotOpened

        known_hosts = SSHKnownHosts(
            self.plugin_transport_args.ssh_known_hosts_file)
        known_host_public_key = known_hosts.lookup(
            self._base_transport_args.host)

        if not known_host_public_key:
            raise ScrapliAuthenticationFailed(
                f"{self._base_transport_args.host} not in known_hosts!")

        # hostkey() yields the key as its first element; it is base64 encoded on a single line
        # so that it can be compared with the text stored for the host in known_hosts
        remote_server_key_info = self.session.hostkey()
        encoded_remote_server_key = remote_server_key_info[0]
        raw_remote_public_key = base64.encodebytes(encoded_remote_server_key)
        remote_public_key = raw_remote_public_key.replace(b"\n", b"").decode()

        if known_host_public_key["public_key"] != remote_public_key:
            raise ScrapliAuthenticationFailed(
                f"{self._base_transport_args.host} in known_hosts but public key does not match!"
            )

    def _authenticate(self) -> None:
        """
        Parent method to try all means of authentication

        Public key authentication is tried first when a private key is configured; password (and
        keyboard interactive) authentication follows. If nothing succeeds this method returns
        without raising and the caller is expected to check the authentication state.

        Args:
            N/A

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if session is unopened/None
            ScrapliAuthenticationFailed: if public key auth fails and there is no username and/or
                password to continue with

        """
        if not self.session:
            raise ScrapliConnectionNotOpened

        if self.plugin_transport_args.auth_private_key:
            self._authenticate_public_key()
            if self.session.userauth_authenticated():
                return
            if (not self.plugin_transport_args.auth_password
                    or not self.plugin_transport_args.auth_username):
                msg = (
                    f"Failed to authenticate to host {self._base_transport_args.host} with private "
                    f"key `{self.plugin_transport_args.auth_private_key}`. Unable to continue "
                    "authentication, missing username, password, or both.")
                raise ScrapliAuthenticationFailed(msg)

        self._authenticate_password()

    def _authenticate_public_key(self) -> None:
        """
        Attempt to authenticate with public key authentication

        Authentication and other ssh2 errors are swallowed so the caller can fall back to
        password authentication.

        Args:
            N/A

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if session is unopened/None

        """
        if not self.session:
            raise ScrapliConnectionNotOpened

        try:
            self.session.userauth_publickey_fromfile(
                self.plugin_transport_args.auth_username,
                self.plugin_transport_args.auth_private_key.encode(),
                self.plugin_transport_args.auth_private_key_passphrase,
            )
        except (AuthenticationError, SSH2Error):
            pass

    def _authenticate_password(self) -> None:
        """
        Attempt to authenticate with password authentication

        If password authentication is rejected, keyboard interactive authentication is tried with
        the same credentials. Authentication failures are swallowed; the caller checks whether
        the session ended up authenticated.

        Args:
            N/A

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if session is unopened/None

        """
        if not self.session:
            raise ScrapliConnectionNotOpened

        try:
            self.session.userauth_password(
                username=self.plugin_transport_args.auth_username,
                password=self.plugin_transport_args.auth_password,
            )
            return
        except AuthenticationError:
            pass
        try:
            self.session.userauth_keyboardinteractive(
                self.plugin_transport_args.auth_username,
                self.plugin_transport_args.auth_password)
        except AttributeError:
            msg = (
                "Keyboard interactive authentication may not be supported in your "
                "ssh2-python version.")
            self.logger.warning(msg)
        except AuthenticationError:
            pass

    def _open_channel(self) -> None:
        """
        Open channel, acquire pty, request interactive shell

        Args:
            N/A

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if session is unopened/None

        """
        if not self.session:
            raise ScrapliConnectionNotOpened

        self.session_channel = self.session.open_session()
        self.session_channel.pty()
        self.session_channel.shell()

    def close(self) -> None:
        """
        Close the channel and the socket, then drop the session and channel references

        Safe to call when `open` was never called or failed part way through.

        Args:
            N/A

        Returns:
            None

        Raises:
            N/A

        """
        self._pre_open_closing_log(closing=True)

        if self.session_channel:
            self.session_channel.close()

        # the socket is created before the channel in `open`, so it must be closed even when no
        # channel exists (e.g. handshake or authentication failed); otherwise it would leak
        if self.socket:
            self.socket.close()

        self.session = None
        self.session_channel = None

        self._post_open_closing_log(closing=True)

    def isalive(self) -> bool:
        """
        Check if the transport is alive

        Args:
            N/A

        Returns:
            bool: False if no channel is open or the channel reports EOF, else True

        Raises:
            N/A

        """
        if not self.session_channel:
            return False
        return not self.session_channel.eof()

    def read(self) -> bytes:
        """
        Read up to 65535 bytes from the channel

        Args:
            N/A

        Returns:
            bytes: data read from the channel

        Raises:
            ScrapliConnectionNotOpened: if channel is unopened/None
            ScrapliConnectionError: if reading from the channel raises any exception

        """
        if not self.session_channel:
            raise ScrapliConnectionNotOpened
        try:
            buf: bytes
            # channel.read returns a pair; only the second element (the data) is of interest
            _, buf = self.session_channel.read(65535)
        except Exception as exc:
            msg = (
                "encountered EOF reading from transport; typically means the device closed the "
                "connection")
            self.logger.critical(msg)
            raise ScrapliConnectionError(msg) from exc
        return buf

    def write(self, channel_input: bytes) -> None:
        """
        Write data to the channel

        Args:
            channel_input: bytes to send to the channel

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if channel is unopened/None

        """
        if not self.session_channel:
            raise ScrapliConnectionNotOpened
        self.session_channel.write(channel_input)

    def _set_timeout(self, value: float) -> None:
        """
        Set session object timeout value

        Args:
            value: timeout in seconds

        Returns:
            None

        Raises:
            ScrapliConnectionNotOpened: if session is unopened/None

        """
        if not self.session:
            raise ScrapliConnectionNotOpened

        # ssh2-python expects timeout in milliseconds; `value` is annotated as float, so convert
        # explicitly rather than handing a float to the integer millisecond API
        self.session.set_timeout(int(value * 1000))