Esempio n. 1
0
 def test_conn_with_extra_parameters(self):
     """Every boolean option in the connection extras must reach the hook."""
     hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
     expected = {
         'compress': True,
         'no_host_key_check': True,
         'allow_host_key_change': False,
         'look_for_keys': True,
     }
     for attr, value in expected.items():
         self.assertEqual(getattr(hook, attr), value)
Esempio n. 2
0
 def test_conn_with_extra_parameters_false_look_for_keys(self):
     """A ``look_for_keys: false`` extra must be honored by the hook."""
     hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS)
     self.assertEqual(hook.look_for_keys, False)
Esempio n. 3
0
 def test_ssh_connection(self):
     """get_conn() used as a context manager yields a working client."""
     ssh_hook = SSHHook(ssh_conn_id='ssh_default')
     with ssh_hook.get_conn() as conn:
         # Pylint no-member false positive, see https://github.com/PyCQA/pylint/issues/1437
         _, out, _ = conn.exec_command('ls')  # pylint: disable=no-member
         self.assertIsNotNone(out.read())
Esempio n. 4
0
 def test_ssh_connection_old_cm(self):
     """The legacy hook-as-context-manager form still produces a client."""
     with SSHHook(ssh_conn_id='ssh_default') as ssh_hook:
         conn = ssh_hook.get_conn()
         _, out, _ = conn.exec_command('ls')
         self.assertIsNotNone(out.read())
Esempio n. 5
0
    def execute(self, context):
        """Run ``self.command`` on the remote host over SSH.

        Resolves the hook (an explicit ``ssh_hook`` wins over
        ``ssh_conn_id``), executes the command, and streams stdout/stderr
        to the task log while the command runs.

        :param context: Airflow task context (unused here beyond the
            operator contract).
        :return: aggregated stdout bytes when ``core.enable_xcom_pickling``
            is on, otherwise the base64-encoded, utf-8 decoded stdout.
        :raises AirflowException: when no hook can be built, no command is
            set, the remote command exits non-zero, or any underlying
            error occurs (wrapped, with the original exception chained).
        """
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info(
                        "ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info("ssh_hook is not provided or invalid. "
                                  "Trying ssh_conn_id to create SSHHook.")
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                            timeout=self.timeout)

            if not self.ssh_hook:
                raise AirflowException(
                    "Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id.")
                self.ssh_hook.remote_host = self.remote_host

            if not self.command:
                raise AirflowException("SSH command not specified. Aborting.")

            with self.ssh_hook.get_conn() as ssh_client:
                self.log.info("Running command: %s", self.command)

                # set timeout taken as params
                stdin, stdout, stderr = ssh_client.exec_command(
                    command=self.command,
                    get_pty=self.get_pty,
                    timeout=self.timeout,
                    environment=self.environment)
                # get channels
                channel = stdout.channel

                # The remote command gets no input: close stdin and tell the
                # channel no more data will be sent.
                stdin.close()
                channel.shutdown_write()

                agg_stdout = b''
                agg_stderr = b''

                # capture any initial output in case channel is closed already
                stdout_buffer_length = len(stdout.channel.in_buffer)

                if stdout_buffer_length > 0:
                    agg_stdout += stdout.channel.recv(stdout_buffer_length)

                # Read from both stdout and stderr until the channel is
                # closed AND both buffers are drained.
                while not channel.closed or \
                        channel.recv_ready() or \
                        channel.recv_stderr_ready():
                    readq, _, _ = select([channel], [], [], self.timeout)
                    for recv in readq:
                        if recv.recv_ready():
                            line = stdout.channel.recv(len(recv.in_buffer))
                            agg_stdout += line
                            self.log.info(line.decode('utf-8').strip('\n'))
                        if recv.recv_stderr_ready():
                            line = stderr.channel.recv_stderr(
                                len(recv.in_stderr_buffer))
                            agg_stderr += line
                            self.log.warning(line.decode('utf-8').strip('\n'))
                    # Once the remote side reported an exit status and both
                    # streams are drained, stop polling.
                    if stdout.channel.exit_status_ready()\
                            and not stderr.channel.recv_stderr_ready()\
                            and not stdout.channel.recv_ready():
                        stdout.channel.shutdown_read()
                        stdout.channel.close()
                        break

                stdout.close()
                stderr.close()

                exit_status = stdout.channel.recv_exit_status()
                if exit_status == 0:
                    # XCom can only carry pickled bytes when pickling is
                    # enabled; otherwise return a base64 string.
                    enable_pickling = conf.getboolean('core',
                                                      'enable_xcom_pickling')
                    if enable_pickling:
                        return agg_stdout
                    else:
                        return b64encode(agg_stdout).decode('utf-8')

                else:
                    error_msg = agg_stderr.decode('utf-8')
                    raise AirflowException(
                        "error running cmd: {0}, error: {1}".format(
                            self.command, error_msg))

        except Exception as e:
            # Chain the original exception so its traceback is preserved
            # in the task log (plain re-raise would discard it).
            raise AirflowException(
                "SSH operator error: {0}".format(str(e))) from e
        # NOTE: the former trailing `return True` was unreachable -- every
        # path through the try block returns or raises, and the except
        # clause always raises -- so it has been removed.
Esempio n. 6
0
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.providers.ssh.hooks.ssh import SSHHook

#---------------
# Local app imports
#--------------

import gblevent_config as gblevent_config

#--------------
# variable setup
#--------------

# Hook for the EMR master node. First positional argument is the Airflow
# connection id -- presumably 'aws_emr' exists in the Airflow connections
# store; verify before deploying.
sshHook = SSHHook('aws_emr')

# spark-submit command line; the {{execution_date.*}} placeholders are
# Jinja template fields rendered by the operator at run time.
spark_location = gblevent_config.GBLEVENT_STAGING_PREP
spark_submit_cmd = '/usr/bin/spark-submit ' + spark_location + '  {{execution_date.year}} {{execution_date.month}} {{execution_date.day}}'

# DAG schedule window, parsed from 'YYYY,MM,DD' strings in the local config.
start_date_val = gblevent_config.STARTDATE
end_date_val = gblevent_config.ENDDATE

start_date = datetime.strptime(start_date_val, '%Y,%m,%d')
end_date = datetime.strptime(end_date_val, '%Y,%m,%d')

#---------------
default_args = {
    'owner': 'udacity',
    'start_date': start_date,
    'end_date': end_date
Esempio n. 7
0
 def test_ssh_connection_with_host_key_extra(self, ssh_client):
     """host_key is dropped while no_host_key_check keeps its default."""
     ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_EXTRA)
     # Default no_host_key_check = True wins unless explicitly overridden,
     # so the extra's host key must be discarded.
     assert ssh_hook.host_key is None
     with ssh_hook.get_conn():
         client_mock = ssh_client.return_value
         assert client_mock.connect.called is True
         assert client_mock.get_host_keys.return_value.add.called is False
Esempio n. 8
0
 def test_conn_with_extra_parameters(self):
     """All boolean extras in CONN_SSH_WITH_EXTRA land on the hook."""
     hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
     expected = [
         ('compress', True),
         ('no_host_key_check', True),
         ('allow_host_key_change', False),
         ('look_for_keys', True),
     ]
     for attr, value in expected:
         # `is` keeps the original identity check against True/False.
         assert getattr(hook, attr) is value
Esempio n. 9
0
def get_ssh_op(script):
    """Build an SSHOperator that runs *script* on the 'ssh_conn' connection.

    :param script: the remote command string to execute.
    :return: a configured SSHOperator with retries disabled.

    NOTE(review): because an explicit ssh_hook is supplied, SSHOperator
    ignores ssh_conn_id='operator_test' -- the hook's 'ssh_conn'
    connection is the one actually used.
    """
    # task_id was an f-string with no placeholders (ruff F541); a plain
    # string literal is equivalent.
    return SSHOperator(task_id='ssh_test',
                       ssh_hook=SSHHook(ssh_conn_id='ssh_conn'),
                       ssh_conn_id='operator_test',
                       retries=0,
                       command=script)