Exemplo n.º 1
0
    def test_save_host_keys(self):
        """
        Verify that SSHClient correctly saves a known_hosts file.

        Adds the RSA host key from ``test_rsa.key`` under the bracketed
        ``[addr]:port`` host id, saves the client's host keys to a temporary
        file and checks that the host id appears in that file. The temporary
        file is always removed, even if an assertion fails.
        """
        # NOTE(review): looks like a leftover from os.tempnam()-based temp
        # files; mkstemp() below does not need it. Kept so the process-wide
        # warning filters stay unchanged.
        warnings.filterwarnings("ignore", "tempnam.*")

        host_key = paramiko.RSAKey.from_private_key_file(
            _support("test_rsa.key"))
        # asbytes() serializes the public key, so this rebuilds a
        # public-only key object.
        public_host_key = paramiko.RSAKey(data=host_key.asbytes())
        fd, localname = mkstemp()
        os.close(fd)

        try:
            client = SSHClient()
            assert len(client.get_host_keys()) == 0

            host_id = "[%s]:%d" % (self.addr, self.port)

            client.get_host_keys().add(host_id, "ssh-rsa", public_host_key)
            assert len(client.get_host_keys()) == 1
            assert public_host_key == client.get_host_keys()[host_id]["ssh-rsa"]

            client.save_host_keys(localname)

            with open(localname) as saved:
                assert host_id in saved.read()
        finally:
            # Clean up the temp file even when an assertion above fails.
            os.unlink(localname)
Exemplo n.º 2
0
class SSH_Client (autosuper) :
    """Convenience wrapper around a paramiko SSH connection plus SFTP session.

    Authenticates with an RSA private key and opens an SFTP session whose
    working directory is `remote_dir`. Any attribute this class does not
    define itself is forwarded to that SFTP client (see `__getattr__`).
    Host keys are kept in `$HOME/.ssh/known_hosts_paramiko`. Unknown hosts
    are accepted automatically (`AutoAddPolicy`), and `close` writes them
    back to that file.
    """

    def __init__ \
        (self, host, privkey, remote_dir = '/tmp', local_dir = '/tmp'
        , password = None, port = 22, user = '******'
        ) :
        """Connect to `host` and open an SFTP session in `remote_dir`.

        host       : name or address of the SSH server;
        privkey    : path of the RSA private key file to authenticate with;
        remote_dir : remote directory the SFTP session changes into;
        local_dir  : local directory that `get_files` downloads into;
        password   : passphrase of `privkey` (None if it is unencrypted);
        port       : TCP port of the SSH server;
        user       : remote user name.
        """
        self.ssh        = SSHClient ()
        self.host       = host
        self.remote_dir = remote_dir
        self.local_dir  = local_dir
        self.key = RSAKey.from_private_key_file (privkey, password = password)
        home = os.environ.get ('HOME', '/root')
        path = os.path.join (home, '.ssh', 'known_hosts_paramiko')
        self.known_hosts = path
        # load_host_keys raises when the file does not exist yet, which made
        # the very first connection fail; `close` writes the file later.
        if os.path.isfile (path) :
            self.ssh.load_host_keys (path)
        self.ssh.set_missing_host_key_policy (AutoAddPolicy ())
        try :
            self.ssh.connect \
                ( host
                , pkey          = self.key
                , port          = port
                , username      = user
                , look_for_keys = False
                , allow_agent   = False
                )
            self.sftp = self.ssh.open_sftp ()
            self.sftp.chdir (self.remote_dir)
        except Exception :
            # Do not leave a half-open SSH transport behind on failure.
            self.ssh.close ()
            raise
    # end def __init__

    def get_files (self, * fn) :
        """Download each remote file in `fn` into `local_dir`.

        Each file keeps its base name locally; relative remote names are
        resolved against the SFTP session's current directory.
        """
        for f in fn :
            dest = os.path.join (self.local_dir, os.path.basename (f))
            self.sftp.get (f, dest)
    # end def get_files

    def list_files (self) :
        """Yield the names of regular files in the current remote directory."""
        for f in self.sftp.listdir_attr () :
            if stat.S_ISREG (f.st_mode) :
                yield (f.filename)
    # end def list_files

    def put_files (self, *fn) :
        """Upload each local file in `fn` into `remote_dir`, keeping its base name."""
        for f in fn :
            dest = os.path.join (self.remote_dir, os.path.basename (f))
            self.sftp.put (f, dest)
    # end def put_files

    def close (self) :
        """Save known host keys, then close the SFTP session and SSH connection."""
        try :
            self.ssh.save_host_keys (self.known_hosts)
        finally :
            # Release the session even if the host keys could not be written.
            self.sftp.close ()
            self.ssh.close ()
    # end def close

    def __getattr__ (self, name) :
        """Forward attributes not defined here to the SFTP client."""
        # __getattr__ only runs when normal lookup fails. If `sftp` itself is
        # missing (failed __init__, copy/unpickle), looking it up here would
        # re-enter __getattr__ forever; fail with AttributeError instead.
        if name == 'sftp' :
            raise AttributeError (name)
        return getattr (self.sftp, name)
Exemplo n.º 3
0
class SSHTransport(CommonTransport):
    """
    SSH transport built on paramiko's ``SSHClient``.

    Paramiko and socket errors raised while connecting or running commands
    are translated into the project's transport exceptions by
    ``_exception_handler``.
    """
    DEFAULT_SSH_PORT = 22

    def __init__(self,
                 host,
                 user=None,
                 password=None,
                 pkey=None,
                 pkey_file=None,
                 port=22,
                 timeout=30,
                 auto_add=True,
                 cmd_timeout=30.0,
                 ignore_known_hosts=True):
        """
        Constructor for the common SSH transport class.
        host               : Host name or address of the SSH server;
        user               : User name to authenticate as;
        password           : Password passed to SSHClient.connect;
        pkey               : Key object used to sign and verify SSH2 data;
                             must be a paramiko PKey instance;
        pkey_file          : File name of the private keys to try for
                             SSH authentication;
        port               : TCP port of the SSH server;
        timeout            : Timeout value for establishing an SSH connection;
        auto_add           : If the specified management node cannot be
                             identified, whether it should be automatically
                             added to the known host list (and saved to a
                             local file when known hosts are not ignored);
        cmd_timeout        : Timeout value for a command to return;
        ignore_known_hosts : When False, host keys are loaded from and saved
                             to ~/xsf_known_hosts; when True that file is
                             not used.
        Raises IncorrectCredentials if pkey is not a PKey or the known hosts
        file cannot be read.
        """
        super(SSHTransport, self).__init__()
        self.host = host
        self.user = user
        self.password = password
        self.pkey = pkey
        self.private_key = pkey_file
        self.port = port
        self.timeout = timeout
        self.cmd_exec_timeout = cmd_timeout
        self.connected_endpoint = None
        self.auto_add_unknown_hosts = auto_add
        self.transport = SSHClient()
        self.sftp_client = None
        self.is_client_connected = False
        self.is_known_hosts_ignored = ignore_known_hosts
        self.svc_client_host_keys_file = os.path.join(
            os.path.expanduser("~"), "xsf_known_hosts")
        if self.pkey is not None:
            if not isinstance(self.pkey, PKey):
                xlog.debug(TransportMessages.SSH_INCORRECT_PRIVATE_KEY)
                raise IncorrectCredentials(
                    message=TransportMessages.SSH_INCORRECT_PRIVATE_KEY)

        # Only use the known hosts file when it exists and is not ignored.
        if (os.path.isfile(self.svc_client_host_keys_file)
                and self.is_known_hosts_ignored is False):
            try:
                self.transport.load_host_keys(
                    filename=self.svc_client_host_keys_file)
            except IOError as ex:
                xlog.debug(ex)
                raise IncorrectCredentials(
                    message=TransportMessages.SSH_FAILED_TO_LOAD_KNOWN_HOSTS,
                    original_exception=ex)

    def __str__(self):
        if self.port == self.DEFAULT_SSH_PORT:
            return self.host
        return 'ssh://%s:%s' % (self.host, self.port)

    def __repr__(self):
        return '<%s>' % self

    @contextmanager
    def _exception_handler(self):
        """
        Translate paramiko/socket exceptions raised in the ``with`` body into
        the project's transport exceptions, logging the original at debug.

        Order matters: subclasses (e.g. BadAuthenticationType, an
        AuthenticationException; socket.timeout, a socket.error) are caught
        before their base classes.
        """
        try:
            yield
        except BadAuthenticationType as ex:
            xlog.debug(ex)
            raise BadAuthenticationTypeException(
                allowed_types=ex.allowed_types, original_exception=ex)
        except BadHostKeyException as ex:
            xlog.debug(ex)
            raise BadHostFingerPrintException(hostname=ex.hostname,
                                              expected_key=ex.expected_key,
                                              presented_key=ex.key,
                                              original_exception=ex)
        except PartialAuthentication as ex:
            xlog.debug(ex)
            raise PartialAuthenticationException(
                allowed_types=ex.allowed_types, original_exception=ex)
        except PasswordRequiredException as ex:
            xlog.debug(ex)
            raise PassphraseRequiredException(
                message=TransportMessages.SSH_PASS_PHRASE_REQUIRED,
                original_exception=ex)
        except AuthenticationException as ex:
            xlog.debug(ex)
            raise IncorrectCredentials(
                message=TransportMessages.SSH_AUTHENTICATION_FAILURE,
                original_exception=ex)
        except socket.timeout as ex:
            xlog.debug(ex)
            raise ConnectionTimedoutException(
                message=TransportMessages.SSH_CONNECT_TIMED_OUT,
                original_exception=ex)
        except (socket.herror, socket.gaierror) as ex:
            xlog.debug(ex)
            raise HostDoesNotExistException(hostname=self.host,
                                            original_exception=ex)
        except (SSHException, socket.error) as ex:
            xlog.debug(ex)
            raise UnableToConnectException(
                message=TransportMessages.SSH_UNABLE_TO_CONNECT(self.host),
                original_exception=ex)

    def connect(self):
        """
        Initialize an SSH connection with the properties
        which were set up in the constructor.
        """
        with self._exception_handler():
            if self.auto_add_unknown_hosts:
                self.transport.set_missing_host_key_policy(
                    paramiko.AutoAddPolicy())
            try:
                self.transport.connect(self.host,
                                       port=self.port,
                                       username=self.user,
                                       password=self.password,
                                       pkey=self.pkey,
                                       key_filename=self.private_key,
                                       timeout=self.timeout)
            except Exception:
                # SSHClient.connect() may fail after starting its transport
                # (e.g. host key rejected, authentication failed). Since
                # is_client_connected stays False, disconnect() would never
                # close it, so release it here. close() is a no-op when no
                # transport exists.
                self.transport.close()
                raise
            self.connected_endpoint = self.host
            self.is_client_connected = True
            if self.is_known_hosts_ignored is False:
                self.transport.save_host_keys(self.svc_client_host_keys_file)

    def is_connected(self):
        return self.is_client_connected

    def disconnect(self):
        """
        Disconnect from the SSH server.
        """
        if self.is_client_connected:
            self.transport.close()
            self.is_client_connected = False

    def reconnect(self):
        self.disconnect()
        self.connect()

    def send_command(self,
                     command,
                     buf_size=-1,
                     raw=False,
                     timeout=0,
                     stdin_input=None):
        """
        Execute ``command`` on a new channel of the connected transport.

        buf_size    : Buffer size passed to the channel's makefile calls;
        raw         : If True, return stdout/stderr as single reads instead
                      of lists of lines;
        timeout     : Channel timeout; a falsy value means cmd_timeout;
        stdin_input : Data written to the command's stdin, after which the
                      write side of the channel is shut down (EOF).
        Returns (stdin, stdout, stderr), where stdin is the channel's
        writable file object.
        Raises ConnectionTimedoutException on a channel timeout (after
        reconnecting to abort the remote call).
        NOTE(review): if connect() has not succeeded, get_transport()
        returns None and this fails with AttributeError.
        """
        with self._exception_handler():
            try:
                channel = self.transport.get_transport().open_session()
                channel.settimeout(timeout or self.cmd_exec_timeout)
                channel.exec_command(command)
                stdin = channel.makefile('wb', buf_size)
                stdout = channel.makefile('rb', buf_size)
                stderr = channel.makefile_stderr('rb', buf_size)
                if stdin_input is not None:
                    stdin.write(stdin_input)
                    stdin.flush()
                    # shutdown_write to close write channel, paramiko will send
                    # EOF to device.
                    channel.shutdown_write()
                if raw:  # gain performance without splitting lines
                    return stdin, stdout.read(), stderr.read()
                return stdin, stdout.readlines(), stderr.readlines()
            except socket.timeout as ex:
                # Log first so the timeout is recorded even if reconnecting
                # fails.
                xlog.error(ex)
                # Need to reconnect to terminate the remote method call
                self.reconnect()
                raise ConnectionTimedoutException(
                    message=TransportMessages.SSH_CON_TIMED_OUT_WHEN_EXEC_CMD,
                    original_exception=ex)
Exemplo n.º 4
0
class SFTPStorage(AbstractStorage):
    """Backup storage on a remote server, accessed over SFTP.

    Paths given to the public methods are resolved relative to
    ``params.backups_path`` on the server.
    """

    def __init__(self, params: StorageConfig):
        from pathlib import PurePosixPath

        self.__log = logging.getLogger(__name__)
        self.__conn = self._create_connection(params)

        self.__conn.chdir(params.backups_path)
        # SFTP paths always use '/', so build remote paths with PurePosixPath
        # instead of the local platform's Path ('\\' on Windows).
        self.__dir = PurePosixPath(params.backups_path)

    def __del__(self):
        # __init__ may have failed part-way; only close what was created.
        if hasattr(self, '_SFTPStorage__conn'):
            self.__log.debug('Closing connection')
            self.__conn.close()
        if hasattr(self, '_SFTPStorage__ssh'):
            # Closing the SFTP client does not stop the SSH transport.
            self.__ssh.close()

    def _create_connection(self, params: StorageConfig) -> SFTPClient:
        """Connect to the SSH server described by *params* and open SFTP.

        Reads ``host`` (required), ``port`` (default 22), ``user``,
        ``password``, ``privateKey``, ``privateKeyPath``, ``allowAgent``,
        ``compress``, ``enableHostKeys`` (default enabled),
        ``hostKeysFilePath`` and ``knownHostsPolicy``
        (``reject`` / ``auto-add`` / ``ignore``, case-insensitive).
        """
        self.__log.debug('Creating connection to SSH server ' + params['host'])
        self.__ssh = SSHClient()
        if 'enableHostKeys' not in params or params['enableHostKeys']:
            host_keys_path = params.get('hostKeysFilePath')
            if host_keys_path is None:
                # User's default known_hosts; paramiko tolerates it missing.
                self.__ssh.load_system_host_keys()
            else:
                # Load an explicit file as client host keys: save_host_keys()
                # only writes those (re-reading this file first). Loading it
                # as system keys made an auto-add save overwrite the file
                # with just the newly added keys.
                self.__ssh.load_host_keys(host_keys_path)

        should_save_host_keys = False
        if 'knownHostsPolicy' in params:
            policy: str = params['knownHostsPolicy'].lower()
            if policy == 'reject':
                self.__ssh.set_missing_host_key_policy(RejectPolicy())
            elif policy == 'auto-add':
                self.__ssh.set_missing_host_key_policy(AutoAddPolicy())
                should_save_host_keys = True
            elif policy == 'ignore':
                self.__ssh.set_missing_host_key_policy(WarningPolicy())
            else:
                self.__log.warning(
                    f'Unknown knownHostsPolicy "{policy}"; using the default policy')

        pkey = None
        if 'privateKey' in params:
            # NOTE(review): PKey is paramiko's abstract base class; building
            # it directly does not parse key material, so this key most
            # likely cannot sign. A concrete type such as
            # RSAKey.from_private_key(io.StringIO(...)) is probably needed —
            # confirm the expected format of 'privateKey'.
            pkey = PKey(data=params['privateKey'])
        self.__ssh.connect(hostname=params['host'],
                           port=params.get('port', 22),
                           username=params.get('user'),
                           password=params.get('password'),
                           pkey=pkey,
                           key_filename=params.get('privateKeyPath'),
                           allow_agent=params.get('allowAgent'),
                           compress=params.get('compress'))

        if should_save_host_keys and 'hostKeysFilePath' in params:
            self.__ssh.save_host_keys(filename=params['hostKeysFilePath'])

        self.__log.debug('Starting SFTP client')
        return self.__ssh.open_sftp()

    def list_directory(self, path: Union[str, Path]) -> List[str]:
        """Return the entry names of *path* (relative to the backups dir)."""
        path = self.__dir / path
        self.__log.debug(f'Retrieving contents of directory {path}')
        return self.__conn.listdir(str(path))

    def create_folder(self, name: str, parent: Union[Path, str] = '.') -> str:
        """Create folder *name* under *parent* unless it already exists.

        Returns the folder's path relative to the backups directory.
        """
        path = self.__dir / parent
        self.__conn.chdir(str(path))
        if name not in self.list_directory(parent):
            self.__log.info(f'Creating folder "{path / name}"')
            self.__conn.mkdir(name)
        else:
            self.__log.debug(f'Folder "{path / name}" already exists')
        return str((path / name).relative_to(self.__dir))

    def upload(self, path: Path, parent: Union[Path, str] = '.'):
        """Upload local file *path* into *parent*, keeping its file name."""
        dir_path = self.__dir / parent
        self.__conn.chdir(str(dir_path))
        self.__log.info(f'Uploading file {path} to {parent}')
        self.__conn.put(str(path), path.name, confirm=True)

    def delete(self, path: Union[Path, str]):
        """Delete a file, or a directory and everything in it, recursively."""
        path = self.__dir / path
        path_stats = self.__conn.stat(str(path))
        self.__log.info(f'Deleting {path}')
        if stat.S_ISDIR(path_stats.st_mode):
            entries_in_dir = self.__conn.listdir(str(path))
            for entry in entries_in_dir:
                self.delete(path / entry)
            self.__conn.rmdir(str(path))
        else:
            self.__conn.unlink(str(path))