class TransportTest(unittest.TestCase):
    """End-to-end tests for a client/server pair of ``Transport`` objects
    talking to each other over a linked pair of ``LoopSocket`` objects."""

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def _prepare_server(self):
        """Install the test RSA host key on the server transport.

        Returns ``(public_host_key, event, server)``: the key object the
        client should verify against, a fresh unset ``threading.Event`` to
        hand to ``start_server``, and a new ``NullServer``. The server is
        NOT started here, so callers may inspect state before starting it.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        return public_host_key, event, server

    def _wait_for_server(self, event):
        """Wait up to one second for ``event``, then check the server is active."""
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def setup_test_server(self):
        """Start the server and password-authenticate the client.

        The ``NullServer`` instance is kept on ``self.server`` so tests can
        inspect it afterwards.
        """
        public_host_key, event, self.server = self._prepare_server()
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key)
        self.tc.auth_password(username='******', password='******')
        self._wait_for_server(event)

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # An unknown cipher name must be rejected with ValueError ...
        self.assertRaises(ValueError, setattr, o, 'ciphers',
                          ('aes256-cbc', 'made-up-cipher'))
        # ... and a non-sequence value with TypeError.
        self.assertRaises(TypeError, setattr, o, 'ciphers', 23)

    def test_2_compute_key(self):
        # Known-answer test: fixed K and H must derive exactly this key.
        # (No 'L' suffix: Python 2 promotes oversized int literals to long
        # automatically, and the suffix is a syntax error on Python 3.)
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                         hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        public_host_key, event, server = self._prepare_server()
        # Nobody is authenticated before the server is even started.
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._wait_for_server(event)
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        options = self.tc.get_security_options()
        options.ciphers = ('aes256-cbc',)
        options.digests = ('hmac-md5-96',)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._wait_for_server(event)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits = 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.tc.set_hexdump(True)

        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        self._wait_for_server(event)

        self.assertEqual(None, getattr(server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Give the 1-second keepalive interval time to fire at least once.
        time.sleep(2)
        self.assertEqual('*****@*****.**', server._global_request)

    def test_6_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        try:
            self.tc.connect(hostkey=public_host_key,
                            username='******', password='******')
        except BadAuthenticationType as e:
            self.assertEqual(['publickey'], e.allowed_types)
        else:
            self.fail('expected BadAuthenticationType')

    def test_7_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        try:
            self.tc.auth_password(username='******', password='******')
        except SSHException:
            pass
        else:
            self.fail('expected SSHException for a bad password')
        self.tc.auth_password(username='******', password='******')
        self._wait_for_server(event)

    def test_8_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        # Password alone is not enough: the server still wants a public key.
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEqual(['publickey'], remain)
        key = DSSKey.from_private_key_file('tests/test_dss.key')
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEqual([], remain)
        self._wait_for_server(event)

    def test_9_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server asked so it can be checked below.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']
        remain = self.tc.auth_interactive('commie', handler)
        self.assertEqual(self.got_title, 'password')
        self.assertEqual(self.got_prompts, [('Password', False)])
        self.assertEqual([], remain)
        self._wait_for_server(event)

    def test_A_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        public_host_key, event, server = self._prepare_server()
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEqual([], remain)
        self._wait_for_server(event)

    def test_B_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        # Accept the matching server-side channel; it is not used further.
        self.ts.accept(1.0)
        try:
            chan.exec_command('no')
        except SSHException:
            pass
        else:
            self.fail('exec_command("no") should have been refused')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())
# Example #2
# 0
class RySftp:
    """
    Copyright (C) 2020  Ryan Joyce <*****@*****.**>

    My first attempt at building a utility that securely connects
    to an SFTP server and downloads a file.

    I make use of the 'Paramiko' package to connect to SFTP servers, and
    shamelessly rip off the concept of Flask's application context.

    It also includes PGP decryption capability, provided by the
    'python-gnupg' package.

    Connection settings come from keyword arguments, falling back to the
    RYSFTP_USER, RYSFTP_PASSWORD, RYSFTP_HOSTNAME and RYSFTP_PORT
    environment variables (the port defaults to 22). Methods decorated
    with ``connects`` must be called inside ``with`` on the instance.
    """

    @catch_errors
    def __init__(self, **kwargs):
        self.user = kwargs.get("user", getenv("RYSFTP_USER"))
        self.password = kwargs.get("password", getenv("RYSFTP_PASSWORD"))
        self.hostname = kwargs.get("hostname", getenv("RYSFTP_HOSTNAME"))
        port = kwargs.get("port", getenv("RYSFTP_PORT"))
        self.port = int(port or 22)
        self.config = _RySftpConfig(**kwargs)

        self._t = Transport((self.hostname, self.port))
        self._t.set_hexdump(False)

        # Populated by __enter__.
        self._sftp = None
        self._channel = None
        self._gpg = self.gpg_instance()
        self._connected = False
        self._lock = Lock()

        # Local paths downloaded / files uploaded so far by this instance.
        self._downloaded = []
        self._uploaded = []

    def __call__(self, *args, **kwargs):
        """Optionally switch the remote directory, then return ``self``.

        The first positional argument, or the ``remotedir`` keyword,
        replaces ``config.remotedir``. This allows ``with client("dir"):``.
        """
        if kwargs.get("remotedir") or args:
            self.config.remotedir = args[0] if args else kwargs["remotedir"]
        return self

    @catch_errors
    def __enter__(self):
        """Connect, open the SFTP session, cd into ``config.remotedir`` and
        push a context for this instance unless one is already on top."""
        self._t.connect(None, self.user, self.password)
        self._sftp = SFTPClient.from_transport(self._t)
        self._channel = self.ssh_channel()
        self._connected = True
        self._sftp.chdir(self.config.remotedir)

        ctx = _ry_ctx_stack.top
        if ctx is None or ctx.ry != self:
            ctx = self.ry_context()
            ctx.push()
        return self

    @catch_errors
    def __exit__(self, exc_type, exc_value, traceback):
        """Close the SFTP session and transport and pop the top context."""
        # Close the SFTP client before the transport it runs over; _sftp is
        # still None if __enter__ failed before creating it.
        if self._sftp is not None:
            self._sftp.close()
        self._t.close()
        self._connected = False

        # NOTE(review): pops whatever context is on top, even if __enter__
        # did not push it because one for this instance already existed.
        ctx = _ry_ctx_stack.top
        if ctx is not None:
            ctx.pop()

    def ssh_channel(self):
        """Get the paramiko Channel from the underlying sftp client

        """
        return self._sftp.get_channel()

    def connects(f):
        """Decorator: only run ``f`` inside this instance's active context.

        If the call passes ``thread=True`` (removed from kwargs), a fresh
        context is pushed for the duration of the call; this is meant for
        worker threads. Otherwise the top of ``_ry_ctx_stack`` must belong
        to this instance and the connection must be open, or
        ``OutsideAppContextError`` is raised.
        """
        @wraps(f)
        @catch_errors
        def wrapped(self, *args, **kwargs):
            ry_ctx = _ry_ctx_stack.top
            if kwargs.pop("thread", False) is True:
                with self.ry_context():
                    return f(self, *args, **kwargs)
            if ry_ctx is not None and ry_ctx.ry == self and self._connected:
                return f(self, *args, **kwargs)
            raise OutsideAppContextError(_ry_ctx_err_msg)
        return wrapped

    def secure(f):
        """
        NOT IMPLEMENTED YET
        WRAPPER TO CHECK IF WE WANT TO ENABLE GNUPG ENCRYPTION ON FILES
        TRANSFERRED

        Currently a pass-through: ``decrypt``/``encrypt`` kwargs are
        inspected but ignored, and ``f`` is called unchanged.
        """

        @wraps(f)
        @catch_errors
        def wrapped(self, *args, **kwargs):
            if kwargs.get("decrypt"):
                pass
            elif kwargs.get("encrypt"):
                pass
            return f(self, *args, **kwargs)
        return wrapped

    def ry_context(self):
        """Return a new ``RyContext`` bound to this instance."""
        return RyContext(self)

    def gpg_instance(self):
        """Return a ``GPG`` for ``config.gpgdir``, or None if it is unset."""
        if self.config.gpgdir:
            return GPG(gnupghome=self.config.gpgdir)
        return None

    def _require_gpg(self):
        """Return the GPG instance, or raise if no ``gpgdir`` was configured."""
        if self._gpg is None:
            raise RuntimeError("GPG is not configured (no gpgdir set)")
        return self._gpg

    @connects
    def dirlist(self, full_remotepath=False):
        """
        Returns a directory list of the current remote directory.

        With ``full_remotepath`` each entry is prefixed with
        ``config.remotedir``.
        """
        dirlist = self._sftp.listdir()
        if full_remotepath:
            dirlist = [f"{self.config.remotedir}/{d}" for d in dirlist]
        return dirlist

    def fstat(self, handle):
        """
        Get stats about a file on the remote server via its handle.

        :raises SFTPError: if the server does not answer with attributes
        """
        log.debug(f"stat request: [{handle}]")
        resp_type, msg = self._blocking_request(CMD_FSTAT, handle)
        if resp_type != CMD_ATTRS:
            raise SFTPError("Expected back attributes")
        return SFTPAttributes._from_msg(msg)

    def open(self, filename, mode="r"):
        """
        Open a remote file, ``filename``, on the server for reading
        or writing.

        Args:
            filename (str): name of remote file, relative to
                ``config.remotedir``
            mode (str): "r" reads; "w" writes, creating/truncating the file

        Returns:
            bytes: the server-supplied file handle

        Raises:
            SFTPError: if the server does not answer with a handle
        """
        filename = self.encode_path(filename)
        pflags = 0
        if "r" in mode:
            pflags |= SFTP_FLAG_READ
        if "w" in mode:
            pflags |= SFTP_FLAG_WRITE | SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC
        attrs = SFTPAttributes()
        resp_type, msg = self._blocking_request(CMD_OPEN, filename, pflags, attrs)
        if resp_type != CMD_HANDLE:
            raise SFTPError("Expected remote file handle")
        return msg.get_binary()

    @connects
    def read(self, handle, size, offset=0):
        """
        Send (without waiting for the reply) a request to read ``size``
        bytes at ``offset`` from the remote file indicated by ``handle``.

        :param str handle: remote file handle to read
        :param int size: bytes to read
        :param int offset: byte offset in the remote file
        :return: the request number; the reply is consumed by
            ``_threaded_reader``
        """
        log.debug(f"read request: [{handle}] at byte [{offset}]")
        req_num = self._request(type(None), CMD_READ, handle, long(offset), size)
        # Remember where this chunk belongs so the reply can be placed.
        # Request number 0 is valid, so record unconditionally.
        _request_stack[req_num] = (offset, size)
        return req_num

    @connects
    def write(self, handle, data, offset=0):
        """
        Send (without waiting for the reply) a request to write ``data`` at
        ``offset`` to the remote file indicated by ``handle``.

        :param str handle: remote file handle to write to
        :param bytes data: data to write
        :return: the request number
        """
        log.debug(f'write request: [{handle}] at byte {offset}')
        return self._request(type(None), CMD_WRITE, handle, long(offset), data)

    def close(self, handle):
        """
        Close the remote file

        :param str handle: remote file handle to close
        :return: the status code from the server's reply
        :raises SFTPError: if the server does not answer with a status
        """
        resp_type, msg = self._blocking_request(CMD_CLOSE, handle)
        if resp_type != CMD_STATUS:
            raise SFTPError("Error closing file")
        status = msg.get_int()
        log.debug(f'closed [{handle}] on server: {status}')
        return status

    @staticmethod
    def _log_transfer(verb, nbytes, elapsed):
        """Log a transfer's duration and, when measurable, its throughput.

        ``elapsed`` can be 0 for small files, so the rate is only computed
        when it is positive (avoids ZeroDivisionError).
        """
        if elapsed > 0:
            log.debug(f'{verb} completed in {elapsed} seconds at '
                      f'{round(nbytes / elapsed / 1000, 2)} kB/s')
        else:
            log.debug(f'{verb} completed in {elapsed} seconds')

    @connects
    def download(self, file):
        """
        Download a single file. ``file`` is a name relative to
        ``config.remotedir``; it is saved as ``config.localdir/file``.

        :param str file: file to download
        :return: the local path as a string
        :raises LocalFileExistsError: if the local file exists and
            ``config.overwrite_local`` is false
        """
        localfile = Path(self.config.localdir, file)
        if not self.config.overwrite_local and localfile.exists():
            raise LocalFileExistsError(localfile)
        with open(localfile, "wb") as fw:
            handle = self.open(file)
            try:
                file_size = self.fstat(handle).st_size
                start = time.perf_counter()
                self._threaded_reader(handle, fw, file_size)
                elapsed = time.perf_counter() - start
            finally:
                # Release the remote handle even if the transfer fails.
                self.close(handle)
        self._log_transfer("download", file_size, elapsed)
        with self._lock:
            self._downloaded.append(str(localfile))
        return str(localfile)

    @connects
    def download_latest(self, dl_num=1, name_filter=None, **kwargs):
        """
        Downloads the ``dl_num`` most recently modified regular files in the
        remote directory that pass ``name_filter`` (``dl_num=None`` means
        all of them). Returns every path downloaded so far by this instance.
        """
        if name_filter is None:
            name_filter = []
        remote_list = sorted(
            self._sftp.listdir_attr(), key=lambda x: x.st_mtime, reverse=True
        )
        to_download = [
            f.filename
            for f in remote_list
            if S_ISREG(f.st_mode) and _apply_name_filter(f.filename, name_filter)
        ][:dl_num]
        for remote_name in to_download:
            self.download(remote_name)
        return self._downloaded

    def download_all(self, **kwargs):
        """
        Downloads all files in the current remote directory (remotedir).

        This needs to be tested for what happens when there are no files
        in the remote directory
        """
        return self.download_latest(None, **kwargs)

    @connects
    def upload_latest(self, ul_num=1, name_filter=None, **kwargs):
        """
        Uploads the ``ul_num`` most recently modified local files whose
        names match ``*<filter>*`` for any entry of ``name_filter``
        (default ``["?."]``). Each matching file is uploaded once, even if
        it matches several filters. Returns every file uploaded so far.
        """
        log.debug(f"UPLOADING THE LATEST {ul_num} FILES")
        if not name_filter:
            name_filter = ["?."]
        localdir = Path(self.config.localdir)
        # dict keys: de-duplicate while keeping first-seen order.
        candidates = {}
        for pattern in name_filter:
            for path in localdir.glob(f"*{pattern}*"):
                if path.is_file():
                    candidates[path] = None
        to_upload = sorted(
            candidates,
            key=lambda x: x.stat().st_mtime,
            reverse=True,
        )[:ul_num]
        for path in to_upload:
            self.upload(path)
        return self._uploaded

    @connects
    def upload(self, file):
        """
        Uploads a single local file into ``config.remotedir`` under the
        same base name.

        :param file: local path of the file to upload
        """
        with open(file, "rb") as fr:
            file_size = os.fstat(fr.fileno()).st_size
            handle = self.open(Path(file).name, "w")
            try:
                start = time.perf_counter()
                self._threaded_writer(handle, fr, file_size)
                elapsed = time.perf_counter() - start
            finally:
                # Release the remote handle even if the transfer fails.
                self.close(handle)
        self._log_transfer("upload", file_size, elapsed)
        with self._lock:
            self._uploaded.append(file)

    def encrypt(self, to_encrypt, recipients, fingerprint):
        """Encrypt ``to_encrypt`` for ``recipients``, signing with
        ``fingerprint``; writes ``config.localdir/<name>.gpg``.

        :raises RuntimeError: if GPG is not configured or encryption fails
        """
        gpg = self._require_gpg()
        output = Path(self.config.localdir, f"{Path(to_encrypt).name}.gpg")
        with open(to_encrypt, "rb") as f:
            result = gpg.encrypt_file(
                recipients=recipients,
                armor=False,
                file=f,
                output=str(output),
                sign=fingerprint,
                passphrase=self.config.gpg_passphrase,
            )
        if not result.ok:
            raise RuntimeError("Error encrypting")
        return result

    def decrypt(self, to_decrypt, output_dir=None, overwrite=False):
        """Decrypt ``to_decrypt`` with ``config.gpg_passphrase``.

        With ``output_dir`` the plaintext is written to
        ``output_dir/<name without its last extension>``; otherwise no
        output file is requested (python-gnupg keeps it on the result).

        :raises LocalFileExistsError: if the output file exists and
            ``overwrite`` is false
        :raises RuntimeError: if GPG is not configured or decryption fails
        """
        gpg = self._require_gpg()
        output = None
        if output_dir is not None:
            target = Path(output_dir, Path(to_decrypt).stem)
            if not overwrite and target.exists():
                raise LocalFileExistsError(target)
            output = str(target)
        with open(to_decrypt, "rb") as open_f:
            result = gpg.decrypt_file(
                file=open_f,
                passphrase=self.config.gpg_passphrase,
                output=output,
            )
        if not result.ok:
            raise RuntimeError("Bad Decryption")
        return result

    def encode_path(self, file):
        """Take a standalone filename, append it to the currently set remote
        directory, and convert it to a utf-8 bytestring

        Args:
            file (str/Path): standalone filename

        Returns:
            bytes: bytestring of remotedir/filename
        """
        path = f"{self.config.remotedir}/{file}"
        return path.encode("utf-8")

    def _blocking_request(self, cmd, *args):
        """Make a request to the server and wait for a response back, blocking
        until it's received. Returns the response

        param int cmd: SSH FTP packet type
        param args: additional contents of packet

        """
        req_num = self._request(type(None), cmd, *args)
        return self._get_response(req_num)

    def _request(self, expects, cmd, *args):
        """
        Build an SSH FTP packet and send it to the server

        Args:
            expects (type): a type we expect back from server, if any
                (currently unused)
            cmd (int): SSH FTP packet type
            args: additional contents of packet

        Returns:
            int: the request number stamped on the packet
        """
        with self._lock:
            # NOTE(review): ``g`` is read by attribute and written by item;
            # this assumes it supports both. It holds the shared counter.
            if getattr(g, "req_num", False):
                req_num = g["req_num"]
            else:
                g["req_num"] = req_num = 0
            msg = Message()
            msg.add_int(req_num)
            for arg in args:
                _add_to_message(msg, arg)
            g["req_num"] += 1
            self._sftp._send_packet(cmd, msg)
        return req_num

    def _get_response(self, wantsback=None):
        """
        Read a packet and then process it into a ``Message`` object

        Returns the SSH packet type value of the response, along with
        the response itself wrapped in a Paramiko ``Message`` object
        (positioned just after the request number).

        :param int wantsback: the expected request #, or None to accept any
        :raises SFTPError: if the reply answers a different request
        """
        resp_type, msg = self._increment_response()
        req_num = msg.get_int()
        if wantsback is not None and req_num != wantsback:
            raise SFTPError(
                f"Expected response to request {wantsback}, got {req_num}"
            )
        return resp_type, msg

    def _increment_response(self):
        """Read one raw packet under the lock and wrap it in a ``Message``."""
        with self._lock:
            resp_type, data = self._sftp._read_packet()
        return resp_type, Message(data)

    def _threaded_transfer(self, way, to_transfer):
        """Run ``getattr(self, way)(item, thread=True)`` for each item on
        its own thread and wait for all of them."""
        threads = []
        for xfr in to_transfer:
            t = Thread(target=getattr(self, way), args=(xfr,), kwargs={"thread": True})
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

    def _threaded_reader(self, handle, writer, size):
        """Issue concurrent read requests covering ``size`` bytes, then read
        one reply per request and write each chunk at its recorded offset."""
        futures = []
        with self._lock:
            lo["expected_responses"] = math.ceil(size / MAX_PAYLOAD_SIZE)
        with ThreadPoolExecutor() as executor:
            n = 0
            while n < size:
                chunk = min(MAX_PAYLOAD_SIZE, size - n)
                futures.append(executor.submit(self.read, handle, chunk, n, thread=True))
                n += chunk
            # result() re-raises any error from sending a request.
            requests = [f.result() for f in futures]
        for _ in requests:
            resp_type, data = self._sftp._read_packet()
            if resp_type != CMD_DATA:
                raise SFTPError("Expected data")
            msg = Message(data)
            resp_num = msg.get_int()
            # pop() so finished requests don't accumulate across transfers
            # (assumes _request_stack is a dict).
            placement = _request_stack.pop(resp_num, None)
            if placement is not None:
                offset = placement[0]
                writer.seek(offset)
                log.debug(f'write local at byte {offset}')
                writer.write(msg.get_string())

    def _threaded_writer(self, handle, reader, size):
        """Send ``reader``'s contents as concurrent write requests, then read
        one status reply per request, raising on any failure."""
        futures = []
        with self._lock:
            lo["expected_responses"] = math.ceil(size / MAX_PAYLOAD_SIZE)
        with ThreadPoolExecutor() as executor:
            pos = 0
            while pos < size:
                data = reader.read(MAX_PAYLOAD_SIZE)
                if not data:
                    # File shrank since it was measured; stop instead of
                    # looping forever at an unchanged position.
                    break
                futures.append(
                    executor.submit(self.write, handle, data, pos, thread=True)
                )
                pos = reader.tell()
        # Surface send errors instead of waiting for replies that won't come.
        for f in futures:
            f.result()
        for _ in futures:
            resp_type, data = self._sftp._read_packet()
            if resp_type != CMD_STATUS:
                raise SFTPError("Expected write status")
            msg = Message(data)
            msg.get_int()  # request number
            status = msg.get_int()
            if status != 0:  # 0 is SFTP_OK
                raise SFTPError(f"Write failed with status {status}")
# Example #3
# 0
class TransportTest(unittest.TestCase):
    def setUp(self):
        """Build two linked loopback sockets and wrap each in a Transport.

        ``self.tc`` is the client-side transport (on ``self.sockc``) and
        ``self.ts`` the server-side one (on ``self.socks``).
        """
        server_sock, client_sock = LoopSocket(), LoopSocket()
        client_sock.link(server_sock)
        self.socks = server_sock
        self.sockc = client_sock
        self.tc = Transport(client_sock)
        self.ts = Transport(server_sock)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self):
        """Start a NullServer on ``self.ts`` and log ``self.tc`` into it.

        Loads the RSA host key from ``tests/test_rsa.key``, registers it with
        the server transport, connects the client using the corresponding
        public key and authenticates by password.  Leaves the server
        interface in ``self.server`` and asserts that the event passed to
        ``start_server`` was set and the server transport is active.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key)
        self.tc.auth_password(username='******', password='******')
        # Give the server side up to one second to signal the event.
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that SecurityOptions.ciphers accepts known ciphers and rejects
        unknown names (ValueError) and non-sequence values (TypeError).
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        """
        verify that _compute_key('C', 32) derives the expected key from a
        fixed shared secret K and exchange hash H.
        """
        # No trailing 'L': Python 2 promotes large int literals to long on its
        # own, and the suffix is a SyntaxError on Python 3.
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        # hexlify() returns bytes on Python 3 (and str on Python 2, where a
        # b'' literal is also str), so compare against a bytes literal.
        self.assertEqual(
            b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        # Before connecting, neither side knows a user or is authenticated.
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******',
                        password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        options = self.tc.get_security_options()
        options.ciphers = ('aes256-cbc', )
        options.digests = ('hmac-md5-96', )
        self.tc.connect(hostkey=public_host_key,
                        username='******',
                        password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits, i.e. 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.tc.set_hexdump(True)

        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******',
                        password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

        # No global request has reached the server yet.
        self.assertEqual(None, getattr(server, '_global_request', None))
        self.tc.set_keepalive(1)
        # Wait longer than the 1-second keepalive interval.
        time.sleep(2)
        self.assertEqual('*****@*****.**', server._global_request)

    def test_6_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        # assertRaises replaces the old bare ``except:``, which also caught the
        # AssertionError from a missing exception and reported it misleadingly.
        with self.assertRaises(BadAuthenticationType) as caught:
            self.tc.connect(hostkey=public_host_key,
                            username='******',
                            password='******')
        self.assertEqual(['publickey'], caught.exception.allowed_types)

    def test_7_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        # assertRaises accepts SSHException and any subclass, like the old
        # issubclass() check, but no longer swallows unrelated errors.
        with self.assertRaises(SSHException):
            self.tc.auth_password(username='******', password='******')
        self.tc.auth_password(username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_8_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        # Password auth alone succeeds partially; publickey is still required.
        remain = self.tc.auth_password(username='******',
                                       password='******')
        self.assertEqual(['publickey'], remain)
        key = DSSKey.from_private_key_file('tests/test_dss.key')
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEqual([], remain)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_9_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server sent so it can be asserted below.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']

        remain = self.tc.auth_interactive('commie', handler)
        self.assertEqual(self.got_title, 'password')
        self.assertEqual(self.got_prompts, [('Password', False)])
        self.assertEqual([], remain)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_A_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, server)
        self.tc.ultra_debug = True
        self.tc.connect(hostkey=public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEqual([], remain)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_B_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        # The test server is expected to refuse the 'no' command.  (The old
        # ``except SSHException, x:`` form is a SyntaxError on Python 3.)
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())