class TransportTest(unittest.TestCase):
    """
    Loopback tests for paramiko's Transport.

    A client Transport (``self.tc``) and a server Transport (``self.ts``) are
    attached to a pair of linked LoopSockets so a full SSH handshake can be
    exercised in-process.
    """

    # NOTE(review): this module defines ``TransportTest`` more than once; a
    # later class statement with the same name shadows this one at import.

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def _start_null_server(self):
        """
        Register the test RSA host key on ``self.ts`` and start it as a server
        backed by a fresh NullServer.

        The client side is NOT connected; callers do that themselves.

        :return: ``(event, server, public_host_key)`` where ``event`` is the
            threading.Event handed to ``start_server`` and ``public_host_key``
            is what the client should pass as ``hostkey``.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        return event, server, public_host_key

    def _assert_server_ready(self, event):
        """Wait up to 1s for ``event`` and check the server transport is active."""
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def setup_test_server(self):
        """Start a NullServer (kept on ``self.server``) and authenticate the client."""
        event, self.server, public_host_key = self._start_null_server()
        self.tc.connect(hostkey=public_host_key)
        self.tc.auth_password(username='******', password='******')
        self._assert_server_ready(event)

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # Unknown cipher names are rejected ...
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        # ... and so are non-sequence values.
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        # Plain int literal: Python 2 promotes it to long automatically, and
        # the trailing ``L`` suffix is a syntax error under Python 3.
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            '207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        # Neither side is authenticated before the handshake.
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        event, server, public_host_key = self._start_null_server()
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._assert_server_ready(event)
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        event, server, public_host_key = self._start_null_server()
        options = self.tc.get_security_options()
        options.ciphers = ('aes256-cbc',)
        options.digests = ('hmac-md5-96',)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._assert_server_ready(event)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits = 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.tc.set_hexdump(True)
        event, server, public_host_key = self._start_null_server()
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._assert_server_ready(event)

        self.assertEqual(None, getattr(server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Keepalive interval is 1 second; wait long enough for one to fire.
        time.sleep(2)
        self.assertEqual('*****@*****.**', server._global_request)

    def test_6_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        _, _, public_host_key = self._start_null_server()
        with self.assertRaises(BadAuthenticationType) as cm:
            self.tc.connect(hostkey=public_host_key,
                            username='******', password='******')
        # assertRaises also accepts subclasses; keep the exact-type check.
        self.assertEqual(BadAuthenticationType, type(cm.exception))
        self.assertEqual(['publickey'], cm.exception.allowed_types)

    def test_7_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        with self.assertRaises(SSHException):
            self.tc.auth_password(username='******', password='******')
        self.tc.auth_password(username='******', password='******')
        self._assert_server_ready(event)

    def test_8_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        # Password alone is not enough; the server still wants a public key.
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEqual(['publickey'], remain)
        key = DSSKey.from_private_key_file('tests/test_dss.key')
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEqual([], remain)
        self._assert_server_ready(event)

    def test_9_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server asked so the test can inspect it.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']

        remain = self.tc.auth_interactive('commie', handler)
        self.assertEqual(self.got_title, 'password')
        self.assertEqual(self.got_prompts, [('Password', False)])
        self.assertEqual([], remain)
        self._assert_server_ready(event)

    def test_A_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEqual([], remain)
        self._assert_server_ready(event)

    def test_B_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        # Accept the server end of the channel; it is not used further here.
        self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())
class RySftp:
    """
    Copyright (C) 2020 Ryan Joyce <*****@*****.**>

    My first attempt at building a utility that securely connects to an SFTP
    server and downloads a file.

    I make use of the 'Paramiko' package to connect to SFTP servers, and
    shamelessly rip off the concept of Flask's application context.

    It also includes PGP decryption capability, provided by the
    'python-gnupg' package.
    """

    @catch_errors
    def __init__(self, **kwargs):
        # Explicit keyword arguments win; otherwise fall back to RYSFTP_* env vars.
        self.user = kwargs.get("user", getenv("RYSFTP_USER"))
        self.password = kwargs.get("password", getenv("RYSFTP_PASSWORD"))
        self.hostname = kwargs.get("hostname", getenv("RYSFTP_HOSTNAME"))
        port = kwargs.get("port", getenv("RYSFTP_PORT"))
        self.port = int(port or 22)
        self.config = _RySftpConfig(**kwargs)
        self._t = Transport((self.hostname, self.port))
        self._t.set_hexdump(False)
        self._sftp = None
        self._channel = None
        self._gpg = self.gpg_instance()
        self._connected = False
        self._lock = Lock()
        self._downloaded = []
        self._uploaded = []

    def __call__(self, *args, **kwargs):
        """Optionally override ``config.remotedir`` (first positional arg or
        ``remotedir=``) and return ``self`` so it can be used in ``with``."""
        if kwargs.get("remotedir") or args:
            self.config.remotedir = args[0] if args else kwargs["remotedir"]
        return self

    @catch_errors
    def __enter__(self):
        """Connect, open the SFTP session in ``config.remotedir`` and push a
        RyContext for this instance if one is not already on top."""
        self._t.connect(None, self.user, self.password)
        self._sftp = SFTPClient.from_transport(self._t)
        self._channel = self.ssh_channel()
        self._connected = True
        self._sftp.chdir(self.config.remotedir)
        ctx = _ry_ctx_stack.top
        if ctx is None or ctx.ry != self:
            ctx = self.ry_context()
            ctx.push()
        return self

    @catch_errors
    def __exit__(self, exc_type, exc_value, traceback):
        """Close the SFTP session and transport, then pop the top context."""
        # Close the SFTP client (and its channel) before tearing down the
        # transport underneath it; guard against a session never opened.
        if self._sftp is not None:
            self._sftp.close()
        self._t.close()
        self._connected = False
        ctx = _ry_ctx_stack.top
        if ctx is not None:
            ctx.pop()

    def ssh_channel(self):
        """Get the paramiko Channel from the underlying sftp client
        """
        return self._sftp.get_channel()

    def connects(f):
        """Decorator: run ``f`` only inside this instance's RyContext while
        connected. Passing ``thread=True`` pushes a fresh context for the call
        (used by worker threads); otherwise OutsideAppContextError is raised
        when no matching, connected context is active."""
        @wraps(f)
        @catch_errors
        def wrapped(self, *args, **kwargs):
            ry_ctx = _ry_ctx_stack.top
            if kwargs.pop("thread", False) is True:
                with self.ry_context():
                    return f(self, *args, **kwargs)
            if ry_ctx is not None and ry_ctx.ry == self and self._connected:
                return f(self, *args, **kwargs)
            raise OutsideAppContextError(_ry_ctx_err_msg)

        return wrapped

    def secure(f):
        """
        NOT IMPLEMENTED YET
        WRAPPER TO CHECK IF WE WANT TO ENABLE GNUPG ENCRYPTION ON FILES TRANSFERRED

        Currently a pass-through: the ``decrypt``/``encrypt`` keyword flags
        are inspected but have no effect.
        """
        @wraps(f)
        @catch_errors
        def wrapped(self, *args, **kwargs):
            if kwargs.get("decrypt"):
                pass
            elif kwargs.get("encrypt"):
                pass
            return f(self, *args, **kwargs)

        return wrapped

    def ry_context(self):
        """Return a new RyContext bound to this instance."""
        return RyContext(self)

    def gpg_instance(self):
        """Return a GPG object for ``config.gpgdir``, or None if unset."""
        if self.config.gpgdir:
            return GPG(gnupghome=self.config.gpgdir)
        return None

    def _require_gpg(self):
        """Return the configured GPG object, raising if no ``gpgdir`` was set."""
        if self._gpg is None:
            raise RuntimeError("GPG is not configured (no gpgdir set)")
        return self._gpg

    @connects
    def dirlist(self, full_remotepath=False):
        """
        Returns a directory list of the current remote directory (remotedir).

        :param bool full_remotepath: prefix each entry with ``remotedir/``
        """
        dirlist = self._sftp.listdir()
        if full_remotepath:
            dirlist = [f"{self.config.remotedir}/{d}" for d in dirlist]
        return dirlist

    def fstat(self, handle):
        """
        Get stats about a file on the remote server via its handle

        :raises SFTPError: if the server does not reply with attributes
        """
        log.debug(f"stat request: [{handle}]")
        resp_type, msg = self._blocking_request(CMD_FSTAT, handle)
        if resp_type != CMD_ATTRS:
            raise SFTPError("Expected back attributes")
        return SFTPAttributes._from_msg(msg)

    def open(self, filename, mode="r"):
        """
        Open a remote file, ``filename``, on the server for reading or writing.

        ``filename`` is joined onto ``config.remotedir`` (see ``encode_path``).

        Args:
            filename (str): name of remote file to open
            mode (str): "r" to read; "w" to write (create + truncate)

        Returns:
            bytes: the server-supplied file handle

        Raises:
            SFTPError: if the server does not reply with a handle
        """
        filename = self.encode_path(filename)
        pflags = 0
        if "r" in mode:
            pflags |= SFTP_FLAG_READ
        if "w" in mode:
            pflags |= SFTP_FLAG_WRITE | SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC
        attrs = SFTPAttributes()
        resp_type, msg = self._blocking_request(CMD_OPEN, filename, pflags, attrs)
        if resp_type != CMD_HANDLE:
            raise SFTPError("Expected remote file handle")
        return msg.get_binary()

    @connects
    def read(self, handle, size, offset=0):
        """
        Send a read request for ``size`` bytes at ``offset`` of the remote file
        indicated by the server supplied ``handle``. The reply is collected
        later by ``_threaded_reader``.

        :param str handle: remote file handle to read
        :param int size: bytes to read
        :param int offset: byte offset to start reading at
        :return: the request number used for this read
        """
        log.debug(f"read request: [{handle}] at byte [{offset}]")
        req_num = self._request(type(None), CMD_READ, handle, long(offset), size)
        # Request number 0 is valid; only skip bookkeeping when nothing was sent.
        if req_num is not None:
            _request_stack[req_num] = (offset, size)
        return req_num

    @connects
    def write(self, handle, data, offset=0):
        """
        Send a write request for ``data`` at ``offset`` of the remote file
        indicated by ``handle``. The status reply is collected later by
        ``_threaded_writer``.

        :param str handle: remote file handle to write to
        :param bytes data: data to write
        :param int offset: byte offset to write at
        :return: the request number used for this write
        """
        log.debug(f'write request: [{handle}] at byte {offset}')
        return self._request(type(None), CMD_WRITE, handle, long(offset), data)

    def close(self, handle):
        """
        Close the remote file

        :param str handle: remote file handle to close
        :return: the status code returned by the server
        :raises SFTPError: if the server does not reply with a status
        """
        resp_type, msg = self._blocking_request(CMD_CLOSE, handle)
        if resp_type != CMD_STATUS:
            raise SFTPError("Error closing file")
        status = msg.get_int()
        log.debug(f'closed [{handle}] on server: {status}')
        return status

    @staticmethod
    def _rate_kbps(num_bytes, seconds):
        """Return a transfer rate in kB/s; ``inf`` when no time elapsed."""
        if seconds <= 0:
            return float("inf")
        return round(num_bytes / seconds / 1000, 2)

    @connects
    def download(self, file):
        """
        Downloads a single file from ``config.remotedir`` into ``config.localdir``.

        :param str file: name of the file inside the remote directory
        :return: the local path written, as a string
        :raises LocalFileExistsError: if the local file exists and
            ``config.overwrite_local`` is false
        """
        localfile = Path(self.config.localdir, file)
        if not self.config.overwrite_local and localfile.exists():
            raise LocalFileExistsError(localfile)
        with open(localfile, "wb") as fw:
            handle = self.open(file)
            try:
                file_size = self.fstat(handle).st_size
                t1 = time.time()
                self._threaded_reader(handle, fw, file_size)
                t2 = time.time()
            finally:
                # Release the remote handle even if the transfer failed.
                self.close(handle)
        log.debug(
            f'download completed in {t2-t1} seconds at '
            f'{self._rate_kbps(file_size, t2 - t1)} kB/s'
        )
        with self._lock:
            self._downloaded.append(str(localfile))
        return str(localfile)

    @connects
    def download_latest(self, dl_num=1, name_filter=None, **kwargs):
        """
        Downloads the latest ``dl_num`` regular files (by mtime) from the
        remote directory; ``dl_num=None`` downloads all of them.

        :param name_filter: patterns passed to ``_apply_name_filter``
        :return: every local path downloaded by this instance so far
        """
        name_filter = name_filter or []
        remote_list = sorted(
            self._sftp.listdir_attr(), key=lambda x: x.st_mtime, reverse=True
        )
        to_download = [
            f.filename
            for f in remote_list
            if S_ISREG(f.st_mode) and _apply_name_filter(f.filename, name_filter)
        ][:dl_num]
        for remote_name in to_download:
            self.download(remote_name)
        return self._downloaded

    def download_all(self, **kwargs):
        """
        Downloads all regular files in the remote directory (remotedir).
        An empty remote directory downloads nothing.
        """
        return self.download_latest(None, **kwargs)

    @connects
    def upload_latest(self, ul_num=1, name_filter=None, **kwargs):
        """
        Uploads the latest ``ul_num`` local files (by mtime) from
        ``config.localdir`` whose names contain one of ``name_filter``
        (glob ``*<pattern>*``). Defaults to ``["?."]``.

        :return: every file uploaded by this instance so far
        """
        log.debug(f"UPLOADING THE LATEST {ul_num} FILES")
        patterns = name_filter or ["?."]
        localdir = Path(self.config.localdir)
        # A file can match several patterns; keep each path once, in order.
        candidates = dict.fromkeys(
            path for pattern in patterns for path in localdir.glob(f"*{pattern}*")
        )
        to_upload = sorted(
            candidates,
            key=lambda x: x.stat().st_mtime,
            reverse=True,
        )[:ul_num]
        for path in to_upload:
            self.upload(path)
        return self._uploaded

    @connects
    def upload(self, file):
        """
        Uploads a single local file into ``config.remotedir`` under its base name.

        :param file: path of the local file to upload
        """
        with open(file, "rb") as fr:
            file_size = os.fstat(fr.fileno()).st_size
            handle = self.open(Path(file).name, "w")
            try:
                t1 = time.time()
                self._threaded_writer(handle, fr, file_size)
                t2 = time.time()
            finally:
                # Release the remote handle even if the transfer failed.
                self.close(handle)
        log.debug(
            f'upload completed in {t2-t1} seconds at '
            f'{self._rate_kbps(file_size, t2 - t1)} kB/s'
        )
        with self._lock:
            self._uploaded.append(file)

    def encrypt(self, to_encrypt, recipients, fingerprint):
        """
        Encrypt (binary, non-armored) and sign ``to_encrypt`` into
        ``config.localdir/<name>.gpg``.

        :raises RuntimeError: if GPG is not configured or encryption fails
        """
        gpg = self._require_gpg()
        output = Path(self.config.localdir, f"{Path(to_encrypt).name}.gpg")
        with open(to_encrypt, "rb") as f:
            result = gpg.encrypt_file(
                recipients=recipients,
                armor=False,
                file=f,
                output=str(output),
                sign=fingerprint,
                passphrase=self.config.gpg_passphrase,
            )
        if not result.ok:
            raise RuntimeError("Error encrypting")
        return result

    def decrypt(self, to_decrypt, output_dir=None, overwrite=False):
        """
        Decrypt ``to_decrypt`` into ``output_dir`` (default ``config.localdir``),
        dropping the final extension, so "name.ext.gpg" becomes "name.ext".

        :param to_decrypt: path of the encrypted file
        :param output_dir: directory to write the plaintext into
        :param bool overwrite: replace an existing output file
        :raises LocalFileExistsError: if the output exists and not ``overwrite``
        :raises RuntimeError: if GPG is not configured or decryption fails
        """
        gpg = self._require_gpg()
        source = Path(to_decrypt)
        output = Path(output_dir or self.config.localdir, source.stem)
        if not overwrite and output.exists():
            raise LocalFileExistsError(output)
        with open(source, "rb") as open_f:
            result = gpg.decrypt_file(
                file=open_f,
                passphrase=self.config.gpg_passphrase,
                output=str(output),
            )
        if not result.ok:
            raise RuntimeError("Bad Decryption")
        return result

    def encode_path(self, file):
        """Take a standalone filename, append it to the currently set remote
        directory, and convert it to a utf-8 bytestring

        Args:
            file (str/Path): standalone filename

        Returns:
            bytes: bytestring of remotedir/filename
        """
        path = f"{self.config.remotedir}/{file}"
        return path.encode("utf-8")

    def _blocking_request(self, cmd, *args):
        """Make a request to the server and wait for a response back,
        blocking until it's received.

        Returns the ``(response type, Message)`` pair

        param int cmd: SSH FTP packet type
        param args: additional contents of packet
        """
        req_num = self._request(type(None), cmd, *args)
        return self._get_response(req_num)

    def _request(self, expects, cmd, *args):
        """
        Build an SSH FTP packet and send it to the server

        Args:
            expects (type): a type we expect back from server, if any (unused)
            cmd (int): SSH FTP packet type
            args: additional contents of packet

        Returns:
            int: the request number embedded in the packet
        """
        with self._lock:
            # ``g`` is accessed with item syntax everywhere else; probing it
            # with getattr() reset the counter to 0 on every call.
            try:
                req_num = g["req_num"]
            except KeyError:
                req_num = 0
            msg = Message()
            msg.add_int(req_num)
            for arg in args:
                _add_to_message(msg, arg)
            g["req_num"] = req_num + 1
            self._sftp._send_packet(cmd, msg)
        return req_num

    def _get_response(self, wantsback=None):
        """
        Read a packet and then process it into a ``Message`` object

        Returns the SSH packet type value of the response, along with the
        response itself wrapped in a Paramiko ``Message`` object

        :param int wantsback: the expected request #, or None to accept any
        :raises SFTPError: if the reply belongs to a different request
        """
        resp_type, msg = self._increment_response()
        req_num = msg.get_int()
        if wantsback is not None and req_num != wantsback:
            raise SFTPError(
                f"Expected response to request {wantsback}, got {req_num}"
            )
        return resp_type, msg

    def _increment_response(self):
        """Read one raw packet under the lock and wrap it in a Message."""
        with self._lock:
            resp_type, data = self._sftp._read_packet()
        return resp_type, Message(data)

    def _threaded_transfer(self, way, to_transfer):
        """Call method ``way`` once per item in ``to_transfer``, each in its
        own thread with ``thread=True``, and wait for all of them."""
        threads = []
        for xfr in to_transfer:
            t = Thread(target=getattr(self, way), args=(xfr,), kwargs={"thread": True})
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

    def _threaded_reader(self, handle, writer, size):
        """Issue READ requests for ``size`` bytes in MAX_PAYLOAD_SIZE chunks
        from a thread pool, then collect each DATA reply and write it into
        ``writer`` at the offset recorded for its request number.

        :raises SFTPError: on a non-DATA reply or an unknown request number
        """
        futures = []
        with self._lock:
            lo["expected_responses"] = math.ceil(size / MAX_PAYLOAD_SIZE)
        with ThreadPoolExecutor() as executor:
            n = 0
            while n < size:
                chunk = min(MAX_PAYLOAD_SIZE, size - n)
                futures.append(executor.submit(self.read, handle, chunk, n, thread=True))
                n += chunk
        # .result() also re-raises any exception from a worker thread.
        requests = [f.result() for f in futures]
        for _ in requests:
            resp_type, data = self._sftp._read_packet()
            if resp_type != CMD_DATA:
                raise SFTPError("Expected data")
            msg = Message(data)
            resp_num = msg.get_int()
            if resp_num not in _request_stack:
                # Silently skipping would leave a hole in the local file.
                raise SFTPError(f"Unexpected response to request {resp_num}")
            offset = _request_stack.pop(resp_num)[0]
            writer.seek(offset)
            log.debug(f'write local at byte {offset}')
            writer.write(msg.get_string())

    def _threaded_writer(self, handle, reader, size):
        """Read ``reader`` in MAX_PAYLOAD_SIZE chunks and issue WRITE requests
        from a thread pool, then collect one STATUS reply per request sent.

        :raises SFTPError: on a non-STATUS reply or a non-OK status code
        """
        futures = []
        with self._lock:
            lo["expected_responses"] = math.ceil(size / MAX_PAYLOAD_SIZE)
        with ThreadPoolExecutor() as executor:
            pos = 0
            while pos < size:
                data = reader.read(MAX_PAYLOAD_SIZE)
                if not data:
                    # Local file shrank mid-upload; stop instead of spinning.
                    break
                futures.append(
                    executor.submit(self.write, handle, data, pos, thread=True)
                )
                pos = reader.tell()
        # Surface any exception raised inside a worker thread.
        for future in futures:
            future.result()
        # Exactly one status reply arrives per WRITE request actually sent.
        for _ in futures:
            resp_type, data = self._sftp._read_packet()
            if resp_type != CMD_STATUS:
                raise SFTPError("Expected status")
            msg = Message(data)
            resp_num = msg.get_int()
            status = msg.get_int()
            # 0 is SSH_FX_OK in the SFTP protocol.
            if status != 0:
                raise SFTPError(f"Write request {resp_num} failed with status {status}")
class TransportTest(unittest.TestCase):
    """
    Loopback tests for paramiko's Transport.

    A client Transport (``self.tc``) and a server Transport (``self.ts``) are
    attached to a pair of linked LoopSockets so a full SSH handshake can be
    exercised in-process.
    """

    # NOTE(review): this module defines ``TransportTest`` more than once; this
    # later class statement shadows any earlier one with the same name.

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def _start_null_server(self):
        """
        Register the test RSA host key on ``self.ts`` and start it as a server
        backed by a fresh NullServer.

        The client side is NOT connected; callers do that themselves.

        :return: ``(event, server, public_host_key)`` where ``event`` is the
            threading.Event handed to ``start_server`` and ``public_host_key``
            is what the client should pass as ``hostkey``.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        return event, server, public_host_key

    def _assert_server_ready(self, event):
        """Wait up to 1s for ``event`` and check the server transport is active."""
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def setup_test_server(self):
        """Start a NullServer (kept on ``self.server``) and authenticate the client."""
        event, self.server, public_host_key = self._start_null_server()
        self.tc.connect(hostkey=public_host_key)
        self.tc.auth_password(username='******', password='******')
        self._assert_server_ready(event)

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # Unknown cipher names are rejected ...
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        # ... and so are non-sequence values.
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        # Plain int literal: Python 2 promotes it to long automatically, and
        # the trailing ``L`` suffix is a syntax error under Python 3.
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            '207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        # Neither side is authenticated before the handshake.
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        event, server, public_host_key = self._start_null_server()
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._assert_server_ready(event)
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        event, server, public_host_key = self._start_null_server()
        options = self.tc.get_security_options()
        options.ciphers = ('aes256-cbc', )
        options.digests = ('hmac-md5-96', )
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._assert_server_ready(event)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits = 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.tc.set_hexdump(True)
        event, server, public_host_key = self._start_null_server()
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._assert_server_ready(event)

        self.assertEqual(None, getattr(server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Keepalive interval is 1 second; wait long enough for one to fire.
        time.sleep(2)
        self.assertEqual('*****@*****.**', server._global_request)

    def test_6_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        _, _, public_host_key = self._start_null_server()
        with self.assertRaises(BadAuthenticationType) as cm:
            self.tc.connect(hostkey=public_host_key,
                            username='******', password='******')
        # assertRaises also accepts subclasses; keep the exact-type check.
        self.assertEqual(BadAuthenticationType, type(cm.exception))
        self.assertEqual(['publickey'], cm.exception.allowed_types)

    def test_7_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        with self.assertRaises(SSHException):
            self.tc.auth_password(username='******', password='******')
        self.tc.auth_password(username='******', password='******')
        self._assert_server_ready(event)

    def test_8_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        # Password alone is not enough; the server still wants a public key.
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEqual(['publickey'], remain)
        key = DSSKey.from_private_key_file('tests/test_dss.key')
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEqual([], remain)
        self._assert_server_ready(event)

    def test_9_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server asked so the test can inspect it.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']

        remain = self.tc.auth_interactive('commie', handler)
        self.assertEqual(self.got_title, 'password')
        self.assertEqual(self.got_prompts, [('Password', False)])
        self.assertEqual([], remain)
        self._assert_server_ready(event)

    def test_A_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        event, server, public_host_key = self._start_null_server()
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEqual([], remain)
        self._assert_server_ready(event)

    def test_B_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        # Accept the server end of the channel; it is not used further here.
        self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())