Example #1
    def test_ssh_connection_without_password(self, ssh_mock):
        hook = SSHHook(remote_host='remote_host',
                       port='port',
                       username='******',
                       timeout=10,
                       key_file='fake.file')

        with hook.get_conn():
            ssh_mock.return_value.connect.assert_called_once_with(
                hostname='remote_host',
                username='******',
                key_filename='fake.file',
                timeout=10,
                compress=True,
                port='port',
                sock=None)
Example #2
    def execute(self, context):
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        sftp_client = ssh_hook.get_conn().open_sftp()

        with NamedTemporaryFile("w") as f:
            sftp_client.get(self.sftp_path, f.name)

            s3_hook.load_file(
                filename=f.name,
                key=self.s3_key,
                bucket_name=self.s3_bucket,
                replace=True
            )
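
The snippet above downloads the SFTP file to a temporary local file and then hands it to S3Hook.load_file. As a usage sketch (not from the source), this is roughly how the surrounding operator, assumed here to be Airflow's SFTPToS3Operator with placeholder connection ids, bucket, and paths, might be wired into a DAG:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator

with DAG(dag_id="sftp_to_s3_demo", start_date=datetime(2021, 1, 1),
         schedule_interval=None) as dag:
    transfer = SFTPToS3Operator(
        task_id="sftp_to_s3",
        sftp_conn_id="sftp_default",       # placeholder SSH/SFTP connection
        sftp_path="/tmp/remote_file.txt",  # placeholder remote path
        s3_conn_id="aws_default",          # placeholder AWS connection
        s3_bucket="my-bucket",             # placeholder bucket
        s3_key="data/remote_file.txt",     # placeholder key
    )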
Example #3
File: sftp.py Project: nsenno-dbr/airflow
    def execute(self, context: Any) -> str:
        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info(
                        "ssh_hook is not provided or invalid. " "Trying ssh_conn_id to create SSHHook."
                    )
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id."
                )
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as ssh_client:
                sftp_client = ssh_client.open_sftp()
                if self.operation.lower() == SFTPOperation.GET:
                    local_folder = os.path.dirname(self.local_filepath)
                    if self.create_intermediate_dirs:
                        Path(local_folder).mkdir(parents=True, exist_ok=True)
                    file_msg = f"from {self.remote_filepath} to {self.local_filepath}"
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp_client.get(self.remote_filepath, self.local_filepath)
                else:
                    remote_folder = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp_client,
                            remote_directory=remote_folder,
                        )
                    file_msg = f"from {self.local_filepath} to {self.remote_filepath}"
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp_client.put(self.local_filepath, self.remote_filepath, confirm=self.confirm)

        except Exception as e:
            raise AirflowException("Error while transferring {}, error: {}".format(file_msg, str(e)))

        return self.local_filepath
Example #4
File: test_ssh.py Project: xwydq/airflow
    def test_tunnel_without_password(self, ssh_mock):
        hook = SSHHook(remote_host='remote_host',
                       port='port',
                       username='******',
                       timeout=10,
                       key_file='fake.file')

        with hook.get_tunnel(1234):
            ssh_mock.assert_called_once_with('remote_host',
                                             ssh_port='port',
                                             ssh_username='******',
                                             ssh_pkey='fake.file',
                                             ssh_proxy=None,
                                             local_bind_address=('localhost', ),
                                             remote_bind_address=('localhost', 1234),
                                             host_pkey_directories=[],
                                             logger=hook.log)
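
For orientation, a minimal sketch (not from the source) of what get_tunnel is typically used for: forwarding a local port to a service that is only reachable from the remote host. The connection id and port are placeholders.

from airflow.providers.ssh.hooks.ssh import SSHHook

hook = SSHHook(ssh_conn_id="ssh_default")  # placeholder connection id
with hook.get_tunnel(5432, remote_host="localhost") as tunnel:
    # While the tunnel is open, connections to 127.0.0.1:tunnel.local_bind_port
    # are forwarded through SSH to port 5432 on the remote side.
    print(tunnel.local_bind_port)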
Example #5
    def test_ssh_connection_with_private_key_extra(self, ssh_mock):
        hook = SSHHook(
            ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
            remote_host='remote_host',
            port='port',
            username='******',
            timeout=10,
        )

        with hook.get_conn():
            ssh_mock.return_value.connect.assert_called_once_with(
                hostname='remote_host',
                username='******',
                pkey=TEST_PKEY,
                timeout=10,
                compress=True,
                port='port',
                sock=None)
Example #6
File: test_ssh.py Project: xwydq/airflow
    def test_tunnel_with_private_key(self, ssh_mock):
        hook = SSHHook(
            ssh_conn_id=self.CONN_SSH_WITH_PRIVATE_KEY_EXTRA,
            remote_host='remote_host',
            port='port',
            username='******',
            timeout=10,
        )

        with hook.get_tunnel(1234):
            ssh_mock.assert_called_once_with('remote_host',
                                             ssh_port='port',
                                             ssh_username='******',
                                             ssh_pkey=TEST_PKEY,
                                             ssh_proxy=None,
                                             local_bind_address=('localhost',),
                                             remote_bind_address=('localhost', 1234),
                                             host_pkey_directories=[],
                                             logger=hook.log)
Example #7
    def get_hook(self) -> SSHHook:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)

        if not self.ssh_hook:
            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id."
            )
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook
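
To make the precedence concrete, a hedged sketch with placeholder values (not from the source): an explicit remote_host wins over whatever host the hook or connection defines.

from airflow.providers.ssh.operators.ssh import SSHOperator

op = SSHOperator(
    task_id="override_host",
    ssh_conn_id="ssh_default",        # the connection may point at one host...
    remote_host="other.example.com",  # ...but an explicit remote_host replaces it
    command="hostname",
)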
Example #8
File: test_ssh.py Project: xwydq/airflow
    def test_tunnel(self):
        hook = SSHHook(ssh_conn_id='ssh_default')

        import subprocess
        import socket

        subprocess_kwargs = dict(
            args=["python", "-c", HELLO_SERVER_CMD],
            stdout=subprocess.PIPE,
        )
        with subprocess.Popen(**subprocess_kwargs) as server_handle, hook.create_tunnel(2135, 2134):
            server_output = server_handle.stdout.read(5)
            self.assertEqual(b"ready", server_output)
            client_socket = socket.socket()  # avoid shadowing the imported socket module
            client_socket.connect(("localhost", 2135))
            response = client_socket.recv(5)
            self.assertEqual(response, b"hello")
            client_socket.close()
            server_handle.communicate()
            self.assertEqual(server_handle.returncode, 0)
Example #9
    def execute(self, context) -> None:
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        sftp_client = ssh_hook.get_conn().open_sftp()

        if self.use_temp_file:
            with NamedTemporaryFile("w") as f:
                sftp_client.get(self.sftp_path, f.name)

                s3_hook.load_file(filename=f.name,
                                  key=self.s3_key,
                                  bucket_name=self.s3_bucket,
                                  replace=True)
        else:
            with sftp_client.file(self.sftp_path, mode='rb') as data:
                s3_hook.get_conn().upload_fileobj(data,
                                                  self.s3_bucket,
                                                  self.s3_key,
                                                  Callback=self.log.info)
Example #10
    def setUp(self):
        hook = SSHHook(ssh_conn_id='ssh_default')

        s3_hook = S3Hook('aws_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'

        self.hook = hook
        self.s3_hook = s3_hook

        self.ssh_client = self.hook.get_conn()
        self.sftp_client = self.ssh_client.open_sftp()

        self.dag = dag
        self.s3_bucket = BUCKET
        self.sftp_path = SFTP_PATH
        self.s3_key = S3_KEY
Example #11
    def setUp(self):
        from airflow.providers.ssh.hooks.ssh import SSHHook

        hook = SSHHook(ssh_conn_id='ssh_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'
        self.hook = hook
        self.dag = dag
        self.test_dir = "/tmp"
        self.test_local_dir = "/tmp/tmp2"
        self.test_remote_dir = "/tmp/tmp1"
        self.test_local_filename = 'test_local_file'
        self.test_remote_filename = 'test_remote_file'
        self.test_local_filepath = f'{self.test_dir}/{self.test_local_filename}'
        # Local Filepath with Intermediate Directory
        self.test_local_filepath_int_dir = f'{self.test_local_dir}/{self.test_local_filename}'
        self.test_remote_filepath = f'{self.test_dir}/{self.test_remote_filename}'
        # Remote Filepath with Intermediate Directory
        self.test_remote_filepath_int_dir = f'{self.test_remote_dir}/{self.test_remote_filename}'
Example #12
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param timeout: timeout (in seconds) for executing the command.
    :type timeout: int
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :type environment: dict
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :type get_pty: bool
    """

    template_fields = ('command', 'remote_host')
    template_ext = ('.sh', )

    @apply_defaults
    def __init__(self,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 command=None,
                 timeout=10,
                 environment=None,
                 get_pty=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.environment = environment
        # Guard against command=None, which would crash startswith().
        self.get_pty = (self.command is not None and self.command.startswith('sudo')) or get_pty

    def execute(self, context):
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. " +
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                            timeout=self.timeout)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. " +
                    "It will replace the remote_host which was defined " +
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            if not self.command:
                raise AirflowException("SSH command not specified. Aborting.")

            with self.ssh_hook.get_conn() as ssh_client:
                self.log.info("Running command: %s", self.command)

                # set timeout taken as params
                stdin, stdout, stderr = ssh_client.exec_command(
                    command=self.command,
                    get_pty=self.get_pty,
                    timeout=self.timeout,
                    environment=self.environment)
                # get channels
                channel = stdout.channel

                # closing stdin
                stdin.close()
                channel.shutdown_write()

                agg_stdout = b''
                agg_stderr = b''

                # capture any initial output in case channel is closed already
                stdout_buffer_length = len(stdout.channel.in_buffer)

                if stdout_buffer_length > 0:
                    agg_stdout += stdout.channel.recv(stdout_buffer_length)

                # read from both stdout and stderr
                while not channel.closed or \
                        channel.recv_ready() or \
                        channel.recv_stderr_ready():
                    readq, _, _ = select([channel], [], [], self.timeout)
                    for c in readq:
                        if c.recv_ready():
                            line = stdout.channel.recv(len(c.in_buffer))
                            agg_stdout += line
                            self.log.info(line.decode('utf-8').strip('\n'))
                        if c.recv_stderr_ready():
                            line = stderr.channel.recv_stderr(
                                len(c.in_stderr_buffer))
                            agg_stderr += line
                            self.log.warning(line.decode('utf-8').strip('\n'))
                    if stdout.channel.exit_status_ready()\
                            and not stderr.channel.recv_stderr_ready()\
                            and not stdout.channel.recv_ready():
                        stdout.channel.shutdown_read()
                        stdout.channel.close()
                        break

                stdout.close()
                stderr.close()

                exit_status = stdout.channel.recv_exit_status()
                if exit_status == 0:
                    enable_pickling = conf.getboolean('core',
                                                      'enable_xcom_pickling')
                    if enable_pickling:
                        return agg_stdout
                    else:
                        return b64encode(agg_stdout).decode('utf-8')

                else:
                    error_msg = agg_stderr.decode('utf-8')
                    raise AirflowException(
                        "error running cmd: {0}, error: {1}".format(
                            self.command, error_msg))

        except Exception as e:
            raise AirflowException("SSH operator error: {0}".format(str(e)))

        return True

    def tunnel(self):
        ssh_client = self.ssh_hook.get_conn()
        ssh_client.get_transport()
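
A hedged usage sketch for the operator above; the connection id and command are placeholders, and the import path assumes the providers package rather than the old contrib one:

from datetime import datetime

from airflow import DAG
from airflow.providers.ssh.operators.ssh import SSHOperator

with DAG(dag_id="ssh_demo", start_date=datetime(2021, 1, 1),
         schedule_interval="@once") as dag:
    uptime = SSHOperator(
        task_id="uptime",
        ssh_conn_id="ssh_default",  # placeholder connection id
        command="uptime",           # stdout is pushed to XCom, base64-encoded
                                    # unless core.enable_xcom_pickling is True
    )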
Example #13
File: test_ssh.py Project: xwydq/airflow
    def test_ssh_connection_old_cm(self):
        with SSHHook(ssh_conn_id='ssh_default') as hook:
            client = hook.get_conn()
            (_, stdout, _) = client.exec_command('ls')
            self.assertIsNotNone(stdout.read())
Example #14
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections. `ssh_conn_id` will be ignored if
        `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param conn_timeout: timeout (in seconds) for maintaining the connection. The default is 10 seconds.
        Nullable. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :type conn_timeout: int
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
    :type cmd_timeout: int
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :type timeout: int
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :type environment: dict
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :type get_pty: bool
    """

    template_fields = ('command', 'remote_host')
    template_ext = ('.sh', )
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_hook: Optional[SSHHook] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        cmd_timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        if self.conn_timeout is None and self.timeout:
            self.conn_timeout = self.timeout
        if self.cmd_timeout is None:
            self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
        self.environment = environment
        self.get_pty = get_pty

        if self.timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` and `cmd_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                stacklevel=1,
            )

    def get_hook(self) -> SSHHook:
        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info(
                    "ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                        conn_timeout=self.conn_timeout)

        if not self.ssh_hook:
            raise AirflowException(
                "Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id.")
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook

    def get_ssh_client(self) -> SSHClient:
        # Remember to use context manager or call .close() on this when done
        self.log.info('Creating ssh_client')
        return self.get_hook().get_conn()

    def exec_ssh_client_command(self, ssh_client: SSHClient,
                                command: str) -> Tuple[int, bytes, bytes]:
        self.log.info("Running command: %s", command)

        # set timeout taken as params
        stdin, stdout, stderr = ssh_client.exec_command(
            command=command,
            get_pty=self.get_pty,
            timeout=self.timeout,
            environment=self.environment,
        )
        # get channels
        channel = stdout.channel

        # closing stdin
        stdin.close()
        channel.shutdown_write()

        agg_stdout = b''
        agg_stderr = b''

        # capture any initial output in case channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)

        if stdout_buffer_length > 0:
            agg_stdout += stdout.channel.recv(stdout_buffer_length)

        # read from both stdout and stderr
        while not channel.closed or channel.recv_ready(
        ) or channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], self.cmd_timeout)
            for recv in readq:
                if recv.recv_ready():
                    line = stdout.channel.recv(len(recv.in_buffer))
                    agg_stdout += line
                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                if recv.recv_stderr_ready():
                    line = stderr.channel.recv_stderr(
                        len(recv.in_stderr_buffer))
                    agg_stderr += line
                    self.log.warning(
                        line.decode('utf-8', 'replace').strip('\n'))
            if (stdout.channel.exit_status_ready()
                    and not stderr.channel.recv_stderr_ready()
                    and not stdout.channel.recv_ready()):
                stdout.channel.shutdown_read()
                try:
                    stdout.channel.close()
                except Exception:
                    # there is a race that when shutdown_read has been called and when
                    # you try to close the connection, the socket is already closed
                    # We should ignore such errors (but we should log them with warning)
                    self.log.warning("Ignoring exception on close",
                                     exc_info=True)
                break

        stdout.close()
        stderr.close()

        exit_status = stdout.channel.recv_exit_status()

        return exit_status, agg_stdout, agg_stderr

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        if exit_status != 0:
            error_msg = stderr.decode('utf-8')
            raise AirflowException(
                f"error running cmd: {self.command}, error: {error_msg}")

    def run_ssh_client_command(self, ssh_client: SSHClient,
                               command: str) -> bytes:
        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(
            ssh_client, command)
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout

    def execute(self, context=None) -> Union[bytes, str]:
        result: Union[bytes, str]
        if self.command is None:
            raise AirflowException(
                "SSH operator error: SSH command not specified. Aborting.")

        # Forcing get_pty to True if the command begins with "sudo".
        self.get_pty = self.command.startswith('sudo') or self.get_pty

        try:
            with self.get_ssh_client() as ssh_client:
                result = self.run_ssh_client_command(ssh_client, self.command)
        except Exception as e:
            raise AirflowException(f"SSH operator error: {str(e)}")
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if not enable_pickling:
            result = b64encode(result).decode('utf-8')
        return result

    def tunnel(self) -> None:
        """Get ssh tunnel"""
        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
        ssh_client.get_transport()
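
To illustrate the split the docstring describes, a hedged sketch with placeholder values: conn_timeout bounds establishing the SSH connection, cmd_timeout bounds the remote command itself, and the deprecated timeout is left unset.

from airflow.providers.ssh.operators.ssh import SSHOperator

task = SSHOperator(
    task_id="long_running",
    ssh_conn_id="ssh_default",        # placeholder connection id
    command="sleep 30 && echo done",  # placeholder command
    conn_timeout=10,                  # seconds to establish the connection
    cmd_timeout=60,                   # seconds to wait for the command
)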
Example #15
File: test_ssh.py Project: zlisde/airflow
    def test_ssh_connection_with_host_key_extra(self, ssh_client):
        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
        assert hook.host_key is None  # no_host_key_check defaults to True unless explicitly overridden
        with hook.get_conn():
            assert ssh_client.return_value.connect.called is True
            assert ssh_client.return_value.get_host_keys.return_value.add.called is False
Example #16
File: test_ssh.py Project: xwydq/airflow
    def test_ssh_connection(self):
        hook = SSHHook(ssh_conn_id='ssh_default')
        with hook.get_conn() as client:
            # Note - Pylint will fail with no-member here due to https://github.com/PyCQA/pylint/issues/1437
            (_, stdout, _) = client.exec_command('ls')  # pylint: disable=no-member
            self.assertIsNotNone(stdout.read())
Example #17
File: test_ssh.py Project: xwydq/airflow
    def test_conn_with_extra_parameters(self):
        ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
        self.assertEqual(ssh_hook.compress, True)
        self.assertEqual(ssh_hook.no_host_key_check, True)
        self.assertEqual(ssh_hook.allow_host_key_change, False)
        self.assertEqual(ssh_hook.look_for_keys, True)
Example #18
File: test_ssh.py Project: xwydq/airflow
    def test_conn_with_extra_parameters_false_look_for_keys(self):
        ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS)
        self.assertEqual(ssh_hook.look_for_keys, False)
Example #19
from datetime import datetime

from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.providers.ssh.hooks.ssh import SSHHook

#---------------
# Local app imports
#--------------

import gblevent_config as gblevent_config

#--------------
# variable setup
#--------------

sshHook = SSHHook('aws_emr')

spark_location = gblevent_config.GBLEVENT_STAGING_PREP
spark_submit_cmd = '/usr/bin/spark-submit ' + spark_location + '  {{execution_date.year}} {{execution_date.month}} {{execution_date.day}}'

start_date_val = gblevent_config.STARTDATE
end_date_val = gblevent_config.ENDDATE

start_date = datetime.strptime(start_date_val, '%Y,%m,%d')
end_date = datetime.strptime(end_date_val, '%Y,%m,%d')

#---------------
default_args = {
    'owner': 'udacity',
    'start_date': start_date,
    'end_date': end_date
}
Example #20
    def execute(self, context):
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. " +
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                            timeout=self.timeout)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. " +
                    "It will replace the remote_host which was defined " +
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            if not self.command:
                raise AirflowException("SSH command not specified. Aborting.")

            with self.ssh_hook.get_conn() as ssh_client:
                self.log.info("Running command: %s", self.command)

                # set timeout taken as params
                stdin, stdout, stderr = ssh_client.exec_command(
                    command=self.command,
                    get_pty=self.get_pty,
                    timeout=self.timeout,
                    environment=self.environment)
                # get channels
                channel = stdout.channel

                # closing stdin
                stdin.close()
                channel.shutdown_write()

                agg_stdout = b''
                agg_stderr = b''

                # capture any initial output in case channel is closed already
                stdout_buffer_length = len(stdout.channel.in_buffer)

                if stdout_buffer_length > 0:
                    agg_stdout += stdout.channel.recv(stdout_buffer_length)

                # read from both stdout and stderr
                while not channel.closed or \
                        channel.recv_ready() or \
                        channel.recv_stderr_ready():
                    readq, _, _ = select([channel], [], [], self.timeout)
                    for c in readq:
                        if c.recv_ready():
                            line = stdout.channel.recv(len(c.in_buffer))
                            agg_stdout += line
                            self.log.info(line.decode('utf-8').strip('\n'))
                        if c.recv_stderr_ready():
                            line = stderr.channel.recv_stderr(
                                len(c.in_stderr_buffer))
                            agg_stderr += line
                            self.log.warning(line.decode('utf-8').strip('\n'))
                    if stdout.channel.exit_status_ready()\
                            and not stderr.channel.recv_stderr_ready()\
                            and not stdout.channel.recv_ready():
                        stdout.channel.shutdown_read()
                        stdout.channel.close()
                        break

                stdout.close()
                stderr.close()

                exit_status = stdout.channel.recv_exit_status()
                if exit_status == 0:
                    enable_pickling = conf.getboolean('core',
                                                      'enable_xcom_pickling')
                    if enable_pickling:
                        return agg_stdout
                    else:
                        return b64encode(agg_stdout).decode('utf-8')

                else:
                    error_msg = agg_stderr.decode('utf-8')
                    raise AirflowException(
                        "error running cmd: {0}, error: {1}".format(
                            self.command, error_msg))

        except Exception as e:
            raise AirflowException("SSH operator error: {0}".format(str(e)))

        return True
Example #21
class SSHOperator(BaseOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections. `ssh_conn_id` will be ignored if
        `ssh_hook` is provided.
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :param command: command to execute on remote host. (templated)
    :param conn_timeout: timeout (in seconds) for maintaining the connection. The default is 10 seconds.
        Nullable. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    :param banner_timeout: timeout to wait for banner from the server in seconds
    """

    template_fields: Sequence[str] = ('command', 'remote_host')
    template_ext: Sequence[str] = ('.sh', )
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_hook: Optional["SSHHook"] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        cmd_timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        banner_timeout: float = 30.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        if self.conn_timeout is None and self.timeout:
            self.conn_timeout = self.timeout
        if self.cmd_timeout is None:
            self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT
        self.environment = environment
        self.get_pty = get_pty
        self.banner_timeout = banner_timeout

        if self.timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` and `cmd_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                stacklevel=2,
            )

    def get_hook(self) -> "SSHHook":
        from airflow.providers.ssh.hooks.ssh import SSHHook

        if self.ssh_conn_id:
            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                self.log.info(
                    "ssh_conn_id is ignored when ssh_hook is provided.")
            else:
                self.log.info(
                    "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                )
                self.ssh_hook = SSHHook(
                    ssh_conn_id=self.ssh_conn_id,
                    conn_timeout=self.conn_timeout,
                    banner_timeout=self.banner_timeout,
                )

        if not self.ssh_hook:
            raise AirflowException(
                "Cannot operate without ssh_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.log.info(
                "remote_host is provided explicitly. "
                "It will replace the remote_host which was defined "
                "in ssh_hook or predefined in connection of ssh_conn_id.")
            self.ssh_hook.remote_host = self.remote_host

        return self.ssh_hook

    def get_ssh_client(self) -> "SSHClient":
        # Remember to use context manager or call .close() on this when done
        self.log.info('Creating ssh_client')
        return self.get_hook().get_conn()

    def exec_ssh_client_command(self, ssh_client: "SSHClient", command: str):
        warnings.warn(
            'exec_ssh_client_command method on SSHOperator is deprecated, call '
            '`ssh_hook.exec_ssh_client_command` instead',
            DeprecationWarning,
        )
        assert self.ssh_hook
        return self.ssh_hook.exec_ssh_client_command(
            ssh_client,
            command,
            timeout=self.timeout,
            environment=self.environment,
            get_pty=self.get_pty)

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        if exit_status != 0:
            raise AirflowException(
                f"SSH operator error: exit status = {exit_status}")

    def run_ssh_client_command(self, ssh_client: "SSHClient",
                               command: str) -> bytes:
        assert self.ssh_hook
        exit_status, agg_stdout, agg_stderr = self.ssh_hook.exec_ssh_client_command(
            ssh_client,
            command,
            timeout=self.timeout,
            environment=self.environment,
            get_pty=self.get_pty)
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout

    def execute(self, context=None) -> Union[bytes, str]:
        result: Union[bytes, str]
        if self.command is None:
            raise AirflowException(
                "SSH operator error: SSH command not specified. Aborting.")

        # Forcing get_pty to True if the command begins with "sudo".
        self.get_pty = self.command.startswith('sudo') or self.get_pty

        with self.get_ssh_client() as ssh_client:
            result = self.run_ssh_client_command(ssh_client, self.command)
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if not enable_pickling:
            result = b64encode(result).decode('utf-8')
        return result

    def tunnel(self) -> None:
        """Get ssh tunnel"""
        ssh_client = self.ssh_hook.get_conn()  # type: ignore[union-attr]
        ssh_client.get_transport()
Example #22
def get_ssh_op(script):
    return SSHOperator(task_id='ssh_test',
                       ssh_hook=SSHHook(ssh_conn_id='ssh_conn'),
                       ssh_conn_id='operator_test',
                       retries=0,
                       command=script)
Example #23
    def execute(self, context) -> Union[bytes, str, bool]:
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info(
                        "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                    )
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                            timeout=self.timeout)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            if not self.command:
                raise AirflowException("SSH command not specified. Aborting.")

            with self.ssh_hook.get_conn() as ssh_client:
                self.log.info("Running command: %s", self.command)

                # set timeout taken as params
                stdin, stdout, stderr = ssh_client.exec_command(
                    command=self.command,
                    get_pty=self.get_pty,
                    timeout=self.timeout,
                    environment=self.environment,
                )
                # get channels
                channel = stdout.channel

                # closing stdin
                stdin.close()
                channel.shutdown_write()

                agg_stdout = b''
                agg_stderr = b''

                # capture any initial output in case channel is closed already
                stdout_buffer_length = len(stdout.channel.in_buffer)

                if stdout_buffer_length > 0:
                    agg_stdout += stdout.channel.recv(stdout_buffer_length)

                # read from both stdout and stderr
                while not channel.closed or channel.recv_ready(
                ) or channel.recv_stderr_ready():
                    readq, _, _ = select([channel], [], [], self.timeout)
                    for recv in readq:
                        if recv.recv_ready():
                            line = stdout.channel.recv(len(recv.in_buffer))
                            agg_stdout += line
                            self.log.info(
                                line.decode('utf-8', 'replace').strip('\n'))
                        if recv.recv_stderr_ready():
                            line = stderr.channel.recv_stderr(
                                len(recv.in_stderr_buffer))
                            agg_stderr += line
                            self.log.warning(
                                line.decode('utf-8', 'replace').strip('\n'))
                    if (stdout.channel.exit_status_ready()
                            and not stderr.channel.recv_stderr_ready()
                            and not stdout.channel.recv_ready()):
                        stdout.channel.shutdown_read()
                        try:
                            stdout.channel.close()
                        except Exception:
                            # there is a race that when shutdown_read has been called and when
                            # you try to close the connection, the socket is already closed
                            # We should ignore such errors (but we should log them with warning)
                            self.log.warning("Ignoring exception on close",
                                             exc_info=True)
                        break

                stdout.close()
                stderr.close()

                exit_status = stdout.channel.recv_exit_status()
                if exit_status == 0:
                    enable_pickling = conf.getboolean('core',
                                                      'enable_xcom_pickling')
                    if enable_pickling:
                        return agg_stdout
                    else:
                        return b64encode(agg_stdout).decode('utf-8')

                else:
                    error_msg = agg_stderr.decode('utf-8')
                    raise AirflowException(
                        f"error running cmd: {self.command}, error: {error_msg}"
                    )

        except Exception as e:
            raise AirflowException(f"SSH operator error: {str(e)}")

        return True
Example #24
class SFTPOperator(BaseOperator):
    """
    SFTPOperator for transferring files from remote host to local or vice a versa.
    This operator uses ssh_hook to open sftp transport channel that serve as basis
    for file transfer.

    :param ssh_hook: predefined ssh_hook to use for remote execution.
        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
    :type ssh_hook: airflow.providers.ssh.hooks.ssh.SSHHook
    :param ssh_conn_id: connection id from airflow Connections.
        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
    :type ssh_conn_id: str
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :type remote_host: str
    :param local_filepath: local file path to get or put. (templated)
    :type local_filepath: str
    :param remote_filepath: remote file path to get or put. (templated)
    :type remote_filepath: str
    :param operation: specify operation 'get' or 'put', defaults to put
    :type operation: str
    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
    :type confirm: bool
    :param create_intermediate_dirs: create missing intermediate directories when
        copying from remote to local and vice-versa. Default is False.

        Example: The following task would copy ``file.txt`` to the remote host
        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
        don't exist. If the parameter is not passed it would error as the directory
        does not exist. ::

            put_file = SFTPOperator(
                task_id="test_sftp",
                ssh_conn_id="ssh_default",
                local_filepath="/tmp/file.txt",
                remote_filepath="/tmp/tmp1/tmp2/file.txt",
                operation="put",
                create_intermediate_dirs=True,
                dag=dag
            )

    :type create_intermediate_dirs: bool
    """
    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')

    @apply_defaults
    def __init__(self,
                 *,
                 ssh_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 local_filepath=None,
                 remote_filepath=None,
                 operation=SFTPOperation.PUT,
                 confirm=True,
                 create_intermediate_dirs=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        if not (self.operation.lower() == SFTPOperation.GET
                or self.operation.lower() == SFTPOperation.PUT):
            raise TypeError(
                "unsupported operation value {0}, expected {1} or {2}".format(
                    self.operation, SFTPOperation.GET, SFTPOperation.PUT))

    def execute(self, context):
        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as ssh_client:
                sftp_client = ssh_client.open_sftp()
                if self.operation.lower() == SFTPOperation.GET:
                    local_folder = os.path.dirname(self.local_filepath)
                    if self.create_intermediate_dirs:
                        Path(local_folder).mkdir(parents=True, exist_ok=True)
                    file_msg = "from {0} to {1}".format(
                        self.remote_filepath, self.local_filepath)
                    self.log.info("Starting to transfer %s", file_msg)
                    sftp_client.get(self.remote_filepath, self.local_filepath)
                else:
                    remote_folder = os.path.dirname(self.remote_filepath)
                    if self.create_intermediate_dirs:
                        _make_intermediate_dirs(
                            sftp_client=sftp_client,
                            remote_directory=remote_folder,
                        )
                    file_msg = "from {0} to {1}".format(
                        self.local_filepath, self.remote_filepath)
                    self.log.info("Starting to transfer file %s", file_msg)
                    sftp_client.put(self.local_filepath,
                                    self.remote_filepath,
                                    confirm=self.confirm)

        except Exception as e:
            raise AirflowException(
                "Error while transferring {0}, error: {1}".format(
                    file_msg, str(e)))

        return self.local_filepath
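
The docstring only shows a "put"; for symmetry, a hedged "get" counterpart with placeholder paths and connection id, downloading a remote file and creating any missing local directories:

get_file = SFTPOperator(
    task_id="fetch_report",
    ssh_conn_id="ssh_default",                       # placeholder connection id
    local_filepath="/tmp/reports/today/report.csv",  # placeholder local path
    remote_filepath="/var/reports/report.csv",       # placeholder remote path
    operation="get",
    create_intermediate_dirs=True,  # creates /tmp/reports/today if missing
    dag=dag,
)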
Example #25
    def test_conn_with_extra_parameters(self):
        ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
        assert ssh_hook.compress is True
        assert ssh_hook.no_host_key_check is True
        assert ssh_hook.allow_host_key_change is False
        assert ssh_hook.look_for_keys is True
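
For reference, a hedged sketch of the kind of Connection "extra" JSON these assertions imply; the key names match SSHHook's documented extras, but the exact test fixture is not shown in the source, so the id and host below are placeholders.

from airflow.models import Connection

CONN_SSH_WITH_EXTRA = Connection(
    conn_id="ssh_with_extra",  # placeholder id
    conn_type="ssh",
    host="remote_host",        # placeholder host
    extra='{"compress": true, "no_host_key_check": true, '
          '"allow_host_key_change": false, "look_for_keys": true}',
)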