Code example #1
File: test_ssh.py Project: zlisde/airflow
    def test_ssh_connection_no_connection_id(self):
        hook = SSHHook(remote_host='localhost')
        self.assertIsNone(hook.ssh_conn_id)
        with hook.get_conn() as client:
            # Note - Pylint will fail with no-member here due to https://github.com/PyCQA/pylint/issues/1437
            (_, stdout, _) = client.exec_command('ls')  # pylint: disable=no-member
            self.assertIsNotNone(stdout.read())
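For context, SSHHook.get_conn() returns a paramiko SSHClient, so the full paramiko API is available on the client. A minimal standalone sketch of the same pattern outside a test (assuming the Airflow SSH provider is installed and an SSH server is reachable on localhost):

    from airflow.providers.ssh.hooks.ssh import SSHHook

    hook = SSHHook(remote_host='localhost')
    with hook.get_conn() as client:
        # exec_command returns (stdin, stdout, stderr) file-like objects
        _, stdout, _ = client.exec_command('ls')
        print(stdout.read().decode('utf-8'))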
Code example #2
File: test_ssh.py Project: zlisde/airflow
    def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, ssh_client):
        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
        assert hook.host_key is None
        with hook.get_conn():
            assert ssh_client.return_value.connect.called is True
            assert ssh_client.return_value.get_host_keys.return_value.add.called is False
Code example #3
File: test_ssh.py Project: zlisde/airflow
    def test_ssh_connection_with_host_key_where_no_host_key_check_is_false(self, ssh_client):
        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE)
        assert hook.host_key.get_base64() == TEST_HOST_KEY
        with hook.get_conn():
            assert ssh_client.return_value.connect.called is True
            assert ssh_client.return_value.get_host_keys.return_value.add.called is True
            assert ssh_client.return_value.get_host_keys.return_value.add.call_args == mock.call(
                hook.remote_host, 'ssh-rsa', hook.host_key)
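Taken together, examples #2 and #3 pin down the host-key behavior: when no_host_key_check is false, a host_key stored in the connection extra is registered with the client's host-key store (as an ssh-rsa key here) before connecting, and nothing is registered when no extra is present.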
Code example #4
File: s3_to_sftp.py Project: kushsharma/airflow
    def execute(self, context: 'Context') -> None:
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        s3_client = s3_hook.get_conn()
        sftp_client = ssh_hook.get_conn().open_sftp()

        with NamedTemporaryFile("w") as f:
            s3_client.download_file(self.s3_bucket, self.s3_key, f.name)
            sftp_client.put(f.name, self.sftp_path)
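This execute method belongs to the S3-to-SFTP transfer operator. A minimal usage sketch, assuming it is the S3ToSFTPOperator from the Amazon provider's transfers module; the connection ids, bucket, and paths below are placeholders:

    from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator

    transfer = S3ToSFTPOperator(
        task_id='s3_to_sftp',
        s3_bucket='my-bucket',            # placeholder bucket
        s3_key='data/report.csv',         # placeholder key
        sftp_path='/upload/report.csv',   # destination path on the SFTP host
        sftp_conn_id='ssh_default',
        s3_conn_id='aws_default',
    )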
Code example #5
File: sftp_to_s3.py Project: folly3/airflow-1
    def execute(self, context):
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        sftp_client = ssh_hook.get_conn().open_sftp()

        with NamedTemporaryFile("w") as f:
            sftp_client.get(self.sftp_path, f.name)

            s3_hook.load_file(filename=f.name,
                              key=self.s3_key,
                              bucket_name=self.s3_bucket,
                              replace=True)
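The reverse direction is symmetrical. A sketch under the same assumptions (SFTPToS3Operator from the Amazon provider; ids and paths are placeholders):

    from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator

    transfer = SFTPToS3Operator(
        task_id='sftp_to_s3',
        sftp_path='/upload/report.csv',   # source path on the SFTP host
        s3_bucket='my-bucket',
        s3_key='data/report.csv',
        sftp_conn_id='ssh_default',
        s3_conn_id='aws_default',
    )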
Code example #6
    def execute(self, context: 'Context') -> None:
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        sftp_client = ssh_hook.get_conn().open_sftp()

        if self.use_temp_file:
            with NamedTemporaryFile("w") as f:
                sftp_client.get(self.sftp_path, f.name)

                s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True)
        else:
            with sftp_client.file(self.sftp_path, mode='rb') as data:
                s3_hook.get_conn().upload_fileobj(data, self.s3_bucket, self.s3_key, Callback=self.log.info)
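The use_temp_file=False branch streams the paramiko SFTP file object straight into boto3's upload_fileobj, so the payload never touches local disk; the trade-off is that the SFTP connection must stay open for the whole upload, whereas the temp-file branch downloads first and then uploads from a local copy.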
Code example #7
    def test_ssh_connection_without_password(self, ssh_mock):
        hook = SSHHook(remote_host='remote_host',
                       port='port',
                       username='******',
                       timeout=10,
                       key_file='fake.file')

        with hook.get_conn():
            ssh_mock.return_value.connect.assert_called_once_with(
                hostname='remote_host',
                username='******',
                key_filename='fake.file',
                timeout=10,
                compress=True,
                port='port',
                sock=None)
Code example #8
    def test_ssh_connection_with_private_key_extra(self, ssh_mock):
        hook = SSHHook(
            ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
            remote_host='remote_host',
            port='port',
            username='******',
            timeout=10,
        )

        with hook.get_conn():
            ssh_mock.return_value.connect.assert_called_once_with(
                hostname='remote_host',
                username='******',
                pkey=TEST_PKEY,
                timeout=10,
                compress=True,
                port='port',
                sock=None)
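Examples #7 and #8 pin the exact keyword arguments the hook forwards to paramiko's SSHClient.connect: compress=True is always set, a key_file becomes key_filename, and a private key supplied through the connection extra is passed as pkey instead.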
Code example #9
class SFTPOperator(BaseOperator):
    """
    SFTPOperator for transferring files from a remote host to local or vice versa.
    This operator uses an ssh_hook to open an SFTP transport channel that serves
    as the basis for the file transfer.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect to (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param local_filepath: local file path to get or put. (templated)
    :type local_filepath: str
    :param remote_filepath: remote file path to get or put. (templated)
    :type remote_filepath: str
    :param operation: specify operation 'get' or 'put', defaults to put
    :type operation: str
    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
    :type confirm: bool
    :param create_intermediate_dirs: create missing intermediate directories when
        copying from remote to local and vice-versa. Default is False.

        Example: The following task would copy ``file.txt`` to the remote host
        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``, ``tmp1`` and ``tmp2`` if they
        don't exist. If the parameter is not passed, the task errors because the
        directory does not exist. ::

            put_file = SFTPOperator(
                task_id="test_sftp",
                ssh_conn_id="ssh_default",
                local_filepath="/tmp/file.txt",
                remote_filepath="/tmp/tmp1/tmp2/file.txt",
                operation="put",
                create_intermediate_dirs=True,
                dag=dag
            )

    :type create_intermediate_dirs: bool
    """
    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')

    @apply_defaults
    def __init__(self,
                 *,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 local_filepath=None,
                 remote_filepath=None,
                 operation=SFTPOperation.PUT,
                 confirm=True,
                 create_intermediate_dirs=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        if not (self.operation.lower() == SFTPOperation.GET
                or self.operation.lower() == SFTPOperation.PUT):
            raise TypeError(
                "unsupported operation value {0}, expected {1} or {2}".format(
                    self.operation, SFTPOperation.GET, SFTPOperation.PUT))

    def execute(self, context):
        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as ssh_client:
                sftp_client = ssh_client.open_sftp()
                if self.operation.lower() == SFTPOperation.GET:
                    local_folder = os.path.dirname(self.local_filepath)
                    if self.create_intermediate_dirs:
                        Path(local_folder).mkdir(parents=True, exist_ok=True)
                    file_msg = "from {0} to {1}".format(
                        self.remote_filepath, self.local_filepath)
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp_client.get(self.remote_filepath, self.local_filepath)
                else:
                    remote_folder = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp_client,
                            remote_directory=remote_folder,
                        )
                    file_msg = "from {0} to {1}".format(
                        self.local_filepath, self.remote_filepath)
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp_client.put(self.local_filepath,
                                    self.remote_filepath,
                                    confirm=self.confirm)

        except Exception as e:
            raise AirflowException(
                "Error while transferring {0}, error: {1}".format(
                    file_msg, str(e)))

        return self.local_filepath
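The docstring shows a put; the get direction only swaps the operation and copies the remote path into the local one. A minimal sketch (same ssh_default connection as above; paths are placeholders):

    get_file = SFTPOperator(
        task_id="fetch_remote_file",
        ssh_conn_id="ssh_default",
        local_filepath="/tmp/downloads/file.txt",
        remote_filepath="/tmp/file.txt",
        operation="get",
        create_intermediate_dirs=True,
        dag=dag,  # assumes a `dag` object in scope, as in the docstring example
    )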
Code example #10
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on a given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect to (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param timeout: timeout (in seconds) for executing the command.
    :type timeout: int
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :type environment: dict
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :type get_pty: bool
    """

    template_fields = ('command', 'remote_host')
    template_ext = ('.sh', )

    @apply_defaults
    def __init__(self,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 command=None,
                 timeout=10,
                 environment=None,
                 get_pty=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.environment = environment
        # get_pty is forced on for sudo commands; guard against a None command
        self.get_pty = (self.command.startswith('sudo') if self.command else False) or get_pty

    def execute(self, context):
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. " +
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                            timeout=self.timeout)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. " +
                    "It will replace the remote_host which was defined " +
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            if not self.command:
                raise AirflowException("SSH command not specified. Aborting.")

            with self.ssh_hook.get_conn() as ssh_client:
                self.log.info("Running command: %s", self.command)

                # run the command with the timeout taken from params
                stdin, stdout, stderr = ssh_client.exec_command(
                    command=self.command,
                    get_pty=self.get_pty,
                    timeout=self.timeout,
                    environment=self.environment)
                # get channels
                channel = stdout.channel

                # closing stdin
                stdin.close()
                channel.shutdown_write()

                agg_stdout = b''
                agg_stderr = b''

                # capture any initial output in case channel is closed already
                stdout_buffer_length = len(stdout.channel.in_buffer)

                if stdout_buffer_length > 0:
                    agg_stdout += stdout.channel.recv(stdout_buffer_length)

                # read from both stdout and stderr
                while not channel.closed or \
                        channel.recv_ready() or \
                        channel.recv_stderr_ready():
                    readq, _, _ = select([channel], [], [], self.timeout)
                    for c in readq:
                        if c.recv_ready():
                            line = stdout.channel.recv(len(c.in_buffer))
                            agg_stdout += line
                            self.log.info(line.decode('utf-8').strip('\n'))
                        if c.recv_stderr_ready():
                            line = stderr.channel.recv_stderr(
                                len(c.in_stderr_buffer))
                            agg_stderr += line
                            self.log.warning(line.decode('utf-8').strip('\n'))
                    if stdout.channel.exit_status_ready()\
                            and not stderr.channel.recv_stderr_ready()\
                            and not stdout.channel.recv_ready():
                        stdout.channel.shutdown_read()
                        stdout.channel.close()
                        break

                stdout.close()
                stderr.close()

                exit_status = stdout.channel.recv_exit_status()
                if exit_status == 0:
                    enable_pickling = conf.getboolean('core',
                                                      'enable_xcom_pickling')
                    if enable_pickling:
                        return agg_stdout
                    else:
                        return b64encode(agg_stdout).decode('utf-8')

                else:
                    error_msg = agg_stderr.decode('utf-8')
                    raise AirflowException(
                        "error running cmd: {0}, error: {1}".format(
                            self.command, error_msg))

        except Exception as e:
            raise AirflowException("SSH operator error: {0}".format(str(e)))

        return True

    def tunnel(self):
        ssh_client = self.ssh_hook.get_conn()
        ssh_client.get_transport()
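A minimal usage sketch for this operator (placeholders throughout). Note that on success the command's stdout is returned from execute, so it is pushed to XCom, base64-encoded unless core.enable_xcom_pickling is set:

    ssh_task = SSHOperator(
        task_id="run_uptime",
        ssh_conn_id="ssh_default",
        command="uptime",
        timeout=30,
        dag=dag,  # assumes a `dag` object in scope
    )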
Code example #11
File: ssh.py Project: abhinavkumar195/airflow
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on a given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections. `ssh_conn_id` will be ignored if
        `ssh_hook` is provided.
    :param remote_host: remote host to connect to (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :param command: command to execute on remote host. (templated)
    :param conn_timeout: timeout (in seconds) for maintaining the connection. The default is 10 seconds.
        Nullable. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :param banner_timeout: timeout to wait for banner from the server in seconds
    """

    template_fields: Sequence[str] = ('command', 'remote_host')
    template_ext: Sequence[str] = ('.sh', )
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_hook: Optional["SSHHook"] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        cmd_timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        banner_timeout: float = 30.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        if self.conn_timeout is None and self.timeout:
            self.conn_timeout = self.timeout
        if self.cmd_timeout is None:
            self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
        self.environment = environment
        self.get_pty = get_pty
        self.banner_timeout = banner_timeout

        if self.timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` and `cmd_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                stacklevel=2,
            )

    def get_hook(self) -> "SSHHook":
        from airflow.providers.ssh.hooks.ssh import SSHHook

        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info(
                    "ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(
                    ssh_conn_id=self.ssh_conn_id,
                    conn_timeout=self.conn_timeout,
                    banner_timeout=self.banner_timeout,
                )

        if not self.ssh_hook:
            raise AirflowException(
                "Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id.")
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook

    def get_ssh_client(self) -> "SSHClient":
        # Remember to use context manager or call .close() on this when done
        self.log.info('Creating ssh_client')
        return self.get_hook().get_conn()

    def exec_ssh_client_command(self, ssh_client: "SSHClient", command: str):
        warnings.warn(
            'exec_ssh_client_command method on SSHOperator is deprecated, call '
            '`ssh_hook.exec_ssh_client_command` instead',
            DeprecationWarning,
        )
        assert self.ssh_hook
        return self.ssh_hook.exec_ssh_client_command(
            ssh_client,
            command,
            timeout=self.timeout,
            environment=self.environment,
            get_pty=self.get_pty)

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        if exit_status != 0:
            raise AirflowException(
                f"SSH operator error: exit status = {exit_status}")

    def run_ssh_client_command(self, ssh_client: "SSHClient",
                               command: str) -> bytes:
        assert self.ssh_hook
        exit_status, agg_stdout, agg_stderr = self.ssh_hook.exec_ssh_client_command(
            ssh_client,
            command,
            timeout=self.timeout,
            environment=self.environment,
            get_pty=self.get_pty)
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout

    def execute(self, context=None) -> Union[bytes, str]:
        result: Union[bytes, str]
        if self.command is None:
            raise AirflowException(
                "SSH operator error: SSH command not specified. Aborting.")

        # Forcing get_pty to True if the command begins with "sudo".
        self.get_pty = self.command.startswith('sudo') or self.get_pty

        with self.get_ssh_client() as ssh_client:
            result = self.run_ssh_client_command(ssh_client, self.command)
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if not enable_pickling:
            result = b64encode(result).decode('utf-8')
        return result

    def tunnel(self) -> None:
        """Get ssh tunnel"""
        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
        ssh_client.get_transport()
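This newer variant splits execution into overridable steps (get_hook, get_ssh_client, run_ssh_client_command, raise_for_status), which makes targeted customization possible. A sketch of one such seam, a hypothetical subclass that tolerates exit status 1:

    from airflow.exceptions import AirflowException
    from airflow.providers.ssh.operators.ssh import SSHOperator

    class LenientSSHOperator(SSHOperator):
        # hypothetical: treat exit status 1 (e.g. grep finding no match) as success
        def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
            if exit_status not in (0, 1):
                raise AirflowException(f"SSH operator error: exit status = {exit_status}")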
Code example #12
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on a given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections. `ssh_conn_id` will be ignored if
        `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect to (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param conn_timeout: timeout (in seconds) for maintaining the connection. The default is 10 seconds.
        Nullable. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :type conn_timeout: int
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
    :type cmd_timeout: int
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :type timeout: int
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :type environment: dict
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :type get_pty: bool
    """

    template_fields = ('command', 'remote_host')
    template_ext = ('.sh', )
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_hook: Optional[SSHHook] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        cmd_timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        if self.conn_timeout is None and self.timeout:
            self.conn_timeout = self.timeout
        if self.cmd_timeout is None:
            self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
        self.environment = environment
        self.get_pty = get_pty

        if self.timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` and `cmd_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                stacklevel=1,
            )

    def get_hook(self) -> SSHHook:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info(
                    "ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                        conn_timeout=self.conn_timeout)

        if not self.ssh_hook:
            raise AirflowException(
                "Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id.")
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook

    def get_ssh_client(self) -> SSHClient:
        # Remember to use context manager or call .close() on this when done
        self.log.info('Creating ssh_client')
        return self.get_hook().get_conn()

    def exec_ssh_client_command(self, ssh_client: SSHClient,
                                command: str) -> Tuple[int, bytes, bytes]:
        self.log.info("Running command: %s", command)

        # run the command with the timeout taken from params
        stdin, stdout, stderr = ssh_client.exec_command(
            command=command,
            get_pty=self.get_pty,
            timeout=self.timeout,
            environment=self.environment,
        )
        # get channels
        channel = stdout.channel

        # closing stdin
        stdin.close()
        channel.shutdown_write()

        agg_stdout = b''
        agg_stderr = b''

        # capture any initial output in case channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)

        if stdout_buffer_length > 0:
            agg_stdout += stdout.channel.recv(stdout_buffer_length)

        # read from both stdout and stderr
        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], self.cmd_timeout)
            for recv in readq:
                if recv.recv_ready():
                    line = stdout.channel.recv(len(recv.in_buffer))
                    agg_stdout += line
                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                if recv.recv_stderr_ready():
                    line = stderr.channel.recv_stderr(
                        len(recv.in_stderr_buffer))
                    agg_stderr += line
                    self.log.warning(
                        line.decode('utf-8', 'replace').strip('\n'))
            if (stdout.channel.exit_status_ready()
                    and not stderr.channel.recv_stderr_ready()
                    and not stdout.channel.recv_ready()):
                stdout.channel.shutdown_read()
                try:
                    stdout.channel.close()
                except Exception:
                    # there is a race that when shutdown_read has been called and when
                    # you try to close the connection, the socket is already closed
                    # We should ignore such errors (but we should log them with warning)
                    self.log.warning("Ignoring exception on close",
                                     exc_info=True)
                break

        stdout.close()
        stderr.close()

        exit_status = stdout.channel.recv_exit_status()

        return exit_status, agg_stdout, agg_stderr

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        if exit_status != 0:
            error_msg = stderr.decode('utf-8')
            raise AirflowException(
                f"error running cmd: {self.command}, error: {error_msg}")

    def run_ssh_client_command(self, ssh_client: SSHClient,
                               command: str) -> bytes:
        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(
            ssh_client, command)
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout

    def execute(self, context=None) -> Union[bytes, str]:
        result: Union[bytes, str]
        if self.command is None:
            raise AirflowException(
                "SSH operator error: SSH command not specified. Aborting.")

        # Forcing get_pty to True if the command begins with "sudo".
        self.get_pty = self.command.startswith('sudo') or self.get_pty

        try:
            with self.get_ssh_client() as ssh_client:
                result = self.run_ssh_client_command(ssh_client, self.command)
        except Exception as e:
            raise AirflowException(f"SSH operator error: {str(e)}")
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if not enable_pickling:
            result = b64encode(result).decode('utf-8')
        return result

    def tunnel(self) -> None:
        """Get ssh tunnel"""
        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
        ssh_client.get_transport()
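Compared with example #11, this variant keeps the channel-reading loop in exec_ssh_client_command on the operator itself rather than delegating to ssh_hook.exec_ssh_client_command, and it wraps execute in a blanket try/except that re-raises as AirflowException; subclasses written against one variant's extension points will not necessarily work against the other.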
Code example #13
File: test_ssh.py Project: zlisde/airflow
    def test_ssh_connection_with_host_key_extra(self, ssh_client):
        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
        # the host_key extra is ignored because no_host_key_check defaults to True unless explicitly overridden
        assert hook.host_key is None
        with hook.get_conn():
            assert ssh_client.return_value.connect.called is True
            assert ssh_client.return_value.get_host_keys.return_value.add.called is False
Code example #14
    def test_ssh_connection(self):
        hook = SSHHook(ssh_conn_id='ssh_default')
        with hook.get_conn() as client:
            # Note - Pylint will fail with no-member here due to https://github.com/PyCQA/pylint/issues/1437
            (_, stdout, _) = client.exec_command('ls')  # pylint: disable=no-member
            assert stdout.read() is not None
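Several examples above assume an ssh_default connection exists. One way to provide it, a sketch using Airflow's AIRFLOW_CONN_<CONN_ID> URI convention (host and user are placeholders):

    import os

    # Airflow resolves connection ids from AIRFLOW_CONN_* environment variables
    os.environ['AIRFLOW_CONN_SSH_DEFAULT'] = 'ssh://user@remote-host:22'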