def test_connection_error_exception(self):
    """Test that we get connection error exception in output with correct
    arguments.

    Connects to a random port on a separate loopback IP where no server is
    listening, with ``stop_on_errors=False``, and checks that the host's
    output carries a ConnectionErrorException whose ``host`` and
    ``args[2]`` match the target host and port.
    """
    # Make port with no server listening on it on separate ip
    host = '127.0.0.3'
    port = self.make_random_port()
    hosts = [host]
    client = ParallelSSHClient(hosts, port=port, pkey=self.user_key,
                               num_retries=1)
    output = client.run_command(self.cmd, stop_on_errors=False)
    client.join(output)
    self.assertTrue(
        'exception' in output[host],
        msg="Got no exception for host %s - expected connection error" % (
            host, ))
    ex = output[host]['exception']
    # assertIsInstance instead of raise/except/else: an `else` after a bare
    # `raise` can never run, so a wrong or missing (None) exception used to
    # surface as an unrelated error rather than a clear test failure.
    self.assertIsInstance(ex, ConnectionErrorException,
                          msg="Expected ConnectionErrorException")
    self.assertEqual(
        ex.host, host,
        msg="Exception host argument is %s, should be %s" % (
            ex.host, host, ))
    # The port is expected as the third positional exception argument.
    self.assertEqual(
        ex.args[2], port,
        msg="Exception port argument is %s, should be %s" % (
            ex.args[2], port, ))
def test_connection_timeout(self):
    """Connecting to a (presumably unreachable) host with a 10ms client
    timeout records a ConnectionErrorException on the first output entry."""
    unreachable = 'fakehost.com'
    client = ParallelSSHClient(
        [unreachable],
        port=self.port,
        pkey=self.user_key,
        timeout=.01,
        num_retries=1,
    )
    results = client.run_command('sleep 1', stop_on_errors=False)
    first_result = results[0]
    self.assertIsInstance(first_result.exception, ConnectionErrorException)
def test_cert_auth(self):
    """Authenticate with the certificate private key plus ``cert_file`` and
    check the command's stdout equals the expected response."""
    cert_client = ParallelSSHClient(
        [self.host], port=self.port,
        pkey=self.cert_pkey, cert_file=self.cert_file)
    host_outputs = cert_client.run_command(self.cmd)
    cert_client.join(host_outputs)
    stdout_lines = list(host_outputs[0].stdout)
    self.assertListEqual(stdout_lines, [self.resp])
def test_read_timeout(self):
    """Reading stdout while the command is still sleeping raises Timeout;
    after join, all three output lines are readable and the channel is at
    EOF."""
    client = ParallelSSHClient([self.host], port=self.port,
                               pkey=self.user_key)
    output = client.run_command(
        'sleep 2; echo me; echo me; echo me', timeout=1)
    # The command sleeps for 2s, longer than the 1s read timeout.
    for entry in output:
        self.assertRaises(Timeout, list, entry.stdout)
    self.assertFalse(output[0].channel.is_eof())
    client.join(output)
    for entry in output:
        self.assertEqual(len(list(entry.stdout)), 3)
    self.assertTrue(output[0].channel.is_eof())
def test_pssh_client_timeout(self):
    """A deliberately tiny client timeout must be recorded as a Timeout
    exception on the host's output entry (``stop_on_errors=False``)."""
    # 0.00001s = 10 microsecond timeout.
    client_timeout = 0.00001
    client = ParallelSSHClient([self.host], port=self.port,
                               pkey=self.user_key,
                               timeout=client_timeout,
                               num_retries=1)
    now = datetime.now()
    output = client.run_command('sleep 1', stop_on_errors=False)
    # Elapsed time is only logged for debugging; it is not asserted on.
    dt = datetime.now() - now
    pssh_logger.debug("Run command took %s", dt)
    self.assertIsInstance(output[self.host].exception, Timeout)
def test_get_last_output(self):
    """After run_command on two hosts, get_last_output returns a list of
    HostOutput with exit code 0, the expected stdout and empty stderr."""
    extra_host = '127.0.0.9'
    extra_server = OpenSSHServer(listen_ip=extra_host, port=self.port)
    extra_server.start_server()
    try:
        hosts = [self.host, extra_host]
        client = ParallelSSHClient(hosts, port=self.port,
                                   pkey=self.user_key)
        # No command has been run yet.
        self.assertTrue(client.cmds is None)
        self.assertTrue(client.get_last_output() is None)
        client.run_command(self.cmd)
        self.assertTrue(client.cmds is not None)
        self.assertEqual(len(client.cmds), len(hosts))
        expected_stdout = [self.resp]
        expected_stderr = []
        last_output = client.get_last_output()
        self.assertIsInstance(last_output, list)
        self.assertEqual(len(last_output), len(hosts))
        self.assertIsInstance(last_output[0], HostOutput)
        client.join(last_output)
        for idx, expected_host in enumerate(hosts):
            host_out = last_output[idx]
            self.assertEqual(host_out.host, expected_host)
            exit_code = host_out.exit_code
            stdout_lines = list(host_out.stdout)
            stderr_lines = list(host_out.stderr)
            self.assertEqual(exit_code, 0)
            self.assertListEqual(expected_stdout, stdout_lines)
            self.assertListEqual(expected_stderr, stderr_lines)
    finally:
        extra_server.stop()
def test_timeout_on_open_session(self):
    """run_command raises Timeout when opening a session blocks for longer
    than the client timeout.

    After connecting and authenticating, the host client's ``open_session``
    is replaced with a stub that sleeps past the timeout.
    """
    timeout = 1
    client = ParallelSSHClient([self.host], port=self.port,
                               pkey=self.user_key,
                               timeout=timeout,
                               num_retries=1)

    # Default bound to the client timeout (not a hard-coded 1) so the stub
    # always blocks one second longer than the configured timeout.
    def _session(timeout=timeout):
        sleep(timeout + 1)

    joinall(client.connect_auth())
    sleep(.01)
    client._host_clients[(0, self.host)].open_session = _session
    self.assertRaises(Timeout, client.run_command, self.cmd)
def test_agent_auth(self):
    """With no private key and only agent/identity auth enabled,
    run_command raises AuthenticationException (the test environment
    presumably offers no usable agent keys)."""
    agent_client = ParallelSSHClient(
        [self.host],
        port=self.port,
        num_retries=1,
        pkey=None,
        allow_agent=True,
        identity_auth=True,
    )
    self.assertRaises(
        AuthenticationException, agent_client.run_command, self.cmd)
def test_pssh_client_auth_failure(self):
    """Running a command as a user the server is expected to reject raises
    AuthenticationException."""
    bad_user_client = ParallelSSHClient(
        [self.host],
        port=self.port,
        user='******',
        pkey=self.user_key,
        num_retries=1,
    )
    self.assertRaises(AuthenticationException,
                      bad_user_client.run_command, self.cmd)
def setUpClass(cls):
    """Shared fixture: restrict key file permissions, call sign_cert(),
    start the OpenSSH test server on 127.0.0.1:2422 and create the client
    shared by all tests."""
    # Owner read/write only, presumably so SSH accepts the private key files.
    # Use the octal literal directly: a `version_info <= (2, )` guard is never
    # true on Python 2.x/3.x, and `int('0600')` is decimal 600, not 0o600.
    _mask = 0o600
    for _file in [PKEY_FILENAME, USER_CERT_PRIV_KEY, CA_USER_KEY]:
        os.chmod(_file, _mask)
    sign_cert()
    cls.host = '127.0.0.1'
    cls.port = 2422
    cls.server = OpenSSHServer(listen_ip=cls.host, port=cls.port)
    cls.server.start_server()
    cls.cmd = 'echo me'
    cls.resp = u'me'
    cls.user_key = PKEY_FILENAME
    cls.user_pub_key = PUB_FILE
    cls.cert_pkey = USER_CERT_PRIV_KEY
    cls.cert_file = USER_CERT_FILE
    cls.user = USER
    # Single client for all tests ensures that the client does not do
    # anything that causes server to disconnect the session and
    # affect all subsequent uses of the same session.
    cls.client = ParallelSSHClient(
        [cls.host],
        pkey=PKEY_FILENAME,
        port=cls.port,
        num_retries=1,
        retry_delay=.1,
    )
def test_gssapi_auth(self):
    """Both GSS-API client configurations (explicit identities with
    delegation, and plain gssapi_auth) are expected to fail authentication
    against the test server."""
    gssapi_variants = [
        dict(gssapi_server_identity='server_id',
             gssapi_client_identity='client_id',
             gssapi_delegate_credentials=True,
             identity_auth=False),
        dict(gssapi_auth=True,
             identity_auth=False),
    ]
    for extra_kwargs in gssapi_variants:
        client = ParallelSSHClient(
            [self.host], port=self.port, num_retries=1, pkey=None,
            **extra_kwargs)
        self.assertRaises(AuthenticationException,
                          client.run_command, self.cmd)
def test_pssh_client_hosts_list_part_failure(self):
    """Test getting output for remainder of host list in the case where one
    host in the host list has a failure.

    The second host has no server, so with ``stop_on_errors=False`` it must
    still appear in the output, carrying a ConnectionErrorException for
    that host.
    """
    hosts = [self.host, '127.1.1.100']
    client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key,
                               num_retries=1)
    output = client.run_command(self.cmd, stop_on_errors=False)
    self.assertFalse(client.finished(output))
    client.join(output, consume_output=True)
    self.assertTrue(client.finished(output))
    self.assertTrue(
        hosts[0] in output,
        msg="Successful host does not exist in output - output is %s" % (
            output, ))
    self.assertTrue(
        hosts[1] in output,
        msg="Failed host does not exist in output - output is %s" % (
            output, ))
    self.assertTrue(
        'exception' in output[hosts[1]],
        msg="Failed host %s has no exception in output - %s" % (
            hosts[1], output, ))
    self.assertTrue(output[hosts[1]].exception is not None)
    self.assertEqual(output[hosts[1]].exception.host, hosts[1])
    failed_exc = output[hosts[1]]['exception']
    # assertIsInstance instead of raise/except/else: the `else` after a bare
    # `raise` was unreachable, so its message could never be reported.
    self.assertIsInstance(
        failed_exc, ConnectionErrorException,
        msg="Expected ConnectionError, got %s instead" % (failed_exc, ))
def test_zero_timeout(self):
    """A client timeout of 0 must not produce an exception for the primary
    host when running a short command against two servers."""
    host = '127.0.0.2'
    server = OpenSSHServer(listen_ip=host, port=self.port)
    server.start_server()
    try:
        client = ParallelSSHClient([self.host, host], port=self.port,
                                   pkey=self.user_key, timeout=0)
        # Run in a spawned task and bound the wait to 3s so a hang cannot
        # block the suite indefinitely.
        cmd = spawn(client.run_command, 'sleep 1', stop_on_errors=False)
        output = cmd.get(timeout=3)
        self.assertTrue(output[self.host].exception is None)
    finally:
        # Always stop the extra server so it does not outlive this test.
        server.stop()
def test_join_timeout(self):
    """join() raises Timeout while a 1.5s command is still running, then
    completes with a longer timeout and leaves the channel at EOF."""
    client = ParallelSSHClient([self.host], port=self.port,
                               pkey=self.user_key)
    output = client.run_command('echo me; sleep 1.5')
    # 1s join timeout is shorter than the command's 1.5s sleep.
    self.assertRaises(Timeout, client.join, output, timeout=1)
    self.assertFalse(
        output[0].client.finished(output[0].channel))
    self.assertFalse(output[0].channel.is_eof())
    # A 2s timeout leaves enough time for the command to finish.
    client.join(output, timeout=2)
    self.assertTrue(output[0].channel.is_eof())
    self.assertTrue(client.finished(output))
def test_multiple_join_timeout(self):
    """Five short commands each join within 1s; a subsequent 2s command
    then times out on join and its channels remain unfinished."""
    client = ParallelSSHClient([self.host], port=self.port,
                               pkey=self.user_key)
    attempts = 5
    for _attempt in range(attempts):
        results = client.run_command(self.cmd, return_list=True)
        client.join(results, timeout=1, consume_output=True)
        for result in results:
            self.assertTrue(result.client.finished(result.channel))
    results = client.run_command('sleep 2', return_list=True)
    self.assertRaises(Timeout, client.join, results,
                      timeout=1, consume_output=True)
    for result in results:
        self.assertFalse(result.client.finished(result.channel))
def setUpClass(cls):
    """Shared fixture: restrict the private key's permissions, start the
    OpenSSH test server on 127.0.0.1:2422 and create the client shared by
    all tests. The user name is taken from the current effective uid."""
    # Owner read/write only, presumably so SSH accepts the private key file.
    # Use the octal literal directly: a `version_info <= (2, )` guard is never
    # true on Python 2.x/3.x, and `int('0600')` is decimal 600, not 0o600.
    _mask = 0o600
    os.chmod(PKEY_FILENAME, _mask)
    cls.host = '127.0.0.1'
    cls.port = 2422
    cls.server = OpenSSHServer(listen_ip=cls.host, port=cls.port)
    cls.server.start_server()
    cls.cmd = 'echo me'
    cls.resp = u'me'
    cls.user_key = PKEY_FILENAME
    cls.user_pub_key = PUB_FILE
    cls.user = pwd.getpwuid(os.geteuid()).pw_name
    # Single client for all tests ensures that the client does not do
    # anything that causes server to disconnect the session and
    # affect all subsequent uses of the same session.
    cls.client = ParallelSSHClient([cls.host], pkey=PKEY_FILENAME,
                                   port=cls.port, num_retries=1)
def test_get_last_output(self):
    """After run_command on two hosts, get_last_output returns one entry
    per host, in host order, each with exit code 0 after join."""
    extra_host = '127.0.0.9'
    server = OpenSSHServer(listen_ip=extra_host, port=self.port)
    server.start_server()
    try:
        hosts = [self.host, extra_host]
        client = ParallelSSHClient(hosts, port=self.port,
                                   pkey=self.user_key)
        # No command has been run yet.
        self.assertTrue(client.cmds is None)
        self.assertTrue(client.get_last_output() is None)
        client.run_command(self.cmd)
        self.assertTrue(client.cmds is not None)
        self.assertEqual(len(client.cmds), len(hosts))
        output = client.get_last_output()
        # assertEqual, not assertTrue: assertTrue's second argument is only
        # the failure message, so the lengths were never compared.
        self.assertEqual(len(output), len(hosts))
        client.join(output)
        # Loop variable renamed so it no longer shadows the server's host.
        for i, expected_host in enumerate(hosts):
            self.assertEqual(output[i].host, expected_host)
            exit_code = output[i].exit_code
            self.assertEqual(exit_code, 0)
    finally:
        server.stop()
def test_join_timeout(self):
    """join() raises Timeout while a 2s command is running, then succeeds
    once the command has completed, leaving the channel at EOF."""
    client = ParallelSSHClient([self.host], port=self.port,
                               pkey=self.user_key)
    output = client.run_command('echo me; sleep 2')
    # Let the long-running command start first to avoid a race.
    time.sleep(.1)
    self.assertRaises(Timeout, client.join, output, timeout=1)
    self.assertFalse(output[self.host].channel.is_eof())
    # Wait out the command before joining again so it has surely finished.
    time.sleep(2)
    client.join(output, timeout=3)
    self.assertTrue(output[self.host].channel.is_eof())
    self.assertTrue(client.finished(output))
def test_default_finished(self):
    """finished() called with no output argument on a fresh client returns
    a truthy value."""
    fresh_client = ParallelSSHClient(
        [self.host], port=self.port, pkey=self.user_key)
    self.assertTrue(fresh_client.finished())