def test_ssh_proxy(self):
    """Run a command on the destination host through an SSH proxy.

    Connection path is client -> proxy -> destination. The proxy server
    accepts no commands and sends no responses of its own; it only forwards
    to the destination, which executes the command as usual.
    """
    # Tear down the default client/server and start a fresh destination
    # server on self.listen_port.
    del self.client
    self.client = None
    self.server.kill()
    dest_server, _ = start_server_from_ip(self.host, port=self.listen_port)
    proxy_ip = '127.0.0.2'
    proxy, proxy_port = start_server_from_ip(proxy_ip)
    pssh = ParallelSSHClient(
        [self.host],
        port=self.listen_port,
        pkey=self.user_key,
        proxy_host=proxy_ip,
        proxy_port=proxy_port,
    )
    try:
        host_output = pssh.run_command(self.fake_cmd)[self.host]
        received = list(host_output['stdout'])
        wanted = [self.fake_resp]
        self.assertEqual(
            wanted, received,
            msg="Got unexpected stdout - %s, expected %s" % (received, wanted))
    finally:
        del pssh
        dest_server.kill()
        proxy.kill()
def test_per_host_tuple_args(self):
    """Test per-host command arguments given as a tuple.

    The test expects each host's output to be ``cmd`` formatted with that
    host's entry in ``host_args``. It covers:

    * single string arguments;
    * tuple arguments;
    * too few host arguments for the host list (``HostArgumentException``);
    * too few values for the format string (``TypeError``).
    """
    host2, host3 = '127.0.0.2', '127.0.0.3'
    server2, _ = start_server_from_ip(host2, port=self.listen_port)
    server3, _ = start_server_from_ip(host3, port=self.listen_port)
    hosts = [self.host, host2, host3]
    client = ParallelSSHClient(hosts, port=self.listen_port,
                               pkey=self.user_key,
                               num_retries=1)
    # try/finally so the extra servers are stopped even if an assertion fails
    try:
        # One string argument per host
        host_args = ('arg1', 'arg2', 'arg3')
        cmd = 'echo %s'
        output = client.run_command(cmd, host_args=host_args)
        for host, host_arg in zip(hosts, host_args):
            stdout = list(output[host]['stdout'])
            self.assertEqual([host_arg], stdout)
            self.assertEqual(output[host]['exit_code'], 0)
        # A tuple of arguments per host
        host_args = (('arg1', 'arg2'),
                     ('arg3', 'arg4'),
                     ('arg5', 'arg6'),)
        cmd = 'echo %s %s'
        output = client.run_command(cmd, host_args=host_args)
        for host, host_arg in zip(hosts, host_args):
            stdout = list(output[host]['stdout'])
            self.assertEqual(["%s %s" % host_arg], stdout)
            self.assertEqual(output[host]['exit_code'], 0)
        # Fewer host arguments than hosts
        self.assertRaises(HostArgumentException, client.run_command,
                          cmd, host_args=[host_args[0]])
        # Invalid number of args for the format string
        host_args = (('arg1', ),)
        self.assertRaises(
            TypeError, client.run_command, cmd, host_args=host_args)
    finally:
        del client
        for server in (server2, server3):
            server.kill()
def test_pssh_hosts_iterator_hosts_modification(self):
    """Test using iterator as host list and modifying host list in place.

    The test expects the iterator host list to be used up by the first
    ``run_command`` call. A second call without reassigning
    ``client.hosts`` should therefore produce no output. Assigning a fresh
    iterator should make the client run on the new hosts.
    """
    host2, host3 = '127.0.0.2', '127.0.0.3'
    server2, _ = start_server_from_ip(host2, port=self.listen_port)
    server3, _ = start_server_from_ip(host3, port=self.listen_port)
    hosts = [self.host, host2]
    client = ParallelSSHClient(iter(hosts), port=self.listen_port,
                               pkey=self.user_key,
                               pool_size=1,
                               )
    try:
        output = client.run_command(self.fake_cmd)
        stdout = [list(output[k]['stdout']) for k in output]
        expected_stdout = [[self.fake_resp], [self.fake_resp]]
        self.assertEqual(len(hosts), len(output),
                         msg="Did not get output from all hosts. Got output for "
                             "%s/%s hosts" % (len(output), len(hosts),))
        # stdout was previously collected but never checked
        self.assertEqual(expected_stdout, stdout,
                         msg="Did not get expected output from all hosts. "
                             "Got %s - expected %s" % (stdout, expected_stdout,))
        # Run again without re-assigning host list, should do nothing
        output = client.run_command(self.fake_cmd)
        self.assertNotIn(hosts[0], output,
                         msg="Expected no host output, got %s" % (output,))
        self.assertFalse(output,
                         msg="Expected empty output, got %s" % (output,))
        # Re-assigning host list with new hosts should work
        hosts = [host2, host3]
        client.hosts = iter(hosts)
        output = client.run_command(self.fake_cmd)
        self.assertEqual(len(hosts), len(output),
                         msg="Did not get output from all hosts. Got output for "
                             "%s/%s hosts" % (len(output), len(hosts),))
        self.assertIn(hosts[1], output,
                      msg="Did not get output for new host %s" % (hosts[1],))
    finally:
        del client
        server2.kill()
        server3.kill()
def test_per_host_dict_args(self):
    """Test per-host command arguments given as dictionaries.

    Each host gets a dict used with a named-placeholder format string. The
    same checks run with a plain host list and with a generator host list.
    A too-short argument list must raise ``HostArgumentException``.
    """
    host2, host3 = '127.0.0.2', '127.0.0.3'
    server2, _ = start_server_from_ip(host2, port=self.listen_port)
    server3, _ = start_server_from_ip(host3, port=self.listen_port)
    hosts = [self.host, host2, host3]
    hosts_gen = (h for h in hosts)
    host_args = [{'host_arg1': 'arg1-%s' % (i,),
                  'host_arg2': 'arg2-%s' % (i,)}
                 for i, _ in enumerate(hosts)]
    cmd = 'echo %(host_arg1)s %(host_arg2)s'
    client = ParallelSSHClient(hosts, port=self.listen_port,
                               pkey=self.user_key,
                               num_retries=1)

    def check(output):
        # Every host must echo its own formatted arguments and exit cleanly
        for host, args in zip(hosts, host_args):
            expected = ["%(host_arg1)s %(host_arg2)s" % args]
            stdout = list(output[host]['stdout'])
            self.assertEqual(expected, stdout)
            self.assertEqual(output[host]['exit_code'], 0)

    # try/finally: previously the extra servers were never stopped
    try:
        check(client.run_command(cmd, host_args=host_args))
        self.assertRaises(HostArgumentException, client.run_command,
                          cmd, host_args=[host_args[0]])
        # Host list generator should work also
        client.hosts = hosts_gen
        check(client.run_command(cmd, host_args=host_args))
        client.hosts = (h for h in hosts)
        self.assertRaises(HostArgumentException, client.run_command,
                          cmd, host_args=[host_args[0]])
        client.hosts = hosts
    finally:
        del client
        server2.kill()
        server3.kill()
def test_proxy_remote_host_failure_timeout(self):
    """Test that timeout setting is passed on to proxy to be used for the
    proxy->remote host connection timeout.

    The client ``timeout`` is set below the destination server's ``timeout``
    setting, and connecting through the proxy is expected to raise
    ``ConnectionErrorException``.
    """
    self.server.kill()
    server_timeout = 0.2
    client_timeout = server_timeout - 0.1
    server, listen_port = start_server_from_ip(self.host,
                                               timeout=server_timeout)
    proxy_host = '127.0.0.2'
    proxy_server, proxy_server_port = start_server_from_ip(proxy_host)
    proxy_user = '******'
    proxy_password = '******'
    # Pass the locals rather than repeating their literal values
    client = ParallelSSHClient([self.host], port=listen_port,
                               pkey=self.user_key,
                               proxy_host=proxy_host,
                               proxy_port=proxy_server_port,
                               proxy_user=proxy_user,
                               proxy_password=proxy_password,
                               proxy_pkey=self.user_key,
                               num_retries=1,
                               timeout=client_timeout,
                               )
    try:
        self.assertRaises(
            ConnectionErrorException, client.run_command, self.fake_cmd)
    finally:
        del client
        server.kill()
        proxy_server.kill()
def test_authentication_exception(self):
    """Test that we get authentication exception in output with correct arguments.

    The server is started with ``fail_auth=True``. With
    ``stop_on_errors=False`` the exception is expected in the host's output.
    Its ``args[1]`` and ``args[2]`` must be the host and port.
    """
    server, port = start_server_from_ip(self.host, fail_auth=True)
    hosts = [self.host]
    client = ParallelSSHClient(hosts, port=port,
                               pkey=self.user_key,
                               agent=self.agent,
                               num_retries=1)
    try:
        output = client.run_command(self.fake_cmd, stop_on_errors=False)
        client.pool.join()
        self.assertIn('exception', output[self.host],
                      msg="Got no exception for host %s - expected connection error" % (
                          self.host,))
        ex = output[self.host]['exception']
        self.assertIsInstance(ex, AuthenticationException,
                              msg="Expected AuthenticationException, got %r" % (ex,))
        self.assertEqual(ex.args[1], self.host,
                         msg="Exception host argument is %s, should be %s" % (
                             ex.args[1], self.host,))
        self.assertEqual(ex.args[2], port,
                         msg="Exception port argument is %s, should be %s" % (
                             ex.args[2], port,))
    finally:
        # Stop the server even if an assertion above fails
        del client
        server.kill()
def test_ssh_proxy_auth(self):
    """Test connecting via an SSH proxy with explicit proxy credentials.

    Path is client -> proxy -> destination. The command is expected to
    return the fake response. The per-host client must keep the given proxy
    user, password and private key.
    """
    host2 = '127.0.0.2'
    proxy_server, proxy_server_port = start_server_from_ip(host2)
    proxy_user = '******'
    proxy_password = '******'
    # Pass the same variable that is asserted on below instead of a
    # duplicated literal.
    client = ParallelSSHClient([self.host], port=self.listen_port,
                               pkey=self.user_key,
                               proxy_host=host2,
                               proxy_port=proxy_server_port,
                               proxy_user=proxy_user,
                               proxy_password=proxy_password,
                               proxy_pkey=self.user_key,
                               num_retries=1,
                               )
    expected_stdout = [self.fake_resp]
    try:
        output = client.run_command(self.fake_cmd)
        stdout = list(output[self.host]['stdout'])
        self.assertEqual(expected_stdout, stdout,
                         msg="Got unexpected stdout - %s, expected %s" % (
                             stdout, expected_stdout,))
        host_client = client.host_clients[self.host]
        self.assertEqual(host_client.proxy_user, proxy_user)
        self.assertEqual(host_client.proxy_password, proxy_password)
        self.assertTrue(host_client.proxy_pkey)
    finally:
        del client
        proxy_server.kill()
def test_pssh_client_hosts_list_part_failure(self):
    """Test getting output for remainder of host list in the case where one
    host in the host list has a failure.

    The second host's server rejects authentication. Both hosts must still
    appear in the output, and the failing one must carry an
    ``AuthenticationException``.
    """
    host2 = '127.0.0.2'
    server2, _ = start_server_from_ip(host2, port=self.listen_port,
                                      fail_auth=True)
    hosts = [self.host, host2]
    client = ParallelSSHClient(hosts,
                               port=self.listen_port,
                               pkey=self.user_key,
                               agent=self.agent)
    try:
        output = client.run_command(self.fake_cmd, stop_on_errors=False)
        self.assertFalse(client.finished(output))
        client.join(output)
        self.assertTrue(client.finished(output))
        self.assertIn(hosts[0], output,
                      msg="Successful host does not exist in output - output is %s" % (output,))
        self.assertIn(hosts[1], output,
                      msg="Failed host does not exist in output - output is %s" % (output,))
        self.assertIn('exception', output[hosts[1]],
                      msg="Failed host %s has no exception in output - %s" % (hosts[1], output,))
        self.assertIsInstance(
            output[hosts[1]]['exception'], AuthenticationException,
            msg="Expected AuthenticationException, got %s instead" % (
                output[hosts[1]]['exception'],))
    finally:
        del client
        server2.kill()
def test_ssh_exception(self):
    """Test that we get ssh exception in output with correct arguments.

    The server is started with ``ssh_exception=True``. The exception is
    expected in the host's output, with the host and port in ``args[1]`` and
    ``args[2]``.
    """
    host = '127.0.0.10'
    server, port = start_server_from_ip(host, ssh_exception=True)
    hosts = [host]
    client = ParallelSSHClient(hosts, port=port,
                               user='******', password='******',
                               pkey=RSAKey.generate(1024),
                               num_retries=1)
    try:
        output = client.run_command(self.fake_cmd, stop_on_errors=False)
        client.pool.join()
        self.assertIn('exception', output[host],
                      msg="Got no exception for host %s - expected connection error" % (
                          host,))
        ex = output[host]['exception']
        self.assertIsInstance(ex, SSHException,
                              msg="Expected SSHException, got %r" % (ex,))
        self.assertEqual(ex.args[1], host,
                         msg="Exception host argument is %s, should be %s" % (
                             ex.args[1], host,))
        self.assertEqual(ex.args[2], port,
                         msg="Exception port argument is %s, should be %s" % (
                             ex.args[2], port,))
    finally:
        # Stop the server even if an assertion above fails
        del client
        server.kill()
def test_pssh_client_auth_failure(self):
    """Test that authentication failure raises ``AuthenticationException``
    from ``run_command`` when the server rejects authentication."""
    server, listen_port = start_server_from_ip(self.host, fail_auth=True)
    client = ParallelSSHClient([self.host], port=listen_port,
                               pkey=self.user_key,
                               agent=self.agent)
    # try/finally so a failed assertion does not leave the server running
    try:
        self.assertRaises(
            AuthenticationException, client.run_command, self.fake_cmd)
    finally:
        del client
        server.kill()
def test_pssh_client_ssh_exception(self):
    """Test that ``run_command`` raises ``SSHException`` against a server
    started with ``ssh_exception=True``."""
    server, listen_port = start_server_from_ip(self.host,
                                               ssh_exception=True)
    client = ParallelSSHClient([self.host],
                               user='******', password='******',
                               port=listen_port,
                               pkey=RSAKey.generate(1024),
                               num_retries=1,
                               )
    # try/finally so a failed assertion does not leave the server running
    try:
        self.assertRaises(SSHException, client.run_command, self.fake_cmd)
    finally:
        del client
        server.kill()
def test_ssh_proxy_auth_fail(self):
    """Test failures while connecting via proxy.

    The destination server rejects authentication, so ``run_command``
    through the proxy must raise ``AuthenticationException``.
    """
    proxy_host = '127.0.0.2'
    server, listen_port = start_server_from_ip(self.host, fail_auth=True)
    proxy_server, proxy_server_port = start_server_from_ip(proxy_host)
    proxy_user = '******'
    proxy_password = '******'
    # Use the locals rather than repeating their literal values
    client = ParallelSSHClient([self.host], port=listen_port,
                               pkey=self.user_key,
                               proxy_host=proxy_host,
                               proxy_port=proxy_server_port,
                               proxy_user=proxy_user,
                               proxy_password=proxy_password,
                               proxy_pkey=self.user_key,
                               num_retries=1,
                               )
    try:
        self.assertRaises(
            AuthenticationException, client.run_command, self.fake_cmd)
    finally:
        del client
        server.kill()
        proxy_server.kill()
def test_ssh_proxy_target_host_failure(self):
    """Connecting through a proxy to a destination with no server fails.

    The default test server is killed and not restarted, so only the proxy
    is reachable. Running a command must raise ``ConnectionErrorException``.
    """
    # Drop the default client and stop the destination server.
    del self.client
    self.client = None
    self.server.kill()
    proxy_ip = '127.0.0.2'
    proxy, proxy_port = start_server_from_ip(proxy_ip)
    pssh = ParallelSSHClient(
        [self.host],
        port=self.listen_port,
        pkey=self.user_key,
        proxy_host=proxy_ip,
        proxy_port=proxy_port,
    )
    try:
        with self.assertRaises(ConnectionErrorException):
            pssh.run_command(self.fake_cmd)
    finally:
        del pssh
        proxy.kill()
def test_pssh_client_timeout(self):
    """Test behaviour when the client timeout is shorter than the server's.

    The command runs with ``stop_on_errors=False``. After waiting past the
    server timeout, the server object is expected to have recorded an
    exception; the failure message expects a ``gevent.Timeout`` from a
    socket timeout.
    """
    server_timeout = 0.2
    client_timeout = server_timeout - 0.1
    server, listen_port = start_server_from_ip(self.host,
                                               timeout=server_timeout)
    client = ParallelSSHClient([self.host], port=listen_port,
                               pkey=self.user_key,
                               timeout=client_timeout,
                               num_retries=1)
    try:
        output = client.run_command(self.fake_cmd, stop_on_errors=False)
        # Wait past the server timeout so it has a chance to fire before
        # checking whether the server recorded an exception.
        gevent.sleep(server_timeout + 0.2)
        client.join(output)
        if not server.exception:
            # self.fail reports a test failure rather than a test error
            self.fail(
                "Expected gevent.Timeout from socket timeout, got none")
    finally:
        del client
        server.kill()
def test_host_config(self):
    """Test per-host configuration functionality of ParallelSSHClient.

    Two local servers are started; the second one (index 1) is started with
    ``fail_auth`` enabled. ``host_config`` supplies port, user and password
    for both hosts, plus a private key for the first host. The test checks
    three things:

    * the second host reports ``AuthenticationException``;
    * the first host does not report a non-zero exit code;
    * the first host's client picked up the user, password and key overrides.
    """
    # range instead of the Python 2-only xrange
    hosts = ['127.0.0.%01d' % n for n in range(1, 3)]
    host_config = dict.fromkeys(hosts)
    servers = []
    user = '******'
    password = '******'
    try:
        for index, host in enumerate(hosts):
            # fail_auth is the host's index: off for hosts[0], on for hosts[1]
            server, port = start_server_from_ip(host, fail_auth=index)
            servers.append(server)
            host_config[host] = {'port': port,
                                 'user': user,
                                 'password': password}
        pkey_data = load_private_key(PKEY_FILENAME)
        host_config[hosts[0]]['private_key'] = pkey_data
        client = ParallelSSHClient(hosts, host_config=host_config)
        output = client.run_command(self.fake_cmd, stop_on_errors=False)
        client.join(output)
        for host in hosts:
            self.assertIn(host, output)
        self.assertIsInstance(
            output[hosts[1]]['exception'], AuthenticationException,
            msg="Expected AuthenticationException on host %s" % (hosts[1],))
        # Check the host that should succeed; hosts[1] failed auth, so
        # checking its exit code would pass trivially.
        self.assertFalse(output[hosts[0]]['exit_code'],
                         msg="Execution failed on host %s" % (hosts[0],))
        host_client = client.host_clients[hosts[0]]
        self.assertTrue(host_client.user == user,
                        msg="Host config user override failed")
        self.assertTrue(host_client.password == password,
                        msg="Host config password override failed")
        self.assertTrue(host_client.pkey == pkey_data,
                        msg="Host config pkey override failed")
    finally:
        for server in servers:
            server.kill()
def test_ssh_client_utf_encoding(self):
    """Test that unicode output works.

    First checks a UTF-8 command through the default client. Then checks
    output from a server started with ``encoding='utf-16'``, read with
    ``run_command(..., encoding='utf-16')``.
    """
    expected = [u'é']
    cmd = u"echo 'é'"
    output = self.client.run_command(cmd)
    stdout = list(output[self.host]['stdout'])
    self.assertEqual(expected, stdout,
                     msg="Got unexpected unicode output %s - expected %s" % (
                         stdout, expected,))
    utf16_server, server_port = start_server_from_ip(
        self.host, encoding='utf-16')
    client = ParallelSSHClient([self.host], port=server_port,
                               pkey=self.user_key)
    # try/finally: the utf-16 server was previously never stopped
    try:
        # File is already set to utf-8, cannot use utf-16 only representations
        # Using ascii characters decoded as utf-16 on py2
        # and utf-8 encoded ascii decoded to utf-16 on py3
        output = client.run_command(self.fake_cmd, encoding='utf-16')
        stdout = list(output[self.host]['stdout'])
        if isinstance(self.fake_resp, bytes):
            expected_utf16 = [self.fake_resp.decode('utf-16')]
        else:
            expected_utf16 = [self.fake_resp.encode('utf-8').decode('utf-16')]
        self.assertEqual(expected_utf16, stdout)
    finally:
        del client
        utf16_server.kill()
def test_pssh_hosts_more_than_pool_size(self):
    """Test we can successfully run on more hosts than our pool size and
    get logs for all hosts."""
    # Make a second server on the same port as the first one
    host2 = '127.0.0.2'
    server2, _ = start_server_from_ip(host2, port=self.listen_port)
    hosts = [self.host, host2]
    client = ParallelSSHClient(hosts,
                               port=self.listen_port,
                               pkey=self.user_key,
                               pool_size=1,
                               )
    try:
        output = client.run_command(self.fake_cmd)
        stdout = [list(output[k]['stdout']) for k in output]
        expected_stdout = [[self.fake_resp], [self.fake_resp]]
        self.assertEqual(len(hosts), len(output),
                         msg="Did not get output from all hosts. Got output for "
                             "%s/%s hosts" % (len(output), len(hosts),))
        # Implicit literal concatenation avoids the stray indentation that a
        # backslash continuation inside the string would embed
        self.assertEqual(expected_stdout, stdout,
                         msg="Did not get expected output from all hosts. "
                             "Got %s - expected %s" % (stdout, expected_stdout,))
    finally:
        del client
        server2.kill()