Exemple #1
0
    def test_run_command_timeout(self):
        """Timed-out commands must still report stdout/stderr for every host.

        Each host client's ``run`` is replaced by a mock that raises
        ``SSHCommandTimeoutError``; the aggregated result per host must be
        marked failed, carry the captured output, and have return code -9.
        """
        # Make sure stdout and stderr are included on timeout
        hosts = ["localhost", "127.0.0.1", "st2build001"]
        client = ParallelSSHClient(hosts=hosts,
                                   user="******",
                                   pkey_file="~/.ssh/id_rsa",
                                   connect=True)
        mock_run = Mock(
            side_effect=SSHCommandTimeoutError(cmd="pwd",
                                               timeout=10,
                                               stdout="a",
                                               stderr="b",
                                               ssh_connect_timeout=30))
        for host in hosts:
            # Host clients are keyed by the parsed hostname, consistent with
            # the other tests in this module; previously the parsed value was
            # computed and then ignored.
            hostname, _ = client._get_host_port_info(host)
            client._hosts_client[hostname].run = mock_run

        results = client.run("pwd")
        for host in hosts:
            result = results[host]
            self.assertEqual(result["failed"], True)
            self.assertEqual(result["stdout"], "a")
            self.assertEqual(result["stderr"], "b")
            self.assertEqual(result["return_code"], -9)
Exemple #2
0
    def test_run_command_timeout(self):
        """Timed-out commands must still report stdout/stderr for every host.

        Each host client's ``run`` is replaced by a mock that raises
        ``SSHCommandTimeoutError``; the aggregated result per host must be
        marked failed, carry the captured output, and have return code -9.
        """
        # Make sure stdout and stderr are included on timeout
        hosts = ['localhost', '127.0.0.1', 'st2build001']
        client = ParallelSSHClient(hosts=hosts,
                                   user='******',
                                   pkey_file='~/.ssh/id_rsa',
                                   connect=True)
        mock_run = Mock(
            side_effect=SSHCommandTimeoutError(cmd='pwd',
                                               timeout=10,
                                               stdout='a',
                                               stderr='b',
                                               ssh_connect_timeout=30))
        for host in hosts:
            # Host clients are keyed by the parsed hostname, consistent with
            # the other tests in this module; previously the parsed value was
            # computed and then ignored.
            hostname, _ = client._get_host_port_info(host)
            client._hosts_client[hostname].run = mock_run

        results = client.run('pwd')
        for host in hosts:
            result = results[host]
            self.assertEqual(result['failed'], True)
            self.assertEqual(result['stdout'], 'a')
            self.assertEqual(result['stderr'], 'b')
            self.assertEqual(result['return_code'], -9)
Exemple #3
0
 def test_run_command_json_output_transformed_to_object(self):
     """JSON written to stdout should come back as a parsed dict."""
     target_host = "127.0.0.1"
     ssh_client = ParallelSSHClient(hosts=[target_host],
                                    user="******",
                                    pkey_file="~/.ssh/id_rsa",
                                    connect=True)
     output = ssh_client.run("stuff", timeout=60)
     self.assertIn(target_host, output)
     expected_stdout = {"foo": "bar"}
     self.assertDictEqual(output[target_host]["stdout"], expected_stdout)
Exemple #4
0
 def test_run_command_json_output_transformed_to_object(self):
     """JSON written to stdout should come back as a parsed dict."""
     hosts = ['127.0.0.1']
     client = ParallelSSHClient(hosts=hosts,
                                user='******',
                                pkey_file='~/.ssh/id_rsa',
                                connect=True)
     results = client.run('stuff', timeout=60)
     # assertIn reports the missing key and the container on failure,
     # unlike assertTrue(x in y) which only says "False is not true".
     self.assertIn('127.0.0.1', results)
     self.assertDictEqual(results['127.0.0.1']['stdout'], {'foo': 'bar'})
Exemple #5
0
 def test_run_command_json_output_transformed_to_object(self):
     """JSON written to stdout should come back as a parsed dict."""
     target_host = '127.0.0.1'
     ssh_client = ParallelSSHClient(hosts=[target_host],
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     output = ssh_client.run('stuff', timeout=60)
     self.assertIn(target_host, output)
     host_stdout = output[target_host]['stdout']
     self.assertDictEqual(host_stdout, {'foo': 'bar'})
Exemple #6
0
 def test_delete_file(self):
     """delete_file() should be forwarded to every per-host client."""
     remote_path = '/remote/stuff'
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     ssh_client.delete_file(remote_path)
     for host_str in hosts:
         name = ssh_client._get_host_port_info(host_str)[0]
         host_client = ssh_client._hosts_client[name]
         host_client.delete_file.assert_called_with(remote_path)
Exemple #7
0
 def test_delete_file(self):
     """Each host client should receive the same delete_file() call."""
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     target = '/remote/stuff'
     ssh_client.delete_file(target)
     for host_str in hosts:
         name, _port = ssh_client._get_host_port_info(host_str)
         delete_mock = ssh_client._hosts_client[name].delete_file
         delete_mock.assert_called_with(target)
Exemple #8
0
 def test_delete_file(self):
     """delete_file() should fan out to all configured hosts."""
     remote_path = "/remote/stuff"
     hosts = ["localhost", "127.0.0.1", "st2build001"]
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user="******",
                                    pkey_file="~/.ssh/id_rsa",
                                    connect=True)
     ssh_client.delete_file(remote_path)
     for host_str in hosts:
         name = ssh_client._get_host_port_info(host_str)[0]
         per_host = ssh_client._hosts_client[name]
         per_host.delete_file.assert_called_with(remote_path)
Exemple #9
0
 def test_delete_dir(self):
     """delete_dir(force=True) reaches every host with timeout=None."""
     hosts = ["localhost", "127.0.0.1", "st2build001"]
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user="******",
                                    pkey_file="~/.ssh/id_rsa",
                                    connect=True)
     ssh_client.delete_dir("/remote/stuff/", force=True)
     for host_str in hosts:
         name, _port = ssh_client._get_host_port_info(host_str)
         per_host = ssh_client._hosts_client[name]
         per_host.delete_dir.assert_called_with("/remote/stuff/",
                                                force=True,
                                                timeout=None)
Exemple #10
0
 def test_put(self):
     """put() forwards mode and mirror_local_mode=False to each host."""
     hosts = ["localhost", "127.0.0.1", "st2build001"]
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user="******",
                                    pkey_file="~/.ssh/id_rsa",
                                    connect=True)
     ssh_client.put("/local/stuff", "/remote/stuff", mode=0o744)
     call_kwargs = dict(mode=0o744, mirror_local_mode=False)
     for host_str in hosts:
         name, _port = ssh_client._get_host_port_info(host_str)
         put_mock = ssh_client._hosts_client[name].put
         put_mock.assert_called_with("/local/stuff", "/remote/stuff",
                                     **call_kwargs)
Exemple #11
0
 def test_run_command(self):
     """run() forwards command, timeout and line-handler flag per host."""
     hosts = ["localhost", "127.0.0.1", "st2build001"]
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user="******",
                                    pkey_file="~/.ssh/id_rsa",
                                    connect=True)
     ssh_client.run("pwd", timeout=60)
     for host_str in hosts:
         name, _port = ssh_client._get_host_port_info(host_str)
         run_mock = ssh_client._hosts_client[name].run
         run_mock.assert_called_with("pwd", timeout=60,
                                     call_line_handler_func=True)
 def test_put(self):
     """put() forwards mode and mirror_local_mode=False to each host."""
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     client = ParallelSSHClient(hosts=hosts,
                                user='******',
                                pkey_file='~/.ssh/id_rsa',
                                connect=True)
     # 0o744 is valid octal syntax on Python 2.6+ and Python 3; the bare
     # leading-zero form (0744) is a SyntaxError on Python 3.
     client.put('/local/stuff', '/remote/stuff', mode=0o744)
     expected_kwargs = {'mode': 0o744, 'mirror_local_mode': False}
     for host in hosts:
         hostname, _ = client._get_host_port_info(host)
         client._hosts_client[hostname].put.assert_called_with(
             '/local/stuff', '/remote/stuff', **expected_kwargs)
Exemple #13
0
    def test_connect_with_bastion(self):
        """After connect(), every host client should carry the bastion host."""
        bastion = 'testing_bastion_host'
        hosts = ['localhost', '127.0.0.1']
        ssh_client = ParallelSSHClient(hosts=hosts,
                                       user='******',
                                       pkey_file='~/.ssh/id_rsa',
                                       bastion_host=bastion,
                                       connect=False)
        ssh_client.connect()

        for host_str in hosts:
            name = ssh_client._get_host_port_info(host_str)[0]
            per_host = ssh_client._hosts_client[name]
            self.assertEqual(per_host.bastion_host, bastion)
Exemple #14
0
    def test_connect_with_bastion(self):
        """The bastion_host option must propagate to each host client."""
        hosts = ['localhost', '127.0.0.1']
        expected_bastion = 'testing_bastion_host'
        ssh_client = ParallelSSHClient(hosts=hosts,
                                       user='******',
                                       pkey_file='~/.ssh/id_rsa',
                                       bastion_host=expected_bastion,
                                       connect=False)
        ssh_client.connect()

        for host_str in hosts:
            name, _port = ssh_client._get_host_port_info(host_str)
            actual = ssh_client._hosts_client[name].bastion_host
            self.assertEqual(actual, expected_bastion)
 def test_delete_dir(self):
     """delete_dir(force=True) reaches every host with timeout=None."""
     remote_dir = '/remote/stuff/'
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     ssh_client.delete_dir(remote_dir, force=True)
     call_kwargs = dict(force=True, timeout=None)
     for host_str in hosts:
         name = ssh_client._get_host_port_info(host_str)[0]
         ssh_client._hosts_client[name].delete_dir.assert_called_with(
             remote_dir, **call_kwargs)
 def test_run_command(self):
     """run() forwards command, timeout and line-handler flag per host."""
     command = 'pwd'
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     ssh_client.run(command, timeout=60)
     for host_str in hosts:
         name = ssh_client._get_host_port_info(host_str)[0]
         run_mock = ssh_client._hosts_client[name].run
         run_mock.assert_called_with(command, timeout=60,
                                     call_line_handler_func=True)
Exemple #17
0
 def test_run_command(self):
     """run() forwards the command and timeout to every host client."""
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     ssh_client.run('pwd', timeout=60)
     for host_str in hosts:
         name, _port = ssh_client._get_host_port_info(host_str)
         ssh_client._hosts_client[name].run.assert_called_with('pwd',
                                                               timeout=60)
Exemple #18
0
def main(user, pkey, password, hosts_str, cmd, file_path, dir_path,
         delete_dir):
    """Exercise ParallelSSHClient against a comma-separated list of hosts.

    Each step runs only when its argument is truthy:
    upload ``file_path``, upload ``dir_path``, run ``cmd``, and delete the
    remote directory ``delete_dir``. Each step's results are pretty-printed.

    :param user: SSH username.
    :param pkey: Path to the private key file (passed as ``pkey_file``).
    :param password: SSH password.
    :param hosts_str: Comma-separated host list, e.g. ``"host1, host2"``.
    :param cmd: Command to run on every host.
    :param file_path: Local file to upload to ``/home/lakshmi/test_file``.
        After the upload, ``ls -rlth`` is run.
    :param dir_path: Local directory to upload to ``/home/lakshmi/``.
        After the upload, ``ls -rlth`` is run.
    :param delete_dir: Remote directory to delete with ``force=True``.
    :raises Exception: If ``file_path`` or ``dir_path`` does not exist locally.
    """
    # Strip whitespace and drop empty entries ("a, b," -> ["a", "b"]),
    # matching how the SSH runner parses its hosts parameter.
    hosts = [h.strip() for h in hosts_str.split(",") if h.strip()]
    client = ParallelSSHClient(user=user,
                               pkey_file=pkey,
                               password=password,
                               hosts=hosts)
    pp = pprint.PrettyPrinter(indent=4)

    def _report(label, results):
        # Pretty-printing a pre-formatted string only shows its repr, with a
        # literal "\n". Print the label, then pretty-print the results object.
        print("%s results:" % label)
        pp.pprint(results)

    if file_path:
        if not os.path.exists(file_path):
            raise Exception("File not found.")
        # The permission mode is an int mask (0o660), not the string "0660".
        _report("Copy", client.put(file_path, "/home/lakshmi/test_file",
                                   mode=0o660))
        _report("ls", client.run("ls -rlth"))

    if dir_path:
        if not os.path.exists(dir_path):
            raise Exception("Directory not found.")
        _report("Copy", client.put(dir_path, "/home/lakshmi/", mode=0o660))
        _report("ls", client.run("ls -rlth"))

    if cmd:
        _report("cmd", client.run(cmd))

    if delete_dir:
        _report("Delete", client.delete_dir(delete_dir, force=True))
Exemple #19
0
 def test_put(self):
     """put() forwards mode and mirror_local_mode=False to each host."""
     local_path, remote_path = '/local/stuff', '/remote/stuff'
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     ssh_client.put(local_path, remote_path, mode=0o744)
     for host_str in hosts:
         name = ssh_client._get_host_port_info(host_str)[0]
         put_mock = ssh_client._hosts_client[name].put
         put_mock.assert_called_with(local_path, remote_path,
                                     mode=0o744, mirror_local_mode=False)
Exemple #20
0
    def test_connect_with_bastion(self):
        """After connect(), every host client should carry the bastion host."""
        bastion = "testing_bastion_host"
        hosts = ["localhost", "127.0.0.1"]
        ssh_client = ParallelSSHClient(hosts=hosts,
                                       user="******",
                                       pkey_file="~/.ssh/id_rsa",
                                       bastion_host=bastion,
                                       connect=False)
        ssh_client.connect()

        for host_str in hosts:
            name, _port = ssh_client._get_host_port_info(host_str)
            per_host = ssh_client._hosts_client[name]
            self.assertEqual(per_host.bastion_host, bastion)
Exemple #21
0
 def test_delete_dir(self):
     """delete_dir(force=True) reaches every host with timeout=None."""
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=True)
     ssh_client.delete_dir('/remote/stuff/', force=True)
     for host_str in hosts:
         name, _port = ssh_client._get_host_port_info(host_str)
         delete_mock = ssh_client._hosts_client[name].delete_dir
         delete_mock.assert_called_with('/remote/stuff/', force=True,
                                        timeout=None)
Exemple #22
0
    def test_host_port_info(self):
        """_get_host_port_info splits a host string into (host, port)."""
        client = ParallelSSHClient(hosts=['dummy'],
                                   user='******',
                                   pkey_file='~/.ssh/id_rsa',
                                   connect=True)
        cases = [
            # No port case. Port should be 22.
            ('1.2.3.4', '1.2.3.4', 22),
            # IPv6 with square brackets with port specified.
            ('[fec2::10]:55', 'fec2::10', 55),
        ]
        for host_str, expected_host, expected_port in cases:
            host, port = client._get_host_port_info(host_str)
            self.assertEqual(host, expected_host)
            self.assertEqual(port, expected_port)
Exemple #23
0
    def test_host_port_info(self):
        """_get_host_port_info splits a host string into (host, port)."""
        client = ParallelSSHClient(hosts=["dummy"],
                                   user="******",
                                   pkey_file="~/.ssh/id_rsa",
                                   connect=True)
        cases = (
            # No port case. Port should be 22.
            ("1.2.3.4", "1.2.3.4", 22),
            # IPv6 with square brackets with port specified.
            ("[fec2::10]:55", "fec2::10", 55),
        )
        for raw, want_host, want_port in cases:
            got_host, got_port = client._get_host_port_info(raw)
            self.assertEqual(got_host, want_host)
            self.assertEqual(got_port, want_port)
Exemple #24
0
    def test_host_port_info(self):
        """_get_host_port_info splits a host string into (host, port)."""
        ssh_client = ParallelSSHClient(hosts=['dummy'],
                                       user='******',
                                       pkey_file='~/.ssh/id_rsa',
                                       connect=True)
        expectations = [
            # No port case. Port should be 22.
            ('1.2.3.4', ('1.2.3.4', 22)),
            # IPv6 with square brackets with port specified.
            ('[fec2::10]:55', ('fec2::10', 55)),
        ]
        for raw, (want_host, want_port) in expectations:
            got_host, got_port = ssh_client._get_host_port_info(raw)
            self.assertEqual(got_host, want_host)
            self.assertEqual(got_port, want_port)
Exemple #25
0
    def test_run_sudo_password_user_friendly_error(self):
        """A sudo failure should surface a readable per-host error message."""
        host = '127.0.0.1'
        ssh_client = ParallelSSHClient(hosts=[host],
                                       user='******',
                                       pkey_file='~/.ssh/id_rsa',
                                       connect=True,
                                       sudo_password=True)
        results = ssh_client.run('stuff', timeout=60)

        expected_error = ('Failed executing command "stuff" on host "127.0.0.1" '
                          'Invalid sudo password provided or sudo is not configured for '
                          'this user (bar)')

        self.assertIn(host, results)
        host_result = results[host]
        self.assertEqual(host_result['succeeded'], False)
        self.assertEqual(host_result['failed'], True)
        self.assertIn(expected_error, host_result['error'])
Exemple #26
0
    def test_run_sudo_password_user_friendly_error(self):
        """A sudo failure should surface a readable per-host error message."""
        hosts = ['127.0.0.1']
        client = ParallelSSHClient(hosts=hosts,
                                   user='******',
                                   pkey_file='~/.ssh/id_rsa',
                                   connect=True,
                                   sudo_password=True)
        results = client.run('stuff', timeout=60)

        expected_error = ('Failed executing command "stuff" on host "127.0.0.1" '
                          'Invalid sudo password provided or sudo is not configured for '
                          'this user (bar)')

        # assertIn shows both the needle and the haystack on failure, unlike
        # assertTrue(x in y), which only reports "False is not true".
        self.assertIn('127.0.0.1', results)
        self.assertEqual(results['127.0.0.1']['succeeded'], False)
        self.assertEqual(results['127.0.0.1']['failed'], True)
        self.assertIn(expected_error, results['127.0.0.1']['error'])
Exemple #27
0
 def test_connect_with_password(self):
     """connect() should log in with the password, agent and key lookup off."""
     hosts = ['localhost', '127.0.0.1']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    password='******',
                                    connect=False)
     ssh_client.connect()
     shared_conn = dict(allow_agent=False,
                        look_for_keys=False,
                        password='******',
                        username='******',
                        timeout=60,
                        port=22)
     for host in hosts:
         inner = ssh_client._hosts_client[host].client
         inner.connect.assert_called_once_with(hostname=host, **shared_conn)
Exemple #28
0
 def test_connect_with_password(self):
     """Each host should be connected exactly once using password auth."""
     hosts = ['localhost', '127.0.0.1']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    password='******',
                                    connect=False)
     ssh_client.connect()
     base_conn = {
         'allow_agent': False,
         'look_for_keys': False,
         'password': '******',
         'username': '******',
         'timeout': 60,
         'port': 22,
     }
     for host in hosts:
         per_host_conn = dict(base_conn, hostname=host)
         connect_mock = ssh_client._hosts_client[host].client.connect
         connect_mock.assert_called_once_with(**per_host_conn)
Exemple #29
0
 def test_connect_with_password(self):
     """connect() should log in with the password, agent and key lookup off."""
     hosts = ["localhost", "127.0.0.1"]
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user="******",
                                    password="******",
                                    connect=False)
     ssh_client.connect()
     shared_conn = dict(allow_agent=False,
                        look_for_keys=False,
                        password="******",
                        username="******",
                        timeout=60,
                        port=22)
     for host in hosts:
         inner = ssh_client._hosts_client[host].client
         inner.connect.assert_called_once_with(hostname=host, **shared_conn)
Exemple #30
0
    def test_run_sudo_password_user_friendly_error(self):
        """A sudo failure should surface a readable per-host error message."""
        host = "127.0.0.1"
        ssh_client = ParallelSSHClient(
            hosts=[host],
            user="******",
            pkey_file="~/.ssh/id_rsa",
            connect=True,
            sudo_password=True,
        )
        results = ssh_client.run("stuff", timeout=60)

        expected_error = (
            'Failed executing command "stuff" on host "127.0.0.1" '
            "Invalid sudo password provided or sudo is not configured for "
            "this user (bar)")

        self.assertIn(host, results)
        host_result = results[host]
        self.assertEqual(host_result["succeeded"], False)
        self.assertEqual(host_result["failed"], True)
        self.assertIn(expected_error, host_result["error"])
Exemple #31
0
 def test_connect_with_key(self):
     """connect() should use the key file, with agent and key lookup off."""
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=False)
     ssh_client.connect()
     # hostname and port vary per host and are supplied in the loop below.
     shared_conn = dict(allow_agent=False,
                        look_for_keys=False,
                        key_filename='~/.ssh/id_rsa',
                        username='******',
                        timeout=60)
     for host_str in hosts:
         name, port = ssh_client._get_host_port_info(host_str)
         inner = ssh_client._hosts_client[name].client
         inner.connect.assert_called_once_with(hostname=name, port=port,
                                               **shared_conn)
Exemple #32
0
 def test_connect_with_key(self):
     """Each host should be connected exactly once using key-file auth."""
     hosts = ['localhost', '127.0.0.1', 'st2build001']
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user='******',
                                    pkey_file='~/.ssh/id_rsa',
                                    connect=False)
     ssh_client.connect()
     base_conn = {
         'allow_agent': False,
         'look_for_keys': False,
         'key_filename': '~/.ssh/id_rsa',
         'username': '******',
         'timeout': 60,
         'port': 22,
     }
     for host_str in hosts:
         name, port = ssh_client._get_host_port_info(host_str)
         per_host_conn = dict(base_conn, hostname=name, port=port)
         connect_mock = ssh_client._hosts_client[name].client.connect
         connect_mock.assert_called_once_with(**per_host_conn)
Exemple #33
0
 def test_connect_with_key(self):
     """connect() should use the key file, with agent and key lookup off."""
     hosts = ["localhost", "127.0.0.1", "st2build001"]
     ssh_client = ParallelSSHClient(hosts=hosts,
                                    user="******",
                                    pkey_file="~/.ssh/id_rsa",
                                    connect=False)
     ssh_client.connect()
     # hostname and port vary per host and are supplied in the loop below.
     shared_conn = dict(allow_agent=False,
                        look_for_keys=False,
                        key_filename="~/.ssh/id_rsa",
                        username="******",
                        timeout=60)
     for host_str in hosts:
         name, port = ssh_client._get_host_port_info(host_str)
         inner = ssh_client._hosts_client[name].client
         inner.connect.assert_called_once_with(hostname=name, port=port,
                                               **shared_conn)
Exemple #34
0
    def test_run_command_timeout(self):
        """Timed-out commands must still report stdout/stderr for every host.

        Each host client's ``run`` is replaced by a mock that raises
        ``SSHCommandTimeoutError``; the aggregated result per host must be
        marked failed, carry the captured output, and have return code -9.
        """
        # Make sure stdout and stderr are included on timeout
        hosts = ['localhost', '127.0.0.1', 'st2build001']
        client = ParallelSSHClient(hosts=hosts,
                                   user='******',
                                   pkey_file='~/.ssh/id_rsa',
                                   connect=True)
        mock_run = Mock(side_effect=SSHCommandTimeoutError(cmd='pwd', timeout=10,
                                                           stdout='a',
                                                           stderr='b'))
        for host in hosts:
            # Host clients are keyed by the parsed hostname, consistent with
            # the other tests in this module; previously the parsed value was
            # computed and then ignored.
            hostname, _ = client._get_host_port_info(host)
            client._hosts_client[hostname].run = mock_run

        results = client.run('pwd')
        for host in hosts:
            result = results[host]
            self.assertEqual(result['failed'], True)
            self.assertEqual(result['stdout'], 'a')
            self.assertEqual(result['stderr'], 'b')
            self.assertEqual(result['return_code'], -9)
 def test_cwd_used_correctly(self):
     """The remote command should cd into the configured cwd before running."""
     script_args = ['blank space']
     remote_action = ParamikoRemoteScriptAction(
         'foo-script', bson.ObjectId(),
         script_local_path_abs='/home/stanley/shiz_storm.py',
         script_local_libs_path_abs=None,
         named_args={}, positional_args=script_args, env_vars={},
         on_behalf_user='******', user='******',
         private_key='---SOME RSA KEY---',
         remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/'
     )
     runner = ParamikoRemoteScriptRunner('runner_1')
     runner._parallel_ssh_client = ParallelSSHClient(['127.0.0.1'], 'stanley')
     runner._run_script_on_remote_host(remote_action)
     expected_cmd = "cd /test/cwd/ && /tmp/shiz_storm.py 'blank space'"
     ParallelSSHClient.run.assert_called_with(expected_cmd, timeout=None)
 def test_cwd_used_correctly(self):
     """The remote command should cd into the configured cwd before running."""
     action_kwargs = dict(
         script_local_path_abs="/home/stanley/shiz_storm.py",
         script_local_libs_path_abs=None,
         named_args={},
         positional_args=["blank space"],
         env_vars={},
         on_behalf_user="******",
         user="******",
         private_key="---SOME RSA KEY---",
         remote_dir="/tmp",
         hosts=["127.0.0.1"],
         cwd="/test/cwd/",
     )
     remote_action = ParamikoRemoteScriptAction("foo-script",
                                                bson.ObjectId(),
                                                **action_kwargs)
     runner = ParamikoRemoteScriptRunner("runner_1")
     runner._parallel_ssh_client = ParallelSSHClient(["127.0.0.1"],
                                                     "stanley")
     runner._run_script_on_remote_host(remote_action)
     expected_cmd = "cd /test/cwd/ && /tmp/shiz_storm.py 'blank space'"
     ParallelSSHClient.run.assert_called_with(expected_cmd, timeout=None)
Exemple #37
0
class BaseParallelSSHRunner(ActionRunner, ShellRunnerMixin):
    """
    Base runner which executes an action on one or more hosts over SSH through a
    ``ParallelSSHClient`` built in :meth:`pre_run`.
    """

    def __init__(self, runner_id):
        super(BaseParallelSSHRunner, self).__init__(runner_id=runner_id)
        self._hosts = None
        self._parallel = True
        self._sudo = False
        self._sudo_password = None
        self._username = None
        self._password = None
        self._private_key = None
        self._passphrase = None
        self._kwarg_op = '--'
        self._cwd = None
        self._env = None
        self._ssh_port = None
        self._timeout = None
        self._bastion_host = None
        # Default "on behalf of" user; pre_run() may override it from the execution context.
        self._on_behalf_user = cfg.CONF.system_user.user

        self._ssh_key_file = None
        self._parallel_ssh_client = None
        self._max_concurrency = cfg.CONF.ssh_runner.max_parallel_actions

    def pre_run(self):
        """
        Parse the runner parameters and build the ``ParallelSSHClient`` for this action.

        Reads hosts, credentials, sudo settings, cwd/env, kwarg operator, timeout and
        bastion host from ``self.runner_parameters`` (and the on-behalf-of user from
        ``self.context``), then instantiates ``self._parallel_ssh_client`` with
        ``connect=True``.

        :raises ActionRunnerPreRunError: If no non-blank host is specified.
        """
        super(BaseParallelSSHRunner, self).pre_run()

        LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
                  self.liveaction_id)

        # Strip before filtering so inputs such as "host1, ,host2" or a trailing ", " do not
        # produce empty host names.
        hosts = (self.runner_parameters.get(RUNNER_HOSTS, '') or '').split(',')
        self._hosts = [host.strip() for host in hosts if host.strip()]
        if not self._hosts:
            # Interpolate eagerly: exceptions store extra args instead of %-formatting them.
            raise ActionRunnerPreRunError('No hosts specified to run action for action %s.' %
                                          (self.liveaction_id,))
        self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
        self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
        self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None)
        self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

        self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, None)
        self._ssh_key_file = self._private_key
        self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
        self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
        self._sudo = self._sudo if self._sudo else False  # normalize any falsy value to False
        self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None)

        if self.context:
            self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)

        self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
        self._env = self.runner_parameters.get(RUNNER_ENV, {})
        self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
        self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT,
                                                   REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
        self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None)

        # Roughly one concurrent connection per three hosts, capped by
        # ssh_runner.max_parallel_actions; serial execution uses a single connection.
        concurrency = len(self._hosts) // 3 + 1 if self._parallel else 1
        if concurrency > self._max_concurrency:
            LOG.debug('Limiting parallel SSH concurrency to %d.', self._max_concurrency)
            concurrency = self._max_concurrency

        client_kwargs = {
            'hosts': self._hosts,
            'user': self._username,
            'port': self._ssh_port,
            'concurrency': concurrency,
            'bastion_host': self._bastion_host,
            'raise_on_any_error': False,
            'connect': True
        }

        def make_store_output_line_func(execution_db, action_db, output_type):
            # Returns a per-line callback which persists output only when streaming is enabled.
            def store_output_line(line):
                if cfg.CONF.actionrunner.stream_output:
                    store_execution_output_data(execution_db=execution_db, action_db=action_db,
                                                data=line, output_type=output_type)

            return store_output_line

        if len(self._hosts) == 1:
            # We only support streaming output when running action on one host. That is because
            # the action output is tied to a particular execution. User can still achieve output
            # streaming for multiple hosts by running one execution per host.
            client_kwargs['handle_stdout_line_func'] = make_store_output_line_func(
                execution_db=self.execution, action_db=self.action, output_type='stdout')
            client_kwargs['handle_stderr_line_func'] = make_store_output_line_func(
                execution_db=self.execution, action_db=self.action, output_type='stderr')
        else:
            LOG.debug('Real-time action output streaming is disabled, because action is running '
                      'on more than one host')

        if self._password:
            client_kwargs['password'] = self._password
        elif self._private_key:
            # Determine if the private_key is a path to the key file or the raw key material
            if self._is_private_key_material(private_key=self._private_key):
                # Raw key material
                client_kwargs['pkey_material'] = self._private_key
            else:
                # Otherwise treat it as a path to a key file; existence is not checked here.
                client_kwargs['pkey_file'] = self._private_key

            if self._passphrase:
                client_kwargs['passphrase'] = self._passphrase
        else:
            # NOTE(review): self._ssh_key_file is copied from self._private_key above, so it is
            # falsy on this branch. The intent is to default to the stanley key file from the
            # config; presumably ParallelSSHClient applies that fallback when pkey_file is
            # empty. Confirm.
            client_kwargs['pkey_file'] = self._ssh_key_file

        if self._sudo_password:
            # Only a boolean flag is handed to the client here, not the password itself.
            client_kwargs['sudo_password'] = True

        self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)

        LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
                 self.runner_id, self.liveaction_id)

    def post_run(self, status, result):
        """Run the parent post_run hook and close the SSH client, if one was created."""
        super(BaseParallelSSHRunner, self).post_run(status=status, result=result)

        # Ensure we close the connection when the action execution finishes
        if self._parallel_ssh_client:
            self._parallel_ssh_client.close()

    def _is_private_key_material(self, private_key):
        """Return True if ``private_key`` looks like raw key material rather than a path."""
        return bool(private_key) and REMOTE_RUNNER_PRIVATE_KEY_HEADER in private_key.lower()

    def _get_env_vars(self):
        """
        Build the environment for the remote command.

        User-supplied ``env`` entries are applied first; common st2 action variables are
        applied afterwards and therefore win on key collisions.

        :rtype: ``dict``
        """
        env_vars = {}

        if self._env:
            env_vars.update(self._env)

        # Include common st2 env vars
        env_vars.update(self._get_common_action_env_variables())

        return env_vars

    @staticmethod
    def _get_result_status(result, allow_partial_failure):
        """
        Derive the overall liveaction status from a ParallelSSHClient result.

        :param result: Either a global error dict (with ``error`` and ``traceback`` keys) or
                       a dict mapping each host to its result dict (or a falsy value).
        :param allow_partial_failure: If True, one succeeding host is enough for success;
                                      otherwise every host must succeed.
        :rtype: ``str``
        """
        if 'error' in result and 'traceback' in result:
            # Assume this is a global failure where the result dictionary doesn't contain entry
            # per host
            success = result.get('succeeded', False)
            return BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success,
                                                                             timeout=False)

        host_results = [r or {} for r in six.itervalues(result)]
        succeeded = [bool(r.get('succeeded', False)) for r in host_results]

        success = any(succeeded) if allow_partial_failure else all(succeeded)
        # Inspect every host (no early exit) so the action is only reported as timed out when
        # *all* hosts timed out, not merely the ones seen before the first failure.
        timeout = all(bool(r.get('timeout', False)) for r in host_results)

        return BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success,
                                                                         timeout=timeout)

    @staticmethod
    def _get_status_for_success_and_timeout(success, timeout):
        """Map success/timeout flags to a liveaction status (success takes precedence)."""
        if success:
            status = LIVEACTION_STATUS_SUCCEEDED
        elif timeout:
            # Note: Right now we only set status to timeout if all the hosts have timed out
            status = LIVEACTION_STATUS_TIMED_OUT
        else:
            status = LIVEACTION_STATUS_FAILED
        return status
Exemple #38
0
    def pre_run(self):
        """
        Parse the runner parameters and build the ``ParallelSSHClient`` for this action.

        Reads hosts, credentials, sudo settings, cwd/env, kwarg operator, timeout and
        bastion host from ``self.runner_parameters`` (and the on-behalf-of user from
        ``self.context``), then instantiates ``self._parallel_ssh_client`` with
        ``connect=True``.

        :raises ActionRunnerPreRunError: If no non-blank host is specified.
        """
        super(BaseParallelSSHRunner, self).pre_run()

        LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
                  self.liveaction_id)

        # Strip before filtering so inputs such as "host1, ,host2" or a trailing ", " do not
        # produce empty host names.
        hosts = (self.runner_parameters.get(RUNNER_HOSTS, '') or '').split(',')
        self._hosts = [host.strip() for host in hosts if host.strip()]
        if not self._hosts:
            # Interpolate eagerly: exceptions store extra args instead of %-formatting them.
            raise ActionRunnerPreRunError('No hosts specified to run action for action %s.' %
                                          (self.liveaction_id,))
        self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
        self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
        self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None)
        self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

        self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, None)
        self._ssh_key_file = self._private_key
        self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
        self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
        self._sudo = self._sudo if self._sudo else False  # normalize any falsy value to False
        self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None)

        if self.context:
            self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)

        self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
        self._env = self.runner_parameters.get(RUNNER_ENV, {})
        self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
        self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT,
                                                   REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
        self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None)

        # Roughly one concurrent connection per three hosts, capped by
        # ssh_runner.max_parallel_actions; serial execution uses a single connection.
        concurrency = len(self._hosts) // 3 + 1 if self._parallel else 1
        if concurrency > self._max_concurrency:
            LOG.debug('Limiting parallel SSH concurrency to %d.', self._max_concurrency)
            concurrency = self._max_concurrency

        client_kwargs = {
            'hosts': self._hosts,
            'user': self._username,
            'port': self._ssh_port,
            'concurrency': concurrency,
            'bastion_host': self._bastion_host,
            'raise_on_any_error': False,
            'connect': True
        }

        def make_store_output_line_func(execution_db, action_db, output_type):
            # Returns a per-line callback which persists output only when streaming is enabled.
            def store_output_line(line):
                if cfg.CONF.actionrunner.stream_output:
                    store_execution_output_data(execution_db=execution_db, action_db=action_db,
                                                data=line, output_type=output_type)

            return store_output_line

        if len(self._hosts) == 1:
            # We only support streaming output when running action on one host. That is because
            # the action output is tied to a particular execution. User can still achieve output
            # streaming for multiple hosts by running one execution per host.
            client_kwargs['handle_stdout_line_func'] = make_store_output_line_func(
                execution_db=self.execution, action_db=self.action, output_type='stdout')
            client_kwargs['handle_stderr_line_func'] = make_store_output_line_func(
                execution_db=self.execution, action_db=self.action, output_type='stderr')
        else:
            LOG.debug('Real-time action output streaming is disabled, because action is running '
                      'on more than one host')

        if self._password:
            client_kwargs['password'] = self._password
        elif self._private_key:
            # Determine if the private_key is a path to the key file or the raw key material
            if self._is_private_key_material(private_key=self._private_key):
                # Raw key material
                client_kwargs['pkey_material'] = self._private_key
            else:
                # Otherwise treat it as a path to a key file; existence is not checked here.
                client_kwargs['pkey_file'] = self._private_key

            if self._passphrase:
                client_kwargs['passphrase'] = self._passphrase
        else:
            # NOTE(review): self._ssh_key_file is copied from self._private_key above, so it is
            # falsy on this branch. The intent is to default to the stanley key file from the
            # config; presumably ParallelSSHClient applies that fallback when pkey_file is
            # empty. Confirm.
            client_kwargs['pkey_file'] = self._ssh_key_file

        if self._sudo_password:
            # Only a boolean flag is handed to the client here, not the password itself.
            client_kwargs['sudo_password'] = True

        self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)

        LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
                 self.runner_id, self.liveaction_id)
Exemple #39
0
    def pre_run(self):
        """
        Parse the runner parameters and build the ``ParallelSSHClient`` for this action.

        Reads hosts, credentials, sudo flag, cwd/env, kwarg operator, timeout and bastion
        host from ``self.runner_parameters`` (and the on-behalf-of user from
        ``self.context``), then instantiates ``self._parallel_ssh_client`` with
        ``connect=True``.

        :raises ActionRunnerPreRunError: If no non-blank host is specified.
        """
        super(BaseParallelSSHRunner, self).pre_run()

        LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
                  self.liveaction_id)

        # Strip before filtering so inputs such as "host1, ,host2" or a trailing ", " do not
        # produce empty host names.
        hosts = (self.runner_parameters.get(RUNNER_HOSTS, '') or '').split(',')
        self._hosts = [host.strip() for host in hosts if host.strip()]
        if not self._hosts:
            # Interpolate eagerly: exceptions store extra args instead of %-formatting them.
            raise ActionRunnerPreRunError('No hosts specified to run action for action %s.' %
                                          (self.liveaction_id,))
        self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
        self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
        self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None)
        self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

        self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, None)
        self._ssh_key_file = self._private_key
        self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
        self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
        self._sudo = self._sudo if self._sudo else False  # normalize any falsy value to False
        if self.context:
            self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)
        self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
        self._env = self.runner_parameters.get(RUNNER_ENV, {})
        self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
        self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT,
                                                   REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
        self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None)

        # Roughly one concurrent connection per three hosts, capped by
        # ssh_runner.max_parallel_actions; serial execution uses a single connection.
        concurrency = len(self._hosts) // 3 + 1 if self._parallel else 1
        if concurrency > self._max_concurrency:
            LOG.debug('Limiting parallel SSH concurrency to %d.', self._max_concurrency)
            concurrency = self._max_concurrency

        client_kwargs = {
            'hosts': self._hosts,
            'user': self._username,
            'port': self._ssh_port,
            'concurrency': concurrency,
            'bastion_host': self._bastion_host,
            'raise_on_any_error': False,
            'connect': True
        }

        if self._password:
            client_kwargs['password'] = self._password
        elif self._private_key:
            # Determine if the private_key is a path to the key file or the raw key material
            if self._is_private_key_material(private_key=self._private_key):
                # Raw key material
                client_kwargs['pkey_material'] = self._private_key
            else:
                # Otherwise treat it as a path to a key file; existence is not checked here.
                client_kwargs['pkey_file'] = self._private_key

            if self._passphrase:
                client_kwargs['passphrase'] = self._passphrase
        else:
            # NOTE(review): self._ssh_key_file is copied from self._private_key above, so it is
            # falsy on this branch. The intent is to default to the stanley key file from the
            # config; presumably ParallelSSHClient applies that fallback when pkey_file is
            # empty. Confirm.
            client_kwargs['pkey_file'] = self._ssh_key_file

        self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)

        LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
                 self.runner_id, self.liveaction_id)
Exemple #40
0
    def pre_run(self):
        """
        Parse the runner parameters and build the ``ParallelSSHClient``.

        Reads hosts, credentials, sudo settings, cwd/env, kwarg operator,
        timeout and bastion host from ``self.runner_parameters`` (and the
        on-behalf-of user from ``self.context``), then instantiates
        ``self._parallel_ssh_client`` with ``connect=True``.

        :raises ActionRunnerPreRunError: If no non-blank host is specified.
        """
        super(BaseParallelSSHRunner, self).pre_run()

        LOG.debug(
            'Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
            self.liveaction_id)

        # Strip before filtering so inputs such as "host1, ,host2" or a
        # trailing ", " do not produce empty host names.
        hosts = (self.runner_parameters.get(RUNNER_HOSTS, '') or '').split(',')
        self._hosts = [host.strip() for host in hosts if host.strip()]
        if not self._hosts:
            # Interpolate eagerly: exceptions store extra args instead of
            # %-formatting them.
            raise ActionRunnerPreRunError(
                'No hosts specified to run action for action %s.' %
                (self.liveaction_id,))
        self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
        self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
        self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY,
                                                       None)
        self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

        self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, None)
        self._ssh_key_file = self._private_key
        self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
        self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
        # Normalize any falsy value to False.
        self._sudo = self._sudo if self._sudo else False
        self._sudo_password = self.runner_parameters.get(
            RUNNER_SUDO_PASSWORD, None)

        if self.context:
            self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER,
                                                    self._on_behalf_user)

        self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
        self._env = self.runner_parameters.get(RUNNER_ENV, {})
        self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
        self._timeout = self.runner_parameters.get(
            RUNNER_TIMEOUT, REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
        self._bastion_host = self.runner_parameters.get(
            RUNNER_BASTION_HOST, None)

        # Roughly one concurrent connection per three hosts, capped by
        # ssh_runner.max_parallel_actions; serial execution uses one.
        concurrency = len(self._hosts) // 3 + 1 if self._parallel else 1
        if concurrency > self._max_concurrency:
            LOG.debug('Limiting parallel SSH concurrency to %d.',
                      self._max_concurrency)
            concurrency = self._max_concurrency

        client_kwargs = {
            'hosts': self._hosts,
            'user': self._username,
            'port': self._ssh_port,
            'concurrency': concurrency,
            'bastion_host': self._bastion_host,
            'raise_on_any_error': False,
            'connect': True
        }

        def make_store_output_line_func(execution_db, action_db, output_type):
            # Per-line callback persisting output only when streaming is on.
            def store_output_line(line):
                if cfg.CONF.actionrunner.stream_output:
                    store_execution_output_data(execution_db=execution_db,
                                                action_db=action_db,
                                                data=line,
                                                output_type=output_type)

            return store_output_line

        if len(self._hosts) == 1:
            # We only support streaming output when running action on one host. That is because
            # the action output is tied to a particular execution. User can still achieve output
            # streaming for multiple hosts by running one execution per host.
            client_kwargs['handle_stdout_line_func'] = \
                make_store_output_line_func(execution_db=self.execution,
                                            action_db=self.action,
                                            output_type='stdout')
            client_kwargs['handle_stderr_line_func'] = \
                make_store_output_line_func(execution_db=self.execution,
                                            action_db=self.action,
                                            output_type='stderr')

        if self._password:
            client_kwargs['password'] = self._password
        elif self._private_key:
            # Determine if the private_key is a path to the key file or the raw key material
            if self._is_private_key_material(private_key=self._private_key):
                # Raw key material
                client_kwargs['pkey_material'] = self._private_key
            else:
                # Otherwise treat it as a key file path (existence unchecked).
                client_kwargs['pkey_file'] = self._private_key

            if self._passphrase:
                client_kwargs['passphrase'] = self._passphrase
        else:
            # NOTE(review): self._ssh_key_file mirrors self._private_key, so
            # it is falsy here. The intent is to default to the stanley key
            # file from the config; presumably ParallelSSHClient applies that
            # fallback when pkey_file is empty. Confirm.
            client_kwargs['pkey_file'] = self._ssh_key_file

        if self._sudo_password:
            # Only a boolean flag is handed to the client, not the password.
            client_kwargs['sudo_password'] = True

        self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)

        LOG.info(
            '[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
            self.runner_id, self.liveaction_id)
Exemple #41
0
class BaseParallelSSHRunner(ActionRunner, ShellRunnerMixin):
    """
    Base runner which executes an action on one or more hosts over SSH
    through a ``ParallelSSHClient`` built in :meth:`pre_run`.
    """

    def __init__(self, runner_id):
        super(BaseParallelSSHRunner, self).__init__(runner_id=runner_id)
        self._hosts = None
        self._parallel = True
        self._sudo = False
        self._sudo_password = None
        self._username = None
        self._password = None
        self._private_key = None
        self._passphrase = None
        self._kwarg_op = '--'
        self._cwd = None
        self._env = None
        self._ssh_port = None
        self._timeout = None
        self._bastion_host = None
        # Default "on behalf of" user; pre_run() may override it from the
        # execution context.
        self._on_behalf_user = cfg.CONF.system_user.user

        self._ssh_key_file = None
        self._parallel_ssh_client = None
        self._max_concurrency = cfg.CONF.ssh_runner.max_parallel_actions

    def pre_run(self):
        """
        Parse the runner parameters and build the ``ParallelSSHClient``.

        Reads hosts, credentials, sudo settings, cwd/env, kwarg operator,
        timeout and bastion host from ``self.runner_parameters`` (and the
        on-behalf-of user from ``self.context``), then instantiates
        ``self._parallel_ssh_client`` with ``connect=True``.

        :raises ActionRunnerPreRunError: If no non-blank host is specified.
        """
        super(BaseParallelSSHRunner, self).pre_run()

        LOG.debug(
            'Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
            self.liveaction_id)

        # Strip before filtering so inputs such as "host1, ,host2" or a
        # trailing ", " do not produce empty host names.
        hosts = (self.runner_parameters.get(RUNNER_HOSTS, '') or '').split(',')
        self._hosts = [host.strip() for host in hosts if host.strip()]
        if not self._hosts:
            # Interpolate eagerly: exceptions store extra args instead of
            # %-formatting them.
            raise ActionRunnerPreRunError(
                'No hosts specified to run action for action %s.' %
                (self.liveaction_id,))
        self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
        self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
        self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY,
                                                       None)
        self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

        self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, None)
        self._ssh_key_file = self._private_key
        self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
        self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
        # Normalize any falsy value to False.
        self._sudo = self._sudo if self._sudo else False
        self._sudo_password = self.runner_parameters.get(
            RUNNER_SUDO_PASSWORD, None)

        if self.context:
            self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER,
                                                    self._on_behalf_user)

        self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
        self._env = self.runner_parameters.get(RUNNER_ENV, {})
        self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
        self._timeout = self.runner_parameters.get(
            RUNNER_TIMEOUT, REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
        self._bastion_host = self.runner_parameters.get(
            RUNNER_BASTION_HOST, None)

        # Roughly one concurrent connection per three hosts, capped by
        # ssh_runner.max_parallel_actions; serial execution uses one.
        concurrency = len(self._hosts) // 3 + 1 if self._parallel else 1
        if concurrency > self._max_concurrency:
            LOG.debug('Limiting parallel SSH concurrency to %d.',
                      self._max_concurrency)
            concurrency = self._max_concurrency

        client_kwargs = {
            'hosts': self._hosts,
            'user': self._username,
            'port': self._ssh_port,
            'concurrency': concurrency,
            'bastion_host': self._bastion_host,
            'raise_on_any_error': False,
            'connect': True
        }

        def make_store_output_line_func(execution_db, action_db, output_type):
            # Per-line callback persisting output only when streaming is on.
            def store_output_line(line):
                if cfg.CONF.actionrunner.stream_output:
                    store_execution_output_data(execution_db=execution_db,
                                                action_db=action_db,
                                                data=line,
                                                output_type=output_type)

            return store_output_line

        if len(self._hosts) == 1:
            # We only support streaming output when running action on one host. That is because
            # the action output is tied to a particular execution. User can still achieve output
            # streaming for multiple hosts by running one execution per host.
            client_kwargs['handle_stdout_line_func'] = \
                make_store_output_line_func(execution_db=self.execution,
                                            action_db=self.action,
                                            output_type='stdout')
            client_kwargs['handle_stderr_line_func'] = \
                make_store_output_line_func(execution_db=self.execution,
                                            action_db=self.action,
                                            output_type='stderr')

        if self._password:
            client_kwargs['password'] = self._password
        elif self._private_key:
            # Determine if the private_key is a path to the key file or the raw key material
            if self._is_private_key_material(private_key=self._private_key):
                # Raw key material
                client_kwargs['pkey_material'] = self._private_key
            else:
                # Otherwise treat it as a key file path (existence unchecked).
                client_kwargs['pkey_file'] = self._private_key

            if self._passphrase:
                client_kwargs['passphrase'] = self._passphrase
        else:
            # NOTE(review): self._ssh_key_file mirrors self._private_key, so
            # it is falsy here. The intent is to default to the stanley key
            # file from the config; presumably ParallelSSHClient applies that
            # fallback when pkey_file is empty. Confirm.
            client_kwargs['pkey_file'] = self._ssh_key_file

        if self._sudo_password:
            # Only a boolean flag is handed to the client, not the password.
            client_kwargs['sudo_password'] = True

        self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)

        LOG.info(
            '[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
            self.runner_id, self.liveaction_id)

    def post_run(self, status, result):
        """Run the parent post_run hook and close the SSH client, if any."""
        super(BaseParallelSSHRunner, self).post_run(status=status,
                                                    result=result)

        # Ensure we close the connection when the action execution finishes
        if self._parallel_ssh_client:
            self._parallel_ssh_client.close()

    def _is_private_key_material(self, private_key):
        """Return True if ``private_key`` looks like raw key material."""
        return (bool(private_key) and
                REMOTE_RUNNER_PRIVATE_KEY_HEADER in private_key.lower())

    def _get_env_vars(self):
        """
        Build the environment for the remote command.

        User-supplied ``env`` entries are applied first; common st2 action
        variables are applied afterwards and win on key collisions.

        :rtype: ``dict``
        """
        env_vars = {}

        if self._env:
            env_vars.update(self._env)

        # Include common st2 env vars
        env_vars.update(self._get_common_action_env_variables())

        return env_vars

    @staticmethod
    def _get_result_status(result, allow_partial_failure):
        """
        Derive the overall liveaction status from a ParallelSSHClient result.

        :param result: Either a global error dict (with ``error`` and
                       ``traceback`` keys) or a dict mapping each host to its
                       result dict (or a falsy value).
        :param allow_partial_failure: If True, one succeeding host is enough
                                      for success; otherwise every host must
                                      succeed.
        :rtype: ``str``
        """
        if 'error' in result and 'traceback' in result:
            # Assume this is a global failure where the result dictionary doesn't contain entry
            # per host
            success = result.get('succeeded', False)
            return BaseParallelSSHRunner._get_status_for_success_and_timeout(
                success=success, timeout=False)

        host_results = [r or {} for r in six.itervalues(result)]
        succeeded = [bool(r.get('succeeded', False)) for r in host_results]

        success = any(succeeded) if allow_partial_failure else all(succeeded)
        # Inspect every host (no early exit) so the action is only reported
        # as timed out when *all* hosts timed out.
        timeout = all(bool(r.get('timeout', False)) for r in host_results)

        return BaseParallelSSHRunner._get_status_for_success_and_timeout(
            success=success, timeout=timeout)

    @staticmethod
    def _get_status_for_success_and_timeout(success, timeout):
        """Map success/timeout flags to a liveaction status."""
        if success:
            status = LIVEACTION_STATUS_SUCCEEDED
        elif timeout:
            # Note: Right now we only set status to timeout if all the hosts have timed out
            status = LIVEACTION_STATUS_TIMED_OUT
        else:
            status = LIVEACTION_STATUS_FAILED
        return status