def test_ssh_agent_authentication(self):
    """Log in using only a key that is held by an SSH agent.

    No key is handed to SSHClient directly. The client's agent is
    replaced with a fake agent holding USER_KEY, and the client is
    expected to pick the key up from that agent when running the
    fake command.
    """
    fake_agent = FakeAgent()
    fake_agent.add_key(USER_KEY)
    client = SSHClient(self.host, port=self.listen_port,
                       _agent=fake_agent)
    channel, _host, stdout_buffer, stderr_buffer = client.exec_command(
        self.fake_cmd)
    channel.close()
    received = list(stdout_buffer)
    # stderr is consumed as well, but its contents are not checked.
    _stderr_lines = list(stderr_buffer)
    self.assertEqual(
        [self.fake_resp], received,
        msg="Got unexpected command output - %s" % (received,))
    del client
def test_ssh_agent_authentication(self):
    """Log in using only a key that is held by an SSH agent.

    No key is handed to SSHClient directly. The client is given a
    fake agent holding USER_KEY and is expected to pick the key up
    from that agent when running the fake command.
    """
    fake_agent = FakeAgent()
    fake_agent.add_key(USER_KEY)
    client = SSHClient(self.host, port=self.listen_port,
                       agent=fake_agent)
    _channel, _host, stdout_buffer, stderr_buffer = client.exec_command(
        self.fake_cmd)
    received = list(stdout_buffer)
    # stderr is consumed as well, but its contents are not checked.
    _stderr_lines = list(stderr_buffer)
    self.assertEqual(
        [self.fake_resp], received,
        msg="Got unexpected command output - %s" % (received,))
    del client
class ParallelSSHClientTest(unittest.TestCase): def setUp(self): self.fake_cmd = 'echo "me"' self.fake_resp = 'me' self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; sleep 1; done' % (lines,) self.user_key = USER_KEY self.host = '127.0.0.1' self.listen_socket = make_socket(self.host) self.listen_port = self.listen_socket.getsockname()[1] self.server = start_server(self.listen_socket) self.agent = FakeAgent() self.agent.add_key(USER_KEY) self.client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) def tearDown(self): del self.server del self.listen_socket del self.client def test_pssh_client_exec_command(self): cmd = self.client.exec_command(self.fake_cmd)[0] output = self.client.get_stdout(cmd) self.assertTrue(self.host in output, msg="No output for host") self.assertTrue(output[self.host]['exit_code'] == 0) def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): output = self.client.run_command('exit 1') expected_exit_code = 1 exit_code = output[self.host]['exit_code'] self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) def test_pssh_client_exec_command_get_buffers(self): client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) cmd = client.exec_command(self.fake_cmd)[0] output = client.get_stdout(cmd, return_buffers=True) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] exit_code = output[self.host]['exit_code'] stdout = list(output[self.host]['stdout']) stderr = list(output[self.host]['stderr']) self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) self.assertEqual(expected_stderr, stderr, msg="Got unexpected stderr - %s, expected %s" % 
(stderr, expected_stderr,)) def test_pssh_client_run_command_get_output(self): client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) output = client.run_command(self.fake_cmd) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] exit_code = output[self.host]['exit_code'] stdout = list(output[self.host]['stdout']) stderr = list(output[self.host]['stderr']) self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) self.assertEqual(expected_stderr, stderr, msg="Got unexpected stderr - %s, expected %s" % (stderr, expected_stderr,)) def test_pssh_client_run_command_get_output_explicit(self): client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) out = client.run_command(self.fake_cmd) cmds = [cmd for host in out for cmd in [out[host]['cmd']]] output = {} for cmd in cmds: client.get_output(cmd, output) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] stdout = list(output[self.host]['stdout']) stderr = list(output[self.host]['stderr']) exit_code = output[self.host]['exit_code'] self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) self.assertEqual(expected_stderr, stderr, msg="Got unexpected stderr - %s, expected %s" % (stderr, expected_stderr,)) del client def test_pssh_client_run_long_command(self): expected_lines = 5 client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) output = client.run_command(self.long_cmd(expected_lines)) self.assertTrue(self.host in output, msg="Got no output for command") stdout = 
list(output[self.host]['stdout']) self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" % ( expected_lines, len(stdout))) del client def test_pssh_client_auth_failure(self): listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] server = start_server(listen_socket, fail_auth=True) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, agent=self.agent) cmd = client.exec_command(self.fake_cmd)[0] # Handle exception try: cmd.get() raise Exception("Expected AuthenticationException, got none") except AuthenticationException: pass del client server.join() def test_pssh_client_hosts_list_part_failure(self): """Test getting output for remainder of host list in the case where one host in the host list has a failure""" server2_socket = make_socket('127.0.0.2', port=self.listen_port) server2_port = server2_socket.getsockname()[1] server2 = start_server(server2_socket, fail_auth=True) hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(hosts, port=self.listen_port, pkey=self.user_key, agent=self.agent) output = client.run_command(self.fake_cmd, stop_on_errors=False) client.join(output) self.assertTrue(hosts[0] in output, msg="Successful host does not exist in output - output is %s" % (output,)) self.assertTrue(hosts[1] in output, msg="Failed host does not exist in output - output is %s" % (output,)) self.assertTrue('exception' in output[hosts[1]], msg="Failed host %s has no exception in output - %s" % (hosts[1], output,)) try: raise output[hosts[1]]['exception'] except AuthenticationException: pass else: raise Exception("Expected AuthenticationException, got %s instead" % ( output[hosts[1]]['exception'],)) del client server2.kill() def test_pssh_client_ssh_exception(self): listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] server = start_server(listen_socket, ssh_exception=True) client = ParallelSSHClient([self.host], user='******', 
password='******', port=listen_port, pkey=paramiko.RSAKey.generate(1024), ) # Handle exception try: client.run_command(self.fake_cmd) raise Exception("Expected SSHException, got none") except SSHException: pass del client server.join() def test_pssh_client_timeout(self): listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] server_timeout=0.2 client_timeout=server_timeout-0.1 server = start_server(listen_socket, timeout=server_timeout) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, timeout=client_timeout) output = client.run_command(self.fake_cmd) # Handle exception try: gevent.sleep(server_timeout+0.2) client.pool.join() if not server.exception: raise Exception( "Expected gevent.Timeout from socket timeout, got none") raise server.exception except gevent.Timeout: pass # chan_timeout = output[self.host]['channel'].gettimeout() # self.assertEqual(client_timeout, chan_timeout, # msg="Channel timeout %s does not match requested timeout %s" %( # chan_timeout, client_timeout,)) del client server.join() def test_pssh_client_exec_command_password(self): """Test password authentication. 
Embedded server accepts any password even empty string""" client = ParallelSSHClient([self.host], port=self.listen_port, password='') cmd = client.exec_command(self.fake_cmd)[0] output = client.get_stdout(cmd) self.assertTrue(self.host in output, msg="No output for host") self.assertTrue(output[self.host]['exit_code'] == 0, msg="Expected exit code 0, got %s" % ( output[self.host]['exit_code'],)) del client def test_pssh_client_long_running_command(self): expected_lines = 5 client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) cmd = client.exec_command(self.long_cmd(expected_lines))[0] output = client.get_stdout(cmd, return_buffers=True) self.assertTrue(self.host in output, msg="Got no output for command") stdout = list(output[self.host]['stdout']) self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" % ( expected_lines, len(stdout))) del client def test_pssh_client_long_running_command_exit_codes(self): expected_lines = 5 client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) output = client.run_command(self.long_cmd(expected_lines)) self.assertTrue(self.host in output, msg="Got no output for command") self.assertTrue(not output[self.host]['exit_code'], msg="Got exit code %s for still running cmd.." 
% ( output[self.host]['exit_code'],)) # Embedded server is also asynchronous and in the same thread # as our client so need to sleep for duration of server connection gevent.sleep(expected_lines) client.join(output) self.assertTrue(output[self.host]['exit_code'] == 0, msg="Got non-zero exit code %s" % ( output[self.host]['exit_code'],)) del client def test_pssh_client_retries(self): """Test connection error retries""" listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] expected_num_tries = 2 client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, num_retries=expected_num_tries) self.assertRaises(ConnectionErrorException, client.run_command, 'blah') try: client.run_command('blah') except ConnectionErrorException, ex: num_tries = ex.args[-1:][0] self.assertEqual(expected_num_tries, num_tries, msg="Got unexpected number of retries %s - expected %s" % (num_tries, expected_num_tries,)) else:
class ParallelSSHClientTest(unittest.TestCase): def setUp(self): self.fake_cmd = 'echo "me"' self.fake_resp = 'me' self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; sleep 1; done' % (lines,) self.user_key = USER_KEY self.host = '127.0.0.1' self.listen_socket = make_socket(self.host) self.listen_port = self.listen_socket.getsockname()[1] self.server = start_server(self.listen_socket) self.agent = FakeAgent() self.agent.add_key(USER_KEY) self.client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) def tearDown(self): del self.server del self.listen_socket del self.client def test_pssh_client_exec_command(self): cmd = self.client.exec_command(self.fake_cmd)[0] output = self.client.get_stdout(cmd) self.assertTrue(self.host in output, msg="No output for host") self.assertTrue(output[self.host]['exit_code'] == 0) def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): output = self.client.run_command('exit 1') expected_exit_code = 1 self.client.join(output) exit_code = output[self.host]['exit_code'] self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) def test_pssh_client_exec_command_get_buffers(self): client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) cmd = client.exec_command(self.fake_cmd)[0] output = client.get_stdout(cmd, return_buffers=True) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] exit_code = output[self.host]['exit_code'] stdout = list(output[self.host]['stdout']) stderr = list(output[self.host]['stderr']) self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) self.assertEqual(expected_stderr, stderr, msg="Got unexpected 
stderr - %s, expected %s" % (stderr, expected_stderr,)) def test_pssh_client_run_command_get_output(self): client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key, agent=self.agent) output = client.run_command(self.fake_cmd) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] stdout = list(output[self.host]['stdout']) stderr = list(output[self.host]['stderr']) exit_code = output[self.host]['exit_code'] self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) self.assertEqual(expected_stderr, stderr, msg="Got unexpected stderr - %s, expected %s" % (stderr, expected_stderr,)) def test_pssh_client_run_command_get_output_explicit(self): client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) out = client.run_command(self.fake_cmd) cmds = [cmd for host in out for cmd in [out[host]['cmd']]] output = {} for cmd in cmds: client.get_output(cmd, output) expected_exit_code = 0 expected_stdout = [self.fake_resp] expected_stderr = [] stdout = list(output[self.host]['stdout']) stderr = list(output[self.host]['stderr']) exit_code = output[self.host]['exit_code'] self.assertEqual(expected_exit_code, exit_code, msg="Got unexpected exit code - %s, expected %s" % (exit_code, expected_exit_code,)) self.assertEqual(expected_stdout, stdout, msg="Got unexpected stdout - %s, expected %s" % (stdout, expected_stdout,)) self.assertEqual(expected_stderr, stderr, msg="Got unexpected stderr - %s, expected %s" % (stderr, expected_stderr,)) del client def test_pssh_client_run_long_command(self): expected_lines = 5 client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) output = client.run_command(self.long_cmd(expected_lines)) self.assertTrue(self.host in output, msg="Got no output for command") 
stdout = list(output[self.host]['stdout']) self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" % ( expected_lines, len(stdout))) del client def test_pssh_client_auth_failure(self): listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] server = start_server(listen_socket, fail_auth=True) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, agent=self.agent) cmd = client.exec_command(self.fake_cmd)[0] # Handle exception try: cmd.get() raise Exception("Expected AuthenticationException, got none") except AuthenticationException: pass del client server.kill() def test_pssh_client_hosts_list_part_failure(self): """Test getting output for remainder of host list in the case where one host in the host list has a failure""" server2_socket = make_socket('127.0.0.2', port=self.listen_port) server2_port = server2_socket.getsockname()[1] server2 = start_server(server2_socket, fail_auth=True) hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(hosts, port=self.listen_port, pkey=self.user_key, agent=self.agent) output = client.run_command(self.fake_cmd, stop_on_errors=False) self.assertFalse(client.finished(output)) client.join(output) self.assertTrue(client.finished(output)) self.assertTrue(hosts[0] in output, msg="Successful host does not exist in output - output is %s" % (output,)) self.assertTrue(hosts[1] in output, msg="Failed host does not exist in output - output is %s" % (output,)) self.assertTrue('exception' in output[hosts[1]], msg="Failed host %s has no exception in output - %s" % (hosts[1], output,)) try: raise output[hosts[1]]['exception'] except AuthenticationException: pass else: raise Exception("Expected AuthenticationException, got %s instead" % ( output[hosts[1]]['exception'],)) del client server2.kill() def test_pssh_client_ssh_exception(self): listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] server = 
start_server(listen_socket, ssh_exception=True) client = ParallelSSHClient([self.host], user='******', password='******', port=listen_port, pkey=paramiko.RSAKey.generate(1024), ) self.assertRaises(SSHException, client.run_command, self.fake_cmd) del client server.kill() def test_pssh_client_timeout(self): listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] server_timeout=0.2 client_timeout=server_timeout-0.1 server = start_server(listen_socket, timeout=server_timeout) client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, timeout=client_timeout) output = client.run_command(self.fake_cmd) # Handle exception try: gevent.sleep(server_timeout+0.2) client.pool.join() if not server.exception: raise Exception( "Expected gevent.Timeout from socket timeout, got none") raise server.exception except gevent.Timeout: pass # chan_timeout = output[self.host]['channel'].gettimeout() # self.assertEqual(client_timeout, chan_timeout, # msg="Channel timeout %s does not match requested timeout %s" %( # chan_timeout, client_timeout,)) del client server.kill() def test_pssh_client_exec_command_password(self): """Test password authentication. 
Embedded server accepts any password even empty string""" client = ParallelSSHClient([self.host], port=self.listen_port, password='') cmd = client.exec_command(self.fake_cmd)[0] output = client.get_stdout(cmd) self.assertTrue(self.host in output, msg="No output for host") self.assertTrue(output[self.host]['exit_code'] == 0, msg="Expected exit code 0, got %s" % ( output[self.host]['exit_code'],)) del client def test_pssh_client_long_running_command(self): expected_lines = 5 client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) cmd = client.exec_command(self.long_cmd(expected_lines))[0] output = client.get_stdout(cmd, return_buffers=True) self.assertTrue(self.host in output, msg="Got no output for command") stdout = list(output[self.host]['stdout']) self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" % ( expected_lines, len(stdout))) del client def test_pssh_client_long_running_command_exit_codes(self): expected_lines = 5 client = ParallelSSHClient([self.host], port=self.listen_port, pkey=self.user_key) output = client.run_command(self.long_cmd(expected_lines)) self.assertTrue(self.host in output, msg="Got no output for command") self.assertTrue(not output[self.host]['exit_code'], msg="Got exit code %s for still running cmd.." 
% ( output[self.host]['exit_code'],)) self.assertFalse(client.finished(output)) # Embedded server is also asynchronous and in the same thread # as our client so need to sleep for duration of server connection gevent.sleep(expected_lines) client.join(output) self.assertTrue(client.finished(output)) self.assertTrue(output[self.host]['exit_code'] == 0, msg="Got non-zero exit code %s" % ( output[self.host]['exit_code'],)) del client def test_pssh_client_retries(self): """Test connection error retries""" listen_socket = make_socket(self.host) listen_port = listen_socket.getsockname()[1] expected_num_tries = 2 client = ParallelSSHClient([self.host], port=listen_port, pkey=self.user_key, num_retries=expected_num_tries) self.assertRaises(ConnectionErrorException, client.run_command, 'blah') try: client.run_command('blah') except ConnectionErrorException, ex: num_tries = ex.args[-1:][0] self.assertEqual(expected_num_tries, num_tries, msg="Got unexpected number of retries %s - expected %s" % (num_tries, expected_num_tries,)) else: