def test_run_command_json_output_transformed_to_object(self):
    """run() results for a host should expose stdout as a dict.

    Per the test name, JSON output is expected to come back parsed; the
    mocked remote output is presumably configured outside this method
    (e.g. by a patch decorator) -- confirm.
    """
    hosts = ['127.0.0.1']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    results = client.run('stuff', timeout=60)
    # assertIn shows the actual keys on failure, unlike assertTrue(x in y).
    self.assertIn('127.0.0.1', results)
    self.assertDictEqual(results['127.0.0.1']['stdout'], {'foo': 'bar'})
def test_delete_file(self):
    """delete_file() must be forwarded, with the same path, to every per-host client."""
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    target_path = '/remote/stuff'
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    client.delete_file(target_path)
    for host in hosts:
        hostname, _port = client._get_host_port_info(host)
        host_client = client._hosts_client[hostname]
        host_client.delete_file.assert_called_with(target_path)
def test_run_command(self):
    """run() must pass the command and its timeout through to each per-host client."""
    # NOTE(review): this variant passes `pkey=`; other tests here pass
    # `pkey_file=` -- confirm which keyword the ParallelSSHClient under test accepts.
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey='~/.ssh/id_rsa', connect=True)
    client.run('pwd', timeout=60)
    for host in hosts:
        hostname, _port = client._get_host_port_info(host)
        run_mock = client._hosts_client[hostname].run
        run_mock.assert_called_with('pwd', timeout=60)
def test_delete_dir(self):
    """delete_dir() must forward the path, force flag and a None timeout to every host."""
    # NOTE(review): this variant passes `pkey=`; other tests here pass
    # `pkey_file=` -- confirm which keyword the ParallelSSHClient under test accepts.
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey='~/.ssh/id_rsa', connect=True)
    client.delete_dir('/remote/stuff/', force=True)
    for host in hosts:
        hostname, _port = client._get_host_port_info(host)
        delete_dir_mock = client._hosts_client[hostname].delete_dir
        delete_dir_mock.assert_called_with('/remote/stuff/', force=True, timeout=None)
def test_put(self):
    """put() must forward both paths, the mode and mirror_local_mode=False to every host."""
    # NOTE(review): this variant passes `pkey=`; other tests here pass
    # `pkey_file=` -- confirm which keyword the ParallelSSHClient under test accepts.
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey='~/.ssh/id_rsa', connect=True)
    # 0o744 has the same value as the legacy literal 0744 but also parses on
    # Python 3, where a leading-zero octal literal is a SyntaxError.
    client.put('/local/stuff', '/remote/stuff', mode=0o744)
    expected_kwargs = {'mode': 0o744, 'mirror_local_mode': False}
    for host in hosts:
        hostname, _ = client._get_host_port_info(host)
        client._hosts_client[hostname].put.assert_called_with(
            '/local/stuff', '/remote/stuff', **expected_kwargs)
def test_connect_with_bastion(self):
    """After connect(), each per-host client must carry the configured bastion host."""
    bastion = 'testing_bastion_host'
    hosts = ['localhost', '127.0.0.1']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa',
                               bastion_host=bastion, connect=False)
    client.connect()
    for host in hosts:
        hostname, _port = client._get_host_port_info(host)
        host_client = client._hosts_client[hostname]
        self.assertEqual(host_client.bastion_host, bastion)
def test_run_command(self):
    """Each per-host client must receive the command together with timeout=60."""
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    client.run('pwd', timeout=60)
    for host in hosts:
        hostname, _port = client._get_host_port_info(host)
        client._hosts_client[hostname].run.assert_called_with('pwd', timeout=60)
def test_put(self):
    """put() must forward both paths, the mode and mirror_local_mode=False to every host."""
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    # 0o744 has the same value as the legacy literal 0744 but also parses on
    # Python 3, where a leading-zero octal literal is a SyntaxError.
    client.put('/local/stuff', '/remote/stuff', mode=0o744)
    expected_kwargs = {
        'mode': 0o744,
        'mirror_local_mode': False
    }
    for host in hosts:
        hostname, _ = client._get_host_port_info(host)
        client._hosts_client[hostname].put.assert_called_with(
            '/local/stuff', '/remote/stuff', **expected_kwargs)
def test_delete_dir(self):
    """Each host's delete_dir must be called with force=True and timeout=None."""
    remote_dir = '/remote/stuff/'
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    client.delete_dir(remote_dir, force=True)
    for host in hosts:
        hostname, _port = client._get_host_port_info(host)
        mocked = client._hosts_client[hostname].delete_dir
        mocked.assert_called_with(remote_dir, force=True, timeout=None)
def test_host_port_info(self):
    """_get_host_port_info() must split host strings into (host, port) pairs."""
    client = ParallelSSHClient(hosts=['dummy'], user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    cases = [
        # No port case. Port should be 22.
        ('1.2.3.4', '1.2.3.4', 22),
        # IPv6 with square brackets with port specified.
        ('[fec2::10]:55', 'fec2::10', 55),
    ]
    for host_str, want_host, want_port in cases:
        host, port = client._get_host_port_info(host_str)
        self.assertEqual(host, want_host)
        self.assertEqual(port, want_port)
def test_connect_with_password(self):
    """With a password, connect() must disable agent and key lookup for every host."""
    hosts = ['localhost', '127.0.0.1']
    client = ParallelSSHClient(hosts=hosts, user='******', password='******',
                               connect=False)
    client.connect()
    for host in hosts:
        expected_conn = {
            'allow_agent': False,
            'look_for_keys': False,
            'password': '******',
            'username': '******',
            'port': 22,
            'hostname': host,
        }
        connect_mock = client._hosts_client[host].client.connect
        connect_mock.assert_called_once_with(**expected_conn)
def pre_run(self):
    """Read runner parameters and create the ParallelSSHClient for this runner.

    Populates the host list, credentials, port, sudo/cwd/env/kwarg-op options
    and timeout from ``self.runner_parameters`` (falling back to config or
    hard-coded defaults). It then builds ``self._parallel_ssh_client``, using
    password auth when a password is set and key auth otherwise.

    Raises:
        ActionRunnerPreRunError: if the hosts parameter has no non-blank host.
    """
    LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
              self.liveaction_id)

    # Strip before filtering so whitespace-only entries (e.g. "a, ,b" or a
    # trailing ", ") don't turn into empty host names.
    hosts = self.runner_parameters.get(RUNNER_HOSTS, '').split(',')
    self._hosts = [h.strip() for h in hosts if h.strip()]
    if not self._hosts:
        # Interpolate here: exception constructors do not apply %-style args.
        raise ActionRunnerPreRunError(
            'No hosts specified to run action for action %s.' % self.liveaction_id)

    self._username = self.runner_parameters.get(RUNNER_USERNAME, cfg.CONF.system_user.user)
    # An explicitly empty/None username also falls back to the system user.
    self._username = self._username or cfg.CONF.system_user.user
    self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
    self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, 22)
    self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, self._ssh_key_file)
    self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
    self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
    self._sudo = self._sudo if self._sudo else False
    self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)
    self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
    self._env = self.runner_parameters.get(RUNNER_ENV, {})
    self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
    self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT,
                                               FABRIC_RUNNER_DEFAULT_ACTION_TIMEOUT)

    LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
             self.runner_id, self.liveaction_id)

    # Roughly one worker per three hosts when parallel, capped at _max_concurrency.
    concurrency = int(len(self._hosts) / 3) + 1 if self._parallel else 1
    if concurrency > self._max_concurrency:
        concurrency = self._max_concurrency
        # Log the value actually used, not the uncapped one.
        LOG.debug('Limiting parallel SSH concurrency to %d.', concurrency)

    if self._password:
        self._parallel_ssh_client = ParallelSSHClient(
            hosts=self._hosts,
            user=self._username,
            password=self._password,
            port=self._ssh_port,
            concurrency=concurrency,
            raise_on_error=False,
            connect=True
        )
    else:
        # Honor the private_key runner parameter (it defaults to
        # self._ssh_key_file). Fall back to the configured key file if it was
        # explicitly passed as empty/None.
        self._parallel_ssh_client = ParallelSSHClient(
            hosts=self._hosts,
            user=self._username,
            pkey=self._private_key or self._ssh_key_file,
            port=self._ssh_port,
            concurrency=concurrency,
            raise_on_error=False,
            connect=True
        )
def test_connect_with_password(self):
    """Password auth must connect every host with timeout=60 and no agent/key lookup."""
    hosts = ['localhost', '127.0.0.1']
    client = ParallelSSHClient(hosts=hosts, user='******', password='******',
                               connect=False)
    client.connect()
    base_conn = {
        'allow_agent': False,
        'look_for_keys': False,
        'password': '******',
        'username': '******',
        'timeout': 60,
        'port': 22,
    }
    for host in hosts:
        expected_conn = dict(base_conn, hostname=host)
        client._hosts_client[host].client.connect.assert_called_once_with(**expected_conn)
def test_connect_with_key(self):
    """Key-file auth must pass key_filename plus each host's parsed hostname/port."""
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=False)
    client.connect()
    base_conn = {
        'allow_agent': False,
        'look_for_keys': False,
        'key_filename': '~/.ssh/id_rsa',
        'username': '******',
    }
    for host in hosts:
        hostname, port = client._get_host_port_info(host)
        expected_conn = dict(base_conn, hostname=hostname, port=port)
        connect_mock = client._hosts_client[hostname].client.connect
        connect_mock.assert_called_once_with(**expected_conn)
def test_connect_with_random_ports(self):
    """Ports embedded in host strings must reach connect() for the matching host."""
    hosts = ['localhost:22', '127.0.0.1:55', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******', password='******',
                               connect=False)
    client.connect()
    base_conn = {
        'allow_agent': False,
        'look_for_keys': False,
        'password': '******',
        'username': '******',
    }
    for host in hosts:
        hostname, port = client._get_host_port_info(host)
        expected_conn = dict(base_conn, hostname=hostname, port=port)
        client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn)
def test_connect_with_key(self):
    """Key-file auth must pass key_filename, timeout=60 and each host's hostname/port."""
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=False)
    client.connect()
    for host in hosts:
        hostname, port = client._get_host_port_info(host)
        expected_conn = {
            'allow_agent': False,
            'look_for_keys': False,
            'key_filename': '~/.ssh/id_rsa',
            'username': '******',
            'timeout': 60,
            'hostname': hostname,
            'port': port,
        }
        client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn)
def test_run_command_timeout(self):
    """A command timeout must give a failed result that still carries stdout/stderr."""
    # Make sure stdout and stderr is included on timeout
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    mock_run = Mock(side_effect=SSHCommandTimeoutError(
        cmd='pwd', timeout=10, stdout='a', stderr='b'))
    for host in hosts:
        # _hosts_client is keyed by the parsed hostname (as in the other tests
        # in this file), not by the raw host string, which may carry ":port".
        hostname, _ = client._get_host_port_info(host)
        client._hosts_client[hostname].run = mock_run
    results = client.run('pwd')
    for host in hosts:
        # NOTE(review): results are looked up by the raw host string --
        # presumably how run() keys them; confirm against ParallelSSHClient.
        result = results[host]
        self.assertEqual(result['failed'], True)
        self.assertEqual(result['stdout'], 'a')
        self.assertEqual(result['stderr'], 'b')
        self.assertEqual(result['return_code'], -9)
def test_connect_with_random_ports(self):
    """connect() must use the port parsed from each host string (default 22)."""
    hosts = ['localhost:22', '127.0.0.1:55', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******', password='******',
                               connect=False)
    client.connect()
    for host in hosts:
        hostname, port = client._get_host_port_info(host)
        expected_conn = {
            'allow_agent': False,
            'look_for_keys': False,
            'password': '******',
            'username': '******',
            'hostname': hostname,
            'port': port,
        }
        connect_mock = client._hosts_client[hostname].client.connect
        connect_mock.assert_called_once_with(**expected_conn)
def test_run_command_timeout(self):
    """A command timeout must give a failed result that still carries stdout/stderr."""
    # Make sure stdout and stderr is included on timeout
    hosts = ['localhost', '127.0.0.1', 'st2build001']
    client = ParallelSSHClient(hosts=hosts, user='******',
                               pkey_file='~/.ssh/id_rsa', connect=True)
    mock_run = Mock(side_effect=SSHCommandTimeoutError(cmd='pwd', timeout=10,
                                                       stdout='a', stderr='b'))
    for host in hosts:
        # _hosts_client is keyed by the parsed hostname (as in the other tests
        # in this file); the previously computed hostname was left unused.
        hostname, _ = client._get_host_port_info(host)
        client._hosts_client[hostname].run = mock_run
    results = client.run('pwd')
    for host in hosts:
        # NOTE(review): results are looked up by the raw host string --
        # presumably how run() keys them; confirm against ParallelSSHClient.
        result = results[host]
        self.assertEqual(result['failed'], True)
        self.assertEqual(result['stdout'], 'a')
        self.assertEqual(result['stderr'], 'b')
        self.assertEqual(result['return_code'], -9)
def test_cwd_used_correctly(self):
    """A configured cwd must be prefixed to the remote script invocation as `cd <cwd> &&`."""
    action = ParamikoRemoteScriptAction(
        'foo-script',
        bson.ObjectId(),
        script_local_path_abs='/home/stanley/shiz_storm.py',
        script_local_libs_path_abs=None,
        named_args={},
        positional_args=['blank space'],
        env_vars={},
        on_behalf_user='******',
        user='******',
        private_key='---SOME RSA KEY---',
        remote_dir='/tmp',
        hosts=['localhost'],
        cwd='/test/cwd/',
    )
    runner = ParamikoRemoteScriptRunner('runner_1')
    runner._parallel_ssh_client = ParallelSSHClient(['localhost'], 'stanley')
    runner._run_script_on_remote_host(action)
    expected_cmd = "cd /test/cwd/ && /tmp/shiz_storm.py 'blank space'"
    # NOTE(review): asserting on the class attribute assumes ParallelSSHClient.run
    # is replaced by a mock outside this method (e.g. a patch decorator) -- confirm.
    ParallelSSHClient.run.assert_called_with(expected_cmd, timeout=None)
def pre_run(self):
    """Read runner parameters and create the ParallelSSHClient for this runner.

    Authentication precedence when building the client: password, then the
    ``private_key`` parameter (raw key material or a key file path, optionally
    with a passphrase), then the default key file in ``self._ssh_key_file``.

    Raises:
        ActionRunnerPreRunError: if the hosts parameter has no non-blank host.
        InvalidCredentialsException: if a username is supplied without a
            password or private key.
    """
    super(BaseParallelSSHRunner, self).pre_run()

    LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
              self.liveaction_id)

    # Strip before filtering so whitespace-only entries (e.g. "a, ,b" or a
    # trailing ", ") don't turn into empty host names.
    hosts = self.runner_parameters.get(RUNNER_HOSTS, '').split(',')
    self._hosts = [h.strip() for h in hosts if h.strip()]
    if not self._hosts:
        # Interpolate here: exception constructors do not apply %-style args.
        raise ActionRunnerPreRunError(
            'No hosts specified to run action for action %s.' % self.liveaction_id)

    self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
    self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
    self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None)
    self._passphrase = self.runner_parameters.get(RUNNER_PASSPHRASE, None)

    # An explicitly supplied username must come with its own credential; only
    # the system-user fallback below may rely on the default key file.
    if self._username:
        if not self._password and not self._private_key:
            msg = ('Either password or private_key data needs to be supplied for user: %s' %
                   self._username)
            raise InvalidCredentialsException(msg)

    self._username = self._username or cfg.CONF.system_user.user
    self._ssh_port = self.runner_parameters.get(RUNNER_SSH_PORT, 22)
    self._ssh_key_file = self._private_key or self._ssh_key_file
    self._parallel = self.runner_parameters.get(RUNNER_PARALLEL, True)
    self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
    self._sudo = self._sudo if self._sudo else False
    self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)
    self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
    self._env = self.runner_parameters.get(RUNNER_ENV, {})
    self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
    self._timeout = self.runner_parameters.get(
        RUNNER_TIMEOUT, REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
    self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None)

    LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
             self.runner_id, self.liveaction_id)

    # Roughly one worker per three hosts when parallel, capped at _max_concurrency.
    concurrency = int(len(self._hosts) / 3) + 1 if self._parallel else 1
    if concurrency > self._max_concurrency:
        concurrency = self._max_concurrency
        # Log the value actually used, not the uncapped one.
        LOG.debug('Limiting parallel SSH concurrency to %d.', concurrency)

    client_kwargs = {
        'hosts': self._hosts,
        'user': self._username,
        'port': self._ssh_port,
        'concurrency': concurrency,
        'bastion_host': self._bastion_host,
        'raise_on_any_error': False,
        'connect': True
    }

    if self._password:
        client_kwargs['password'] = self._password
    elif self._private_key:
        # Determine if the private_key is a path to the key file or the raw key material
        is_key_material = self._is_private_key_material(private_key=self._private_key)

        if is_key_material:
            # Raw key material
            client_kwargs['pkey_material'] = self._private_key
        else:
            # Assume it's a path to the key file; no existence check is done here.
            client_kwargs['pkey_file'] = self._private_key

        # The passphrase applies to whichever form of private key was given.
        if self._passphrase:
            client_kwargs['passphrase'] = self._passphrase
    else:
        # Default to stanley key file specified in the config
        client_kwargs['pkey_file'] = self._ssh_key_file

    self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)