class SshConnection:
    """
    tsc ssh connection

    Holds a ``Transport`` to ``host:port`` plus one ``'session'`` channel
    opened on it.  The connection attempt is made in the constructor; a
    failure is reported on stdout instead of being raised, and both
    ``connection`` and ``session`` are then reset to ``None``.
    """

    def __init__(self, name, host, port=22, user=None, password=None):
        """
        init

        :param name: label for this connection (only stored on the instance)
        :param host: host passed to ``Transport``
        :param port: port passed to ``Transport`` (default 22)
        :param user: username handed to ``Transport.connect``
        :param password: password handed to ``Transport.connect``
        """
        self.connection = None
        self.session = None
        self.host = host
        self.port = port
        self.loginAccount = {
            'user': user,
            'pass': password
        }
        self.name = name
        self.timeout = 4    # also used as the keepalive interval below
        self.nbytes = 4096  # not used in this block; presumably a read size
        try:
            self.connection = Transport((self.host, self.port))
            self.connection.connect(username=self.loginAccount['user'],
                                    password=self.loginAccount['pass'])
            self.connection.set_keepalive(self.timeout)
            self.session = self.connection.open_channel(kind='session')
            print('ssh connection created!')
        except Exception as e:
            # Deliberate best-effort: release whatever was opened, then
            # report the problem without raising to the caller.
            self._cleanup()
            print('ssh connection failed: {er}'.format(er=e))

    def _cleanup(self):
        """Close the session and the transport, tolerating missing pieces."""
        for attr in ('session', 'connection'):
            # getattr() keeps this safe on partially initialised instances
            # (e.g. when invoked from __del__ after a failed __init__).
            resource = getattr(self, attr, None)
            if resource is not None:
                try:
                    resource.close()
                except Exception:
                    # Nothing more can be done for a resource that fails to
                    # close while we are already tearing down.
                    pass
            setattr(self, attr, None)

    def __del__(self):
        """Release the channel and transport when the object is collected."""
        self._cleanup()
class TransportTest(ParamikoTest):
    """
    Tests that drive a client ``Transport`` (``self.tc``) against a server
    ``Transport`` (``self.ts``) over a pair of linked ``LoopSocket`` objects.
    """

    def setUp(self):
        """Create two linked loop sockets and a transport on each end."""
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        """Close both transports and both loop sockets."""
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """
        Start a ``NullServer`` on ``self.ts`` and connect ``self.tc`` to it.

        :param client_options: optional callable, invoked with the client's
            security options before the handshake
        :param server_options: optional callable, invoked with the server's
            security options before the handshake
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertTrue(not event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that the cipher list can be replaced, and that unknown
        ciphers and non-sequence values are rejected.
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertTrue((b'aes256-cbc', b'blowfish-cbc') != o.ciphers)
        o.ciphers = (b'aes256-cbc', b'blowfish-cbc')
        self.assertEqual((b'aes256-cbc', b'blowfish-cbc'), o.ciphers)
        try:
            o.ciphers = (b'aes256-cbc', b'made-up-cipher')
            self.fail('expected ValueError')
        except ValueError:
            pass
        try:
            o.ciphers = 23
            self.fail('expected TypeError')
        except TypeError:
            pass

    def test_2_compute_key(self):
        """
        verify that _compute_key() yields the expected 32-byte key for a
        fixed K / H / session_id.
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify(b'0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key(b'C', 32)
        self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                         hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = (b'aes256-cbc',)
            options.digests = (b'hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual(b'aes256-cbc', self.tc.local_cipher)
        self.assertEqual(b'aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual(b'*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        try:
            chan.exec_command('no')
            self.fail('expected SSHException')
        except SSHException:
            pass

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send(b'Hello there.\n')
        schan.send_stderr(b'This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual(b'This is on stderr.\n', f.readline())
        self.assertEqual(b'', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send(b'Hello there.\n')
        schan.send_stderr(b'This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'This is on stderr.\n', f.readline())
        self.assertEqual(b'', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send(b'communist j. cat\n')
        f = schan.makefile()
        self.assertEqual(b'communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual(b'', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        try:
            self.tc.open_channel(b'bogus')
            self.fail('expected exception')
        except ChannelException as x:
            self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, x.code)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send(b'Hello there.\n')
        self.assertTrue(not chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'', f.readline())
        # poll for up to ~5 seconds before giving up
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send(b'hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(b'', chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for i in range(20):
            chan.send(b'x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for i in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = (b'zlib',)
        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command(b'yes')
        schan = self.ts.accept(1.0)

        # name-mangled private counter of the packetizer
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send(b'x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        # tests show this is actually compressed to *52 bytes*!  including
        # packet overhead!  nice!! :)
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(52, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command(b'yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, addr):
            requested.append(addr)
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None,
                         getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual(b'MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send(b'hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        self.ts.accept(1.0)

        requested = []

        def handler(c, origin_addr, server_addr):
            requested.append(origin_addr)
            requested.append(server_addr)
            self.tc._queue_incoming_channel(c)

        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect((b'127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(),
                                                   ss.getpeername())
        cch = self.tc.accept()

        sch.send(b'hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward(b'127.0.0.1', port)
        self.assertTrue(self.server._listen is None)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel(b'direct-tcpip',
                                  ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()
        # release the plain sockets opened by this test as well
        cch.close()
        greeting_server.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr(b'hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = b'*' * 1024
        while total < 1024 * 1024:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < 1024 * 1024)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received
        after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST
        #      is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is
        #      sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls
        #      Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None,
                                          self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()
                self.last = None

            def run(self):
                try:
                    for i in range(1, 1 + self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send(b"x" * 2048)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None,
                                          self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # read from the channel this thread was given,
                            # not from whatever the enclosing test calls
                            # ``chan``
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        else:
                            if random.randint(0, 1):
                                time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST).encode())
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)    # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)
        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500    # The deadlock does not happen every time, but it
                            # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer, checking
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()
class TransportTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() self.sockc.link(self.socks) self.tc = Transport(self.sockc) self.ts = Transport(self.socks) def tearDown(self): self.tc.close() self.ts.close() self.socks.close() self.sockc.close() def setup_test_server( self, client_options=None, server_options=None, connect_kwargs=None ): host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) if client_options is not None: client_options(self.tc.get_security_options()) if server_options is not None: server_options(self.ts.get_security_options()) event = threading.Event() self.server = NullServer() self.assertTrue(not event.is_set()) self.ts.start_server(event, self.server) if connect_kwargs is None: connect_kwargs = dict( hostkey=public_host_key, username="******", password="******", ) self.tc.connect(**connect_kwargs) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) def test_security_options(self): o = self.tc.get_security_options() self.assertEqual(type(o), SecurityOptions) self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers) o.ciphers = ("aes256-cbc", "blowfish-cbc") self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers) try: o.ciphers = ("aes256-cbc", "made-up-cipher") self.assertTrue(False) except ValueError: pass try: o.ciphers = 23 self.assertTrue(False) except TypeError: pass def testb_security_options_reset(self): o = self.tc.get_security_options() # should not throw any exceptions o.ciphers = o.ciphers o.digests = o.digests o.key_types = o.key_types o.kex = o.kex o.compression = o.compression def test_compute_key(self): self.tc.K = 
123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 # noqa self.tc.H = b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3" # noqa self.tc.session_id = self.tc.H key = self.tc._compute_key("C", 32) self.assertEqual( b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995", # noqa hexlify(key).upper(), ) def test_simple(self): """ verify that we can establish an ssh link with ourselves across the loopback sockets. this is hardly "simple" but it's simpler than the later tests. :) """ host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() self.assertTrue(not event.is_set()) self.assertEqual(None, self.tc.get_username()) self.assertEqual(None, self.ts.get_username()) self.assertEqual(False, self.tc.is_authenticated()) self.assertEqual(False, self.ts.is_authenticated()) self.ts.start_server(event, server) self.tc.connect( hostkey=public_host_key, username="******", password="******" ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) self.assertEqual("slowdive", self.tc.get_username()) self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.tc.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated()) def testa_long_banner(self): """ verify that a long banner doesn't mess up the handshake. 
""" host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() self.assertTrue(not event.is_set()) self.socks.send(LONG_BANNER) self.ts.start_server(event, server) self.tc.connect( hostkey=public_host_key, username="******", password="******" ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) def test_special(self): """ verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ def force_algorithms(options): options.ciphers = ("aes256-cbc",) options.digests = ("hmac-md5-96",) self.setup_test_server(client_options=force_algorithms) self.assertEqual("aes256-cbc", self.tc.local_cipher) self.assertEqual("aes256-cbc", self.tc.remote_cipher) self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) self.tc.send_ignore(1024) self.tc.renegotiate_keys() self.ts.send_ignore(1024) @slow def test_keepalive(self): """ verify that the keepalive will be sent. """ self.setup_test_server() self.assertEqual(None, getattr(self.server, "_global_request", None)) self.tc.set_keepalive(1) time.sleep(2) self.assertEqual("*****@*****.**", self.server._global_request) def test_exec_command(self): """ verify that exec_command() does something reasonable. 
""" self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) try: chan.exec_command( b"command contains \xfc and is not a valid UTF-8 string" ) self.assertTrue(False) except SSHException: pass chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) schan.send("Hello there.\n") schan.send_stderr("This is on stderr.\n") schan.close() f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("", f.readline()) f = chan.makefile_stderr() self.assertEqual("This is on stderr.\n", f.readline()) self.assertEqual("", f.readline()) # now try it with combined stdout/stderr chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) schan.send("Hello there.\n") schan.send_stderr("This is on stderr.\n") schan.close() chan.set_combine_stderr(True) f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("This is on stderr.\n", f.readline()) self.assertEqual("", f.readline()) def testa_channel_can_be_used_as_context_manager(self): """ verify that exec_command() does something reasonable. """ self.setup_test_server() with self.tc.open_session() as chan: with self.ts.accept(1.0) as schan: chan.exec_command("yes") schan.send("Hello there.\n") schan.close() f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("", f.readline()) def test_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) chan.send("communist j. cat\n") f = schan.makefile() self.assertEqual("communist j. cat\n", f.readline()) chan.close() self.assertEqual("", f.readline()) def test_channel_exception(self): """ verify that ChannelException is thrown for a bad open-channel request. 
""" self.setup_test_server() try: self.tc.open_channel("bogus") self.fail("expected exception") except ChannelException as e: self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) def test_exit_status(self): """ verify that get_exit_status() works. """ self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) chan.exec_command("yes") schan.send("Hello there.\n") self.assertTrue(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() schan.shutdown_write() schan.send_exit_status(23) schan.close() f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("", f.readline()) count = 0 while not chan.exit_status_ready(): time.sleep(0.1) count += 1 if count > 50: raise Exception("timeout") self.assertEqual(23, chan.recv_exit_status()) chan.close() def test_select(self): """ verify that select() on a channel works. """ self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) # nothing should be ready r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.send("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): r, w, e = select.select([chan], [], [], 0.1) if chan in r: break time.sleep(0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(b"hello\n", chan.recv(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.close() # detect eof? for i in range(10): r, w, e = select.select([chan], [], [], 0.1) if chan in r: break time.sleep(0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(bytes(), chan.recv(16)) # make sure the pipe is still open for now... p = chan._pipe self.assertEqual(False, p._closed) chan.close() # ...and now is closed. 
self.assertEqual(True, p._closed) def test_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. """ self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) self.assertEqual(self.tc.H, self.tc.session_id) for i in range(20): chan.send("x" * 1024) chan.close() # allow a few seconds for the rekeying to complete for i in range(50): if self.tc.H != self.tc.session_id: break time.sleep(0.1) self.assertNotEqual(self.tc.H, self.tc.session_id) schan.close() def test_compression(self): """ verify that zlib compression is basically working. """ def force_compression(o): o.compression = ("zlib",) self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes chan.send("x" * 1024) bytes2 = self.tc.packetizer._Packetizer__sent_bytes block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"] mac_size = self.tc._mac_info[self.tc.local_mac]["size"] # tests show this is actually compressed to *52 bytes*! including # packet overhead! nice!! :) self.assertTrue(bytes2 - bytes < 1024) self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) chan.close() schan.close() def test_x11(self): """ verify that an x11 port can be requested and opened. 
""" self.setup_test_server() chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] def handler(c, addr_port): addr, port = addr_port requested.append((addr, port)) self.tc._queue_incoming_channel(c) self.assertEqual( None, getattr(self.server, "_x11_screen_number", None) ) cookie = chan.request_x11(0, single_connection=True, handler=handler) self.assertEqual(0, self.server._x11_screen_number) self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol) self.assertEqual(cookie, self.server._x11_auth_cookie) self.assertEqual(True, self.server._x11_single_connection) x11_server = self.ts.open_x11_channel(("localhost", 6093)) x11_client = self.tc.accept() self.assertEqual("localhost", requested[0][0]) self.assertEqual(6093, requested[0][1]) x11_server.send("hello") self.assertEqual(b"hello", x11_client.recv(5)) x11_server.close() x11_client.close() chan.close() schan.close() def test_reverse_port_forwarding(self): """ verify that a client can ask the server to open a reverse port for forwarding. """ self.setup_test_server() chan = self.tc.open_session() chan.exec_command("yes") self.ts.accept(1.0) requested = [] def handler(c, origin_addr_port, server_addr_port): requested.append(origin_addr_port) requested.append(server_addr_port) self.tc._queue_incoming_channel(c) port = self.tc.request_port_forward("127.0.0.1", 0, handler) self.assertEqual(port, self.server._listen.getsockname()[1]) cs = socket.socket() cs.connect(("127.0.0.1", port)) ss, _ = self.server._listen.accept() sch = self.ts.open_forwarded_tcpip_channel( ss.getsockname(), ss.getpeername() ) cch = self.tc.accept() sch.send("hello") self.assertEqual(b"hello", cch.recv(5)) sch.close() cch.close() ss.close() cs.close() # now cancel it. self.tc.cancel_port_forward("127.0.0.1", port) self.assertTrue(self.server._listen is None) def test_port_forwarding(self): """ verify that a client can forward new connections from a locally- forwarded port. 
""" self.setup_test_server() chan = self.tc.open_session() chan.exec_command("yes") self.ts.accept(1.0) # open a port on the "server" that the client will ask to forward to. greeting_server = socket.socket() greeting_server.bind(("127.0.0.1", 0)) greeting_server.listen(1) greeting_port = greeting_server.getsockname()[1] cs = self.tc.open_channel( "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000) ) sch = self.ts.accept(1.0) cch = socket.socket() cch.connect(self.server._tcpip_dest) ss, _ = greeting_server.accept() ss.send(b"Hello!\n") ss.close() sch.send(cch.recv(8192)) sch.close() self.assertEqual(b"Hello!\n", cs.recv(7)) cs.close() def test_stderr_select(self): """ verify that select() on a channel works even if only stderr is receiving data. """ self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) # nothing should be ready r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.send_stderr("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): r, w, e = select.select([chan], [], [], 0.1) if chan in r: break time.sleep(0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(b"hello\n", chan.recv_stderr(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.close() chan.close() def test_send_ready(self): """ verify that send_ready() indicates when a send would not block. 
""" self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) self.assertEqual(chan.send_ready(), True) total = 0 K = "*" * 1024 limit = 1 + (64 * 2 ** 15) while total < limit: chan.send(K) total += len(K) if not chan.send_ready(): break self.assertTrue(total < limit) schan.close() chan.close() self.assertEqual(chan.send_ready(), True) def test_rekey_deadlock(self): """ Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent Note: When this test fails, it may leak threads. """ # Test for an obscure deadlocking bug that can occur if we receive # certain messages while initiating a key exchange. # # The deadlock occurs as follows: # # In the main thread: # 1. The user's program calls Channel.send(), which sends # MSG_CHANNEL_DATA to the remote host. # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and # sets the __need_rekey flag. # # In the Transport thread: # 3. Packetizer notices that the __need_rekey flag is set, and raises # NeedRekeyException. # 4. In response to NeedRekeyException, the transport thread sends # MSG_KEXINIT to the remote host. # # On the remote host (using any SSH implementation): # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST # is sent. # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is # sent. # # In the main thread: # 7. The user's program calls Channel.send(). # 8. Channel.send acquires Channel.lock, then calls # Transport._send_user_message(). # 9. Transport._send_user_message waits for Transport.clear_to_send # to be set (i.e., it waits for re-keying to complete). # Channel.lock is still held. # # In the Transport thread: # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust # is called to handle it. # 11. Channel._window_adjust tries to acquire Channel.lock, but it # blocks because the lock is already held by the main thread. 
# # The result is that the Transport thread never processes the remote # host's MSG_KEXINIT packet, because it becomes deadlocked while # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. # We set up two separate threads for sending and receiving packets, # while the main thread acts as a watchdog timer. If the timer # expires, a deadlock is assumed. class SendThread(threading.Thread): def __init__(self, chan, iterations, done_event): threading.Thread.__init__( self, None, None, self.__class__.__name__ ) self.setDaemon(True) self.chan = chan self.iterations = iterations self.done_event = done_event self.watchdog_event = threading.Event() self.last = None def run(self): try: for i in range(1, 1 + self.iterations): if self.done_event.is_set(): break self.watchdog_event.set() # print i, "SEND" self.chan.send("x" * 2048) finally: self.done_event.set() self.watchdog_event.set() class ReceiveThread(threading.Thread): def __init__(self, chan, done_event): threading.Thread.__init__( self, None, None, self.__class__.__name__ ) self.setDaemon(True) self.chan = chan self.done_event = done_event self.watchdog_event = threading.Event() def run(self): try: while not self.done_event.is_set(): if self.chan.recv_ready(): chan.recv(65536) self.watchdog_event.set() else: if random.randint(0, 1): time.sleep(random.randint(0, 500) / 1000.0) finally: self.done_event.set() self.watchdog_event.set() self.setup_test_server() self.ts.packetizer.REKEY_BYTES = 2048 chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) # Monkey patch the client's Transport._handler_table so that the client # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial # MSG_KEXINIT. This is used to simulate the effect of network latency # on a real MSG_CHANNEL_WINDOW_ADJUST message. 
self.tc._handler_table = ( self.tc._handler_table.copy() ) # copy per-class dictionary _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] def _negotiate_keys_wrapper(self, m): if self.local_kex_init is None: # Remote side sent KEXINIT # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # before responding to the incoming MSG_KEXINIT. m2 = Message() m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST) m2.add_int(chan.remote_chanid) m2.add_int(1) # bytes to add self._send_message(m2) return _negotiate_keys(self, m) self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper # Parameters for the test iterations = 500 # The deadlock does not happen every time, but it # should after many iterations. timeout = 5 # This event is set when the test is completed done_event = threading.Event() # Start the sending thread st = SendThread(schan, iterations, done_event) st.start() # Start the receiving thread rt = ReceiveThread(chan, done_event) rt.start() # Act as a watchdog timer, checking deadlocked = False while not deadlocked and not done_event.is_set(): for event in (st.watchdog_event, rt.watchdog_event): event.wait(timeout) if done_event.is_set(): break if not event.is_set(): deadlocked = True break event.clear() # Tell the threads to stop (if they haven't already stopped). Note # that if one or more threads are deadlocked, they might hang around # forever (until the process exits). done_event.set() # Assertion: We must not have detected a timeout. self.assertFalse(deadlocked) # Close the channels schan.close() chan.close() def test_sanitze_packet_size(self): """ verify that we conform to the rfc of packet and window sizes. """ for val, correct in [ (4095, MIN_PACKET_SIZE), (None, DEFAULT_MAX_PACKET_SIZE), (2 ** 32, MAX_WINDOW_SIZE), ]: self.assertEqual(self.tc._sanitize_packet_size(val), correct) def test_sanitze_window_size(self): """ verify that we conform to the rfc of packet and window sizes. 
""" for val, correct in [ (32767, MIN_WINDOW_SIZE), (None, DEFAULT_WINDOW_SIZE), (2 ** 32, MAX_WINDOW_SIZE), ]: self.assertEqual(self.tc._sanitize_window_size(val), correct) @slow def test_handshake_timeout(self): """ verify that we can get a hanshake timeout. """ # Tweak client Transport instance's Packetizer instance so # its read_message() sleeps a bit. This helps prevent race conditions # where the client Transport's timeout timer thread doesn't even have # time to get scheduled before the main client thread finishes # handshaking with the server. # (Doing this on the server's transport *sounds* more 'correct' but # actually doesn't work nearly as well for whatever reason.) class SlowPacketizer(Packetizer): def read_message(self): time.sleep(1) return super(SlowPacketizer, self).read_message() # NOTE: prettttty sure since the replaced .packetizer Packetizer is now # no longer doing anything with its copy of the socket...everything'll # be fine. Even tho it's a bit squicky. self.tc.packetizer = SlowPacketizer(self.tc.sock) # Continue with regular test red tape. host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() self.assertTrue(not event.is_set()) self.tc.handshake_timeout = 0.000000000001 self.ts.start_server(event, server) self.assertRaises( EOFError, self.tc.connect, hostkey=public_host_key, username="******", password="******", ) def test_select_after_close(self): """ verify that select works when a channel is already closed. 
""" self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) schan.close() # give client a moment to receive close notification time.sleep(0.1) r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) def test_channel_send_misc(self): """ verify behaviours sending various instances to a channel """ self.setup_test_server() text = u"\xa7 slice me nicely" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: self.fail("Test server transport failed to accept") sfile = schan.makefile() # TypeError raised on non string or buffer type self.assertRaises(TypeError, chan.send, object()) self.assertRaises(TypeError, chan.sendall, object()) # sendall() accepts a unicode instance chan.sendall(text) expected = text.encode("utf-8") self.assertEqual(sfile.read(len(expected)), expected) @needs_builtin("buffer") def test_channel_send_buffer(self): """ verify sending buffer instances to a channel """ self.setup_test_server() data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: self.fail("Test server transport failed to accept") sfile = schan.makefile() # send() accepts buffer instances sent = 0 while sent < len(data): sent += chan.send(buffer(data, sent, 8)) # noqa self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a buffer instance chan.sendall(buffer(data)) # noqa self.assertEqual(sfile.read(len(data)), data) @needs_builtin("memoryview") def test_channel_send_memoryview(self): """ verify sending memoryview instances to a channel """ self.setup_test_server() data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: self.fail("Test server transport failed to accept") sfile = schan.makefile() # send() accepts memoryview slices sent = 0 view = memoryview(data) while sent < len(view): sent += 
chan.send(view[sent : sent + 8]) self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a memoryview instance chan.sendall(memoryview(data)) self.assertEqual(sfile.read(len(data)), data) def test_server_rejects_open_channel_without_auth(self): try: self.setup_test_server(connect_kwargs={}) self.tc.open_session() except ChannelException as e: assert e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: assert False, "Did not raise ChannelException!" def test_server_rejects_arbitrary_global_request_without_auth(self): self.setup_test_server(connect_kwargs={}) # NOTE: this dummy global request kind would normally pass muster # from the test server. self.tc.global_request("acceptable") # Global requests never raise exceptions, even on failure (not sure why # this was the original design...ugh.) Best we can do to tell failure # happened is that the client transport's global_response was set back # to None; if it had succeeded, it would be the response Message. err = "Unauthed global response incorrectly succeeded!" assert self.tc.global_response is None, err def test_server_rejects_port_forward_without_auth(self): # NOTE: at protocol level port forward requests are treated same as a # regular global request, but Paramiko server implements a special-case # method for it, so it gets its own test. (plus, THAT actually raises # an exception on the client side, unlike the general case...) self.setup_test_server(connect_kwargs={}) try: self.tc.request_port_forward("localhost", 1234) except SSHException as e: assert "forwarding request denied" in str(e) else: assert False, "Did not raise SSHException!" 
def _send_unimplemented(self, server_is_sender):
    """
    Send a bare MSG_UNIMPLEMENTED from one transport to the other and
    check that the receiving side does not send anything back.

    ``server_is_sender`` selects the direction: truthy means the server
    transport sends to the client, falsy means client to server.
    """
    self.setup_test_server()
    if server_is_sender:
        origin, target = self.ts, self.tc
    else:
        origin, target = self.tc, self.ts
    # Replace the receiver's outgoing path so any reply can be observed.
    target._send_message = Mock()
    packet = Message()
    packet.add_byte(cMSG_UNIMPLEMENTED)
    origin._send_message(packet)
    # TODO: there is no existing threading event that signals "the
    # receiver has processed the message", and no reply is expected here
    # (so blocking on one is not an option); a short sleep is used instead.
    time.sleep(0.1)
    assert not target._send_message.called


def test_server_does_not_respond_to_MSG_UNIMPLEMENTED(self):
    """Client -> server MSG_UNIMPLEMENTED must not trigger a reply."""
    self._send_unimplemented(server_is_sender=False)


def test_client_does_not_respond_to_MSG_UNIMPLEMENTED(self):
    """Server -> client MSG_UNIMPLEMENTED must not trigger a reply."""
    self._send_unimplemented(server_is_sender=True)


def _send_client_message(self, message_type):
    """
    Connect without authenticating and send the server a message that
    consists only of the ``message_type`` byte.

    The server transport's ``_send_message`` is replaced with a Mock so
    callers can inspect what it tried to send in response.
    """
    self.setup_test_server(connect_kwargs={})
    self.ts._send_message = Mock()
    # NOTE: not 100% realistic (most of these message types would carry
    # further fields), but enough to exercise message dispatch.
    packet = Message()
    # TODO: the cMSG_XXX vs MSG_XXX duality is awkward; the former is
    # almost always just byte_chr(the latter), which is relied on here.
    type_byte = byte_chr(message_type)
    packet.add_byte(type_byte)
    self.tc._send_message(packet)
    # No good way to wait for the server to act (same problem as in
    # _send_unimplemented), so sleep briefly.
    time.sleep(0.1)


def _expect_unimplemented(self):
    """
    Assert that the server's mocked ``_send_message`` was called exactly
    once and that the message it was given starts with MSG_UNIMPLEMENTED.
    """
    # A single MSG_UNIMPLEMENTED implies dispatch fell through instead of
    # truly handling the message.
    # NOTE: when the bug is present this is usually the first thing to
    # fail, since real handling often does not send a reply right away.
    server_send = self.ts._send_message
    assert server_send.call_count == 1
    positional_args = server_send.call_args[0]
    reply = positional_args[0]
    # The captured message is pre-send, not post-receive, so rewind it
    # before reading.
    reply.rewind()
    assert reply.get_byte() == cMSG_UNIMPLEMENTED


def test_server_transports_reject_client_message_types(self):
    """
    Every message type in AuthHandler's client handler table must be
    answered with MSG_UNIMPLEMENTED when sent to a server transport.
    """
    # TODO: handle Transport's own tables too, not just its inner auth
    # handler's table. See TODOs in auth_handler.py
    for client_only_type in AuthHandler._client_handler_table:
        self._send_client_message(client_only_type)
        self._expect_unimplemented()
        # Fresh sockets/transports for the next message type.
        self.tearDown()
        self.setUp()


def test_server_rejects_client_MSG_USERAUTH_SUCCESS(self):
    """
    A client-sent MSG_USERAUTH_SUCCESS must not mark the server side as
    authenticated and must be answered with MSG_UNIMPLEMENTED.
    """
    self._send_client_message(MSG_USERAUTH_SUCCESS)
    # Sanity checks: neither the transport nor its auth handler may
    # consider the session authenticated.
    server = self.ts
    assert not server.authenticated
    assert not server.auth_handler.authenticated
    # Behaviour of the real fix: the message is rejected outright.
    self._expect_unimplemented()
class TransportTest(unittest.TestCase):
    """
    Tests that run a client Transport (``self.tc``) against a server
    Transport (``self.ts``) joined by a linked pair of LoopSocket objects.
    """

    def setUp(self):
        """Create the linked loopback sockets and one Transport per end."""
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        """Close both transports, then both loopback sockets."""
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """
        Start a NullServer on ``self.ts`` and connect ``self.tc`` to it.

        ``client_options`` / ``server_options``, when given, are callables
        that receive the respective transport's security options object
        before the handshake so a test can restrict algorithms.
        The server object is stored on ``self.server``.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())
        event = threading.Event()
        self.server = NullServer()
        self.assertTrue(not event.is_set())
        self.ts.start_server(event, self.server)
        # NOTE(review): the credentials look redacted in this copy of the
        # file (test_3_simple expects the username 'slowdive') -- confirm
        # the values NullServer accepts before relying on these tests.
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that the cipher list can be read and replaced, and that
        unknown cipher names or non-sequence values are rejected.
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # setattr() lets assertRaises drive the attribute assignment.
        self.assertRaises(ValueError, setattr, o, 'ciphers',
                          ('aes256-cbc', 'made-up-cipher'))
        self.assertRaises(TypeError, setattr, o, 'ciphers', 23)

    def test_2_compute_key(self):
        """
        verify that _compute_key() yields the expected 32 bytes for a
        fixed K, H and session id.
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # Push the banner through the server-side socket before the
        # server transport starts.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc', )
            options.digests = ('hmac-md5-96', )

        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Wait longer than the 1-second keepalive interval set above.
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        # The 'no' command is expected to be refused.
        self.assertRaises(SSHException, chan.exec_command, 'no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

    def test_6a_channel_can_be_used_as_context_manager(self):
        """
        verify that client and server channels can be used in ``with``
        statements.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan:
            with self.ts.accept(1.0) as schan:
                chan.exec_command('yes')
                schan.send('Hello there.\n')
                schan.close()

                f = chan.makefile()
                self.assertEqual('Hello there.\n', f.readline())
                self.assertEqual('', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send('communist j. cat\n')
        f = schan.makefile()
        self.assertEqual('communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual('', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        try:
            self.tc.open_channel('bogus')
            self.fail('expected exception')
        except ChannelException as e:
            self.assertEqual(e.code, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send('Hello there.\n')
        self.assertTrue(not chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        # Poll for the exit status, giving up after roughly 5 seconds.
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(bytes(), chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        # Lower the rekey threshold so the 20 KiB sent below crosses it.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for i in range(20):
            chan.send('x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for i in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = ('zlib', )

        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Read the packetizer's name-mangled private byte counter around
        # a single 1 KiB send.
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send('x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        # tests show this is actually compressed to *52 bytes*!  including
        # packet overhead!  nice!! :)
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(52, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, addr_port):
            addr, port = addr_port
            requested.append((addr, port))
            self.tc._queue_incoming_channel(c)

        self.assertEqual(
            None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send('hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        # The accepted server-side channel is not used by this test.
        self.ts.accept(1.0)

        requested = []

        def handler(c, origin_addr_port, server_addr_port):
            requested.append(origin_addr_port)
            requested.append(server_addr_port)
            self.tc._queue_incoming_channel(c)

        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect(('127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(
            ss.getsockname(), ss.getpeername())
        cch = self.tc.accept()

        sch.send('hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward('127.0.0.1', port)
        self.assertTrue(self.server._listen is None)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        # The accepted server-side channel is not used by this test.
        self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel(
            'direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = '*' * 1024
        limit = 1 + (64 * 2**15)
        # Keep sending 1 KiB chunks until send_ready() turns false; the
        # assertion below requires that to happen before `limit` bytes.
        while total < limit:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received
        after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST
        #      is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is
        #      sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls
        #      Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Send ``iterations`` 2 KiB chunks on ``chan``."""

            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(
                    self, None, None, self.__class__.__name__)
                # Daemon thread: a deadlocked sender must not keep the
                # process alive.
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()
                self.last = None

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        # Signal progress to the watchdog before each send.
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drain ``chan`` until ``done_event`` is set."""

            def __init__(self, chan, done_event):
                threading.Thread.__init__(
                    self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # Read from the channel this thread was given,
                            # not from the enclosing test's local variable.
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        elif random.randint(0, 1):
                            time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        # copy per-class dictionary
        self.tc._handler_table = self.tc._handler_table.copy()
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)  # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)

        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500  # The deadlock does not happen every time, but it
        # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer: each thread must signal progress within
        # `timeout` seconds, otherwise a deadlock is assumed.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_J_sanitze_packet_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(4095, MIN_PACKET_SIZE),
                             (None, DEFAULT_MAX_PACKET_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_packet_size(val), correct)

    def test_K_sanitze_window_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(32767, MIN_WINDOW_SIZE),
                             (None, DEFAULT_WINDOW_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_window_size(val), correct)

    def test_L_handshake_timeout(self):
        """
        verify that we can get a handshake timeout.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # A vanishingly small timeout; connect() is expected to fail with
        # EOFError as asserted below.
        self.tc.handshake_timeout = 0.000000000001
        self.ts.start_server(event, server)
        self.assertRaises(EOFError, self.tc.connect,
                          hostkey=public_host_key,
                          username='******',
                          password='******')
class TransportTest(unittest.TestCase):
    """
    Tests for Transport, run against a client Transport (``self.tc``) and a
    server Transport (``self.ts``) joined by a pair of linked LoopSockets.
    """

    def setUp(self):
        """Create the linked loopback sockets and wrap each in a Transport."""
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        """Close both transports, then the sockets underneath them."""
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(
        self, client_options=None, server_options=None, connect_kwargs=None,
    ):
        """
        Start a NullServer on ``self.ts`` and connect ``self.tc`` to it.

        :param client_options:
            optional callable; invoked with the client's security options
            before the handshake so a test can restrict algorithms.
        :param server_options:
            optional callable; same, for the server's security options.
        :param connect_kwargs:
            optional dict passed verbatim to ``self.tc.connect``. When None,
            the host key and the default username/password are used. Passing
            an empty dict connects without host key checking or auth.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertTrue(not event.is_set())
        self.ts.start_server(event, self.server)
        if connect_kwargs is None:
            connect_kwargs = dict(
                hostkey=public_host_key,
                username='******',
                password='******',
            )
        self.tc.connect(**connect_kwargs)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that security options can be read and set, and that invalid
        values are rejected with ValueError / TypeError.
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_1b_security_options_reset(self):
        """
        verify that assigning each security option back to itself is accepted.
        """
        o = self.tc.get_security_options()
        # should not throw any exceptions
        o.ciphers = o.ciphers
        o.digests = o.digests
        o.key_types = o.key_types
        o.kex = o.kex
        o.compression = o.compression

    def test_2_compute_key(self):
        """
        verify _compute_key() against a fixed K / H / session_id vector.
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                         hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # The banner is written to the raw socket before the server starts,
        # so the client sees it ahead of the protocol handshake.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc',)
            options.digests = ('hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    @slow
    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

    def test_6a_channel_can_be_used_as_context_manager(self):
        """
        verify that client and server channels work in ``with`` statements.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan:
            with self.ts.accept(1.0) as schan:
                chan.exec_command('yes')
                schan.send('Hello there.\n')
                schan.close()

                f = chan.makefile()
                self.assertEqual('Hello there.\n', f.readline())
                self.assertEqual('', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send('communist j. cat\n')
        f = schan.makefile()
        self.assertEqual('communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual('', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        with self.assertRaises(ChannelException) as caught:
            self.tc.open_channel('bogus')
        self.assertTrue(
            caught.exception.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send('Hello there.\n')
        self.assertTrue(not chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        # Poll for up to ~5 seconds for the exit status to arrive.
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(bytes(), chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        # Lower the rekey threshold so the 20 KiB sent below crosses it.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for i in range(20):
            chan.send('x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for i in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = ('zlib',)
        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Read the packetizer's name-mangled private byte counter before and
        # after the send to measure how many bytes went on the wire.
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send('x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size']
        mac_size = self.tc._mac_info[self.tc.local_mac]['size']
        # tests show this is actually compressed to *52 bytes*!  including packet overhead!  nice!! :)
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(16 + block_size + mac_size, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, addr_port):
            addr, port = addr_port
            requested.append((addr, port))
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send('hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, origin_addr_port, server_addr_port):
            requested.append(origin_addr_port)
            requested.append(server_addr_port)
            self.tc._queue_incoming_channel(c)

        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect(('127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
        cch = self.tc.accept()

        sch.send('hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward('127.0.0.1', port)
        self.assertTrue(self.server._listen is None)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()
        # Release the plain sockets this test opened itself; nothing else
        # closes them.
        cch.close()
        greeting_server.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = '*' * 1024
        limit = 1+(64 * 2 ** 15)
        # Keep sending until send_ready() turns False; that must happen
        # before `limit` bytes have been written.
        while total < limit:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Sends `iterations` 2048-byte chunks, pinging the watchdog before each."""

            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                # Daemon threads cannot keep the process alive if they hang.
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()
                self.last = None

            def run(self):
                try:
                    for i in range(1, 1+self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drains the channel until done, pinging the watchdog on each read."""

            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # Read from the channel this thread was given,
                            # not from whatever `chan` the enclosing test
                            # happens to have bound.
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        else:
                            if random.randint(0, 1):
                                time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)    # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)
        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500    # The deadlock does not happen every time, but it
                            # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer, checking
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_J_sanitze_packet_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(4095, MIN_PACKET_SIZE),
                             (None, DEFAULT_MAX_PACKET_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_packet_size(val), correct)

    def test_K_sanitze_window_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(32767, MIN_WINDOW_SIZE),
                             (None, DEFAULT_WINDOW_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_window_size(val), correct)

    @slow
    def test_L_handshake_timeout(self):
        """
        verify that we can get a handshake timeout.
        """
        # Tweak client Transport instance's Packetizer instance so
        # its read_message() sleeps a bit. This helps prevent race conditions
        # where the client Transport's timeout timer thread doesn't even have
        # time to get scheduled before the main client thread finishes
        # handshaking with the server.
        # (Doing this on the server's transport *sounds* more 'correct' but
        # actually doesn't work nearly as well for whatever reason.)
        class SlowPacketizer(Packetizer):
            def read_message(self):
                time.sleep(1)
                return super(SlowPacketizer, self).read_message()
        # NOTE: prettttty sure since the replaced .packetizer Packetizer is now
        # no longer doing anything with its copy of the socket...everything'll
        # be fine. Even tho it's a bit squicky.
        self.tc.packetizer = SlowPacketizer(self.tc.sock)
        # Continue with regular test red tape.
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.tc.handshake_timeout = 0.000000000001
        self.ts.start_server(event, server)
        self.assertRaises(EOFError, self.tc.connect,
                          hostkey=public_host_key,
                          username='******',
                          password='******')

    def test_M_select_after_close(self):
        """
        verify that select works when a channel is already closed.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        schan.close()

        # give client a moment to receive close notification
        time.sleep(0.1)

        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

    def test_channel_send_misc(self):
        """
        verify behaviours sending various instances to a channel
        """
        self.setup_test_server()
        text = u"\xa7 slice me nicely"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # TypeError raised on non string or buffer type
            self.assertRaises(TypeError, chan.send, object())
            self.assertRaises(TypeError, chan.sendall, object())

            # sendall() accepts a unicode instance
            chan.sendall(text)
            expected = text.encode("utf-8")
            self.assertEqual(sfile.read(len(expected)), expected)

    @needs_builtin('buffer')
    def test_channel_send_buffer(self):
        """
        verify sending buffer instances to a channel
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # send() accepts buffer instances
            sent = 0
            while sent < len(data):
                sent += chan.send(buffer(data, sent, 8))
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() accepts a buffer instance
            chan.sendall(buffer(data))
            self.assertEqual(sfile.read(len(data)), data)

    @needs_builtin('memoryview')
    def test_channel_send_memoryview(self):
        """
        verify sending memoryview instances to a channel
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # send() accepts memoryview slices
            sent = 0
            view = memoryview(data)
            while sent < len(view):
                sent += chan.send(view[sent:sent+8])
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() accepts a memoryview instance
            chan.sendall(memoryview(data))
            self.assertEqual(sfile.read(len(data)), data)

    def test_server_rejects_open_channel_without_auth(self):
        """
        verify that opening a session on an unauthenticated connection raises
        ChannelException with the administratively-prohibited code.
        """
        try:
            # Empty connect_kwargs: connect without any authentication.
            self.setup_test_server(connect_kwargs={})
            self.tc.open_session()
        except ChannelException as e:
            assert e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
        else:
            assert False, "Did not raise ChannelException!"

    def test_server_rejects_arbitrary_global_request_without_auth(self):
        """
        verify that a global request on an unauthenticated connection leaves
        the client's global_response unset.
        """
        self.setup_test_server(connect_kwargs={})
        # NOTE: this dummy global request kind would normally pass muster
        # from the test server.
        self.tc.global_request('acceptable')
        # Global requests never raise exceptions, even on failure (not sure why
        # this was the original design...ugh.) Best we can do to tell failure
        # happened is that the client transport's global_response was set back
        # to None; if it had succeeded, it would be the response Message.
        err = "Unauthed global response incorrectly succeeded!"
        assert self.tc.global_response is None, err

    def test_server_rejects_port_forward_without_auth(self):
        """
        verify that a port forward request on an unauthenticated connection
        raises SSHException mentioning the denied forwarding request.
        """
        # NOTE: at protocol level port forward requests are treated same as a
        # regular global request, but Paramiko server implements a special-case
        # method for it, so it gets its own test. (plus, THAT actually raises
        # an exception on the client side, unlike the general case...)
        self.setup_test_server(connect_kwargs={})
        try:
            self.tc.request_port_forward('localhost', 1234)
        except SSHException as e:
            assert "forwarding request denied" in str(e)
        else:
            assert False, "Did not raise SSHException!"
class SSH(object):
    """Wrapper for Paramiko Transport and Channel for Expect-like sessions.

    Text received from the channel accumulates in ``received_text`` until a
    regular expression (by default ``prompt_re``) matches it.
    """

    def __init__(self, user, passwd=""):
        """Initialize SSH wrapper but do not connect.

        :param user: username passed to ``Transport.connect``.
        :param passwd: password passed to ``Transport.connect``.
        """
        self.user = user
        self.passwd = passwd
        # Raw string: same pattern text as before, without relying on
        # invalid string escapes such as "\." being passed through.
        self.prompt_re = r"([a-z]+@[a-zA-Z0-9\.\-\_]+[>#%])"
        self.reset_timeout_on_newlines = True
        self._timeout_sec = 10
        self.quiet = 0
        self.received_text = ""

    def connect(self, hostname_address):
        """Connect to SSH target using paramiko Transport/Channel.

        Opens a session channel with a pty and a shell, waits for the first
        prompt, then sends ``set cli screen-length 0``.

        :param hostname_address: passed unchanged to ``Transport``.
        :raises NetDevError: if no prompt shows up before the timeout.
        """
        self._transport = Transport(hostname_address)
        self._transport.connect(username=self.user, password=self.passwd)
        self.ssh_channel = self._transport.open_channel("session")
        self.pty = self.ssh_channel.get_pty()
        self.ssh_channel.invoke_shell()
        self.wait_for_prompt()
        # The trailing newline submits the command; without it the text would
        # sit unexecuted and be glued onto the next line sent.
        self.ssh_channel.send("set cli screen-length 0\n")

    def txrx_status(self):
        """Return send and receive status as a string formatted for screen."""
        chan = self.ssh_channel
        return "Send Ready: {}, Receive Ready: {}".format(
            chan.send_ready(), chan.recv_ready())

    @property
    def timeout(self):
        """Seconds ``wait_for_regex`` waits before giving up."""
        return self._timeout_sec

    @timeout.setter
    def timeout(self, new_timeout_sec):
        """Set the number of seconds ``wait_for_regex`` waits."""
        self._timeout_sec = new_timeout_sec

    def wait_for_prompt(self):
        """Call ``wait_for_regex`` using ``self.prompt_re`` as argument."""
        return self.wait_for_regex(self.prompt_re)

    def wait_for_regex(self, expression, wait_sec=0.1):
        """
        Loop receive and wait until a regular expression is matched in output.

        :param expression: pattern searched for in ``self.received_text``.
        :param wait_sec: seconds to sleep between polls of the channel.
        :returns: the accumulated ``self.received_text`` once it matches.
        :raises NetDevError: if the timeout expires without a match.
        """
        chan = self.ssh_channel
        # time.clock() no longer exists on Python 3.8+ and counted CPU time on
        # Unix, which hardly advances while sleeping. Use a wall clock,
        # monotonic where the interpreter provides one.
        clock = getattr(time, "monotonic", time.time)
        done = None
        start_time = clock()
        while (clock() - start_time) <= self.timeout and not done:
            time.sleep(wait_sec)
            while chan.recv_ready():
                this_receive = chan.recv(1024)
                if not isinstance(this_receive, str):
                    # Python 3 channels hand back bytes; received_text is str.
                    # "replace" keeps a split multi-byte character from
                    # raising mid-stream.
                    this_receive = this_receive.decode("utf-8", "replace")
                self.received_text = self.received_text + this_receive
                # print to screen if quiet is not set
                if not self.quiet:
                    stdout.write(this_receive)
                    stdout.flush()
                # restart the timer while newline-bearing output keeps coming
                if self.reset_timeout_on_newlines and "\n" in self.received_text:
                    start_time = clock()
            done = re.search(expression, self.received_text)
        # no match at this point means the timeout expired
        if not done:
            raise NetDevError(
                "Timeout waiting for expression match: '{}'".format(expression))
        return self.received_text

    def _sendline(self, line):
        """Send a single line, stripped and terminated with a newline.

        :raises NetDevError: if the channel is not ready to send.
        """
        chan = self.ssh_channel
        if not chan.send_ready():
            raise NetDevError("Attempted to send when send not ready")
        chan.send(line.strip() + "\n")

    def send(self, textblock):
        """Work through a textblock and send one line at a time.

        Leading and trailing whitespace of the block is dropped first, so a
        triple-quoted block does not send blank lines at either end. After
        each line the prompt is awaited and ``received_text`` is cleared.

        :raises NetDevError: if a send is attempted when the channel is not
            ready, or a prompt does not arrive before the timeout.
        """
        configtext = textblock.strip()
        for line in configtext.splitlines():
            self._sendline(line)
            # discard earlier output so only this command's prompt matches
            self.received_text = ""
            self.wait_for_prompt()
            self.received_text = ""
class SSH(object):
    """Wrapper for Paramiko Transport and Channel for Expect-like sessions.

    Text received from the channel accumulates in ``received_text`` until a
    regular expression (by default ``prompt_re``) matches it.
    """

    def __init__(self, user, passwd=""):
        """Initialize SSH wrapper but do not connect.

        :param user: username passed to ``Transport.connect``.
        :param passwd: password passed to ``Transport.connect``.
        """
        self.user = user
        self.passwd = passwd
        # Raw string: same pattern text as before, without relying on
        # invalid string escapes such as "\." being passed through.
        self.prompt_re = r"([a-z]+@[a-zA-Z0-9\.\-\_]+[>#%])"
        self.reset_timeout_on_newlines = True
        self._timeout_sec = 10
        self.quiet = 0
        self.received_text = ""

    def connect(self, hostname_address):
        """Connect to SSH target using paramiko Transport/Channel.

        Opens a session channel with a pty and a shell, waits for the first
        prompt, then sends ``set cli screen-length 0``.

        :param hostname_address: passed unchanged to ``Transport``.
        :raises NetDevError: if no prompt shows up before the timeout.
        """
        self._transport = Transport(hostname_address)
        self._transport.connect(username=self.user, password=self.passwd)
        self.ssh_channel = self._transport.open_channel("session")
        self.pty = self.ssh_channel.get_pty()
        self.ssh_channel.invoke_shell()
        self.wait_for_prompt()
        # The trailing newline submits the command; without it the text would
        # sit unexecuted and be glued onto the next line sent.
        self.ssh_channel.send("set cli screen-length 0\n")

    def txrx_status(self):
        """Return send and receive status as a string formatted for screen."""
        chan = self.ssh_channel
        return "Send Ready: {}, Receive Ready: {}".format(
            chan.send_ready(), chan.recv_ready())

    @property
    def timeout(self):
        """Seconds ``wait_for_regex`` waits before giving up."""
        return self._timeout_sec

    @timeout.setter
    def timeout(self, new_timeout_sec):
        """Set the number of seconds ``wait_for_regex`` waits."""
        self._timeout_sec = new_timeout_sec

    def wait_for_prompt(self):
        """Call ``wait_for_regex`` using ``self.prompt_re`` as argument."""
        return self.wait_for_regex(self.prompt_re)

    def wait_for_regex(self, expression, wait_sec=0.1):
        """
        Loop receive and wait until a regular expression is matched in output.

        :param expression: pattern searched for in ``self.received_text``.
        :param wait_sec: seconds to sleep between polls of the channel.
        :returns: the accumulated ``self.received_text`` once it matches.
        :raises NetDevError: if the timeout expires without a match.
        """
        chan = self.ssh_channel
        # time.clock() no longer exists on Python 3.8+ and counted CPU time on
        # Unix, which hardly advances while sleeping. Use a wall clock,
        # monotonic where the interpreter provides one.
        clock = getattr(time, "monotonic", time.time)
        done = None
        start_time = clock()
        while (clock() - start_time) <= self.timeout and not done:
            time.sleep(wait_sec)
            while chan.recv_ready():
                this_receive = chan.recv(1024)
                if not isinstance(this_receive, str):
                    # Python 3 channels hand back bytes; received_text is str.
                    # "replace" keeps a split multi-byte character from
                    # raising mid-stream.
                    this_receive = this_receive.decode("utf-8", "replace")
                self.received_text = self.received_text + this_receive
                # print to screen if quiet is not set
                if not self.quiet:
                    stdout.write(this_receive)
                    stdout.flush()
                # restart the timer while newline-bearing output keeps coming
                if self.reset_timeout_on_newlines and "\n" in self.received_text:
                    start_time = clock()
            done = re.search(expression, self.received_text)
        # no match at this point means the timeout expired
        if not done:
            raise NetDevError(
                "Timeout waiting for expression match: '{}'".format(expression))
        return self.received_text

    def _sendline(self, line):
        """Send a single line, stripped and terminated with a newline.

        :raises NetDevError: if the channel is not ready to send.
        """
        chan = self.ssh_channel
        if not chan.send_ready():
            raise NetDevError("Attempted to send when send not ready")
        chan.send(line.strip() + "\n")

    def send(self, textblock):
        """Work through a textblock and send one line at a time.

        Leading and trailing whitespace of the block is dropped first, so a
        triple-quoted block does not send blank lines at either end. After
        each line the prompt is awaited and ``received_text`` is cleared.

        :raises NetDevError: if a send is attempted when the channel is not
            ready, or a prompt does not arrive before the timeout.
        """
        configtext = textblock.strip()
        for line in configtext.splitlines():
            self._sendline(line)
            # discard earlier output so only this command's prompt matches
            self.received_text = ""
            self.wait_for_prompt()
            self.received_text = ""
#!/usr/bin/env python
"""Run ``cat /etc/passwd`` over a paramiko Transport and print each line
of the output stripped and upper-cased."""
from paramiko import Transport

# Bind every handle up front so the cleanup below can tell which ones were
# actually created; otherwise an early failure makes `finally` raise a
# NameError that hides the real error.
transport = None
channel = None
instream = None
try:
    transport = Transport(("dhrona.net", 12276))
    transport.connect(username="******", password="******")
    channel = transport.open_channel("session")
    channel.exec_command("cat /etc/passwd")
    instream = channel.makefile("r")
    for line in instream:
        print(line.strip().upper())
except Exception as e:
    print("Caught some exception: ", e)
finally:
    # Close in reverse order of creation, skipping anything never opened.
    for handle in (instream, channel, transport):
        if handle is not None:
            handle.close()