def test_ssh_connection_without_password(self, ssh_mock):
    """A hook configured with only ``key_file`` passes it to ``connect`` as ``key_filename``."""
    hook = SSHHook(
        remote_host='remote_host',
        port='port',
        username='******',
        timeout=10,
        key_file='fake.file',
    )
    with hook.get_conn():
        connect_mock = ssh_mock.return_value.connect
        connect_mock.assert_called_once_with(
            hostname='remote_host',
            username='******',
            key_filename='fake.file',
            timeout=10,
            compress=True,
            port='port',
            sock=None,
        )
def execute(self, context):
    """
    Download ``self.sftp_path`` over SFTP into a temporary file and upload it to S3.

    The SSH connection and the SFTP session are closed once the upload finishes,
    including when the download or upload raises.
    """
    from contextlib import closing

    self.s3_key = self.get_s3_key(self.s3_key)
    ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
    s3_hook = S3Hook(self.s3_conn_id)

    # ``closing`` yields the wrapped object itself and calls ``close()`` on exit,
    # so neither the SSH client nor the SFTP session outlives the task.
    with closing(ssh_hook.get_conn()) as ssh_client, closing(ssh_client.open_sftp()) as sftp_client:
        with NamedTemporaryFile("w") as f:
            sftp_client.get(self.sftp_path, f.name)
            s3_hook.load_file(
                filename=f.name,
                key=self.s3_key,
                bucket_name=self.s3_bucket,
                replace=True,
            )
def execute(self, context: Any) -> str:
    """
    Transfer one file over SFTP in the direction given by ``self.operation``.

    The hook is resolved from ``self.ssh_hook`` or, failing that, ``self.ssh_conn_id``.
    ``self.remote_host``, when set, overrides the hook's host.

    :return: ``self.local_filepath``.
    :raises AirflowException: wrapping any error raised while resolving the hook
        or transferring the file; the original exception is chained as the cause.
    """
    from contextlib import closing

    file_msg = None
    try:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. "
                    "Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id."
            )
            self.ssh_hook.remote_host = self.remote_host

        # Close the SFTP session explicitly rather than leaving it open until the
        # SSH client is torn down.
        with self.ssh_hook.get_conn() as ssh_client, closing(ssh_client.open_sftp()) as sftp_client:
            if self.operation.lower() == SFTPOperation.GET:
                local_folder = os.path.dirname(self.local_filepath)
                if self.create_intermediate_dirs:
                    Path(local_folder).mkdir(parents=True, exist_ok=True)
                file_msg = f"from {self.remote_filepath} to {self.local_filepath}"
                self.log.info("Starting to transfer %s", file_msg)
                sftp_client.get(self.remote_filepath, self.local_filepath)
            else:
                remote_folder = os.path.dirname(self.remote_filepath)
                if self.create_intermediate_dirs:
                    _make_intermediate_dirs(
                        sftp_client=sftp_client,
                        remote_directory=remote_folder,
                    )
                file_msg = f"from {self.local_filepath} to {self.remote_filepath}"
                self.log.info("Starting to transfer file %s", file_msg)
                sftp_client.put(self.local_filepath, self.remote_filepath, confirm=self.confirm)

    except Exception as e:
        # Chain the cause so the original traceback is preserved in the task log.
        raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}") from e

    return self.local_filepath
def test_tunnel_without_password(self, ssh_mock):
    """``get_tunnel`` forwards the hook's key file as ``ssh_pkey`` when no password is set."""
    hook = SSHHook(
        remote_host='remote_host',
        port='port',
        username='******',
        timeout=10,
        key_file='fake.file',
    )
    with hook.get_tunnel(1234):
        expected_kwargs = dict(
            ssh_port='port',
            ssh_username='******',
            ssh_pkey='fake.file',
            ssh_proxy=None,
            local_bind_address=('localhost', ),
            remote_bind_address=('localhost', 1234),
            host_pkey_directories=[],
            logger=hook.log,
        )
        ssh_mock.assert_called_once_with('remote_host', **expected_kwargs)
def test_ssh_connection_with_private_key_extra(self, ssh_mock):
    """A private key stored in the connection extras reaches ``connect`` as ``pkey``."""
    hook = SSHHook(
        ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
        remote_host='remote_host',
        port='port',
        username='******',
        timeout=10,
    )
    expected_connect_kwargs = {
        'hostname': 'remote_host',
        'username': '******',
        'pkey': TEST_PKEY,
        'timeout': 10,
        'compress': True,
        'port': 'port',
        'sock': None,
    }
    with hook.get_conn():
        ssh_mock.return_value.connect.assert_called_once_with(**expected_connect_kwargs)
def test_tunnel_with_private_key(self, ssh_mock):
    """``get_tunnel`` passes the private key from the connection extras as ``ssh_pkey``."""
    hook = SSHHook(
        ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
        remote_host='remote_host',
        port='port',
        username='******',
        timeout=10,
    )
    with hook.get_tunnel(1234):
        expected_kwargs = {
            'ssh_port': 'port',
            'ssh_username': '******',
            'ssh_pkey': TEST_PKEY,
            'ssh_proxy': None,
            'local_bind_address': ('localhost',),
            'remote_bind_address': ('localhost', 1234),
            'host_pkey_directories': [],
            'logger': hook.log,
        }
        ssh_mock.assert_called_once_with('remote_host', **expected_kwargs)
def get_hook(self) -> SSHHook:
    """
    Resolve and return the ``SSHHook`` this operator should use.

    An already-provided ``SSHHook`` wins over ``ssh_conn_id``. Otherwise a new hook
    is built from ``ssh_conn_id`` with ``self.conn_timeout``. ``self.remote_host``,
    when not None, overrides the hook's host.

    :raises AirflowException: if neither a hook nor a connection id is available.
    """
    if self.ssh_conn_id:
        existing_hook = self.ssh_hook
        if existing_hook and isinstance(existing_hook, SSHHook):
            self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
        else:
            self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
            self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)

    hook = self.ssh_hook
    if not hook:
        raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

    override_host = self.remote_host
    if override_host is not None:
        self.log.info(
            "remote_host is provided explicitly. "
            "It will replace the remote_host which was defined "
            "in ssh_hook or predefined in connection of ssh_conn_id."
        )
        hook.remote_host = override_host
    return hook
def test_tunnel(self):
    """Bytes sent to the local end of ``create_tunnel`` reach a server on the remote end."""
    import socket
    import subprocess
    import sys

    hook = SSHHook(ssh_conn_id='ssh_default')

    # Use the running interpreter rather than whatever "python" resolves to on PATH.
    subprocess_kwargs = dict(
        args=[sys.executable, "-c", HELLO_SERVER_CMD],
        stdout=subprocess.PIPE,
    )
    with subprocess.Popen(**subprocess_kwargs) as server_handle, hook.create_tunnel(2135, 2134):
        server_output = server_handle.stdout.read(5)
        self.assertEqual(b"ready", server_output)
        # A distinct name keeps the ``socket`` module usable. The ``with`` closes
        # the client socket even when the assertion fails.
        with socket.socket() as client_socket:
            client_socket.connect(("localhost", 2135))
            response = client_socket.recv(5)
            self.assertEqual(response, b"hello")
        server_handle.communicate()
        self.assertEqual(server_handle.returncode, 0)
def execute(self, context) -> None:
    """
    Copy ``self.sftp_path`` from the SFTP server to ``s3://{s3_bucket}/{s3_key}``.

    With ``use_temp_file`` the file is first downloaded to a local temporary file
    and uploaded with ``load_file``. Otherwise the remote file handle is streamed
    directly to ``upload_fileobj``. The SSH connection and SFTP session are closed
    afterwards, including on error.
    """
    from contextlib import closing

    self.s3_key = self.get_s3_key(self.s3_key)
    ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
    s3_hook = S3Hook(self.s3_conn_id)

    # ``closing`` hands back the object itself and guarantees ``close()`` on exit.
    with closing(ssh_hook.get_conn()) as ssh_client, closing(ssh_client.open_sftp()) as sftp_client:
        if self.use_temp_file:
            with NamedTemporaryFile("w") as f:
                sftp_client.get(self.sftp_path, f.name)
                s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True)
        else:
            with sftp_client.file(self.sftp_path, mode='rb') as data:
                s3_hook.get_conn().upload_fileobj(data, self.s3_bucket, self.s3_key, Callback=self.log.info)
def setUp(self):
    """Create SSH/S3 hooks, a one-off DAG, and an open SFTP session for each test."""
    ssh_hook = SSHHook(ssh_conn_id='ssh_default')
    s3_hook = S3Hook('aws_default')
    ssh_hook.no_host_key_check = True

    dag = DAG(
        TEST_DAG_ID + 'test_schedule_dag_once',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
    )
    dag.schedule_interval = '@once'

    self.hook = ssh_hook
    self.s3_hook = s3_hook
    # Connection and SFTP session are kept on the test case for use by the tests.
    self.ssh_client = self.hook.get_conn()
    self.sftp_client = self.ssh_client.open_sftp()
    self.dag = dag

    self.s3_bucket, self.sftp_path, self.s3_key = BUCKET, SFTP_PATH, S3_KEY
def setUp(self):
    """Build the SSH hook, a one-off DAG, and the local/remote file paths used by the tests."""
    from airflow.providers.ssh.hooks.ssh import SSHHook

    ssh_hook = SSHHook(ssh_conn_id='ssh_default')
    ssh_hook.no_host_key_check = True

    dag = DAG(
        TEST_DAG_ID + 'test_schedule_dag_once',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
    )
    dag.schedule_interval = '@once'

    self.hook = ssh_hook
    self.dag = dag

    base_dir = "/tmp"
    local_dir = "/tmp/tmp2"
    remote_dir = "/tmp/tmp1"
    local_name = 'test_local_file'
    remote_name = 'test_remote_file'

    self.test_dir = base_dir
    self.test_local_dir = local_dir
    self.test_remote_dir = remote_dir
    self.test_local_filename = local_name
    self.test_remote_filename = remote_name

    # Files placed directly under the base directory.
    self.test_local_filepath = f'{base_dir}/{local_name}'
    self.test_remote_filepath = f'{base_dir}/{remote_name}'
    # Files placed under an intermediate directory that may not exist yet.
    self.test_local_filepath_int_dir = f'{local_dir}/{local_name}'
    self.test_remote_filepath_int_dir = f'{remote_dir}/{remote_name}'
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param timeout: timeout (in seconds) for executing the command.
    :type timeout: int
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :type environment: dict
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :type get_pty: bool
    """

    template_fields = ('command', 'remote_host')
    template_ext = ('.sh', )

    @apply_defaults
    def __init__(self,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 command=None,
                 timeout=10,
                 environment=None,
                 get_pty=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.environment = environment
        # `command` defaults to None. Guard the `startswith` call so construction
        # does not fail with AttributeError before `execute` can report the missing
        # command. `execute` re-checks after templating.
        self.get_pty = (self.command is not None and self.command.startswith('sudo')) or get_pty

    def execute(self, context):
        """
        Run ``self.command`` on the remote host and return its stdout.

        :return: the raw stdout bytes when XCom pickling is enabled, otherwise the
            stdout base64-encoded as a UTF-8 string.
        :raises AirflowException: on a missing hook/command, a non-zero exit status,
            or any other error; the original exception is chained as the cause.
        """
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                            timeout=self.timeout)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            if not self.command:
                raise AirflowException("SSH command not specified. Aborting.")

            # `command` is templated, so only the rendered value reliably tells
            # whether it starts with sudo.
            self.get_pty = self.command.startswith('sudo') or self.get_pty

            with self.ssh_hook.get_conn() as ssh_client:
                self.log.info("Running command: %s", self.command)

                # set timeout taken as params
                stdin, stdout, stderr = ssh_client.exec_command(
                    command=self.command,
                    get_pty=self.get_pty,
                    timeout=self.timeout,
                    environment=self.environment)
                # get channels
                channel = stdout.channel

                # closing stdin
                stdin.close()
                channel.shutdown_write()

                agg_stdout = b''
                agg_stderr = b''

                # capture any initial output in case channel is closed already
                stdout_buffer_length = len(stdout.channel.in_buffer)

                if stdout_buffer_length > 0:
                    agg_stdout += stdout.channel.recv(stdout_buffer_length)

                # read from both stdout and stderr
                while not channel.closed or \
                        channel.recv_ready() or \
                        channel.recv_stderr_ready():
                    readq, _, _ = select([channel], [], [], self.timeout)
                    for c in readq:
                        if c.recv_ready():
                            line = stdout.channel.recv(len(c.in_buffer))
                            agg_stdout += line
                            # 'replace' keeps non-UTF-8 output from aborting the task.
                            self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                        if c.recv_stderr_ready():
                            line = stderr.channel.recv_stderr(
                                len(c.in_stderr_buffer))
                            agg_stderr += line
                            self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
                    if stdout.channel.exit_status_ready() \
                            and not stderr.channel.recv_stderr_ready() \
                            and not stdout.channel.recv_ready():
                        stdout.channel.shutdown_read()
                        try:
                            stdout.channel.close()
                        except Exception:
                            # The socket may already be closed after shutdown_read;
                            # log and continue rather than failing a finished command.
                            self.log.warning("Ignoring exception on close", exc_info=True)
                        break

                stdout.close()
                stderr.close()
                exit_status = stdout.channel.recv_exit_status()
                if exit_status == 0:
                    enable_pickling = conf.getboolean(
                        'core', 'enable_xcom_pickling'
                    )
                    if enable_pickling:
                        return agg_stdout
                    return b64encode(agg_stdout).decode('utf-8')

                error_msg = agg_stderr.decode('utf-8', 'replace')
                raise AirflowException(
                    "error running cmd: {0}, error: {1}".format(
                        self.command, error_msg))

        except Exception as e:
            raise AirflowException("SSH operator error: {0}".format(str(e))) from e

    def tunnel(self):
        """Open a connection with ``self.ssh_hook`` and fetch its transport.

        NOTE(review): the client and transport are neither returned nor closed
        here — confirm whether callers rely on this.
        """
        ssh_client = self.ssh_hook.get_conn()
        ssh_client.get_transport()
def test_ssh_connection_old_cm(self):
    """The hook can still be used as a context manager around ``get_conn``."""
    with SSHHook(ssh_conn_id='ssh_default') as hook:
        ssh_client = hook.get_conn()
        _stdin, out_stream, _stderr = ssh_client.exec_command('ls')
        output = out_stream.read()
        self.assertIsNotNone(output)
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections. `ssh_conn_id` will be ignored if
        `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param conn_timeout: timeout (in seconds) for maintaining the connection. The default is 10 seconds.
        Nullable. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :type conn_timeout: int
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
    :type cmd_timeout: int
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :type timeout: int
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :type environment: dict
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :type get_pty: bool
    """

    template_fields = ('command', 'remote_host')
    template_ext = ('.sh', )
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_hook: Optional[SSHHook] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        cmd_timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        # The deprecated `timeout` back-fills whichever new option was not given.
        if self.conn_timeout is None and self.timeout:
            self.conn_timeout = self.timeout
        if self.cmd_timeout is None:
            self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
        self.environment = environment
        self.get_pty = get_pty

        if self.timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` and `cmd_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                # Point the warning at the code constructing the operator.
                stacklevel=2,
            )

    def get_hook(self) -> SSHHook:
        """
        Resolve and return the ``SSHHook`` to use, applying ``remote_host`` if set.

        :raises AirflowException: if neither a hook nor a connection id is available.
        """
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)

        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id."
            )
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook

    def get_ssh_client(self) -> SSHClient:
        """Return a fresh client from the resolved hook; the caller must close it."""
        # Remember to use context manager or call .close() on this when done
        self.log.info('Creating ssh_client')
        return self.get_hook().get_conn()

    def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> Tuple[int, bytes, bytes]:
        """
        Run *command* on *ssh_client*, logging stdout/stderr as they arrive.

        :return: ``(exit_status, stdout_bytes, stderr_bytes)``.
        """
        self.log.info("Running command: %s", command)

        # Use the effective command timeout. The deprecated `timeout` is already
        # folded into `cmd_timeout` in __init__.
        stdin, stdout, stderr = ssh_client.exec_command(
            command=command,
            get_pty=self.get_pty,
            timeout=self.cmd_timeout,
            environment=self.environment,
        )
        # get channels
        channel = stdout.channel

        # closing stdin
        stdin.close()
        channel.shutdown_write()

        agg_stdout = b''
        agg_stderr = b''

        # capture any initial output in case channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)

        if stdout_buffer_length > 0:
            agg_stdout += stdout.channel.recv(stdout_buffer_length)

        # read from both stdout and stderr
        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], self.cmd_timeout)
            for recv in readq:
                if recv.recv_ready():
                    line = stdout.channel.recv(len(recv.in_buffer))
                    agg_stdout += line
                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                if recv.recv_stderr_ready():
                    line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
                    agg_stderr += line
                    self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
            if (
                stdout.channel.exit_status_ready()
                and not stderr.channel.recv_stderr_ready()
                and not stdout.channel.recv_ready()
            ):
                stdout.channel.shutdown_read()
                try:
                    stdout.channel.close()
                except Exception:
                    # there is a race that when shutdown_read has been called and when
                    # you try to close the connection, the socket is already closed
                    # We should ignore such errors (but we should log them with warning)
                    self.log.warning("Ignoring exception on close", exc_info=True)
                break

        stdout.close()
        stderr.close()
        exit_status = stdout.channel.recv_exit_status()

        return exit_status, agg_stdout, agg_stderr

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        """Raise ``AirflowException`` carrying the decoded stderr if *exit_status* is non-zero."""
        if exit_status != 0:
            # 'replace' stops a non-UTF-8 stderr from masking the real failure.
            error_msg = stderr.decode('utf-8', 'replace')
            raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")

    def run_ssh_client_command(self, ssh_client: SSHClient, command: str) -> bytes:
        """Run *command*, raise on a non-zero exit status, and return the collected stdout."""
        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(ssh_client, command)
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout

    def execute(self, context=None) -> Union[bytes, str]:
        """
        Run ``self.command`` remotely and return its stdout.

        :return: raw bytes when XCom pickling is enabled, otherwise base64 text.
        :raises AirflowException: if no command is set or the remote run fails.
        """
        result: Union[bytes, str]
        if self.command is None:
            raise AirflowException("SSH operator error: SSH command not specified. Aborting.")

        # Forcing get_pty to True if the command begins with "sudo".
        self.get_pty = self.command.startswith('sudo') or self.get_pty

        try:
            with self.get_ssh_client() as ssh_client:
                result = self.run_ssh_client_command(ssh_client, self.command)
        except Exception as e:
            # Chain the cause so the original traceback is kept.
            raise AirflowException(f"SSH operator error: {str(e)}") from e
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if not enable_pickling:
            result = b64encode(result).decode('utf-8')
        return result

    def tunnel(self) -> None:
        """Get ssh tunnel"""
        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
        ssh_client.get_transport()
def test_ssh_connection_with_host_key_extra(self, ssh_client):
    """A host key in the extras is not stored or registered while host-key checking is off."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
    # Since default no_host_key_check = True unless explicit override
    assert hook.host_key is None
    with hook.get_conn():
        client_mock = ssh_client.return_value
        assert client_mock.connect.called is True
        assert client_mock.get_host_keys.return_value.add.called is False
def test_ssh_connection(self):
    """Running ``ls`` through a real connection yields readable stdout."""
    hook = SSHHook(ssh_conn_id='ssh_default')
    with hook.get_conn() as client:
        # Note - Pylint will fail with no-member here due to https://github.com/PyCQA/pylint/issues/1437
        _stdin, out_stream, _stderr = client.exec_command('ls')  # pylint: disable=no-member
        output = out_stream.read()
        self.assertIsNotNone(output)
def test_conn_with_extra_parameters(self):
    """Connection extras are reflected on the corresponding hook attributes."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
    expected_values = (
        ('compress', True),
        ('no_host_key_check', True),
        ('allow_host_key_change', False),
        ('look_for_keys', True),
    )
    for attribute, expected in expected_values:
        self.assertEqual(getattr(hook, attribute), expected)
def test_conn_with_extra_parameters_false_look_for_keys(self):
    """``look_for_keys`` can be disabled through the connection extras."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS)
    look_for_keys = hook.look_for_keys
    self.assertEqual(look_for_keys, False)
# Module-level setup for what appears to be a DAG that runs a spark-submit
# command on a remote host over SSH.
# NOTE(review): `datetime` is used below but is not imported in this fragment —
# confirm it is imported elsewhere in the file.
# NOTE(review): SSHOperator comes from the legacy `airflow.contrib` path while
# SSHHook comes from `airflow.providers` — confirm the targeted Airflow version.
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.providers.ssh.hooks.ssh import SSHHook

#---------------
# Local app imports
#--------------
import gblevent_config as gblevent_config

#--------------
# variable setup
#--------------
# Presumably 'aws_emr' is the ssh_conn_id (first positional argument) — verify.
sshHook = SSHHook('aws_emr')
spark_location = gblevent_config.GBLEVENT_STAGING_PREP
# The {{execution_date.*}} placeholders are Jinja syntax; presumably they are
# rendered by Airflow when this string is passed to a templated field such as
# SSHOperator.command — confirm at the point of use.
spark_submit_cmd = '/usr/bin/spark-submit ' + spark_location + ' {{execution_date.year}} {{execution_date.month}} {{execution_date.day}}'
# Config dates are parsed with the format '%Y,%m,%d' (e.g. "2019,1,31").
start_date_val = gblevent_config.STARTDATE
end_date_val = gblevent_config.ENDDATE
start_date = datetime.strptime(start_date_val, '%Y,%m,%d')
end_date = datetime.strptime(end_date_val, '%Y,%m,%d')

#---------------
# NOTE(review): this dict literal is truncated in the visible source.
default_args = {
    'owner': 'udacity',
    'start_date': start_date,
    'end_date': end_date
def execute(self, context):
    """
    Resolve the SSH hook, run ``self.command`` remotely and return its stdout.

    :return: raw stdout bytes when XCom pickling is enabled, otherwise the stdout
        base64-encoded as a UTF-8 string.
    :raises AirflowException: on a missing hook/command, a non-zero exit status or
        any other error; the original exception is chained as the cause.
    """
    try:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info(
                    "ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info("ssh_hook is not provided or invalid. "
                              "Trying ssh_conn_id to create SSHHook.")
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                        timeout=self.timeout)

        if not self.ssh_hook:
            raise AirflowException(
                "Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id.")
            self.ssh_hook.remote_host = self.remote_host

        if not self.command:
            raise AirflowException("SSH command not specified. Aborting.")

        with self.ssh_hook.get_conn() as ssh_client:
            self.log.info("Running command: %s", self.command)

            # set timeout taken as params
            stdin, stdout, stderr = ssh_client.exec_command(
                command=self.command,
                get_pty=self.get_pty,
                timeout=self.timeout,
                environment=self.environment)
            # get channels
            channel = stdout.channel

            # closing stdin
            stdin.close()
            channel.shutdown_write()

            agg_stdout = b''
            agg_stderr = b''

            # capture any initial output in case channel is closed already
            stdout_buffer_length = len(stdout.channel.in_buffer)

            if stdout_buffer_length > 0:
                agg_stdout += stdout.channel.recv(stdout_buffer_length)

            # read from both stdout and stderr
            while not channel.closed or \
                    channel.recv_ready() or \
                    channel.recv_stderr_ready():
                readq, _, _ = select([channel], [], [], self.timeout)
                for c in readq:
                    if c.recv_ready():
                        line = stdout.channel.recv(len(c.in_buffer))
                        agg_stdout += line
                        # 'replace' keeps non-UTF-8 output from aborting the task.
                        self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                    if c.recv_stderr_ready():
                        line = stderr.channel.recv_stderr(
                            len(c.in_stderr_buffer))
                        agg_stderr += line
                        self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
                if stdout.channel.exit_status_ready() \
                        and not stderr.channel.recv_stderr_ready() \
                        and not stdout.channel.recv_ready():
                    stdout.channel.shutdown_read()
                    try:
                        stdout.channel.close()
                    except Exception:
                        # The socket may already be closed after shutdown_read;
                        # log instead of failing a command that already finished.
                        self.log.warning("Ignoring exception on close", exc_info=True)
                    break

            stdout.close()
            stderr.close()
            exit_status = stdout.channel.recv_exit_status()
            if exit_status == 0:
                enable_pickling = conf.getboolean(
                    'core', 'enable_xcom_pickling'
                )
                if enable_pickling:
                    return agg_stdout
                return b64encode(agg_stdout).decode('utf-8')

            error_msg = agg_stderr.decode('utf-8', 'replace')
            raise AirflowException(
                "error running cmd: {0}, error: {1}".format(
                    self.command, error_msg))

    except Exception as e:
        raise AirflowException("SSH operator error: {0}".format(str(e))) from e
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections. `ssh_conn_id` will be ignored if
        `ssh_hook` is provided.
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :param command: command to execute on remote host. (templated)
    :param conn_timeout: timeout (in seconds) for maintaining the connection. The default is 10 seconds.
        Nullable. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :param banner_timeout: timeout to wait for banner from the server in seconds
    """

    template_fields: Sequence[str] = ('command', 'remote_host')
    template_ext: Sequence[str] = ('.sh', )
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_hook: Optional["SSHHook"] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        cmd_timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        banner_timeout: float = 30.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        # The deprecated `timeout` back-fills whichever new option was not given.
        if self.conn_timeout is None and self.timeout:
            self.conn_timeout = self.timeout
        if self.cmd_timeout is None:
            self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
        self.environment = environment
        self.get_pty = get_pty
        self.banner_timeout = banner_timeout

        if self.timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` and `cmd_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                stacklevel=2,
            )

    def get_hook(self) -> "SSHHook":
        """
        Resolve and return the ``SSHHook`` to use, applying ``remote_host`` if set.

        :raises AirflowException: if neither a hook nor a connection id is available.
        """
        from airflow.providers.ssh.hooks.ssh import SSHHook

        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(
                    ssh_conn_id=self.ssh_conn_id,
                    conn_timeout=self.conn_timeout,
                    banner_timeout=self.banner_timeout,
                )

        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id."
            )
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook

    def _require_ssh_hook(self) -> "SSHHook":
        """Return the resolved hook; raise instead of ``assert`` (stripped under -O) when unset."""
        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
        return self.ssh_hook

    def get_ssh_client(self) -> "SSHClient":
        """Return a fresh client from the resolved hook; the caller must close it."""
        # Remember to use context manager or call .close() on this when done
        self.log.info('Creating ssh_client')
        return self.get_hook().get_conn()

    def exec_ssh_client_command(self, ssh_client: "SSHClient", command: str):
        """Deprecated: delegate to ``SSHHook.exec_ssh_client_command``."""
        warnings.warn(
            'exec_ssh_client_command method on SSHOperator is deprecated, call '
            '`ssh_hook.exec_ssh_client_command` instead',
            DeprecationWarning,
            stacklevel=2,
        )
        # `cmd_timeout` already includes the deprecated `timeout` (see __init__).
        return self._require_ssh_hook().exec_ssh_client_command(
            ssh_client,
            command,
            timeout=self.cmd_timeout,
            environment=self.environment,
            get_pty=self.get_pty,
        )

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        """Raise ``AirflowException`` if *exit_status* is non-zero."""
        if exit_status != 0:
            raise AirflowException(f"SSH operator error: exit status = {exit_status}")

    def run_ssh_client_command(self, ssh_client: "SSHClient", command: str) -> bytes:
        """Run *command* via the hook, raise on non-zero exit, and return the collected stdout."""
        # Previously the deprecated `timeout` was passed here, so a user-supplied
        # `cmd_timeout` was never used. `cmd_timeout` is the effective value.
        exit_status, agg_stdout, agg_stderr = self._require_ssh_hook().exec_ssh_client_command(
            ssh_client,
            command,
            timeout=self.cmd_timeout,
            environment=self.environment,
            get_pty=self.get_pty,
        )
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout

    def execute(self, context=None) -> Union[bytes, str]:
        """
        Run ``self.command`` remotely and return its stdout.

        :return: raw bytes when XCom pickling is enabled, otherwise base64 text.
        :raises AirflowException: if no command is set or the command exits non-zero.
        """
        result: Union[bytes, str]
        if self.command is None:
            raise AirflowException("SSH operator error: SSH command not specified. Aborting.")

        # Forcing get_pty to True if the command begins with "sudo".
        self.get_pty = self.command.startswith('sudo') or self.get_pty

        with self.get_ssh_client() as ssh_client:
            result = self.run_ssh_client_command(ssh_client, self.command)
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if not enable_pickling:
            result = b64encode(result).decode('utf-8')
        return result

    def tunnel(self) -> None:
        """Get ssh tunnel"""
        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
        ssh_client.get_transport()
def get_ssh_op(script, task_id='ssh_test'):
    """
    Build an ``SSHOperator`` test task that runs *script*.

    The operator receives both an explicit hook (``ssh_conn``) and an
    ``ssh_conn_id`` (``operator_test``), with retries disabled.

    :param script: command string handed to the operator.
    :param task_id: task id for the operator; defaults to ``'ssh_test'``.
    """
    # The original used an f-string with no placeholders for the task id.
    return SSHOperator(
        task_id=task_id,
        ssh_hook=SSHHook(ssh_conn_id='ssh_conn'),
        ssh_conn_id='operator_test',
        retries=0,
        command=script,
    )
def execute(self, context) -> Union[bytes, str, bool]:
    """
    Resolve the SSH hook, run ``self.command`` remotely and return its stdout.

    :return: raw stdout bytes when XCom pickling is enabled, otherwise the stdout
        base64-encoded as a UTF-8 string.
    :raises AirflowException: on a missing hook/command, a non-zero exit status or
        any other error; the original exception is chained as the cause.
    """
    try:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, timeout=self.timeout)

        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id."
            )
            self.ssh_hook.remote_host = self.remote_host

        if not self.command:
            raise AirflowException("SSH command not specified. Aborting.")

        with self.ssh_hook.get_conn() as ssh_client:
            self.log.info("Running command: %s", self.command)

            # set timeout taken as params
            stdin, stdout, stderr = ssh_client.exec_command(
                command=self.command,
                get_pty=self.get_pty,
                timeout=self.timeout,
                environment=self.environment,
            )
            # get channels
            channel = stdout.channel

            # closing stdin
            stdin.close()
            channel.shutdown_write()

            agg_stdout = b''
            agg_stderr = b''

            # capture any initial output in case channel is closed already
            stdout_buffer_length = len(stdout.channel.in_buffer)

            if stdout_buffer_length > 0:
                agg_stdout += stdout.channel.recv(stdout_buffer_length)

            # read from both stdout and stderr
            while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
                readq, _, _ = select([channel], [], [], self.timeout)
                for recv in readq:
                    if recv.recv_ready():
                        line = stdout.channel.recv(len(recv.in_buffer))
                        agg_stdout += line
                        self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                    if recv.recv_stderr_ready():
                        line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
                        agg_stderr += line
                        self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
                if (
                    stdout.channel.exit_status_ready()
                    and not stderr.channel.recv_stderr_ready()
                    and not stdout.channel.recv_ready()
                ):
                    stdout.channel.shutdown_read()
                    try:
                        stdout.channel.close()
                    except Exception:
                        # there is a race that when shutdown_read has been called and when
                        # you try to close the connection, the socket is already closed
                        # We should ignore such errors (but we should log them with warning)
                        self.log.warning("Ignoring exception on close", exc_info=True)
                    break

            stdout.close()
            stderr.close()
            exit_status = stdout.channel.recv_exit_status()
            if exit_status == 0:
                enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
                if enable_pickling:
                    return agg_stdout
                return b64encode(agg_stdout).decode('utf-8')

            # 'replace' stops a non-UTF-8 stderr from masking the real failure
            # with a UnicodeDecodeError.
            error_msg = agg_stderr.decode('utf-8', 'replace')
            raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")

    except Exception as e:
        raise AirflowException(f"SSH operator error: {str(e)}") from e
class SFTPOperator(BaseOperator):
    """
    SFTPOperator for transferring files from remote host to local or vice versa.
    This operator uses ssh_hook to open an sftp transport channel that serves as
    the basis for file transfer.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param local_filepath: local file path to get or put. (templated)
    :type local_filepath: str
    :param remote_filepath: remote file path to get or put. (templated)
    :type remote_filepath: str
    :param operation: specify operation 'get' or 'put', defaults to put
    :type operation: str
    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
    :type confirm: bool
    :param create_intermediate_dirs: create missing intermediate directories when
        copying from remote to local and vice-versa. Default is False.

        Example: The following task would copy ``file.txt`` to the remote host
        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``, ``tmp1`` and ``tmp2`` if they
        don't exist. If the parameter is not passed it would error as the directory
        does not exist. ::

            put_file = SFTPOperator(
                task_id="test_sftp",
                ssh_conn_id="ssh_default",
                local_filepath="/tmp/file.txt",
                remote_filepath="/tmp/tmp1/tmp2/file.txt",
                operation="put",
                create_intermediate_dirs=True,
                dag=dag
            )

    :type create_intermediate_dirs: bool
    """

    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')

    @apply_defaults
    def __init__(self,
                 *,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 local_filepath=None,
                 remote_filepath=None,
                 operation=SFTPOperation.PUT,
                 confirm=True,
                 create_intermediate_dirs=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        if not (self.operation.lower() == SFTPOperation.GET or
                self.operation.lower() == SFTPOperation.PUT):
            raise TypeError(
                "unsupported operation value {0}, expected {1} or {2}".format(
                    self.operation, SFTPOperation.GET, SFTPOperation.PUT))

    def execute(self, context):
        """
        Transfer one file over SFTP in the direction given by ``self.operation``.

        :return: ``self.local_filepath``.
        :raises AirflowException: wrapping any error raised while resolving the hook
            or transferring; the original exception is chained as the cause.
        """
        from contextlib import closing

        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            # Close the SFTP session explicitly instead of leaving it open.
            with self.ssh_hook.get_conn() as ssh_client, \
                    closing(ssh_client.open_sftp()) as sftp_client:
                if self.operation.lower() == SFTPOperation.GET:
                    local_folder = os.path.dirname(self.local_filepath)
                    if self.create_intermediate_dirs:
                        Path(local_folder).mkdir(parents=True, exist_ok=True)
                    file_msg = "from {0} to {1}".format(
                        self.remote_filepath, self.local_filepath)
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp_client.get(self.remote_filepath, self.local_filepath)
                else:
                    remote_folder = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp_client,
                            remote_directory=remote_folder,
                        )
                    file_msg = "from {0} to {1}".format(
                        self.local_filepath, self.remote_filepath)
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp_client.put(self.local_filepath,
                                    self.remote_filepath,
                                    confirm=self.confirm)

        except Exception as e:
            raise AirflowException(
                "Error while transferring {0}, error: {1}".format(
                    file_msg, str(e))) from e

        return self.local_filepath
def test_conn_with_extra_parameters(self):
    """Each connection extra maps onto the matching boolean attribute of the hook."""
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
    expected_flags = [
        ('compress', True),
        ('no_host_key_check', True),
        ('allow_host_key_change', False),
        ('look_for_keys', True),
    ]
    for attr_name, expected in expected_flags:
        assert getattr(hook, attr_name) is expected