Ejemplo n.º 1
1
def main():
    """Build an eXe source tarball for a branch and optionally publish it.

    Usage: ``script VERSION [--install] [--local | USERNAME PASSWORD]``

    Steps, in order:
      1. ``svn export`` branch VERSION into a temporary directory, copy the
         local firefox directory into it and tar it as
         ``exe-VERSION-source.tgz`` in the temporary directory.
      2. Unless ``--local`` is given, upload the tarball over SFTP to
         shell.eduforge.org:/home/pub/exe, using the last two command-line
         arguments as username and password.
      3. Create ``exe-VERSION.ebuild`` from the ``exe-0.7.ebuild`` template
         in the exported tree if it does not already exist.
      4. When running as root (uid 0): copy the tarball to the Gentoo
         distfiles directory and rebuild the overlay directory
         /usr/local/portage/dev-python/exe. Then run ``ebuild ... manifest``
         and ``digest``; also run ``fetch`` (unless ``--local``) and
         ``install`` (with ``--install``).

    The function relies on ``os.chdir``; statement order matters.
    """
    if len(sys.argv) < 2:
        print 'Usage: %s [version] [--install] [--local|username password]' % sys.argv[0]
        print 'Where [version] is the branch you want to checkout'
        print 'and username and password are for your eduforge account'
        print 'Eg. %s 0.7 --local' % sys.argv[0]
    else:
        version = sys.argv[1]
        branch = 'http://exe.cfdl.auckland.ac.nz/svn/exe/branches/%s' % version
        origDir = Path(sys.argv[0]).abspath().dirname()
        tmp = TempDirPath()
        os.chdir(tmp)
        # NOTE(review): version is interpolated unquoted into a shell command.
        os.system('svn export %s exe' % branch)
        (origDir/'../../exe/webui/firefox').copytree(tmp/'exe/exe/webui/firefox')
        os.chdir(tmp/'exe')
        # The tarball path is '../...' relative to tmp/exe, i.e. it lands in tmp,
        # presumably so the '*' glob below does not pick it up.
        tarball = Path('../exe-%s-source.tgz' % version).abspath()
        os.system('tar czf %s *' % tarball)
        os.chdir(tmp)
        if '--local' not in sys.argv:
            try:
                from paramiko import Transport
            except ImportError:
                print 'To upload you need to install paramiko python library from:'
                print 'http://www.lag.net/paramiko'
                sys.exit(2)
            from socket import socket, gethostbyname
            s = socket()
            s.connect((gethostbyname('shell.eduforge.org'), 22))
            t = Transport(s)
            t.connect()
            # Username and password are taken positionally from the end of argv.
            t.auth_password(sys.argv[-2], sys.argv[-1])
            f = t.open_sftp_client()
            f.chdir('/home/pub/exe')
            f.put(tarball.encode('utf8'), tarball.basename().encode('utf8'))
            # NOTE(review): neither f nor t is closed explicitly.
        if os.getuid() == 0:
            tarball.copyfile('/usr/portage/distfiles/' + tarball.basename())
        os.chdir(tmp/'exe/installs/gentoo')
        newEbuildFilename = Path('exe-%s.ebuild' % version).abspath()
        if not newEbuildFilename.exists():
            # Fall back to the 0.7 ebuild as a template for this version.
            Path('exe-0.7.ebuild').copy(newEbuildFilename)
        if os.getuid() == 0:
            # Recreate the local overlay directory from scratch.
            ebuildDir = Path('/usr/local/portage/dev-python/exe')
            if ebuildDir.exists():
                ebuildDir.rmtree()
            ebuildDir.makedirs()
            os.chdir(ebuildDir)
            newEbuildFilename.copy(ebuildDir)
            filesDir = ebuildDir/'files'
            filesDir.makedirs()
            Path(tmp/'exe/installs/gentoo/all-config.patch').copy(filesDir)
            if '--local' not in sys.argv:
                # Remove the distfiles copy so 'ebuild fetch' downloads it again,
                # with GENTOO_MIRRORS cleared.
                oldTarball = Path('/usr/portage/distfiles/')/tarball.basename()
                if oldTarball.exists():
                    oldTarball.remove()
                os.environ['GENTOO_MIRRORS']=''
                os.system('ebuild %s fetch' % newEbuildFilename.basename())
            os.system('ebuild %s manifest' % newEbuildFilename.basename())
            os.system('ebuild %s digest' % newEbuildFilename.basename())
            if '--install' in sys.argv:
                os.system('ebuild %s install' % newEbuildFilename.basename())
Ejemplo n.º 2
0
class SshConnection:
    """ tsc ssh connection

    Opens a password-authenticated paramiko Transport to ``host:port`` and
    a "session" channel on it, with a keepalive every ``timeout`` seconds.
    On failure the error is printed and ``connection``/``session`` are left
    as ``None``.
    """

    def __init__(self, name, host, port=22, user=None, password=None):
        """Connect to ``host:port`` as ``user`` and open a session channel."""

        self.connection = None
        self.session = None
        self.host = host
        self.port = port
        self.loginAccount = {
            'user': user,
            'pass': password
        }
        self.name = name
        self.timeout = 4

        self.nbytes = 4096

        try:
            self.connection = Transport((self.host, self.port))
            self.connection.connect(username=self.loginAccount['user'],
                                    password=self.loginAccount['pass'])
            self.connection.set_keepalive(self.timeout)
            self.session = self.connection.open_channel(kind='session')

            print('ssh connection created!')
        except Exception as e:
            # The class previously called an undefined __del__ here, which
            # raised AttributeError and hid the real error.
            self.close()
            print('ssh connection failed: {er}'.format(er=e))

    def close(self):
        """Close the session channel and the transport, if open."""
        if self.session is not None:
            self.session.close()
            self.session = None
        if self.connection is not None:
            self.connection.close()
            self.connection = None

    def __del__(self):
        # Best-effort cleanup on garbage collection; never raise from __del__.
        try:
            self.close()
        except Exception:
            pass
Ejemplo n.º 3
0
    def write_redirects_to_sftp(self, from_path, to_path, cron):
        """Upload ``from_path`` to ``to_path`` on the configured SFTP server.

        Authenticates with the RSA key at ``SFTP_SSH_KEY_PATH`` and checks the
        server against ``SFTP_REMOTE_HOST_PUBLIC_KEY`` (port 22).

        :param from_path: local file to upload.
        :param to_path: destination path on the remote host.
        :param cron: if true, return a plain sentence for cron output;
            otherwise return a JSON string with ``type`` and ``message`` keys.
        :returns: a success or failure message. Ordinary errors are reported
            through the return value rather than raised.
        """
        remote_server = None
        try:
            ssh_key_object = RSAKey(filename=app.config['SFTP_SSH_KEY_PATH'],
                                    password=app.config['SFTP_SSH_KEY_PASSPHRASE'])

            remote_server_public_key = HostKeyEntry.from_line(app.config['SFTP_REMOTE_HOST_PUBLIC_KEY']).key
            # This will throw a warning, but the (string, int) tuple will automatically be parsed into a Socket object
            remote_server = Transport((app.config['SFTP_REMOTE_HOST'], 22))
            remote_server.connect(hostkey=remote_server_public_key, username=app.config['SFTP_USERNAME'], pkey=ssh_key_object)

            sftp = SFTPClient.from_transport(remote_server)
            sftp.put(from_path, to_path)
            if cron:
                return 'SFTP publish from %s to %s succeeded' % (from_path, to_path)
            else:
                return fjson.dumps({
                    'type': 'success',
                    'message': 'Redirect updates successful'
                })
        except Exception:
            if cron:
                return 'SFTP publish from %s to %s failed' % (from_path, to_path)
            else:
                return fjson.dumps({
                    'type': 'danger',
                    'message': 'Redirect updates failed'
                })
        finally:
            # Always release the SSH connection (also closes the SFTP channel).
            if remote_server is not None:
                remote_server.close()
Ejemplo n.º 4
0
class SFTPCopy:
    """Recursively copy a local directory tree to a remote host over SFTP."""

    def __init__(self, host, port, username, password):
        """Open an SSH transport to ``host:port`` and start an SFTP session.

        Logs in with ``username``/``password``. If login or session setup
        fails, the transport is closed before the error propagates.
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.t = Transport((self.host, self.port))
        try:
            # Note: connect() must be called with keyword arguments here.
            self.t.connect(username=self.username, password=self.password)
            self.sftp = SFTPClient.from_transport(self.t)
        except Exception:
            self.t.close()
            raise

    @staticmethod
    def _remote_path(localdir, root, name, remotedir):
        """Map entry ``name`` found in local dir ``root`` to its remote path.

        Uses ``relpath`` so a trailing separator on ``localdir`` does not
        chop the first character off subdirectory names.
        """
        rel = path.relpath(root, localdir)
        if rel == path.curdir:
            rel = ''
        return path.join(remotedir, rel, name)

    def copyData(self, localdir, remotedir):
        """Mirror ``localdir`` into existing remote directory ``remotedir``.

        Creates each subdirectory and uploads each file. Per-entry failures
        (e.g. a directory that already exists) are printed and skipped.
        """
        for root, subdirs, filenames in walk(localdir):
            for subdir in subdirs:
                dest = self._remote_path(localdir, root, subdir, remotedir)
                try:
                    self.sftp.mkdir(dest)
                except Exception as e:
                    print('ERROR: make directory %s %s' % (dest, e))
            for filename in filenames:
                src = path.join(root, filename)
                dest = self._remote_path(localdir, root, filename, remotedir)
                try:
                    self.sftp.put(src, dest)
                except Exception as e:
                    print('ERROR: touch file %s %s' % (dest, e))
Ejemplo n.º 5
0
 def scp(self, srcFile, destPath):
     """Upload local ``srcFile`` to ``destPath`` on the remote host via SFTP.

     Opens a new password-authenticated transport for this single transfer
     and always closes the SFTP session and transport afterwards. Errors
     from the upload (e.g. ``IOError``) propagate to the caller.
     """
     transport = Transport((self.host, int(self.port)))
     try:
         transport.connect(username=self.user, password=self.passwd)
         sftp = SFTPClient.from_transport(transport)
         try:
             sftp.put(srcFile, destPath)
         finally:
             sftp.close()
     finally:
         transport.close()
Ejemplo n.º 6
0
    def run(self):
        """Serve SSH clients forever, one server-mode Transport per connection.

        Each accepted socket is given ``self.key`` as host key and ``self`` as
        the server interface. ``accept`` runs with a 15 second timeout and the
        timeout is not handled here, so it propagates out of the loop.
        """
        self.socket.listen(100)
        while True:
            self.socket.settimeout(15)
            client_sock, _peer = self.socket.accept()
            ssh_transport = Transport(client_sock)
            ssh_transport.add_server_key(self.key)
            ssh_transport.start_server(Event(), server=self)
Ejemplo n.º 7
0
 def sftpOpenConnection(self, target):
     """Open an SFTP connection to the given target host.

     Credentials come from ``self.getCredentials(target)`` (the original
     docstring states these are read from /etc/ssh/sftp_passwd). The server
     key is checked against the first key recorded for the host in
     /etc/ssh/ssh_known_hosts.

     Raises ``KeyError`` if the host is not in the known-hosts file. If login
     fails, the transport is closed before the error propagates.
     """
     from paramiko import Transport, SFTPClient
     from paramiko.util import load_host_keys
     hostname, username, password = self.getCredentials(target)
     hostkeys = load_host_keys('/etc/ssh/ssh_known_hosts')
     try:
         host_entry = hostkeys[hostname]
     except KeyError:
         raise KeyError('no host key for %r in /etc/ssh/ssh_known_hosts' % (hostname,))
     # list() keeps this working on Python 3, where values() is a view.
     hostkey = list(host_entry.values())[0]
     trans = Transport((hostname, 22))
     try:
         trans.connect(username=username, password=password, hostkey=hostkey)
         return SFTPClient.from_transport(trans)
     except Exception:
         trans.close()
         raise
Ejemplo n.º 8
0
    def _createConnection(self):
        """
        @see: L{_createConnection<datafinder.persistence.common.connection.pool.ConnectionPool._createConnection>}

        Open and authenticate an SSH transport to the configured TSM host on
        C{DEFAULT_SSH_PORT} using the configured username and password.

        @raise PersistenceError: if connecting or authenticating fails; any
            partially opened transport is closed first.
        """

        connection = None
        try:
            connection = Transport((self._configuration.hostname, DEFAULT_SSH_PORT))
            connection.connect(username=self._configuration.username, password=self._configuration.password)
            return connection
        except (SSHException, socket.error, socket.gaierror) as error:
            if connection is not None:
                connection.close()
            errorMessage = u"Unable to establish SSH connection to TSM host '%s'! " \
                           % (self._configuration.hostname) + "\nReason: '%s'" % str(error)
            raise PersistenceError(errorMessage)
Ejemplo n.º 9
0
 def sftpOpen(self, target):
     """Open an SFTP session to ``self.hostname`` on port 22.

     Authenticates with ``self.username``/``self.password`` and verifies the
     server against ``self.hostkey``; ``target`` is not used here. The SSH
     banner timeout is raised to 120 seconds. Only non-secret connection
     details are printed. If login fails, the transport is closed before the
     error propagates.
     """
     from paramiko import Transport, SFTPClient
     print('>>> sftpOpen self.username: %s' % (self.username,))
     print('             self.hostname: %s' % (self.hostname,))
     trans = Transport((self.hostname, 22))
     trans.banner_timeout = 120
     try:
         trans.connect(username=self.username,
                       password=self.password,
                       hostkey=self.hostkey)
         return SFTPClient.from_transport(trans)
     except Exception:
         trans.close()
         raise
Ejemplo n.º 10
0
    def __init__(self, host, port, user, pw):
        """Create a transport to ``host:port`` and log in via ``conectar()``.

        The connection parameters are kept in ``self._datos``. The key at
        ``~/.ssh/id_rsa`` is registered with ``setPrivateKey`` before
        ``conectar()`` runs.
        """
        self.transport = Transport((host, port))
        self._datos = dict(host=host, port=port, user=user, pw=pw)
        self.rsa_key = None

        default_key_path = expanduser('~/.ssh/id_rsa')
        self.setPrivateKey(default_key_path)
        self.conectar()
Ejemplo n.º 11
0
 def _authenticate(self):
     """Open an SFTP session using the ``sftp_*`` entries of ``self.config``.

     Sets ``self._transport`` and ``self.session``. If login or session setup
     fails, the transport is closed before the error propagates.
     """
     self._transport = SFTPTransport((self.config['sftp_host'],
                                      self.config['sftp_port']))
     try:
         self._transport.connect(username=self.config['sftp_user'],
                                 password=self.config['sftp_password'])
         self.session = SFTPClient.from_transport(self._transport)
     except Exception:
         self._transport.close()
         raise
     logging.info('SFTP Authorization succeed')
Ejemplo n.º 12
0
    def __init__(self, request, client_address, server):
        """Wrap the incoming connection in an SSH transport and run the handler.

        Builds ``self.transport`` from the client socket and creates
        ``self.pty_handler``, a class mixing ``TelnetToPtyHandler`` into
        ``self.telnet_handler``. It then calls ``BaseRequestHandler.__init__``,
        which in the standard library runs ``setup()``/``handle()``/``finish()``.
        """
        self.request = request
        self.client_address = client_address
        self.tcp_server = server

        # Keep track of channel information from the transport
        self.channels = {}

        # Use the raw socket if the request exposes one as `_sock` (presumably
        # the Python 2 socket wrapper); otherwise use the request itself.
        try:
            self.client = request._sock
        except AttributeError as e:
            self.client = request
        # Transport turns the socket into an SSH transport
        self.transport = Transport(self.client)

        # Create the PTY handler class by mixing in
        TelnetHandlerClass = self.telnet_handler

        class MixedPtyHandler(TelnetToPtyHandler, TelnetHandlerClass):
            # BaseRequestHandler does not inherit from object, must call the __init__ directly
            # NOTE(review): depending on whether TelnetToPtyHandler's __init__
            # chains via super(), TelnetHandlerClass.__init__ may run twice.
            def __init__(self, *args):
                super(MixedPtyHandler, self).__init__(*args)
                TelnetHandlerClass.__init__(self, *args)

        self.pty_handler = MixedPtyHandler

        # Call the base class to run the handler
        BaseRequestHandler.__init__(self, request, client_address, server)
Ejemplo n.º 13
0
	def sshAuthentication(self, clientsock):
		"""Perform the server side of the SSH handshake on ``clientsock``.

		Loads the server host key described by ``self.sshData``:
		"hostKeyFile" (relative paths are resolved against
		``C.YENCAP_CONF_HOME``) and "hostKeyType" ("dss" or "rsa"). It then
		negotiates an SSH2 session in server mode and waits up to 20 seconds
		for the client to open a channel.

		Returns ``(clientsock, False, None)`` on any failure. Errors while
		reading the key file are not caught here.
		"""
		# setup logging
		paramiko.util.log_to_file(C.SYSLOG_FILE)

		# Check that SSH server parameters have been set:
		if self.sshData is None:
			return clientsock, False, None

		# Load private key of the server
		filekey = self.sshData["hostKeyFile"]
		if not filekey.startswith("/"):
			filekey = C.YENCAP_CONF_HOME + "/" + filekey

		# Build a key object from the file path:
		keyType = self.sshData["hostKeyType"]
		if keyType == "dss":
			priv_host_key = paramiko.DSSKey(filename=filekey)
		elif keyType == "rsa":
			priv_host_key = paramiko.RSAKey(filename=filekey)
		else:
			# Unsupported type: report it explicitly instead of failing later
			# with an unbound priv_host_key.
			LogManager.getInstance().logError("Unsupported SSH host key type: %s" % keyType)
			return clientsock, False, None

		t = None
		try:
			event = threading.Event()
			# Create a new SSH session over an existing socket, or socket-like object.
			t = Transport(clientsock)
			# Add a host key to the list of keys used for server mode.
			t.add_server_key(priv_host_key)
			# paramiko.ServerInterface defines an interface for controlling the behavior of paramiko in server mode.
			server = SSHServerModule()
			# Negotiate a new SSH2 session as a server.
			t.start_server(event, server)
			# Poll until negotiation finishes or the transport dies.
			while True:
				event.wait(0.1)
				if not t.is_active():
					return clientsock, False, None
				if event.is_set():
					break

			# Return the next channel opened by the client over this transport, in server mode.
			channel = t.accept(20)

			if channel is None:
				return clientsock, False, None

		except Exception as e:
			LogManager.getInstance().logError("Caught exception: %s: %s" % (str(e.__class__), str(e)))
			traceback.print_exc()

			if t is not None:
				# Best-effort cleanup; the failure has already been logged.
				try:
					t.close()
				except Exception:
					pass
			return clientsock, False, None
		# NOTE(review): the success path falls through and returns None rather
		# than a tuple like the failure paths; confirm callers expect this.
Ejemplo n.º 14
0
 def __init__(self, host, port, username, password):
     """Open an SSH transport to ``host:port`` and start an SFTP session.

     Logs in with ``username``/``password``. If login or session setup
     fails, the transport is closed before the error propagates.
     """
     self.host = host
     self.port = port
     self.username = username
     self.password = password
     self.t = Transport((self.host, self.port))
     try:
         # Note: connect() must be called with keyword arguments here.
         self.t.connect(username=self.username, password=self.password)
         self.sftp = SFTPClient.from_transport(self.t)
     except Exception:
         self.t.close()
         raise
Ejemplo n.º 15
0
def listdir(hostname, path="/var/tmp", filter="", port=1035, username="", password=""):
    """
        paramiko sftp listdir wrapper, with option to filter files

    Returns the entry names in remote directory ``path`` that ``filter``
    matches with ``re.match`` (anchored at the start of the name). The
    default empty filter matches every name.

    Exits the process with status 1 if ``filter`` is not a valid regular
    expression; this is checked before any connection is opened. The SSH
    transport is always closed before returning.
    """
    # Validate the filter first so a typo does not cost a network round-trip.
    try:
        rex = re.compile(filter)
    except re.error:
        print("Invalid regular expression: " + filter)
        sys.exit(1)

    # Paramiko client configuration
    t = Transport((hostname, port))
    try:
        t.connect(username=username, password=password)
        sftp = SFTPClient.from_transport(t)
        return [x for x in sftp.listdir(path) if rex.match(x)]
    finally:
        t.close()
Ejemplo n.º 16
0
 def setup(self):
     """Wrap the accepted request socket in a server-side SSH transport.

     Loads the server moduli, limits MAC digests to hmac-sha1, sets the
     compression preferences, installs the server's host key and registers
     ``MyTSFTPServer`` as the handler for the "sftp" subsystem.
     """
     transport = Transport(self.request)
     self.transport = transport
     transport.load_server_moduli()
     options = transport.get_security_options()
     options.digests = ('hmac-sha1', )
     options.compression = ('*****@*****.**', 'none')
     transport.add_server_key(self.server.host_key)
     transport.set_subsystem_handler('sftp', MyTSFTPServer,
                                     MyTSFTPServerInterface)
Ejemplo n.º 17
0
 def connect(self, hostname_address):
     """Open an interactive shell on the SSH target via a paramiko Transport.

     ``hostname_address`` is handed directly to ``Transport``. After
     password login, a PTY is requested on a new "session" channel and a
     shell is started. The method waits for the prompt and then sends the
     CLI paging command.
     """
     transport = Transport(hostname_address)
     self._transport = transport
     transport.connect(username=self.user, password=self.passwd)
     channel = transport.open_channel("session")
     self.ssh_channel = channel
     self.pty = channel.get_pty()
     channel.invoke_shell()
     self.wait_for_prompt()
     # NOTE(review): no trailing newline is sent, so the command may not be
     # submitted until a later write adds one; confirm this is intended.
     self.ssh_channel.send("set cli screen-length 0")
Ejemplo n.º 18
0
    def upload_report_to_sftp(self, client_id, report_date, absolute_filename):
        """
        Upload the given file, using SFTP, to the configured FTP server. The
        file should be uploaded to the appropriate directory for the specified
        client and the date of the report.

        Raises ``Clients.DoesNotExist`` if the client has no configuration, and
        re-raises any upload error after logging it. The SFTP session and SSH
        transport are closed when they were opened.
        """
        try:
            client = Clients.objects.get(id=client_id)
        except Clients.DoesNotExist:
            logger.exception(u'No configuration for client {0}.'.format(client_id))
            raise

        filename = basename(absolute_filename)
        base_folder, env_folder, year_folder, month_folder = self._get_sftp_dirs(client, report_date)

        # Pre-bind so the cleanup below never touches unbound names when the
        # connection itself fails.
        transport = None
        sftp = None
        try:
            logger.debug(u'SFTP logging on to {0} as {1}'.format(settings.SFTP_SERVER, settings.SFTP_USERNAME))
            transport = Transport((settings.SFTP_SERVER, settings.SFTP_PORT))
            transport.connect(username=settings.SFTP_USERNAME, password=settings.SFTP_PASSWORD)
            sftp = SFTPClient.from_transport(transport)

            logger.debug(u'SFTP dir {0}/{1}/{2}/{3}'.format(base_folder, env_folder, year_folder, month_folder))
            sftp.chdir(base_folder)
            self._make_or_change_sftp_dir(sftp, env_folder)
            self._make_or_change_sftp_dir(sftp, year_folder)
            self._make_or_change_sftp_dir(sftp, month_folder)

            logger.debug(u'SFTP uploading {0}'.format(filename))
            sftp.put(absolute_filename, filename)
        except Exception:
            logger.exception(u'Unrecoverable exception during SFTP upload process.')
            raise
        finally:
            logger.debug(u'SFTP logging off')

            if sftp is not None:
                try:
                    sftp.close()
                except Exception:
                    logger.exception(u'SFTP exception while closing SFTP session.')

            if transport is not None:
                try:
                    transport.close()
                except Exception:
                    logger.exception(u'SFTP exception while closing SSH connection.')
Ejemplo n.º 19
0
    def setCipher(self, cipher):
        """Reconnect using only ``cipher`` for the SSH session.

        Opens a new transport to the host/port in ``self._datos`` and
        restricts it to ``cipher``. It then logs in with the stored user and
        password and sets a 60 second keepalive. Once the new transport is
        connected, the transport it replaced (if any) is closed instead of
        being leaked.
        """
        previous = getattr(self, 'transport', None)

        self.transport = Transport((self._datos['host'],
                                    self._datos['port']))

        self.transport.get_security_options().ciphers = [cipher, ]

        self.transport.connect(username=self._datos['user'],
                               password=self._datos['pw'])

        self.transport.set_keepalive(60)

        # Only drop the old connection once the replacement is up.
        if previous is not None and previous is not self.transport:
            previous.close()
Ejemplo n.º 20
0
def connect(username, hostname='lxplus.cern.ch', port=22):
    """Connect to a given host.

    Steps:
      1. Open a TCP connection to ``hostname:port`` and negotiate SSH.
      2. Compare the server key with ``~/.ssh/known_hosts``, falling back to
         ``~/ssh/known_hosts``. An unknown key only prints a warning; a
         changed key aborts.
      3. Authenticate ``username`` with ``agent_auth``, then ``manual_auth``
         if still unauthenticated. An empty ``username`` is prompted for,
         defaulting to the local login name.

    Exits the process with status 1 on connection, negotiation, host-key
    mismatch or authentication failure.

    :returns: ``(transport, sock)``.
    """
    print("Connecting to %s@%s" % (username, hostname))
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((hostname, port))
    except Exception as err:
        print('*** Connect failed: ' + str(err))
        sys.exit(1)
    transport = Transport(sock)
    try:
        transport.start_client()
    except paramiko.SSHException as err:
        print("SSH negotiation failed\n%s" % str(err))
        # Nothing below can work without a negotiated session, so stop here.
        transport.close()
        sys.exit(1)

    try:
        keys = paramiko.util.load_host_keys(
            os.path.expanduser('~/.ssh/known_hosts'))
    except IOError:
        try:
            keys = paramiko.util.load_host_keys(
                os.path.expanduser('~/ssh/known_hosts'))
        except IOError:
            print('*** Unable to open host keys file')
            keys = {}

    # check server's host key -- this is important.
    key = transport.get_remote_server_key()
    if hostname not in keys:
        print('*** WARNING: Unknown host key!')
    elif key.get_name() not in keys[hostname]:
        print('*** WARNING: Unknown host key!')
    elif keys[hostname][key.get_name()] != key:
        print('*** WARNING: Host key has changed!!!')
        transport.close()
        sys.exit(1)

    # get username
    if not username:
        default_username = getpass.getuser()
        username = raw_input('Username [%s]: ' % default_username)
        if not username:
            username = default_username

    agent_auth(transport, username)
    if not transport.is_authenticated():
        manual_auth(transport, username, hostname)
    if not transport.is_authenticated():
        print('*** Authentication failed. :(')
        transport.close()
        sys.exit(1)
    return transport, sock
Ejemplo n.º 21
0
def sftp_connect(config):
    """Open an SFTP session using private-key authentication.

    ``config`` must provide "hostname", "port", "username",
    "privatekey_file", "privatekey_passphrase" and "hostkeys_file". The
    server is verified against the host key recorded for "hostname".

    Exits the process with status 1 on ``SSHException``. Any transport
    opened before the failure is closed first.
    """
    t = None
    try:
        hostname = config["hostname"]
        port = int(config["port"])
        username = config["username"]
        pkeyfile = config["privatekey_file"]
        pkeypassword = config["privatekey_passphrase"]
        host_keys = load_host_keys(config["hostkeys_file"])
        _hostkeytype, hostkey = get_hostkeytype_and_hostkey(host_keys,
                                                            hostname)
        t = Transport((hostname, port))
        pkey = get_privatekey_from_file(pkeyfile, pkeypassword)
        t.connect(username=username, pkey=pkey, hostkey=hostkey)

        return SFTPClient.from_transport(t)

    except SSHException as e:
        if t is not None:
            t.close()
        print(" [*] " + str(e))
        sys.exit(1)
Ejemplo n.º 22
0
class SFTPConnection:
  """
  Handle an SFTP (SSH File Transfer Protocol) connection.
  """

  def __init__(self, url, user_name, password=None, private_key=None):
    """Store connection settings; no network I/O happens here.

    ``private_key`` is an RSA private key given as text. Supplying both
    ``password`` and ``private_key`` raises ``SFTPError``.
    """
    self.url = url
    self.user_name = user_name
    if password and private_key:
      raise SFTPError("Password and private_key cannot be defined simultaneously")
    self.password = password
    self.private_key = private_key

  @staticmethod
  def _host_and_port(schema):
    """Return ``(hostname, port)`` from a parsed URL; the port defaults to 22."""
    port = schema.port if schema.port is not None else 22
    return schema.hostname, int(port)

  def connect(self):
    """Get a handle to a remote connection.

    The URL must use the ``sftp`` scheme; a URL without an explicit port
    uses 22. On success ``self.transport`` and ``self.conn`` are set, and
    the session changes into the URL's path if there is one.

    Raises ``SFTPError`` for a bad scheme, missing credentials, errors
    wrapped from the login, or a path that cannot be entered.
    """
    # Check URL
    schema = urlparse(self.url)
    if schema.scheme == 'sftp':
      self.transport = Transport(self._host_and_port(schema))
    else:
      raise SFTPError('Not a valid sftp url %s, type is %s' % (self.url, schema.scheme))
    # Add authentication to transport
    try:
      if self.password:
        self.transport.connect(username=self.user_name, password=self.password)
      elif self.private_key:
        self.transport.connect(username=self.user_name,
                               pkey=RSAKey.from_private_key(StringIO(self.private_key)))
      else:
        raise SFTPError("No password or private_key defined")
      # Connect
      self.conn = SFTPClient.from_transport(self.transport)
    except (socket.gaierror, error) as msg:
      raise SFTPError(str(msg) + ' while establishing connection to %s' % (self.url,))
    # Go to specified directory; strip trailing slashes but keep a bare "/".
    path = schema.path.rstrip('/') or schema.path
    try:
      if path:
        self.conn.chdir(path)
    except IOError as msg:
      raise SFTPError(str(msg) + ' while changing to dir -%r-' % (path,))
Ejemplo n.º 23
0
    def _connect(self):
        """Open the SSH transport and SFTP client for this connection.

        Uses a 2**27 SSH window and raises the packetizer's rekey thresholds
        to 2**32 bytes/packets (presumably to reduce rekeying during large
        transfers). Authenticates with the RSA key at ``self._key_path`` when
        ``self._auth == 'key'``, otherwise with ``self._user``/``self._passwd``.
        If login or SFTP setup fails, the transport is closed before the error
        propagates.
        """
        self._conn = Transport((self._host, self._port))
        self._conn.window_size = 2 ** 27
        self._conn.packetizer.REKEY_BYTES = 2 ** 32
        self._conn.packetizer.REKEY_PACKETS = 2 ** 32
        try:
            if self._auth == 'key':
                pkey = RSAKey.from_private_key_file(self._key_path)
                self._conn.connect(username=self._user, pkey=pkey)
            else:
                self._conn.connect(username=self._user, password=self._passwd)

            self._client = SFTPClient.from_transport(self._conn)
        except Exception:
            self._conn.close()
            raise
Ejemplo n.º 24
0
def main():
    """Rename the Windows installer builds and upload them to eduforge.

    The SSH username is the last command-line argument; the password is
    prompted for with ``getpass``. ``release`` (used in the new file names)
    is not defined in this function and is read from an enclosing scope,
    presumably the module. Files are renamed with ``renameFile`` and
    uploaded with ``upFile`` into /home/pub/exe on shell.eduforge.org.
    """
    try:
        from paramiko import Transport
    except ImportError:
        print
        print 'To upload you need to install paramiko python library from:'
        # Trailing comma: the next print continues on the same line.
        print 'http://www.lag.net/paramiko',
        print 'or on ubuntu go: apt-get install python2.4-paramiko'
        print
        sys.exit(2)
    server = 'shell.eduforge.org'
    basedir = '/home/pub/exe/'
    print 'Please enter password for %s@%s:' % (sys.argv[-1], server)
    password = getpass()
    print 'Renaming files'
    install = Path('eXe_install_windows.exe')
    newName = Path('eXe-install-%s.exe' % release)
    install = renameFile(install, newName)
    ready2run = Path('exes.exe')
    newName = Path('eXe-ready2run-%s.exe' % release)
    ready2run = renameFile(ready2run, newName)
    print 'Uploading'
    print 'connecting to %s...' % server
    from socket import socket, gethostbyname
    s = socket()
    s.connect((gethostbyname(server), 22))
    t = Transport(s)
    t.connect()
    t.auth_password(sys.argv[-1], password)
    sftp = t.open_sftp_client()
    sftp.chdir(basedir)
    upFile(sftp, install)
    upFile(sftp, ready2run)
    # NOTE(review): neither sftp nor t is closed explicitly.
Ejemplo n.º 25
0
 def _connect(self):
     """Open the SFTP connection unless one is already open.

     Uses a 2**27 SSH window and raises the packetizer's rekey thresholds to
     2**32. Any failure is re-raised as ``IrmaSFTPError``. In that case the
     half-open transport is closed and ``self._conn`` is reset to ``None``, so
     a later call retries instead of reporting "Already connected".
     """
     if self._conn is not None:
         log.warn("Already connected to sftp server")
         return
     try:
         self._conn = Transport((self._host, self._port))
         self._conn.window_size = pow(2, 27)
         self._conn.packetizer.REKEY_BYTES = pow(2, 32)
         self._conn.packetizer.REKEY_PACKETS = pow(2, 32)
         self._conn.connect(username=self._user, password=self._passwd)
         self._client = SFTPClient.from_transport(self._conn)
     except Exception as e:
         if self._conn is not None:
             # Best-effort cleanup; the original error is reported below.
             try:
                 self._conn.close()
             except Exception:
                 pass
             self._conn = None
         raise IrmaSFTPError("{0}".format(e))
Ejemplo n.º 26
0
class RemoteFileManager(object):
    """Read, create and poll files on a remote host over SFTP.

    Call :meth:`connect` before any other method.
    """

    def __init__(self, host, user, password, port=22):
        """Store connection settings; no network I/O happens here."""
        self._host = host
        self._port = port
        self._user = user
        self._pass = password

    def connect(self):
        """Open the SSH transport and SFTP client using password login."""
        self._transport = Transport((self._host, self._port))
        self._transport.connect(username=self._user, password=self._pass)
        self._client = SFTPClient.from_transport(self._transport)

    def files(self, path=None):
        """List entry names in remote ``path``; ``None`` is sent as ``''``."""
        path = '' if path is None else path
        return self._client.listdir(path)

    def file_text(self, path):
        """Return the lines of remote file ``path``, line endings included."""
        with self._client.open(path, 'r') as remote_file:
            return remote_file.readlines()

    def lock_wait(self, path, sleep_time=3):
        """Block while the lock file ``path`` exists on the remote host.

        Polls every ``sleep_time`` seconds. Only an ``IOError`` from opening
        the file (e.g. it does not exist) ends the wait. Other errors, such
        as a dropped connection, propagate instead of being treated as
        "no lock".
        """
        while True:
            try:
                lock_file = self._client.open(path, 'r')
            except IOError:
                ## No lock file present on remote
                break
            lock_file.close()
            print('Waiting for lock on sums file to be released...')
            sleep(sleep_time)

    def file_iter(self, path):
        """Yield the lines of remote file ``path``, creating it empty if missing."""
        try:
            self._client.stat(path)
        except IOError:
            # Create an empty file so the read below succeeds.
            self._client.open(path, 'w').close()

        with self._client.open(path, 'r') as remote_file:
            for line in remote_file:
                yield line
Ejemplo n.º 27
0
def get_sftp_conn(config):
    """Make a SFTP connection, returns sftp client and connection objects

    ``config['remote_location']`` is a URL whose netloc gives the host and
    an optional ``:port`` (default 22). The remote user is
    ``config['remote_username']``, falling back to the local login name.
    Each key from ``get_ssh_keys`` is tried in turn until one authenticates.

    :returns: ``(sftp, transport)``.
    :raises SaChannelUpdateTransportError: if no key works or any other error
        occurs while connecting. Any open transport is closed first.
    """
    remote = config.get('remote_location')
    parts = urlparse(remote)

    if ':' in parts.netloc:
        hostname, port = parts.netloc.split(':')
    else:
        hostname = parts.netloc
        port = 22
    port = int(port)

    username = config.get('remote_username') or getuser()
    luser = get_local_user(username)
    sshdir = get_ssh_dir(config, luser)
    hostkey = get_host_keys(hostname, sshdir)

    transport = None
    try:
        sftp = None
        keys = get_ssh_keys(sshdir)
        transport = Transport((hostname, port))
        while not keys.empty():
            try:
                # NOTE(review): if PKey is paramiko's abstract base class, it
                # cannot load concrete key files; confirm what PKey refers to.
                key = PKey.from_private_key_file(keys.get())
                transport.connect(
                    hostkey=hostkey,
                    username=username,
                    password=None,
                    pkey=key)
                sftp = SFTPClient.from_transport(transport)
                break
            except (PasswordRequiredException, SSHException):
                # This key did not work; try the next one.
                pass
        if sftp is None:
            raise SaChannelUpdateTransportError("SFTP connection failed")
        return sftp, transport
    except Exception as msg:
        # Do not leak a half-open transport when giving up.
        if transport is not None:
            transport.close()
        if isinstance(msg, SaChannelUpdateTransportError):
            raise
        raise SaChannelUpdateTransportError(msg)
Ejemplo n.º 28
0
    def __call__(self):
        """Upload the context's XML export archive to the configured SFTP host.

        The host and credentials come from the registry's
        ``IRecensioSettings``; the archive is written as ``self.filename``
        on port 22. The upload counts as successful only if the remote size
        matches the local size. The method always returns a status message;
        network resources and the blob file are closed in every case.
        """
        registry = getUtility(IRegistry)
        recensio_settings = registry.forInterface(IRecensioSettings)
        host = recensio_settings.xml_export_server
        username = recensio_settings.xml_export_username
        password = recensio_settings.xml_export_password
        if not host:
            return 'no host configured'
        log.info("Starting XML export to sftp")

        exporter = getUtility(IFactory, name='chronicon_exporter')()
        export_xml = exporter.get_export_obj(self.context)
        if export_xml is None:
            msg = "Could not get export file object: {0}".format(
                exporter.export_filename)
            log.error(msg)
            return msg

        zipstream = export_xml.getFile()
        transport = None
        sftp = None
        try:
            transport = Transport((host, 22))
            transport.connect(username=username, password=password)
            sftp = SFTPClient.from_transport(transport)
            blob_file = zipstream.getBlob().open()
            try:
                attribs = sftp.putfo(blob_file, self.filename)
            finally:
                blob_file.close()
        except (IOError, SSHException) as ioe:
            msg = "Export failed, {0}: {1}".format(ioe.__class__.__name__, ioe)
            log.error(msg)
            return msg
        finally:
            if sftp is not None:
                sftp.close()
            if transport is not None:
                transport.close()
        if attribs.st_size == zipstream.get_size():
            msg = "Export successful"
            log.info(msg)
            return msg
        else:
            msg = "Export failed, {0}/{1} bytes transferred".format(
                attribs.st_size, zipstream.get_size())
            log.error(msg)
            return msg
Ejemplo n.º 29
0
 def connect(self):
   """Get a handle to a remote connection.

   The URL must use the ``sftp`` scheme; a URL without an explicit port
   uses 22. Authenticates with ``self.password`` or, if that is not set,
   the RSA key text in ``self.private_key``. Sets ``self.transport`` and
   ``self.conn``.

   Raises ``SFTPError`` for a bad scheme, missing credentials, or errors
   wrapped from the login.
   """
   # Check URL
   schema = urlparse(self.url)
   if schema.scheme == 'sftp':
     # urlparse gives port None when the URL has none; SSH defaults to 22.
     port = schema.port if schema.port is not None else 22
     self.transport = Transport((schema.hostname, int(port)))
   else:
     raise SFTPError('Not a valid sftp url %s, type is %s' % (self.url, schema.scheme))
   # Add authentication to transport
   try:
     if self.password:
       self.transport.connect(username=self.user_name, password=self.password)
     elif self.private_key:
       self.transport.connect(username=self.user_name,
                              pkey=RSAKey.from_private_key(StringIO(self.private_key)))
     else:
       raise SFTPError("No password or private_key defined")
     # Connect
     self.conn = SFTPClient.from_transport(self.transport)
   except (socket.gaierror, error) as msg:
     raise SFTPError(str(msg) + ' while establishing connection to %s' % (self.url,))
Ejemplo n.º 30
0
class TransportTest(ParamikoTest):
    """Loopback tests for paramiko ``Transport`` in client and server mode.

    ``tc`` is the client transport and ``ts`` the server transport; they
    talk over a pair of linked ``LoopSocket`` objects.
    """

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """Start the server side with the test host key and log the client in.

        ``client_options`` / ``server_options``, if given, are called with the
        corresponding transport's ``SecurityOptions`` before negotiation.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key,
                        username='******',
                        password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # Unknown cipher names and non-sequence values must be rejected.
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(
            '207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
            hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******',
                        password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******',
                        password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc', )
            options.digests = ('hmac-md5-96', )

        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        self.ts.accept(1.0)
        # The NullServer rejects this command.
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())
Ejemplo n.º 31
0
def sftp_client_from_transport_old(hostname, username, password, port=22):
    """Open an SFTP client over a password-authenticated paramiko Transport.

    :param hostname: server to connect to.
    :param username: login name.
    :param password: login password.
    :param port: TCP port of the SSH server (default 22).
    :returns: an ``SFTPClient``, or None if paramiko could not open the SFTP
        channel (in that case the transport is closed).
    :raises: whatever ``Transport.connect`` raises on failure; the transport
        is closed before the exception propagates.
    """
    from paramiko import Transport, SFTPClient
    tn = Transport((hostname, port))
    try:
        tn.connect(username=username, password=password)
        client = SFTPClient.from_transport(tn)
    except Exception:
        # Don't leave the transport's socket/thread running on failure.
        tn.close()
        raise
    if client is None:
        tn.close()
    return client
Ejemplo n.º 32
0
def upload_sftp(server_info, upload_files, delete_files, force_reupload=False):
    """Synchronise local files to one SFTP server.

    :param server_info: ``(server, info)`` pair; ``info`` is a dict with the
        keys ``host``, ``port``, ``username``, ``password`` and ``path`` (the
        remote root directory).
    :param upload_files: iterable of ``(path, file)`` pairs.  The local file
        ``'upload' + path + '/' + file`` is uploaded to
        ``info['path'] + path + '/' + file``; ``path`` therefore presumably
        starts with ``'/'`` -- TODO confirm against callers.
    :param delete_files: remote paths (joined to ``info['path']`` with '/')
        to delete first; directories are removed with ``rmdir``.
    :param force_reupload: re-upload remote files even when sizes match.
    :returns: ``(ok, plugins_load, plugins_reload)``.  ``ok`` is False when
        authentication or the initial ``chdir`` fails.  The lists hold the
        names (without ``.smx``) of newly uploaded / re-uploaded ``.smx`` files.
    """
    _, info = server_info
    root = info['path']
    t = Transport((info['host'], info['port']))
    # The transport is closed on every exit path (early returns and errors
    # included) so its socket and background thread are not leaked.
    try:
        try:
            t.connect(username=info['username'], password=info['password'])
        except AuthenticationException as e:
            print('\t\t%s' % e)
            return False, [], []

        plugins_load = []
        plugins_reload = []

        sftp = SFTPClient.from_transport(t)
        try:
            sftp.chdir(root)
        except FileNotFoundError as e:
            print('\t\tCannot chdir to remote path %s: %s' % (root, e))
            return False, [], []

        if delete_files:
            print('\t\tProcessing %d deletions...' % (len(delete_files)))
            for i, d in enumerate(delete_files):
                print('\t\t\t%d. %s' % (i + 1, d))
                remote_target = root + '/' + d
                try:
                    file_info = sftp.stat(remote_target)
                    if stat.S_ISDIR(file_info.st_mode):
                        sftp.rmdir(remote_target)
                    else:
                        sftp.remove(remote_target)
                except FileNotFoundError:
                    # Already absent: nothing to delete.
                    pass
                except Exception as e:
                    # Best effort: report and continue with the rest.
                    print('\t\t\t\tFailed: %s' % e)

        for (path, filename) in upload_files:
            relative_filepath = path + '/' + filename
            remote_filepath = root + relative_filepath
            local_filepath = 'upload' + relative_filepath
            print('\t\tProcessing %s' % filename)

            try:
                sftp.stat(root + '/' + path)
            except FileNotFoundError:
                print('\t\t\t%s does not exist, creating...' % path)
                # Create each missing parent of ``path`` (split on '/',
                # skipping the leading character), then the leaf directory.
                i = 1
                while i != -1:
                    i = path.find('/', i)
                    if i != -1:
                        parent = root + path[0:i]
                        try:
                            sftp.stat(parent)
                        except FileNotFoundError:
                            sftp.mkdir(parent)
                        i += 1
                sftp.mkdir(root + path)

            # Only the remote stat decides "missing remotely"; local I/O errors
            # from getsize/put must not be mistaken for a missing remote file.
            try:
                fileinfo = sftp.stat(remote_filepath)
            except FileNotFoundError:
                print('\t\t\tUploading')
                sftp.put(local_filepath, remote_filepath)
                if filename.endswith('.smx'):
                    plugins_load.append(filename[:-4])
                continue

            print('\t\t\tRemote file found')
            if force_reupload:
                print('\t\t\tForcing reupload')
                upload = True
            else:
                upload = os.path.getsize(local_filepath) != fileinfo.st_size
                if upload:
                    print('\t\t\tSize mismatch -- Reuploading')

            if upload:
                sftp.put(local_filepath, remote_filepath)
                if filename.endswith('.smx'):
                    plugins_reload.append(filename[:-4])
    finally:
        t.close()

    print('\t\tLogout')
    return True, plugins_load, plugins_reload
Ejemplo n.º 33
0
 def setUp(self):
     """Wire two LoopSockets back-to-back and wrap each end in a Transport."""
     server_end, client_end = LoopSocket(), LoopSocket()
     client_end.link(server_end)
     self.socks, self.sockc = server_end, client_end
     self.tc = Transport(client_end)
     self.ts = Transport(server_end)
Ejemplo n.º 34
0
class SSHHandler(ServerInterface, BaseRequestHandler):
    """SocketServer request handler that runs an SSH server on each request.

    Subclasses are expected to set ``telnet_handler`` (the telnet handler
    class mixed into the PTY handler) and ``host_key`` (the server's private
    host key).
    """
    telnet_handler = None
    pty_handler = None
    host_key = None
    username = None

    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.tcp_server = server

        # Keep track of channel information from the transport
        self.channels = {}

        self.client = request._sock
        # Transport turns the socket into an SSH transport
        self.transport = Transport(self.client)

        # Create the PTY handler class by mixing in
        TelnetHandlerClass = self.telnet_handler

        class MixedPtyHandler(TelnetToPtyHandler, TelnetHandlerClass):
            # BaseRequestHandler does not inherit from object, must call the __init__ directly
            def __init__(self, *args):
                TelnetHandlerClass.__init__(self, *args)

        self.pty_handler = MixedPtyHandler

        # Call the base class to run the handler (it invokes setup()/handle()).
        BaseRequestHandler.__init__(self, request, client_address, server)

    def setup(self):
        '''Set up the connection.

        Loads the server moduli, installs ``host_key``, runs the server-side
        SSH negotiation, then accepts channels until an accept() times out
        while no channel thread in ``self.channels`` is alive.
        '''
        print('New request from address %s, port %d' % self.client_address)

        try:
            moduli_loaded = self.transport.load_server_moduli()
        except Exception:
            print('(Failed to load moduli -- gex will be unsupported.)')
            raise
        if not moduli_loaded:
            # load_server_moduli() signals a missing moduli file by returning
            # False rather than raising; previously this went unreported.
            print('(Failed to load moduli -- gex will be unsupported.)')

        try:
            self.transport.add_server_key(self.host_key)
        except Exception:
            if self.host_key is None:
                raise NotImplementedError(
                    'Host key not set!  SSHHandler instance must define the host_key parameter.  Try host_key = paramiko_ssh.getRsaKeyFile("server_rsa.key").'
                )
            # Any other failure is a real error; it used to be swallowed,
            # leaving the server to negotiate without a host key.
            raise

        try:
            # Tell transport to use this object as a server
            print('Starting SSH server-side negotiation')
            self.transport.start_server(server=self)
        except SSHException as e:
            print('SSH negotiation failed. %s' % e)
            raise

        # Accept any requested channels
        while True:
            channel = self.transport.accept(20)
            if channel is not None:
                print('Accepted channel %s' % channel)
                continue
            # accept() timed out: stop once no channel thread is still running.
            if not any(thread.is_alive() for thread in self.channels.values()):
                break
Ejemplo n.º 35
0
class sessionSSH:
    """NETCONF-over-SSH client session from the manager to one agent.

    The constructor opens the TCP socket and wraps it in a paramiko
    ``Transport``; :meth:`connect` then authenticates and opens the NETCONF
    subsystem channel as ``self.chan``.
    """

    def __init__(self, agent, user):
        """Collect connection parameters and open the TCP connection.

        :param agent: object providing getIp(), getPublicKey(),
            getPublicKeyType() and getVersion() (IP version, 4 or 6).
        :param user: object providing getLogin(), getPrivateKeyFile(),
            getPrivateKeyType() and getPassword().
        :raises ValueError: if the agent's IP version is neither 4 nor 6.
        """
        self.hostname = agent.getIp()
        self.username = user.getLogin()
        self.publicKey = agent.getPublicKey()
        self.publicKeyType = agent.getPublicKeyType()
        self.version = agent.getVersion()

        self.privateKeyFile = user.getPrivateKeyFile()
        self.privateKeyType = user.getPrivateKeyType()
        self.password = user.getPassword()

        self.raw_data = ''

        # Choose the address family (IPv4 or IPv6) from the agent's version.
        if self.version == 4:
            family = socket.AF_INET
        elif self.version == 6:
            family = socket.AF_INET6
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError('unsupported IP version: %r' % (self.version,))
        sock = socket.socket(family, socket.SOCK_STREAM)

        #sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        try:
            # Connect to the agent (The SSH tunnel is done later)
            sock.connect((self.hostname, C.NETCONF_SSH_PORT))
            # Create a new SSH session over an existing socket (here sock).
            self.ssh = Transport(sock)
        except Exception:
            # Don't leak the descriptor when the agent is unreachable.
            sock.close()
            raise

        self.i = 1

    def connect(self):
        """Authenticate to the agent and open the NETCONF subsystem.

        The agent's host key is built from the base64 ``publicKey`` ('rsa' or
        'dss') and passed to ``Transport.connect`` for verification.  The
        manager authenticates with its private key file when one is set,
        otherwise with its password.  Errors are logged to syslog.

        :returns: ``C.SUCCESS`` on success, ``C.FAILED`` on any error.
        """
        try:
            key_classes = {'rsa': RSAKey, 'dss': DSSKey}

            # Build a public key object from the server (agent) key.
            agent_key_class = key_classes.get(self.publicKeyType)
            if agent_key_class is None:
                raise ValueError('unsupported agent public key type: %r'
                                 % (self.publicKeyType,))
            # b64decode replaces decodestring (removed in Python 3.9); both
            # decode via binascii.a2b_base64.
            agent_public_key = agent_key_class(
                data=base64.b64decode(self.publicKey))

            if self.privateKeyFile is not None:
                # Using client (manager) private key to authenticate
                user_key_class = key_classes.get(self.privateKeyType)
                if user_key_class is None:
                    raise ValueError('unsupported private key type: %r'
                                     % (self.privateKeyType,))
                user_private_key = user_key_class.from_private_key_file(
                    self.privateKeyFile)
                self.ssh.connect(hostkey=agent_public_key,
                                 username=self.username,
                                 pkey=user_private_key)
            else:
                # Using client (manager) password to authenticate
                self.ssh.connect(hostkey=agent_public_key,
                                 username=self.username,
                                 password=self.password)

            # Request a new channel to the server, of type "session".
            self.chan = self.ssh.open_session()

            # Request a "netconf" subsystem on the server:
            self.chan.invoke_subsystem(C.NETCONF_SSH_SUBSYSTEM)

        except Exception as exp:
            syslog.openlog("YencaP Manager")
            syslog.syslog(syslog.LOG_ERR, str(exp))
            syslog.closelog()
            return C.FAILED

        return C.SUCCESS
Ejemplo n.º 36
0
class Session:
    """One client connection handled by the SSH proxy.

    Holds the client-side paramiko ``Transport`` (created lazily by the
    ``transport`` property), the forwarded agent and the ssh/scp/sftp
    channels opened for the client.  Usable as a context manager; leaving
    the ``with`` block calls :meth:`close`.
    """

    # Optional tuple of cipher names to restrict the client-side transport to.
    CIPHERS = None

    def __init__(self, proxyserver, client_socket, client_address, authenticator, remoteaddr):

        self._transport = None

        self.channel = None

        self.proxyserver = proxyserver
        self.client_socket = client_socket
        self.client_address = client_address
        self.name = "{fr}->{to}".format(fr=client_address, to=remoteaddr)

        self.agent_requested = threading.Event()

        self.ssh = False
        self.ssh_channel = None
        self.ssh_client = None
        self.ssh_pty_kwargs = None

        self.scp = False
        self.scp_channel = None
        self.scp_command = ''

        self.sftp = False
        self.sftp_channel = None
        self.sftp_client = None
        self.sftp_client_ready = threading.Event()

        self.username = ''
        self.socket_remote_address = remoteaddr
        self.remote_address = (None, None)
        self.key = None
        self.agent = None
        self.authenticator = authenticator(self)

    @property
    def running(self):
        """True while the proxy runs and no opened main channel is closed."""
        # Using status of main channels to determine session status (-> releasability of resources)
        # - often calculated, cpu heavy (?)
        opened = [ch for ch in (self.ssh_channel, self.scp_channel, self.sftp_channel) if ch]
        ch_active = all(not ch.closed for ch in opened)
        return self.proxyserver.running and ch_active

    @property
    def transport(self):
        """Client-side Transport, created and configured on first access."""
        if not self._transport:
            self._transport = Transport(self.client_socket)
            cve202014145.hookup_cve_2020_14145(self)
            if self.CIPHERS:
                if not isinstance(self.CIPHERS, tuple):
                    raise ValueError('ciphers must be a tuple')
                self._transport.get_security_options().ciphers = self.CIPHERS
            self._transport.add_server_key(self.proxyserver.host_key)
            self._transport.set_subsystem_handler('sftp', ProxySFTPServer, self.proxyserver.sftp_interface)

        return self._transport

    def _start_channels(self):
        """Obtain the agent, authenticate, and signal SFTP readiness.

        :returns: True when the session may proceed, False otherwise.
        """
        # create client or master channel
        if self.ssh_client:
            self.sftp_client_ready.set()
            return True

        if not self.agent and (self.authenticator.REQUEST_AGENT or self.authenticator.REQUEST_AGENT_BREAKIN):
            try:
                if self.agent_requested.wait(1) or self.authenticator.REQUEST_AGENT_BREAKIN:
                    self.agent = AgentProxy(self.transport)
            except ChannelException:
                logging.error("Breakin not successful! Closing ssh connection to client")
                self.agent = None
                self.close()
                return False
        # Connect method start
        if not self.agent:
            logging.error('no ssh agent forwarded')
            return False

        if self.authenticator.authenticate() != AUTH_SUCCESSFUL:
            logging.error('Permission denied (publickey)')
            return False

        # Connect method end
        if not self.scp and not self.ssh and not self.sftp:
            if self.transport.is_active():
                self.transport.close()
                return False

        self.sftp_client_ready.set()
        return True

    def start(self):
        """Negotiate with the client, wait for a channel and authentication.

        :returns: True when the session has fully started, False otherwise.
        """
        event = threading.Event()
        self.transport.start_server(
            event=event,
            server=self.proxyserver.authentication_interface(self)
        )

        # Poll for the first channel so a stopped proxy or a closed main
        # channel aborts the wait.  This loop only exits with a channel set
        # (or returns False), so no "no channel" check is needed afterwards.
        while not self.channel:
            self.channel = self.transport.accept(0.5)
            if not self.running:
                self.transport.close()
                return False

        # wait for authentication
        event.wait()

        if not self.transport.is_active():
            return False

        if not self._start_channels():
            return False

        logging.debug("(%s) session started", self)
        return True

    def close(self):
        """Release the agent, the remote ssh client and the client transport.

        Does not create a Transport when none was ever opened.
        """
        if self.agent:
            self.agent.close()
            logging.debug("(%s) session agent cleaned up", self)
        # Use the existing transport only; the ``transport`` property would
        # build (and configure) a new one just to close it.
        transport = self._transport
        if self.ssh_client:
            logging.debug("(%s) closing ssh client to remote", self)
            self.ssh_client.transport.close()
            # With graceful exit the completion_event can be polled to wait, well ..., for completion
            # it can also only be a graceful exit if the ssh client has already been established
            if transport is not None:
                completion = transport.completion_event
                if completion is not None and completion.is_set() and transport.is_active():
                    completion.clear()
                    while transport.is_active():
                        if completion.wait(0.1):
                            break
        if transport is not None:
            transport.close()
        logging.debug("(%s) session closed", self)

    def __str__(self):
        return self.name

    def __enter__(self):
        return self

    def __exit__(self, value_type, value, traceback):
        self.close()
Ejemplo n.º 37
0
class SFTPStore(Store):
    """implements the sftp:// storage backend

        configuration via openssh/sftp style urls and
        .ssh/config files

        does not support password authentication or password
        protected authentication keys"""
    def __init__(self, url, **kw):
        # NOTE(review): ``url``/``kw`` are unused and self.netloc / self.path
        # are read without being assigned here -- presumably Store provides
        # them; confirm against Store.
        if '@' in self.netloc:
            # rsplit keeps any '@' inside the user name intact.
            user, self.netloc = self.netloc.rsplit('@', 1)
        else:
            user = None

        self.config = SSHHostConfig(self.netloc, user)

        host_keys = paramiko.util.load_host_keys(
            os.path.expanduser('~/.ssh/known_hosts'))
        try:
            # First known_hosts key recorded for this host alias.
            self.hostkey = list(
                host_keys[self.config['hostkeyalias']].values())[0]
        except Exception:
            # Show the resolved config to help diagnose the missing entry.
            print(str(self.config))
            raise

        if 'identityfile' in self.config:
            key_file = os.path.expanduser(self.config['identityfile'])
            # The config does not say which key type this is: try RSA, then
            # DSS.  (RSAKey(path) treated the path as a Message, and the old
            # fallback referenced a non-existent DSAKey and e.message.)
            try:
                self.auth_key = RSAKey.from_private_key_file(key_file)
            except SSHException as rsa_error:
                try:
                    self.auth_key = DSSKey.from_private_key_file(key_file)
                except SSHException:
                    raise rsa_error
        else:
            self.auth_key = None
            for candidate, key_class in (('~/.ssh/id_rsa', RSAKey),
                                         ('~/.ssh/id_dsa', DSSKey)):
                filename = os.path.expanduser(candidate)
                if os.path.exists(filename):
                    self.auth_key = key_class.from_private_key_file(filename)
                    break
            if self.auth_key is None:
                raise SSHException(
                    'no identityfile configured and neither ~/.ssh/id_rsa '
                    'nor ~/.ssh/id_dsa exists')

        self.__connect()

    def __connect(self):
        """Open the transport, authenticate with the key and chdir to path."""
        self.t = Transport((self.config['hostname'], self.config['port']))
        try:
            # Pass the known_hosts key so paramiko verifies the server;
            # without hostkey= any server key would be accepted.
            self.t.connect(hostkey=self.hostkey,
                           username=self.config['user'],
                           pkey=self.auth_key)
            self.client = SFTPClient.from_transport(self.t)
            self.client.chdir(self.path)
        except Exception:
            self.t.close()
            raise

    def __build_fn(self, name):
        return "%s/%s" % (self.path, name)

    def list(self, type):
        """Names in the remote directory matching ``type_patterns[type]``."""
        return list(
            filter(type_patterns[type].match, self.client.listdir(self.path)))

    def get(self, type, name):
        """Open the remote file ``name`` for binary reading."""
        return self.client.open(self.__build_fn(name), mode='rb')

    def put(self, type, name, fp):
        """Copy file-like ``fp`` to the remote file ``name`` in 4 KiB chunks."""
        remote_file = self.client.open(self.__build_fn(name), mode='wb')
        try:
            buf = fp.read(4096)
            while len(buf) > 0:
                remote_file.write(buf)
                buf = fp.read(4096)
        finally:
            # Release the remote handle even if a read/write fails.
            remote_file.close()

    def delete(self, type, name):
        """Remove the remote file ``name``."""
        self.client.remove(self.__build_fn(name))

    def stat(self, type, name):
        """Return ``{'size': ...}`` for ``name``; raise NotFoundError if absent."""
        try:
            attrs = self.client.stat(self.__build_fn(name))
            return {'size': attrs.st_size}
        except IOError:
            raise NotFoundError

    def close(self):
        """connection has to be explicitly closed, otherwise
            it will keep the process running indefinitely"""
        self.client.close()
        self.t.close()
Ejemplo n.º 38
0
                pw = value[i]
            elif (key[i].upper() == 'PATH'):
                path = value[i]
            elif (key[i].upper() == 'LOOKBACK'):
                lookback = value[i]
    except:
        log.warning('Could not initialize variables')
        FEWSConnect.logFileToXML(logFile, runInfo['outputDiagnosticFile'])
        sys.exit()

# Check if host site is available to retrieve MARFC files from sftp.

    port = int(port)
    try:
        # paramiko.util.log_to_file(tempDir + '\\paramiko.log')
        transport = Transport((host, port))
        transport.connect(username=usr, password=pw)
        sftp = SFTPClient.from_transport(transport)
        log.info('Succesfully connected to %s', host)
        print('Successfully connected to ' + host)
    except:
        log.warning('Could not connect to %s', host)
        print('could not connect')
        print format_exc()

# defines number of hours (files) to retrieve based on lookback value (in days)
# This is defined very simply and assumes MARFC updates a new MPE file for every hour
# as was observed when this script was written.
    hrsback = int(lookback) * 24
    # Retrieve list of files from sftp site
    s = sftp.listdir(path)
Ejemplo n.º 39
0
 def test_preferred_lists_default_to_private_attribute_contents(self):
     """Each public preferred_* list mirrors its private _preferred_* twin."""
     transport = Transport(sock=Mock())
     for kind in ("ciphers", "macs", "keys", "kex"):
         public = getattr(transport, "preferred_" + kind)
         private = getattr(transport, "_preferred_" + kind)
         assert public == private
Ejemplo n.º 40
0
	def create_connection(self, host, port, username, password, key):
		"""Open an SFTP session to ``host:port`` and store it on ``self._connection``.

		Authenticates with the private key ``key`` when given; otherwise
		falls back to ``password`` (which was previously ignored).
		"""
		transport = Transport(sock=(host, port))
		try:
			# paramiko uses pkey auth when pkey is set, password auth otherwise.
			transport.connect(username=username, password=password, pkey=key)
			self._connection = SFTPClient.from_transport(transport)
		except Exception:
			# Don't leak the transport's socket/thread on a failed login.
			transport.close()
			raise
Ejemplo n.º 41
0
class ParamikoSshConnection(BaseSshConnection):
    """Interactive shell connection built on a raw paramiko ``Transport``."""

    def connect(self, wait_prompt=True):
        """Open the socket, authenticate, and start an interactive shell.

        Uses password auth when ``self.password`` is set; otherwise loads
        ``self.private_key_file`` as a DSS key when ``self.key_algorithm`` is
        ``DSA_KEY_ALGORITHM`` and as an RSA key in every other case.

        :param wait_prompt: when True, read up to the prompt via receive().
        """
        self.socket = socket(AF_INET, SOCK_STREAM)
        self.socket.connect((self.hostname, self.port))
        self.session = Transport(self.socket)
        self.session.start_client()
        if self.password is not None:
            self.session.auth_password(self.username, self.password)
        else:
            key_class = (
                DSSKey if self.key_algorithm == DSA_KEY_ALGORITHM else RSAKey
            )
            key = key_class.from_private_key_file(
                self.private_key_file, self.key_passphrase
            )
            self.session.auth_publickey(self.username, key)

        self.channel = self.session.open_session()
        self.channel.get_pty()
        self.channel.invoke_shell()
        if wait_prompt:
            self.receive()

    @property
    def connected(self):
        """True when the socket is open and a session and channel exist."""
        return bool(
            self.socket
            and not self.socket._closed
            and self.session
            and self.channel
        )

    def send(self, line, socket_timeout=None):
        """Send ``line`` plus a newline; returns what Channel.sendall returns."""
        socket_timeout = (
            socket_timeout
            if socket_timeout is not None
            else self.socket_timeout
        )
        self.channel.settimeout(socket_timeout)
        size = self.channel.sendall(line + "\n")
        return size

    def receive(
        self, regex=None, socket_timeout=None, timeout=None, buffer_size=None
    ):
        """Read from the channel until ``regex`` matches the accumulated text.

        Arguments left as None fall back to ``self.prompt_regex``,
        ``self.socket_timeout``, ``self.timeout`` and ``self.buffer_size``.

        :returns: the sanitized output once ``regex`` matches.
        :raises SocketTimeoutException: if the channel returned no data
            (EOF) before ``regex`` matched.
        :raises ReceiveTimeoutException: if ``timeout`` seconds elapsed
            before ``regex`` matched.
        """
        import codecs

        regex = regex if regex is not None else self.prompt_regex
        socket_timeout = (
            socket_timeout
            if socket_timeout is not None
            else self.socket_timeout
        )
        timeout = timeout if timeout is not None else self.timeout
        buffer_size = (
            buffer_size if buffer_size is not None else self.buffer_size
        )

        assert regex is not None
        assert socket_timeout is None or isinstance(
            socket_timeout, (int, float)
        )
        assert timeout is None or isinstance(timeout, (int, float))
        assert isinstance(buffer_size, int) and buffer_size > 0

        # Decode incrementally: a multi-byte UTF-8 character split across
        # two recv() calls would make a per-chunk .decode() raise.
        decoder = codecs.getincrementaldecoder("utf-8")()

        self.channel.settimeout(socket_timeout)
        start = time()
        chunk = self.channel.recv(buffer_size)
        # ``size`` counts raw bytes so that 0 reliably means EOF.
        size = len(chunk)
        output = decoder.decode(chunk, final=not chunk)
        LOG.debug(output)
        duration = time() - start
        matched = regex.search(output) is not None
        while (
            not matched
            and (timeout is None or duration < timeout)
            and size > 0
        ):
            chunk = self.channel.recv(buffer_size)
            size = len(chunk)
            data = decoder.decode(chunk, final=not chunk)
            LOG.debug(data)
            output += data
            duration = time() - start
            matched = regex.search(output) is not None

        # A match wins even if it arrived exactly at the deadline.
        if not matched:
            if not size:
                raise SocketTimeoutException(
                    output, socket_timeout, duration, regex.pattern
                )
            if timeout is not None and duration >= timeout:
                raise ReceiveTimeoutException(
                    output, timeout, duration, regex.pattern
                )

        return self.sanitize(output)

    def disconnect(self):
        """Close the channel, then the transport, then the socket."""
        if self.channel:
            self.channel.close()

        if self.session:
            self.session.close()

        if self.socket:
            self.socket.close()
Ejemplo n.º 42
0
class TransportTest(unittest.TestCase):
    def setUp(self):
        """Build a client/server Transport pair joined by loopback sockets."""
        server_end = LoopSocket()
        client_end = LoopSocket()
        client_end.link(server_end)
        self.socks, self.sockc = server_end, client_end
        self.tc = Transport(client_end)
        self.ts = Transport(server_end)

    def tearDown(self):
        """Close both transports first, then both loopback sockets."""
        for closeable in (self.tc, self.ts, self.socks, self.sockc):
            closeable.close()

    def setup_test_server(
        self, client_options=None, server_options=None, connect_kwargs=None
    ):
        """Start the server transport and connect the client to it.

        ``client_options`` / ``server_options`` are optional callables that
        receive the respective SecurityOptions before negotiation;
        ``connect_kwargs`` overrides the default ``connect()`` arguments.
        The NullServer instance is stored on ``self.server``.
        """
        private_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
        public_key = RSAKey(data=private_key.asbytes())
        self.ts.add_server_key(private_key)

        # Let the caller tweak the algorithm lists, client side first.
        for hook, side in ((client_options, self.tc), (server_options, self.ts)):
            if hook is not None:
                hook(side.get_security_options())

        ready = threading.Event()
        self.server = NullServer()
        self.assertFalse(ready.is_set())
        self.ts.start_server(ready, self.server)
        kwargs = connect_kwargs
        if kwargs is None:
            kwargs = {
                "hostkey": public_key,
                "username": "******",
                "password": "******",
            }
        self.tc.connect(**kwargs)
        ready.wait(1.0)
        self.assertTrue(ready.is_set())
        self.assertTrue(self.ts.is_active())

    def test_security_options(self):
        """SecurityOptions accepts a valid cipher tuple and rejects bad values."""
        opts = self.tc.get_security_options()
        self.assertEqual(type(opts), SecurityOptions)
        wanted = ("aes256-cbc", "blowfish-cbc")
        self.assertNotEqual(wanted, opts.ciphers)
        opts.ciphers = wanted
        self.assertEqual(wanted, opts.ciphers)
        # an unknown cipher name is rejected
        with self.assertRaises(ValueError):
            opts.ciphers = ("aes256-cbc", "made-up-cipher")
        # a non-sequence value is rejected
        with self.assertRaises(TypeError):
            opts.ciphers = 23

    def testb_security_options_reset(self):
        """Re-assigning every option to its current value must not raise."""
        opts = self.tc.get_security_options()
        for option in ("ciphers", "digests", "key_types", "kex", "compression"):
            setattr(opts, option, getattr(opts, option))

    def test_compute_key(self):
        """_compute_key derives the expected material from fixed K and H."""
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929  # noqa
        self.tc.H = b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3"  # noqa
        transport = self.tc
        transport.session_id = transport.H
        expected = b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995"  # noqa
        derived = transport._compute_key("C", 32)
        self.assertEqual(expected, hexlify(derived).upper())

    def test_simple(self):
        """
        establish an ssh link between the two loopback transports and check
        that usernames and authentication state change once connected.  the
        most basic end-to-end handshake in this suite.
        """
        private_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
        public_key = RSAKey(data=private_key.asbytes())
        self.ts.add_server_key(private_key)
        done = threading.Event()
        server = NullServer()
        self.assertFalse(done.is_set())
        pair = (self.tc, self.ts)
        # Before the handshake neither side knows a user or is authenticated.
        for side in pair:
            self.assertEqual(None, side.get_username())
        for side in pair:
            self.assertEqual(False, side.is_authenticated())
        self.ts.start_server(done, server)
        self.tc.connect(
            hostkey=public_key, username="******", password="******"
        )
        done.wait(1.0)
        self.assertTrue(done.is_set())
        self.assertTrue(self.ts.is_active())
        for side in pair:
            self.assertEqual("slowdive", side.get_username())
        for side in pair:
            self.assertEqual(True, side.is_authenticated())

    def testa_long_banner(self):
        """
        a long banner written to the server socket before negotiation must
        not break the handshake.
        """
        private_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
        public_key = RSAKey(data=private_key.asbytes())
        self.ts.add_server_key(private_key)
        done = threading.Event()
        server = NullServer()
        self.assertFalse(done.is_set())
        # Push the banner ahead of the SSH version exchange.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(done, server)
        self.tc.connect(
            hostkey=public_key, username="******", password="******"
        )
        done.wait(1.0)
        self.assertTrue(done.is_set())
        self.assertTrue(self.ts.is_active())

    def test_special(self):
        """
        the client can force an unusual cipher/MAC pair, and keys can be
        renegotiated mid-stream.
        """

        def pin_cbc_md5(opts):
            opts.ciphers = ("aes256-cbc",)
            opts.digests = ("hmac-md5-96",)

        self.setup_test_server(client_options=pin_cbc_md5)
        client = self.tc
        self.assertEqual("aes256-cbc", client.local_cipher)
        self.assertEqual("aes256-cbc", client.remote_cipher)
        packetizer = client.packetizer
        # hmac-md5-96 truncates the MAC to 12 bytes in both directions.
        self.assertEqual(12, packetizer.get_mac_size_out())
        self.assertEqual(12, packetizer.get_mac_size_in())

        client.send_ignore(1024)
        client.renegotiate_keys()
        self.ts.send_ignore(1024)

    @slow
    def test_keepalive(self):
        """
        enabling a 1-second keepalive must produce a global request on the
        server within two seconds.
        """
        self.setup_test_server()
        before = getattr(self.server, "_global_request", None)
        self.assertEqual(None, before)
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual("*****@*****.**", self.server._global_request)

    def test_exec_command(self):
        """
        exec_command() must reject a non-UTF-8 command with SSHException, and a
        running command must deliver stdout/stderr separately, or merged once
        set_combine_stderr(True) is enabled.
        """
        self.setup_test_server()

        def run_yes():
            # Start "yes" on a fresh session; the server writes one line to
            # each stream and then closes its end.
            client_chan = self.tc.open_session()
            client_chan.exec_command("yes")
            server_chan = self.ts.accept(1.0)
            server_chan.send("Hello there.\n")
            server_chan.send_stderr("This is on stderr.\n")
            server_chan.close()
            return client_chan, server_chan

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command(
                b"command contains \xfc and is not a valid UTF-8 string"
            )

        chan, schan = run_yes()
        out = chan.makefile()
        self.assertEqual("Hello there.\n", out.readline())
        self.assertEqual("", out.readline())
        err = chan.makefile_stderr()
        self.assertEqual("This is on stderr.\n", err.readline())
        self.assertEqual("", err.readline())

        # Second round: stderr folded into stdout.
        chan, schan = run_yes()
        chan.set_combine_stderr(True)
        merged = chan.makefile()
        self.assertEqual("Hello there.\n", merged.readline())
        self.assertEqual("This is on stderr.\n", merged.readline())
        self.assertEqual("", merged.readline())

    def testa_channel_can_be_used_as_context_manager(self):
        """
        client and server channels work as context managers around an
        exec_command() exchange.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan, self.ts.accept(1.0) as schan:
            chan.exec_command("yes")
            schan.send("Hello there.\n")
            schan.close()

            reader = chan.makefile()
            self.assertEqual("Hello there.\n", reader.readline())
            self.assertEqual("", reader.readline())

    def test_invoke_shell(self):
        """
        invoke_shell() carries client writes to the server channel, and the
        server side sees EOF once the client closes.
        """
        self.setup_test_server()
        client_chan = self.tc.open_session()
        client_chan.invoke_shell()
        server_chan = self.ts.accept(1.0)
        line = "communist j. cat\n"
        client_chan.send(line)
        reader = server_chan.makefile()
        self.assertEqual(line, reader.readline())
        client_chan.close()
        self.assertEqual("", reader.readline())

    def test_channel_exception(self):
        """
        opening an unknown channel type raises ChannelException carrying the
        administratively-prohibited code.
        """
        self.setup_test_server()
        with self.assertRaises(ChannelException) as caught:
            self.tc.open_channel("bogus")
        self.assertEqual(
            caught.exception.code, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
        )

    def test_exit_status(self):
        """
        the exit status sent by the server is reported by recv_exit_status()
        once exit_status_ready() turns true.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command("yes")
        schan.send("Hello there.\n")
        self.assertFalse(chan.exit_status_ready())
        # Force EOF in both directions, then report the status and close.
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        reader = chan.makefile()
        self.assertEqual("Hello there.\n", reader.readline())
        self.assertEqual("", reader.readline())
        # Poll for roughly five seconds before giving up.
        for _ in range(51):
            if chan.exit_status_ready():
                break
            time.sleep(0.1)
        else:
            raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_select(self):
        """
        select() reports a channel readable only while data or EOF is
        pending, and closing the channel closes its internal pipe.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        def assert_idle():
            # Nothing may be pending on the channel.
            readable, writable, errored = select.select([chan], [], [], 0.1)
            self.assertEqual([], readable)
            self.assertEqual([], writable)
            self.assertEqual([], errored)

        def assert_becomes_readable():
            # Poll for up to about a second until the channel is readable.
            for _ in range(10):
                readable, writable, errored = select.select(
                    [chan], [], [], 0.1
                )
                if chan in readable:
                    break
                time.sleep(0.1)
            self.assertEqual([chan], readable)
            self.assertEqual([], writable)
            self.assertEqual([], errored)

        assert_idle()
        schan.send("hello\n")
        assert_becomes_readable()
        self.assertEqual(b"hello\n", chan.recv(6))
        assert_idle()

        schan.close()
        # EOF must also wake select().
        assert_becomes_readable()
        self.assertEqual(bytes(), chan.recv(16))

        # The pipe stays open until the channel itself is closed.
        pipe = chan._pipe
        self.assertEqual(False, pipe._closed)
        chan.close()
        self.assertEqual(True, pipe._closed)

    def test_renegotiate(self):
        """
        Check that lowering REKEY_BYTES makes the client transport rekey
        mid-stream, so its exchange hash H diverges from the session id.
        """
        self.setup_test_server()
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command("yes")
        schan = self.ts.accept(1.0)

        # Before any rekey, H is still the original session id.
        self.assertEqual(self.tc.H, self.tc.session_id)
        payload = "x" * 1024
        for _ in range(20):
            chan.send(payload)
        chan.close()

        def rekeyed():
            return self.tc.H != self.tc.session_id

        # Allow up to about five seconds for the rekey to finish.
        for _ in range(50):
            if rekeyed():
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_compression(self):
        """
        verify that zlib compression is basically working.

        Sending 1024 identical bytes must put fewer than 1024 bytes on the
        wire, and exactly 16 + block size + MAC size bytes for this payload.
        """

        def force_compression(o):
            o.compression = ("zlib",)

        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command("yes")
        schan = self.ts.accept(1.0)

        # Named sent_before/sent_after so the builtin ``bytes`` is not
        # shadowed inside this method.
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send("x" * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"]
        mac_size = self.tc._mac_info[self.tc.local_mac]["size"]
        wire_bytes = sent_after - sent_before
        # tests show this is actually compressed to *52 bytes*!  including
        # packet overhead!  nice!! :)
        self.assertLess(wire_bytes, 1024)
        self.assertEqual(16 + block_size + mac_size, wire_bytes)

        chan.close()
        schan.close()

    def test_x11(self):
        """
        Check that request_x11() records the screen/auth parameters on the
        server and that an X11 channel opened by the server reaches the
        client's handler and carries data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command("yes")
        schan = self.ts.accept(1.0)

        seen = []

        def handler(c, addr_port):
            addr, port = addr_port
            seen.append((addr, port))
            self.tc._queue_incoming_channel(c)

        server = self.server
        # No X11 request has been made yet.
        self.assertEqual(None, getattr(server, "_x11_screen_number", None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, server._x11_screen_number)
        self.assertEqual("MIT-MAGIC-COOKIE-1", server._x11_auth_protocol)
        self.assertEqual(cookie, server._x11_auth_cookie)
        self.assertEqual(True, server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(("localhost", 6093))
        x11_client = self.tc.accept()
        origin_addr, origin_port = seen[0]
        self.assertEqual("localhost", origin_addr)
        self.assertEqual(6093, origin_port)

        x11_server.send("hello")
        self.assertEqual(b"hello", x11_client.recv(5))

        for channel in (x11_server, x11_client, chan, schan):
            channel.close()

    def test_reverse_port_forwarding(self):
        """
        Check that a client-requested reverse port forward listens on the
        server, carries data over a forwarded-tcpip channel, and is torn down
        by cancel_port_forward().
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command("yes")
        self.ts.accept(1.0)

        forwarded = []

        def handler(c, origin_addr_port, server_addr_port):
            forwarded.extend((origin_addr_port, server_addr_port))
            self.tc._queue_incoming_channel(c)

        port = self.tc.request_port_forward("127.0.0.1", 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        # Connect to the forwarded port and bridge the accepted socket into a
        # forwarded-tcpip channel back to the client.
        cs = socket.socket()
        cs.connect(("127.0.0.1", port))
        ss, _ = self.server._listen.accept()
        local_addr = ss.getsockname()
        peer_addr = ss.getpeername()
        sch = self.ts.open_forwarded_tcpip_channel(local_addr, peer_addr)
        cch = self.tc.accept()

        sch.send("hello")
        self.assertEqual(b"hello", cch.recv(5))
        for closeable in (sch, cch, ss, cs):
            closeable.close()

        # Cancelling must drop the server-side listener.
        self.tc.cancel_port_forward("127.0.0.1", port)
        self.assertIsNone(self.server._listen)

    def test_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.

        A plain TCP "greeting" server is started locally; the client opens a
        direct-tcpip channel to it and must receive the greeting relayed
        through the server-side channel.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command("yes")
        self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        # Previously leaked: make sure the listening socket is closed.
        self.addCleanup(greeting_server.close)
        greeting_server.bind(("127.0.0.1", 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel(
            "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000)
        )
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        # Previously leaked: close the relay socket at test teardown too.
        self.addCleanup(cch.close)
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b"Hello!\n")
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b"Hello!\n", cs.recv(7))
        cs.close()

    def test_stderr_select(self):
        """
        Check that select() reports a channel readable when data arrives only
        on its stderr stream, and idle again once that data is consumed.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        def select_once():
            return select.select([chan], [], [], 0.1)

        def assert_ready(ready, readable):
            readers, writers, errors = ready
            self.assertEqual([chan] if readable else [], readers)
            self.assertEqual([], writers)
            self.assertEqual([], errors)

        # Nothing has been sent yet.
        assert_ready(select_once(), False)

        schan.send_stderr("hello\n")

        # Poll up to 10 times (about one second) for readiness.
        for _ in range(10):
            ready = select_once()
            if chan in ready[0]:
                break
            time.sleep(0.1)
        assert_ready(ready, True)

        self.assertEqual(b"hello\n", chan.recv_stderr(6))

        # stderr has been drained, so the channel is idle again.
        assert_ready(select_once(), False)

        schan.close()
        chan.close()

    def test_send_ready(self):
        """
        Check that send_ready() turns False once the remote window fills up,
        well before 2 MiB have been sent, and is True again after closing.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        chunk = "*" * 1024
        limit = 1 + (64 * 2 ** 15)
        sent = 0
        # Keep sending until send_ready() reports that a send would block.
        while sent < limit:
            chan.send(chunk)
            sent += len(chunk)
            window_full = not chan.send_ready()
            if window_full:
                break
        self.assertLess(sent, limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received
        after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST
        #      is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is
        #      sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls
        #      Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Send 2 KiB chunks on ``chan`` until done or out of iterations."""

            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(
                    self, None, None, self.__class__.__name__
                )
                # Daemon so a deadlocked sender cannot keep the process
                # alive. Thread.setDaemon() is deprecated (Python 3.10+); the
                # ``daemon`` attribute is the supported spelling.
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                # Set on every iteration so the watchdog can see progress.
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    # Always release the other thread and the watchdog.
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drain ``chan`` until ``done_event`` is set."""

            def __init__(self, chan, done_event):
                threading.Thread.__init__(
                    self, None, None, self.__class__.__name__
                )
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                # Set whenever data is received so the watchdog sees progress.
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # Read from this thread's own channel rather than
                            # the enclosing test's ``chan`` closure variable.
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        else:
                            if random.randint(0, 1):
                                time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command("yes")
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = (
            self.tc._handler_table.copy()
        )  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)  # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)

        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500  # The deadlock does not happen every time, but it
        # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer: each thread must signal progress within
        # ``timeout`` seconds, otherwise a deadlock is assumed.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_sanitze_packet_size(self):
        """
        Check that _sanitize_packet_size() clamps requested packet sizes into
        the allowed range and substitutes the default for None.
        """
        cases = (
            (4095, MIN_PACKET_SIZE),
            (None, DEFAULT_MAX_PACKET_SIZE),
            (2 ** 32, MAX_WINDOW_SIZE),
        )
        sanitize = self.tc._sanitize_packet_size
        for requested, expected in cases:
            self.assertEqual(sanitize(requested), expected)

    def test_sanitze_window_size(self):
        """
        Check that _sanitize_window_size() clamps requested window sizes into
        the allowed range and substitutes the default for None.
        """
        cases = (
            (32767, MIN_WINDOW_SIZE),
            (None, DEFAULT_WINDOW_SIZE),
            (2 ** 32, MAX_WINDOW_SIZE),
        )
        sanitize = self.tc._sanitize_window_size
        for requested, expected in cases:
            self.assertEqual(sanitize(requested), expected)

    @slow
    def test_handshake_timeout(self):
        """
        verify that a handshake which exceeds handshake_timeout makes
        connect() raise EOFError.
        """
        # Slow down the client's packetizer so each read_message() sleeps for
        # a second first. That way the client's handshake-timeout timer thread
        # is guaranteed to get scheduled before the handshake could finish.
        # (Patching the server transport instead *sounds* more correct but
        # works much less reliably in practice.)
        class SlowPacketizer(Packetizer):
            def read_message(self):
                time.sleep(1)
                return super(SlowPacketizer, self).read_message()

        # The replaced Packetizer no longer touches its copy of the socket,
        # so swapping in a new one around the same socket should be safe.
        self.tc.packetizer = SlowPacketizer(self.tc.sock)

        # Regular server setup.
        host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())

        # A practically-zero timeout guarantees the handshake times out.
        self.tc.handshake_timeout = 1e-12
        self.ts.start_server(event, server)
        with self.assertRaises(EOFError):
            self.tc.connect(
                hostkey=public_host_key,
                username="******",
                password="******",
            )

    def test_select_after_close(self):
        """
        Check that select() reports a channel readable once the server has
        already closed its end.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        self.ts.accept(1.0).close()

        # Let the close notification reach the client.
        time.sleep(0.1)

        readers, writers, errors = select.select([chan], [], [], 0.1)
        self.assertEqual([chan], readers)
        self.assertEqual([], writers)
        self.assertEqual([], errors)

    def test_channel_send_misc(self):
        """
        Check that send()/sendall() reject non-string objects with TypeError
        and that sendall() UTF-8 encodes unicode text.
        """
        self.setup_test_server()
        text = u"\xa7 slice me nicely"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            reader = schan.makefile()

            # Neither send nor sendall accepts an arbitrary object.
            for sender in (chan.send, chan.sendall):
                self.assertRaises(TypeError, sender, object())

            # Unicode goes over the wire as UTF-8.
            chan.sendall(text)
            encoded = text.encode("utf-8")
            self.assertEqual(reader.read(len(encoded)), encoded)

    @needs_builtin("buffer")
    def test_channel_send_buffer(self):
        """
        Check that send() and sendall() accept Python 2 ``buffer`` objects.
        """
        self.setup_test_server()
        data = 3 * b"some test data\n whole"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            reader = schan.makefile()

            # Feed send() 8-byte buffer windows until everything is sent.
            offset = 0
            total = len(data)
            while offset < total:
                offset += chan.send(buffer(data, offset, 8))  # noqa
            self.assertEqual(reader.read(total), data)

            # sendall() takes a whole buffer at once.
            chan.sendall(buffer(data))  # noqa
            self.assertEqual(reader.read(total), data)

    @needs_builtin("memoryview")
    def test_channel_send_memoryview(self):
        """
        Check that send() and sendall() accept ``memoryview`` objects.
        """
        self.setup_test_server()
        data = 3 * b"some test data\n whole"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            reader = schan.makefile()

            # Feed send() 8-byte memoryview slices until everything is sent.
            view = memoryview(data)
            offset = 0
            while offset < len(view):
                offset += chan.send(view[offset : offset + 8])
            self.assertEqual(reader.read(len(data)), data)

            # sendall() takes a whole memoryview at once.
            chan.sendall(memoryview(data))
            self.assertEqual(reader.read(len(data)), data)

    def test_server_rejects_open_channel_without_auth(self):
        """
        An unauthenticated client must get ChannelException with code
        OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED when opening a session.
        """
        # Setup stays outside the expectation so only open_session() can
        # satisfy it.
        self.setup_test_server(connect_kwargs={})
        # assertRaises instead of ``assert False``: plain asserts are stripped
        # under ``python -O``, which would make this test pass vacuously.
        with self.assertRaises(ChannelException) as ctx:
            self.tc.open_session()
        self.assertEqual(
            OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, ctx.exception.code
        )

    def test_server_rejects_arbitrary_global_request_without_auth(self):
        """
        An unauthenticated client's global request must fail, even for a kind
        the test server would otherwise accept.
        """
        self.setup_test_server(connect_kwargs={})
        # "acceptable" would normally be approved by the test server.
        self.tc.global_request("acceptable")
        # global_request() never raises, even on failure, so the only signal
        # is that the client's global_response was reset to None (a success
        # would have left the response Message there).
        response = self.tc.global_response
        assert response is None, "Unauthed global response incorrectly succeeded!"

    def test_server_rejects_port_forward_without_auth(self):
        """
        An unauthenticated client's port-forward request must raise
        SSHException mentioning "forwarding request denied".
        """
        # NOTE: at protocol level port forward requests are treated same as a
        # regular global request, but Paramiko server implements a special-case
        # method for it, so it gets its own test. (plus, THAT actually raises
        # an exception on the client side, unlike the general case...)
        self.setup_test_server(connect_kwargs={})
        # assertRaises instead of ``assert False``: plain asserts are stripped
        # under ``python -O``, which would make this test pass vacuously.
        with self.assertRaises(SSHException) as ctx:
            self.tc.request_port_forward("localhost", 1234)
        self.assertIn("forwarding request denied", str(ctx.exception))

    def _send_unimplemented(self, server_is_sender):
        """
        Send MSG_UNIMPLEMENTED from one transport and assert the other side
        does not send anything back.

        :param server_is_sender: True to send from the server transport,
            False to send from the client transport.
        """
        self.setup_test_server()
        if server_is_sender:
            sender, recipient = self.ts, self.tc
        else:
            sender, recipient = self.tc, self.ts
        # Capture anything the recipient tries to send in response.
        recipient._send_message = Mock()

        msg = Message()
        msg.add_byte(cMSG_UNIMPLEMENTED)
        sender._send_message(msg)
        # TODO: there is no threading event that signals the recipient has
        # processed the message, and no reply is expected to block on, so a
        # short sleep is the best available synchronisation.
        time.sleep(0.1)
        assert not recipient._send_message.called

    def test_server_does_not_respond_to_MSG_UNIMPLEMENTED(self):
        """The server must stay silent when the client sends MSG_UNIMPLEMENTED."""
        client_sends = False
        self._send_unimplemented(server_is_sender=client_sends)

    def test_client_does_not_respond_to_MSG_UNIMPLEMENTED(self):
        """The client must stay silent when the server sends MSG_UNIMPLEMENTED."""
        server_sends = True
        self._send_unimplemented(server_is_sender=server_sends)

    def _send_client_message(self, message_type):
        """
        From an unauthenticated client, send a bare message of
        ``message_type`` while the server's _send_message is mocked.

        :param message_type: integer message id; converted with byte_chr().
        """
        self.setup_test_server(connect_kwargs={})
        # Record (instead of send) whatever the server replies with.
        self.ts._send_message = Mock()
        # Only the type byte is sent. Real messages of most of these types
        # carry more fields, but dispatch is all that is under test here.
        # (cMSG_XXX is usually just byte_chr(MSG_XXX), hence the conversion.)
        msg = Message()
        msg.add_byte(byte_chr(message_type))
        self.tc._send_message(msg)
        # No event signals that the server acted, so wait briefly (same
        # approach as the MSG_UNIMPLEMENTED tests).
        time.sleep(0.1)

    def _expect_unimplemented(self):
        """
        Assert that the mocked server _send_message was called exactly once,
        with a MSG_UNIMPLEMENTED reply.

        A MSG_UNIMPLEMENTED reply means the message fell through dispatch
        rather than being handled. When the bug is present this is usually
        the first check to fail, since real handling often sends nothing back
        immediately.
        """
        send_mock = self.ts._send_message
        assert send_mock.call_count == 1
        positional_args = send_mock.call_args[0]
        reply = positional_args[0]
        # The captured Message was never sent, so rewind before reading it.
        reply.rewind()
        assert reply.get_byte() == cMSG_UNIMPLEMENTED

    def test_server_transports_reject_client_message_types(self):
        """
        Every message type in the auth handler's client-side table must be
        answered with MSG_UNIMPLEMENTED when a client sends it to a server.
        """
        # TODO: handle Transport's own tables too, not just its inner auth
        # handler's table. See TODOs in auth_handler.py
        client_only_types = AuthHandler._client_handler_table
        for msg_type in client_only_types:
            self._send_client_message(msg_type)
            self._expect_unimplemented()
            # Start each iteration from fresh transports.
            self.tearDown()
            self.setUp()

    def test_server_rejects_client_MSG_USERAUTH_SUCCESS(self):
        """
        A client-sent MSG_USERAUTH_SUCCESS must not authenticate the server
        transport, and must be answered with MSG_UNIMPLEMENTED.
        """
        self._send_client_message(MSG_USERAUTH_SUCCESS)
        # Neither the transport nor its auth handler may consider itself
        # authenticated.
        for party in (self.ts, self.ts.auth_handler):
            assert not party.authenticated
        # Behaviour introduced by the actual fix.
        self._expect_unimplemented()
Ejemplo n.º 43
0
class SSHHandler(ServerInterface, BaseRequestHandler):
    '''Serve one SSH connection whose socket is ``request._sock``.

    Subclasses set ``telnet_handler`` (the telnet handler class used for PTY
    sessions) and ``host_key``, and may set the ``authCallback*`` hooks to
    enable password, public-key or username-only authentication.
    '''

    # Telnet handler class; mixed with TelnetToPtyHandler in __init__.
    telnet_handler = None
    # Replaced per instance in __init__ by the mixed PTY handler class.
    pty_handler = None
    # Key handed to Transport.add_server_key(); must be set by subclasses.
    host_key = None
    # Name of the authenticated user, recorded by set_username().
    username = None

    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.tcp_server = server

        # Maps each channel to the (not yet started) thread serving its PTY;
        # filled by check_channel_pty_request, started by
        # check_channel_shell_request.
        self.channels = {}

        self.client = request._sock
        # Transport turns the socket into an SSH transport
        self.transport = Transport(self.client)

        # Create the PTY handler class by mixing TelnetToPtyHandler into the
        # configured telnet handler.
        TelnetHandlerClass = self.telnet_handler

        class MixedPtyHandler(TelnetToPtyHandler, TelnetHandlerClass):
            # BaseRequestHandler does not inherit from object, must call the __init__ directly
            def __init__(self, *args):
                TelnetHandlerClass.__init__(self, *args)

        self.pty_handler = MixedPtyHandler

        # Call the base class to run the handler (presumably socketserver's
        # BaseRequestHandler, whose __init__ runs setup/handle/finish).
        BaseRequestHandler.__init__(self, request, client_address, server)

    def setup(self):
        '''Setup the connection.

        Starts server-side SSH negotiation, then accepts channels until an
        accept() times out while no PTY thread is alive.

        Raises NotImplementedError if ``host_key`` is not set, and re-raises
        SSHException if negotiation fails.
        '''
        # Pass host and port separately: a single tuple argument cannot fill
        # both placeholders and made logging report a formatting error.
        log.debug('New request from address %s, port %d',
                  self.client_address[0], self.client_address[1])

        try:
            self.transport.load_server_moduli()
        except Exception:
            log.exception(
                '(Failed to load moduli -- gex will be unsupported.)')
            raise

        # Check the key up front. Previously every add_server_key() failure
        # was caught and silently ignored when a key *was* set.
        if self.host_key is None:
            log.critical(
                'Host key not set!  SSHHandler MUST define the host_key parameter.'
            )
            raise NotImplementedError(
                'Host key not set!  SSHHandler instance must define the host_key parameter.  Try host_key = paramiko_ssh.getRsaKeyFile("server_rsa.key").'
            )
        self.transport.add_server_key(self.host_key)

        try:
            # Tell transport to use this object as a server
            log.debug('Starting SSH server-side negotiation')
            self.transport.start_server(server=self)
        except SSHException as e:
            log.warning('SSH negotiation failed. %s', e)
            raise

        # Accept requested channels; stop once accept() times out (20 s) and
        # no PTY thread is still running.
        while True:
            channel = self.transport.accept(20)
            if channel is not None:
                log.info('Accepted channel %s', channel)
                continue
            threads = list(self.channels.values())
            if not any(thread.is_alive() for thread in threads):
                break

    class dummy_request(object):
        '''Minimal stand-in for a request: only carries ``_sock``.'''

        def __init__(self):
            self._sock = None

    @classmethod
    def streamserver_handle(cls, socket, address):
        '''Translate this class for use in a StreamServer'''
        request = cls.dummy_request()
        request._sock = socket
        server = None
        cls(request, address, server)

    def finish(self):
        '''Called when the socket closes from the client.'''
        self.transport.close()

    def check_channel_request(self, kind, chanid):
        '''Allow only "session" channels.'''
        if kind == 'session':
            return OPEN_SUCCEEDED
        return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def set_username(self, username):
        '''Record and log the authenticated user name.'''
        self.username = username
        log.info('User logged in: %s', username)

    ######  Handle User Authentication ######

    # Override these with functions to use for callbacks. Each callback
    # signals rejection by raising an exception.
    authCallback = None
    authCallbackKey = None
    authCallbackUsername = None

    def get_allowed_auths(self, username):
        '''Return the comma-separated auth methods enabled by the callbacks.'''
        methods = []
        if self.authCallbackUsername is not None:
            methods.append('none')
        if self.authCallback is not None:
            methods.append('password')
        if self.authCallbackKey is not None:
            methods.append('publickey')

        if not methods:
            # If no methods were defined, use none
            methods.append('none')

        log.debug('Configured authentication methods: %r', methods)
        return ','.join(methods)

    def check_auth_password(self, username, password):
        '''Accept the password unless authCallback raises.'''
        # ``except Exception`` (not bare) so interpreter-level exits such as
        # KeyboardInterrupt/SystemExit are not turned into AUTH_FAILED.
        try:
            self.authCallback(username, password)
        except Exception:
            return AUTH_FAILED
        else:
            self.set_username(username)
            return AUTH_SUCCESSFUL

    def check_auth_publickey(self, username, key):
        '''Accept the public key unless authCallbackKey raises.'''
        try:
            self.authCallbackKey(username, key)
        except Exception:
            return AUTH_FAILED
        else:
            self.set_username(username)
            return AUTH_SUCCESSFUL

    def check_auth_none(self, username):
        '''Accept "none" auth; authCallbackUsername, if set, may veto it.'''
        if self.authCallbackUsername is None:
            self.set_username(username)
            return AUTH_SUCCESSFUL
        try:
            self.authCallbackUsername(username)
        except Exception:
            return AUTH_FAILED
        else:
            self.set_username(username)
            return AUTH_SUCCESSFUL

    def check_channel_shell_request(self, channel):
        '''Request to start a shell on the given channel'''
        try:
            self.channels[channel].start()
        except KeyError:
            log.error(
                'Requested to start a channel (%r) that was not previously set up.',
                channel)
            return False
        else:
            return True

    def check_channel_pty_request(self, channel, term, width, height,
                                  pixelwidth, pixelheight, modes):
        '''Request to allocate a PTY terminal.

        Only prepares the serving thread; it is started by a later shell
        request.
        '''
        log.debug('PTY requested.  Setting up %r.', self.telnet_handler)
        pty_thread = Thread(target=self.start_pty_request,
                            args=(channel, term, modes))
        self.channels[channel] = pty_thread

        return True

    def start_pty_request(self, channel, term, modes):
        '''Start a PTY - intended to run it a (green)thread.'''
        request = self.dummy_request()
        request._sock = channel
        # modes encoding: http://www.ietf.org/rfc/rfc4254.txt page 18
        request.modes = modes
        request.term = term
        request.username = self.username

        # This should block until the user quits the pty
        self.pty_handler(request, self.client_address, self.tcp_server)

        # Shutdown the entire session
        self.transport.close()
Ejemplo n.º 44
0
def sftp_server():
    """
    Set up an in-memory SFTP server thread. Yields the client Transport/socket.

    The resulting client Transport (along with all the server components) will
    be the same object throughout the test session; the `sftp` fixture then
    creates new higher level client objects wrapped around the client
    Transport, as necessary.

    When the generator is resumed or closed after the yield (or if setup
    fails), both transports and both loopback sockets are closed so the
    server thread does not outlive the session.
    """
    # Sockets & transports
    socks = LoopSocket()
    sockc = LoopSocket()
    sockc.link(socks)
    tc = Transport(sockc)
    ts = Transport(socks)
    try:
        # Auth
        host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
        ts.add_server_key(host_key)
        # Server setup
        event = threading.Event()
        server = StubServer()
        ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer)
        ts.start_server(event, server)
        # Wait (so client has time to connect? Not sure. Old.)
        event.wait(1.0)
        # Make & yield connection.
        tc.connect(username="******", password="******")
        yield tc
    finally:
        # Teardown: previously the transports and sockets were never closed.
        tc.close()
        ts.close()
        socks.close()
        sockc.close()
Ejemplo n.º 45
0
class AuthTest (unittest.TestCase):
    """Authentication scenarios over an in-memory client/server Transport pair.

    ``tc`` is the client transport and ``ts`` the server transport; they talk
    over two linked ``LoopSocket`` objects.
    """

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def start_server(self):
        """Start a NullServer on ``ts`` using the test RSA host key.

        Stores the public half of the host key on ``self.public_host_key`` so
        tests can pass it to the client's ``connect(hostkey=...)``.
        """
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        self.public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        self.event = threading.Event()
        self.server = NullServer()
        self.assertFalse(self.event.is_set())
        self.ts.start_server(self.event, self.server)

    def verify_finished(self):
        """Assert that server negotiation finished (within 1 s) and ``ts`` is up."""
        self.event.wait(1.0)
        self.assertTrue(self.event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        self.start_server()
        with self.assertRaises(BadAuthenticationType) as ctx:
            self.tc.connect(hostkey=self.public_host_key,
                            username='******', password='******')
        # The original test demanded the exact class, not a subclass.
        self.assertIs(BadAuthenticationType, type(ctx.exception))
        self.assertEqual(['publickey'], ctx.exception.allowed_types)

    def test_2_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        with self.assertRaises(AuthenticationException):
            self.tc.auth_password(username='******', password='******')
        self.tc.auth_password(username='******', password='******')
        self.verify_finished()

    def test_3_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEqual(['publickey'], remain)
        key = DSSKey.from_private_key_file(test_path('test_dss.key'))
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEqual([], remain)
        self.verify_finished()

    def test_4_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)

        def handler(title, instructions, prompts):
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']
        remain = self.tc.auth_interactive('commie', handler)
        self.assertEqual(self.got_title, 'password')
        self.assertEqual(self.got_prompts, [('Password', False)])
        self.assertEqual([], remain)
        self.verify_finished()

    def test_5_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEqual([], remain)
        self.verify_finished()

    def test_6_auth_utf8(self):
        """
        verify that utf-8 encoding happens in authentication.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password('utf8', _pwd)
        self.assertEqual([], remain)
        self.verify_finished()

    def test_7_auth_non_utf8(self):
        """
        verify that non-utf-8 encoded passwords can be used for broken
        servers.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password('non-utf8', '\xff')
        self.assertEqual([], remain)
        self.verify_finished()

    def test_8_auth_gets_disconnected(self):
        """
        verify that we catch a server disconnecting during auth, and report
        it as an auth failure.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        # Previously this test passed silently when no exception was raised.
        with self.assertRaises(AuthenticationException):
            self.tc.auth_password('bad-server', 'hello')
Ejemplo n.º 46
0
class SFTPConnection:
    """
    Handle an SFTP (SSH File Transfer Protocol) connection.

    :param url: ``sftp://host[:port][/path]`` URL.  The port defaults to 22;
        a non-empty path (after stripping trailing slashes) is entered with
        ``chdir`` once connected.
    :param user_name: SSH login name.
    :param password: password to authenticate with (exclusive with
        ``private_key``).
    :param private_key: RSA private key text to authenticate with.
    :param bind_address: optional local address to bind the socket to.
    :raises SFTPError: if both ``password`` and ``private_key`` are given.
    """
    def __init__(self,
                 url,
                 user_name,
                 password=None,
                 private_key=None,
                 bind_address=None):
        self.url = url
        self.user_name = user_name
        if password and private_key:
            raise SFTPError(
                "Password and private_key cannot be defined simultaneously")
        self.password = password
        self.private_key = private_key
        self.bind_address = bind_address

    def connect(self):
        """Open the SSH transport, authenticate, and open the SFTP session.

        On success ``self.transport`` and ``self.conn`` (an ``SFTPClient``)
        are set.

        :raises SFTPError: if the URL scheme is not ``sftp``, no stream
            socket family is found, neither password nor private key is set,
            a socket error occurs while authenticating/opening the session,
            or the remote directory cannot be entered.
        """
        # Check URL
        schema = urlparse(self.url)
        if schema.scheme != 'sftp':
            raise SFTPError('Not a valid sftp url %s, type is %s' %
                            (self.url, schema.scheme))
        hostname = schema.hostname
        # urlparse reports a missing port as None; int(None) used to raise
        # TypeError here.  Fall back to the standard SSH port.
        port = 22 if schema.port is None else int(schema.port)
        # Socket creation code inspired from paramiko.Transport.__init__
        # with added bind support.
        for family, socktype, _, _, _ in getaddrinfo(
                hostname,
                port,
                AF_UNSPEC,
                SOCK_STREAM,
        ):
            if socktype == SOCK_STREAM:
                sock = socket(family, SOCK_STREAM)
                try:
                    if self.bind_address:
                        # XXX: Expects bind address to be of same family as hostname.
                        # May not be easy if name resolution is involved.
                        # Try to reconciliate them ?
                        sock.bind((self.bind_address, 0))
                    retry_on_signal(lambda: sock.connect((hostname, port)))
                except Exception:
                    # Don't leak the file descriptor; the error still propagates.
                    sock.close()
                    raise
                break
        else:
            raise SFTPError('No suitable socket family found')
        self.transport = Transport(sock)
        # Add authentication to transport
        try:
            if self.password:
                self.transport.connect(username=self.user_name,
                                       password=self.password)
            elif self.private_key:
                self.transport.connect(username=self.user_name,
                                       pkey=RSAKey.from_private_key(
                                           StringIO(self.private_key)))
            else:
                raise SFTPError("No password or private_key defined")
            # Connect
            self.conn = SFTPClient.from_transport(self.transport)
        except (gaierror, error) as msg:
            raise SFTPError(
                str(msg) + ' while establishing connection to %s' %
                (self.url, ))
        # Go to specified directory.  The stripped result used to be
        # discarded, so e.g. a bare "/" path still triggered a chdir.
        remote_path = schema.path.rstrip('/')
        try:
            if remote_path:
                self.conn.chdir(remote_path)
        except IOError as msg:
            raise SFTPError(
                str(msg) + ' while changing to dir -%r-' % (schema.path, ))
Ejemplo n.º 47
0
class Session:
    """Server-side state for one client connection handled by the proxy.

    A paramiko ``Transport`` is created lazily around ``client_socket``.
    :meth:`start` waits for the client to open a channel and finish
    authentication.  The session then requires an SSH agent before calling
    the authenticator; on success ``sftp_client_ready`` is set.  Usable as a
    context manager: leaving the ``with`` block calls :meth:`close`.
    """
    # Optional tuple of cipher names forced on the server transport.
    CIPHERS = None

    def __init__(self, proxyserver, client_socket, client_address,
                 authenticator, remoteaddr):
        """
        :param proxyserver: owning proxy server; provides ``running``,
            ``host_key``, ``sftp_interface`` and ``authentication_interface``.
        :param client_socket: accepted socket the Transport is built on.
        :param client_address: address of the connecting client.
        :param authenticator: callable invoked as ``authenticator(self)``;
            its result must provide ``AGENT_FORWARDING`` and
            ``authenticate()``.
        :param remoteaddr: remote address stored as ``socket_remote_address``.
        """

        self._transport = None

        self.channel = None

        self.proxyserver = proxyserver
        self.client_socket = client_socket
        self.client_address = client_address

        self.ssh = False
        self.ssh_channel = None
        self.ssh_client = None

        self.scp = False
        self.scp_channel = None
        self.scp_command = ''

        self.sftp = False
        self.sftp_channel = None
        self.sftp_client = None
        self.sftp_client_ready = threading.Event()

        self.username = ''
        self.socket_remote_address = remoteaddr
        self.remote_address = (None, None)
        self.key = None
        self.agent = None
        self.authenticator = authenticator(self)

    @property
    def running(self):
        """Whether the owning proxy server is still running."""
        return self.proxyserver.running

    @property
    def transport(self):
        """Server Transport for this session, created on first access.

        :raises ValueError: if ``CIPHERS`` is set but is not a tuple.
        """
        if not self._transport:
            self._transport = Transport(self.client_socket)
            if self.CIPHERS:
                if not isinstance(self.CIPHERS, tuple):
                    raise ValueError('ciphers must be a tuple')
                self._transport.get_security_options().ciphers = self.CIPHERS
            self._transport.add_server_key(self.proxyserver.host_key)
            self._transport.set_subsystem_handler(
                'sftp', ProxySFTPServer, self.proxyserver.sftp_interface)

        return self._transport

    def _start_channels(self):
        """Authenticate the proxied session and mark it ready.

        :return: True when the session may proceed (``sftp_client_ready`` is
            then set), False otherwise.  On False, a message may have been
            sent on the client channel and/or the transport closed.
        """
        # create client or master channel
        if self.ssh_client:
            self.sftp_client_ready.set()
            return True

        if not self.agent and self.authenticator.AGENT_FORWARDING:
            try:
                self.agent = AgentServerProxy(self.transport)
                self.agent.connect()
            except Exception:
                self.close()
                return False
        # Connect method start
        if not self.agent:
            self.channel.send('Kein SSH Agent weitergeleitet\r\n')
            return False

        if self.authenticator.authenticate() != AUTH_SUCCESSFUL:
            self.channel.send('Permission denied (publickey).\r\n')
            return False
        logging.info('connection established')

        # Connect method end
        if not self.scp and not self.ssh and not self.sftp:
            # Nothing was requested on this connection.  Report failure even
            # when the transport is already gone; previously that case fell
            # through and returned True.
            if self.transport.is_active():
                self.transport.close()
            return False

        self.sftp_client_ready.set()
        return True

    def start(self):
        """Run the server handshake and prepare the session.

        :return: True if a channel was opened and the session is ready;
            False if the proxy stopped, the client vanished before opening a
            channel, the transport died during authentication, or
            :meth:`_start_channels` refused the session.
        """
        event = threading.Event()
        self.transport.start_server(
            event=event,
            server=self.proxyserver.authentication_interface(self))

        while not self.channel:
            self.channel = self.transport.accept(0.5)
            if not self.running:
                if self.transport.is_active():
                    self.transport.close()
                return False
            if not self.channel and not self.transport.is_active():
                # The client went away before opening a channel.  Without this
                # check the loop polled a dead transport until the proxy
                # server stopped.
                break

        if not self.channel:
            logging.error('error opening channel!')
            if self.transport.is_active():
                self.transport.close()
            return False

        # wait for authentication
        event.wait()

        if not self.transport.is_active():
            return False

        if not self._start_channels():
            return False

        logging.info("session started")
        return True

    def close(self):
        """Close the transport (if still active) and the agent proxy."""
        if self.transport.is_active():
            self.transport.close()
        if self.agent:
            self.agent.close()

    def __enter__(self):
        return self

    def __exit__(self, value_type, value, traceback):
        self.close()
Ejemplo n.º 48
0
class TransportTest(unittest.TestCase):
    def setUp(self):
        # Two linked in-memory sockets with a Transport on each end:
        # ``tc`` acts as the client and ``ts`` as the server.
        self.socks, self.sockc = LoopSocket(), LoopSocket()
        self.sockc.link(self.socks)
        self.tc, self.ts = Transport(self.sockc), Transport(self.socks)

    def tearDown(self):
        # Close both transports first, then the underlying loop sockets.
        for resource in (self.tc, self.ts, self.socks, self.sockc):
            resource.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """Start a NullServer on ``ts`` and connect/authenticate ``tc`` to it.

        :param client_options: optional callable that receives the client's
            SecurityOptions before the handshake (e.g. to force ciphers).
        :param server_options: the same, for the server transport.

        The server object is stored on ``self.server``.  Asserts that server
        negotiation finished within one second and that ``ts`` is active.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that SecurityOptions.ciphers accepts a tuple of known cipher
        names, and rejects unknown names (ValueError) and a non-tuple value
        such as an int (TypeError).
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_1b_security_options_reset(self):
        """Re-assigning each SecurityOptions field its own value must not raise."""
        opts = self.tc.get_security_options()
        for field in ('ciphers', 'digests', 'key_types', 'kex', 'compression'):
            setattr(opts, field, getattr(opts, field))

    def test_2_compute_key(self):
        """_compute_key('C', 32) must reproduce a known derived key."""
        shared_secret = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        exchange_hash = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
        expected = b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995'

        self.tc.K = shared_secret
        self.tc.H = exchange_hash
        self.tc.session_id = exchange_hash
        derived = self.tc._compute_key('C', 32)
        self.assertEqual(expected, hexlify(derived).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        # Before the handshake neither side has a user or is authenticated.
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        # Push the banner bytes onto the server socket before the server starts.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can force a specific cipher/MAC pair during
        the handshake, and can renegotiate keys mid-stream.
        """
        def pin_cipher_and_mac(opts):
            opts.ciphers = ('aes256-cbc', )
            opts.digests = ('hmac-md5-96', )

        self.setup_test_server(client_options=pin_cipher_and_mac)
        for negotiated in (self.tc.local_cipher, self.tc.remote_cipher):
            self.assertEqual('aes256-cbc', negotiated)
        packetizer = self.tc.packetizer
        # hmac-md5-96 -> 96-bit (12-byte) MACs in both directions.
        self.assertEqual(12, packetizer.get_mac_size_out())
        self.assertEqual(12, packetizer.get_mac_size_in())

        # Send traffic, rekey, then send traffic from the other side.
        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    @slow
    def test_5_keepalive(self):
        """
        verify that enabling set_keepalive() on the client results in a
        global request reaching the server.
        """
        self.setup_test_server()
        server = self.server
        # No global request may have been seen before keepalive is enabled.
        self.assertEqual(None, getattr(server, '_global_request', None))
        interval = 1
        self.tc.set_keepalive(interval)
        # Wait two intervals so a keepalive should have gone out.
        time.sleep(2 * interval)
        self.assertEqual('*****@*****.**', server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable: non-UTF-8
        commands are rejected, and stdout/stderr are delivered separately or
        combined depending on set_combine_stderr().
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        # Accept the server end of this session; it is not used further.
        self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command(
                b'command contains \xfc and is not a valid UTF-8 string')

        def start_yes_session():
            # Open a session running 'yes'; the server writes one line to
            # stdout and one to stderr, then closes its end.
            client = self.tc.open_session()
            client.exec_command('yes')
            server = self.ts.accept(1.0)
            server.send('Hello there.\n')
            server.send_stderr('This is on stderr.\n')
            server.close()
            return client

        chan = start_yes_session()
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = start_yes_session()
        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

    def test_6a_channel_can_be_used_as_context_manager(self):
        """
        verify that client and server Channel objects can be used as context
        managers (``with`` statement) while exchanging data.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan:
            with self.ts.accept(1.0) as schan:
                chan.exec_command('yes')
                schan.send('Hello there.\n')
                schan.close()

                f = chan.makefile()
                self.assertEqual('Hello there.\n', f.readline())
                self.assertEqual('', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that a line sent on an invoke_shell() channel arrives at the
        server, and that closing the client end yields EOF there.
        """
        line = 'communist j. cat\n'
        self.setup_test_server()
        client_chan = self.tc.open_session()
        client_chan.invoke_shell()
        server_chan = self.ts.accept(1.0)
        client_chan.send(line)
        reader = server_chan.makefile()
        self.assertEqual(line, reader.readline())
        client_chan.close()
        self.assertEqual('', reader.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        with self.assertRaises(ChannelException) as ctx:
            self.tc.open_channel('bogus')
        self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
                         ctx.exception.code)

    def test_9_exit_status(self):
        """
        verify that exit_status_ready()/recv_exit_status() report the exit
        status sent by the server.
        """
        self.setup_test_server()

        client_chan = self.tc.open_session()
        server_chan = self.ts.accept(1.0)
        client_chan.exec_command('yes')
        server_chan.send('Hello there.\n')
        self.assertFalse(client_chan.exit_status_ready())
        # trigger an EOF, then report status 23 and close
        server_chan.shutdown_read()
        server_chan.shutdown_write()
        server_chan.send_exit_status(23)
        server_chan.close()

        reader = client_chan.makefile()
        self.assertEqual('Hello there.\n', reader.readline())
        self.assertEqual('', reader.readline())
        # Poll for the status: at most 51 checks, 0.1 s apart.
        for _ in range(51):
            if client_chan.exit_status_ready():
                break
            time.sleep(0.1)
        else:
            raise Exception("timeout")
        self.assertEqual(23, client_chan.recv_exit_status())
        client_chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel reports readability for data and
        for EOF, and that the channel's pipe is closed with the channel.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        def poll():
            # One 0.1 s select() on the client channel alone.
            return select.select([chan], [], [], 0.1)

        def poll_until_readable():
            # Retry for up to ~1 s until the channel shows up as readable.
            for _ in range(10):
                ready = poll()
                if chan in ready[0]:
                    break
                time.sleep(0.1)
            return ready

        # nothing should be ready
        self.assertEqual(([], [], []), tuple(poll()))

        schan.send('hello\n')
        self.assertEqual(([chan], [], []), tuple(poll_until_readable()))
        self.assertEqual(b'hello\n', chan.recv(6))

        # drained again
        self.assertEqual(([], [], []), tuple(poll()))

        schan.close()

        # EOF makes the channel readable and recv() returns empty bytes
        self.assertEqual(([chan], [], []), tuple(poll_until_readable()))
        self.assertEqual(bytes(), chan.recv(16))

        # the internal pipe stays open until the channel itself is closed
        pipe = chan._pipe
        self.assertEqual(False, pipe._closed)
        chan.close()
        self.assertEqual(True, pipe._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport renegotiates keys mid-stream once enough
        data has been sent.
        """
        self.setup_test_server()
        # Lower the rekey threshold below the ~20 KiB sent below.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # No rekey yet: the exchange hash still equals the session id.
        self.assertEqual(self.tc.H, self.tc.session_id)
        chunk = 'x' * 1024
        for _ in range(20):
            chan.send(chunk)
        chan.close()

        # Give the rekey up to ~5 s (50 x 0.1 s) to complete.
        for _ in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(options):
            options.compression = ('zlib', )

        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Name-mangled private byte counter of the client packetizer.
        # (These locals used to be called ``bytes``, shadowing the builtin.)
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send('x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        sent = sent_after - sent_before
        block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size']
        mac_size = self.tc._mac_info[self.tc.local_mac]['size']
        # tests show this is actually compressed to *52 bytes*!  including packet overhead!  nice!! :)
        self.assertLess(sent, 1024)
        self.assertEqual(16 + block_size + mac_size, sent)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 forwarding request is recorded by the server and
        that an x11 channel opened by the server reaches the client handler.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        seen = []

        def on_x11(x11_chan, addr_port):
            # Record the advertised origin, then queue the channel so the
            # client transport's accept() returns it.
            addr, port = addr_port
            seen.append((addr, port))
            self.tc._queue_incoming_channel(x11_chan)

        server = self.server
        self.assertEqual(None, getattr(server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=on_x11)
        self.assertEqual(0, server._x11_screen_number)
        self.assertEqual('MIT-MAGIC-COOKIE-1', server._x11_auth_protocol)
        self.assertEqual(cookie, server._x11_auth_cookie)
        self.assertEqual(True, server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual(('localhost', 6093), seen[0])

        x11_server.send('hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        for channel in (x11_server, x11_client, chan, schan):
            channel.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding, use it, and cancel it again.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []

        def on_forward(fwd_chan, origin_addr_port, server_addr_port):
            requested.extend((origin_addr_port, server_addr_port))
            self.tc._queue_incoming_channel(fwd_chan)

        port = self.tc.request_port_forward('127.0.0.1', 0, on_forward)
        listener = self.server._listen
        self.assertEqual(port, listener.getsockname()[1])

        outside = socket.socket()
        outside.connect(('127.0.0.1', port))
        accepted, _ = listener.accept()
        server_side = self.ts.open_forwarded_tcpip_channel(
            accepted.getsockname(), accepted.getpeername())
        client_side = self.tc.accept()

        server_side.send('hello')
        self.assertEqual(b'hello', client_side.recv(5))
        for resource in (server_side, client_side, accepted, outside):
            resource.close()

        # cancelling must make the server drop its listening socket
        self.tc.cancel_port_forward('127.0.0.1', port)
        self.assertIsNone(self.server._listen)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        try:
            greeting_server.bind(('127.0.0.1', 0))
            greeting_server.listen(1)
            greeting_port = greeting_server.getsockname()[1]

            cs = self.tc.open_channel('direct-tcpip',
                                      ('127.0.0.1', greeting_port),
                                      ('', 9000))
            sch = self.ts.accept(1.0)
            # Act as the server's outbound connection to the requested target.
            cch = socket.socket()
            try:
                cch.connect(self.server._tcpip_dest)

                ss, _ = greeting_server.accept()
                ss.send(b'Hello!\n')
                ss.close()
                sch.send(cch.recv(8192))
                sch.close()

                self.assertEqual(b'Hello!\n', cs.recv(7))
                cs.close()
            finally:
                # These real OS sockets were previously never closed.
                cch.close()
        finally:
            greeting_server.close()

        chan.close()
        schan.close()

    def test_G_stderr_select(self):
        """
        verify that select() reports a channel as readable even when only
        stderr has pending data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        def poll():
            # One 0.1 s select() on the client channel alone.
            return select.select([chan], [], [], 0.1)

        # nothing should be ready
        self.assertEqual(([], [], []), tuple(poll()))

        schan.send_stderr('hello\n')

        # give the stderr data up to ~1 s to appear
        for _ in range(10):
            ready = poll()
            if chan in ready[0]:
                break
            time.sleep(0.1)
        self.assertEqual(([chan], [], []), tuple(ready))

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # drained again
        self.assertEqual(([], [], []), tuple(poll()))

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() eventually turns False while the server
        does not read, and is True again after the channel is closed.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        chunk = '*' * 1024
        # Upper bound on how much we push; presumably larger than the
        # channel's send window, so send_ready() must go False first.
        ceiling = 1 + (64 * 2**15)
        sent = 0
        while sent < ceiling:
            chan.send(chunk)
            sent += len(chunk)
            if not chan.send_ready():
                break
        self.assertTrue(sent < ceiling)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Send 2 KiB chunks on ``chan`` until done or out of iterations."""

            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None,
                                          self.__class__.__name__)
                # Daemon thread: if it deadlocks it must not keep the
                # interpreter alive once the test run is over.
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                # Set on every iteration so the watchdog can see progress.
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    # Whatever happens, end the test and wake the watchdog.
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drain ``chan`` until ``done_event`` is set."""

            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None,
                                          self.__class__.__name__)
                # Daemon thread, for the same reason as SendThread.
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                # Set after every successful recv so the watchdog sees progress.
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        elif random.randint(0, 1):
                            # Random back-off varies the thread interleaving
                            # from run to run.
                            time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        # Tiny rekey threshold so the server re-keys constantly while sending.
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        # Copy first: the table is a per-class dictionary.
        self.tc._handler_table = self.tc._handler_table.copy()
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]

        def _negotiate_keys_wrapper(transport, m):
            # ``transport`` is the client Transport; handlers are stored
            # unbound and called with the transport as first argument.
            if transport.local_kex_init is None:  # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)  # bytes to add
                transport._send_message(m2)
            return _negotiate_keys(transport, m)

        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test.
        # The deadlock does not happen every time, but it should after many
        # iterations.
        iterations = 500
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer: each thread must signal progress within
        # ``timeout`` seconds, otherwise a deadlock is assumed.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_J_sanitze_packet_size(self):
        """
        verify that requested packet sizes are sanitized per the RFC: a too-small
        value becomes MIN_PACKET_SIZE, None becomes DEFAULT_MAX_PACKET_SIZE and
        a too-large value is capped at MAX_WINDOW_SIZE.
        """
        cases = (
            (4095, MIN_PACKET_SIZE),
            (None, DEFAULT_MAX_PACKET_SIZE),
            (2**32, MAX_WINDOW_SIZE),
        )
        sanitize = self.tc._sanitize_packet_size
        for requested, expected in cases:
            self.assertEqual(sanitize(requested), expected)

    def test_K_sanitze_window_size(self):
        """
        verify that requested window sizes are sanitized per the RFC: a too-small
        value becomes MIN_WINDOW_SIZE, None becomes DEFAULT_WINDOW_SIZE and a
        too-large value is capped at MAX_WINDOW_SIZE.
        """
        cases = (
            (32767, MIN_WINDOW_SIZE),
            (None, DEFAULT_WINDOW_SIZE),
            (2**32, MAX_WINDOW_SIZE),
        )
        sanitize = self.tc._sanitize_window_size
        for requested, expected in cases:
            self.assertEqual(sanitize(requested), expected)

    @slow
    def test_L_handshake_timeout(self):
        """
        verify that we can get a handshake timeout.
        """

        # Tweak client Transport instance's Packetizer instance so
        # its read_message() sleeps a bit. This helps prevent race conditions
        # where the client Transport's timeout timer thread doesn't even have
        # time to get scheduled before the main client thread finishes
        # handshaking with the server.
        # (Doing this on the server's transport *sounds* more 'correct' but
        # actually doesn't work nearly as well for whatever reason.)
        class SlowPacketizer(Packetizer):
            def read_message(self):
                time.sleep(1)
                return super(SlowPacketizer, self).read_message()

        # NOTE: the replaced Packetizer no longer does anything with its copy
        # of the socket, so sharing the socket with the new one should be
        # fine, even though it is a bit squicky.
        self.tc.packetizer = SlowPacketizer(self.tc.sock)
        # Continue with regular test red tape.
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        # Effectively-zero timeout so the handshake cannot finish in time.
        self.tc.handshake_timeout = 0.000000000001
        self.ts.start_server(event, server)
        self.assertRaises(EOFError,
                          self.tc.connect,
                          username='******',
                          password='******')

    def test_M_select_after_close(self):
        """
        check that select() reports a channel as readable once the server side
        has closed it.
        """
        self.setup_test_server()
        client_chan = self.tc.open_session()
        client_chan.invoke_shell()
        server_chan = self.ts.accept(1.0)
        server_chan.close()

        # Let the close notification reach the client before polling.
        time.sleep(0.1)

        readable, writable, errored = select.select([client_chan], [], [], 0.1)
        self.assertEqual([client_chan], readable)
        self.assertEqual([], writable)
        self.assertEqual([], errored)

    def test_channel_send_misc(self):
        """
        check how a channel's send()/sendall() treat assorted argument types
        """
        self.setup_test_server()
        text = u"\xa7 slice me nicely"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            server_file = schan.makefile()

            # Objects that are neither strings nor buffers are rejected.
            for sender in (chan.send, chan.sendall):
                self.assertRaises(TypeError, sender, object())

            # A unicode string is accepted by sendall() and read back UTF-8 encoded.
            chan.sendall(text)
            encoded = text.encode("utf-8")
            self.assertEqual(server_file.read(len(encoded)), encoded)

    @needs_builtin('buffer')
    def test_channel_send_buffer(self):
        """
        check that a channel accepts ``buffer`` objects in send() and sendall()
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # Push the payload through send() in 8-byte buffer windows,
            # advancing by the byte count send() reports.
            offset = 0
            total = len(data)
            while offset < total:
                offset += chan.send(buffer(data, offset, 8))
            self.assertEqual(sfile.read(total), data)

            # sendall() takes a buffer over the whole payload.
            chan.sendall(buffer(data))
            self.assertEqual(sfile.read(total), data)

    @needs_builtin('memoryview')
    def test_channel_send_memoryview(self):
        """
        check that a channel accepts ``memoryview`` objects in send() and sendall()
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # Feed send() successive 8-byte slices of one view, advancing by
            # the byte count send() reports.
            view = memoryview(data)
            offset = 0
            total = len(view)
            while offset < total:
                offset += chan.send(view[offset:offset + 8])
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() takes a view over the whole payload.
            chan.sendall(memoryview(data))
            self.assertEqual(sfile.read(len(data)), data)
Ejemplo n.º 49
0
class SSH2NetSessionParamiko:
    def __init__(self, p_self):
        """
        Initialize SSH2NetSessionParamiko Object

        This object, through composition, allows for using Paramiko as the underlying "driver"
        for SSH2Net instead of the default "ssh2-python". Paramiko will be ever so slightly slower
        but as you will most likely be I/O constrained it shouldn't matter! "ssh2-python" as of
        20 October 2019 has a bug preventing keyboard interactive authentication from working as
        desired; this is the reason Paramiko is in here now!

        Args:
            p_self: SSH2Net object

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        # Share the SSH2Net object's attribute dict, so state written by either
        # object (session, channel, ...) is visible to both.
        self.__dict__ = p_self.__dict__
        # NOTE(review): these are copied explicitly, presumably because they are
        # methods defined on SSH2Net's class (not instance attributes) and so are
        # not reachable through the shared __dict__ alone -- confirm.
        self._session_alive = p_self._session_alive
        self._session_open = p_self._session_open
        self._channel_alive = p_self._channel_alive

    def _session_open_connect(self) -> None:
        """
        Perform session handshake for paramiko (instead of default ssh2-python)

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            RequirementsNotSatisfied: if paramiko is not installed (chained to the
                original ModuleNotFoundError)
            Exception: catch all for unknown exceptions during session handshake;
                logged, then re-raised with its original traceback

        """
        try:
            from paramiko import Transport  # noqa
        except ModuleNotFoundError as exc:
            err = f"Module '{exc.name}' not installed!"
            msg = f"***** {err} {'*' * (80 - len(err))}"
            fix = (
                f"To resolve this issue, install '{exc.name}'. You can do this in one of the "
                "following ways:\n"
                "1: 'pip install -r requirements-paramiko.txt'\n"
                "2: 'pip install ssh2net[paramiko]'")
            warning = "\n" + msg + "\n" + fix + "\n" + msg
            warnings.warn(warning)
            raise RequirementsNotSatisfied from exc
        try:
            self.session = Transport(self.sock)
            self.session.start_client()
            # Give the paramiko Transport an ssh2-python style set_timeout()
            # that takes milliseconds (see _set_timeout).
            self.session.set_timeout = self._set_timeout
        except Exception as exc:
            logging.critical(
                f"Failed to complete handshake with host {self.host}; "
                f"Exception: {exc}")
            raise

    def _session_public_key_auth(self) -> None:
        """
        Perform public key based auth on SSH2NetSession

        An AuthenticationException is logged but NOT raised.

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            Exception: catch all for unhandled exceptions (logged, then re-raised)

        """
        try:
            self.session.auth_publickey(self.auth_user, self.auth_public_key)
        except AuthenticationException:
            # NOTE(review): deliberately not re-raised; presumably the caller
            # checks the session state and falls back to password auth -- confirm.
            logging.critical(
                f"Public key authentication with host {self.host} failed.")
        except Exception as exc:
            logging.critical(
                "Unknown error occurred during public key authentication with host "
                f"{self.host}; Exception: {exc}")
            raise

    def _session_password_auth(self) -> None:
        """
        Perform password or keyboard interactive based auth on SSH2NetSession

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            AuthenticationFailed: if authentication fails (chained to paramiko's
                AuthenticationException)
            Exception: catch all for unknown other exceptions (logged, then re-raised)

        """
        try:
            self.session.auth_password(self.auth_user, self.auth_password)
        except AuthenticationException as exc:
            logging.critical(
                f"Password authentication with host {self.host} failed. Exception: {exc}."
                "\n\tNote: Paramiko automatically attempts both standard auth as well as keyboard "
                "interactive auth. Paramiko exception about bad auth type may be misleading!"
            )
            raise AuthenticationFailed from exc
        except Exception as exc:
            logging.critical(
                "Unknown error occurred during password authentication with host "
                f"{self.host}; Exception: {exc}")
            raise

    def _channel_open_driver(self) -> None:
        """
        Open a session channel on the transport and request a PTY on it

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        self.channel = self.session.open_session()
        self.channel.get_pty()
        logging.debug(f"Channel to host {self.host} opened")

    def _channel_invoke_shell(self) -> None:
        """
        Invoke shell on channel

        Additionally, this "re-points" some ssh2net method calls to the appropriate paramiko
        methods. This happens as ssh2net is primarily built on "ssh2-python" and there is not
        full parity between paramiko/ssh2-python.

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        self._shell = True
        self.channel.invoke_shell()
        # ssh2-python style names, backed by paramiko equivalents.
        self.channel.read = self._paramiko_read_channel
        self.channel.write = self.channel.sendall
        self.session.set_blocking = self._set_blocking
        self.channel.flush = self._flush

    def _paramiko_read_channel(self, size: int = 1024):
        """
        Patch channel.read method for paramiko driver

        "ssh2-python" returns a tuple from channel.read, "paramiko" simply returns the data
        from the channel; wrap the data in a tuple for parity with "ssh2-python".

        Args:
            size: maximum number of bytes to receive (default 1024, which is
                what this shim always used before the argument existed)

        Returns:
            tuple: (None, data) -- the received bytes in the second slot; None
                fills the first slot of the ssh2-python tuple

        Raises:
            N/A  # noqa

        """
        channel_read = self.channel.recv(size)
        return None, channel_read

    def _flush(self):
        """
        Patch a "flush" method for paramiko driver

        Reads and discards pending channel data until none is ready, then
        writes a single newline to the channel.

        Need to investigate this further for two things:
            1) is "flush" even necessary when using ssh2-python driver?
            2) if it is necessary, is there a combination of reads/writes that would implement
                this in a sane fashion for paramiko

        Args:
            N/A  # noqa

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        while True:
            # Short pause, presumably so data still in flight can arrive.
            time.sleep(0.01)
            if self.channel.recv_ready():
                self._paramiko_read_channel()
            else:
                self.channel.write("\n")
                return

    def _set_blocking(self, blocking):
        """
        Patch an ssh2-python style session.set_blocking for paramiko driver

        Sets the channel's blocking mode, then re-applies ``self.session_timeout``
        (milliseconds) as the channel timeout in seconds.

        Args:
            blocking: passed straight to paramiko's Channel.setblocking

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        # need to reset timeout because it seems paramiko sets it to 0 if you set to non blocking
        # NOTE(review): in paramiko a positive settimeout() also makes the
        # channel blocking-with-timeout, so this may undo setblocking(False) --
        # confirm this is intended.
        self.channel.setblocking(blocking)
        # paramiko uses seconds instead of ms
        self.channel.settimeout(self.session_timeout / 1000)

    def _set_timeout(self, timeout):
        """
        Patch an ssh2-python style session.set_timeout for paramiko driver

        Args:
            timeout: timeout in milliseconds; converted to seconds for paramiko

        Returns:
            N/A  # noqa

        Raises:
            N/A  # noqa

        """
        # paramiko uses seconds instead of ms
        self.channel.settimeout(timeout / 1000)
Ejemplo n.º 50
0
def trans():
    """
    Create `LoopSocket`-based server/client `Transport`s, yielding the latter.

    Uses `NullServer` under the hood. Both transports and both sockets are
    closed afterwards, even if server start-up fails or the generator is
    closed/thrown into before completing normally.
    """
    # NOTE: based on the setup/teardown/start_server/verify_finished methods
    # found in ye olde test_auth.py

    # "Network" setup
    socks = LoopSocket()
    sockc = LoopSocket()
    sockc.link(socks)
    tc = Transport(sockc)
    ts = Transport(socks)

    try:
        # Start up the in-memory server
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        ts.start_server(event, server)

        # Tests frequently need to call Transport.connect on the client side, etc
        yield tc
    finally:
        # Close things down
        tc.close()
        ts.close()
        socks.close()
        sockc.close()
Ejemplo n.º 51
0
 def __connect(self):
     """Open an SSH transport, authenticate with ``self.auth_key`` and open an
     SFTP client, then change into ``self.path``.

     Stores the transport on ``self.t`` and the SFTP client on ``self.client``.
     If authentication or SFTP client creation fails, the transport is closed
     before the exception propagates so its socket is not leaked.
     """
     self.t = Transport((self.config['hostname'], self.config['port']))
     try:
         self.t.connect(username=self.config['user'], pkey=self.auth_key)
         self.client = SFTPClient.from_transport(self.t)
     except Exception:
         self.t.close()
         raise
     self.client.chdir(self.path)
Ejemplo n.º 52
0
    def upload_report(self, filepath, file_name, report_date):
        """
        Upload a report file to the configured SFTP server.

        The remote location is ``<base folder>/<env folder>/<year>/<month>/<file_name>``.
        ``<base folder>`` and ``<env folder>`` come from ``self._sftp_settings['path']``
        (its last component is the env folder). ``<year>`` and ``<month>`` are the
        first two '-'-separated parts of ``report_date``. Errors during the upload
        are logged, not raised.

        Args:
            filepath: local path of the report; when falsy, the file is looked
                up as ``file_name`` inside ``settings.REPORTS_ROOT``.
            file_name: name to give the uploaded file on the server.
            report_date: date string such as '2014-05-06', or an object with
                ``strftime`` (formatted with ``settings.DATE_FORMAT_YMD``).

        >>> a = CSVgenerator('123')
        >>> a.upload_report(None, 'upload.txt', '2014-05-06')
        """
        # appropriate changes required for doc test to run - dev only.
        # test_file_path = join(settings.REPORTS_ROOT, '16b6a354-ff32-4be4-b648-fe51fc5b1508.csv')
        # # trace.info('----- {}'.format(abspath(test_file_path)))
        # absolute_filename = abspath(test_file_path)
        # test_filename = '16b6a354-ff32-4be4-b648-fe51fc5b1508.csv'

        #TODO remove after QA testing.
        if not filepath:
            absolute_filename = abspath(join(settings.REPORTS_ROOT, file_name))
        else:
            absolute_filename = abspath(filepath)

        try:
            if not isinstance(report_date, str):
                report_date = report_date.strftime(settings.DATE_FORMAT_YMD)
        except Exception:
            log.exception('')

        # Pre-bind so the cleanup in ``finally`` only closes what was opened.
        sftp = None
        transport = None
        try:
            #TODO use sftp credentials, once available.
            #TODO get client upload location for per-client report.
            # ftp_client_dir = client.ftp_client_dir
            # if ftp_client_dir == '':
            #     ftp_client_dir = getattr(settings, 'DEFAULT_SFTP_LOCATION', 'default')
            #     logger.exception(u'No FTP configuration for client {} using default value.'.format(client.name))
            #

            year_folder, month_folder, _ = report_date.split('-')

            # Assumes the configured path starts with '/': the first (empty)
            # component is skipped, the middle ones form the base folder and
            # the last one is the env folder.
            base_path = self._sftp_settings.get('path', '').split('/')
            base_folders = base_path[1:-1]
            # join() needs at least one argument; with no intermediate folders
            # fall back to the root instead of raising TypeError.
            base_folder = ('/' + join(*base_folders)) if base_folders else '/'
            env_folder = str(base_path[-1])

            from paramiko import Transport, SFTPClient
            server = self._sftp_settings.get('server', '')
            username = self._sftp_settings.get('username', '')
            try:
                # Log the values actually used for the connection.
                log.info(u'SFTP logging on to {0} as {1}'.format(
                    server, username))
                transport = Transport((server,
                                       self._sftp_settings.get('port', 22)))
                transport.connect(
                    username=username,
                    password=self._sftp_settings.get('password', ''))
                sftp = SFTPClient.from_transport(transport)
                log.info(u'SFTP dir {0}/{1}/{2}/{3}'.format(
                    base_folder, env_folder, year_folder, month_folder))
                try:
                    sftp.chdir(base_folder)
                except Exception:
                    log.debug('Unable to change to base folder on ftp server.')

                self._make_or_change_sftp_dir(sftp, env_folder)
                self._make_or_change_sftp_dir(sftp, year_folder)
                self._make_or_change_sftp_dir(sftp, month_folder)

                log.debug(u'SFTP uploading {0}'.format(absolute_filename))
                sftp.put(absolute_filename, file_name)
            except Exception:
                log.exception(
                    u'Unrecoverable exception during SFTP upload process.')

            finally:
                log.debug(u'SFTP logging off')

                if sftp is not None:
                    try:
                        sftp.close()
                    except Exception:
                        log.exception(
                            u'SFTP exception while closing SFTP session.')

                if transport is not None:
                    try:
                        transport.close()
                    except Exception:
                        log.exception(
                            u'SFTP exception while closing SSH connection.')

        except Exception:
            # Still best-effort, but keep the traceback instead of a bare debug line.
            log.exception('Error while uploading report.')
Ejemplo n.º 53
0
class AuthTest(unittest.TestCase):
    """Authentication tests against a NullServer over an in-memory LoopSocket pair."""

    def setUp(self):
        # Client and server transports joined by linked in-memory sockets.
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def start_server(self):
        """Load the test RSA host key and start the NullServer on ``self.ts``."""
        host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
        self.public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        self.event = threading.Event()
        self.server = NullServer()
        self.assertFalse(self.event.is_set())
        self.ts.start_server(self.event, self.server)

    def verify_finished(self):
        """Wait up to 1s for the server event, then check the server is still active."""
        self.event.wait(1.0)
        self.assertTrue(self.event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        self.start_server()
        with self.assertRaises(BadAuthenticationType) as cm:
            self.tc.connect(
                hostkey=self.public_host_key,
                username="******",
                password="******",
            )
        # Require the exact type, not a subclass.
        self.assertEqual(BadAuthenticationType, type(cm.exception))
        self.assertEqual(["publickey"], cm.exception.allowed_types)

    def test_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        with self.assertRaises(AuthenticationException):
            self.tc.auth_password(username="******", password="******")
        self.tc.auth_password(username="******", password="******")
        self.verify_finished()

    def test_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password(username="******",
                                       password="******")
        self.assertEqual(["publickey"], remain)
        key = DSSKey.from_private_key_file(_support("test_dss.key"))
        remain = self.tc.auth_publickey(username="******", key=key)
        self.assertEqual([], remain)
        self.verify_finished()

    def test_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)

        def handler(title, instructions, prompts):
            # Record what the server asked so it can be checked below.
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ["cat"]

        remain = self.tc.auth_interactive("commie", handler)
        self.assertEqual(self.got_title, "password")
        self.assertEqual(self.got_prompts, [("Password", False)])
        self.assertEqual([], remain)
        self.verify_finished()

    def test_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password("commie", "cat")
        self.assertEqual([], remain)
        self.verify_finished()

    def test_auth_utf8(self):
        """
        verify that utf-8 encoding happens in authentication.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password("utf8", _pwd)
        self.assertEqual([], remain)
        self.verify_finished()

    def test_auth_non_utf8(self):
        """
        verify that non-utf-8 encoded passwords can be used for broken
        servers.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password("non-utf8", "\xff")
        self.assertEqual([], remain)
        self.verify_finished()

    def test_auth_gets_disconnected(self):
        """
        verify that we catch a server disconnecting during auth, and report
        it as an auth failure.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        # Must fail: previously the test silently passed if no exception came.
        with self.assertRaises(AuthenticationException):
            self.tc.auth_password("bad-server", "hello")

    @slow
    def test_auth_non_responsive(self):
        """
        verify that authentication times out if server takes to long to
        respond (or never responds).
        """
        self.tc.auth_timeout = 1  # 1 second, to speed up test
        self.start_server()
        self.tc.connect()
        # Must fail: previously the test silently passed if no exception came.
        with self.assertRaises(AuthenticationException) as cm:
            self.tc.auth_password("unresponsive-server", "hello")
        self.assertIn("Authentication timeout", str(cm.exception))
Ejemplo n.º 54
0
def sftp_client_from_transport(hostname, username, password, port=22):
    """Open an SFTP client to ``hostname`` using password authentication.

    Args:
        hostname: SSH server to connect to.
        username: login name.
        password: login password.
        port: SSH port; defaults to 22, the value previously hard-coded.

    Returns:
        The SFTP client opened over a new paramiko ``Transport``. Nothing here
        closes that transport on success.

    Raises:
        Whatever paramiko raises while connecting, authenticating or opening
        the SFTP channel. The transport is closed first so its socket is not
        leaked.
    """
    from paramiko import Transport
    tn = Transport((hostname, port))
    try:
        tn.connect(username=username, password=password)
        return tn.open_sftp_client()
    except Exception:
        tn.close()
        raise
Ejemplo n.º 55
0
class IrmaSFTP(IrmaFTP):
    """Irma SFTP handler

    This class handles the connection with an SFTP server and provides
    functions for interacting with it. Every public method wraps failures in
    IrmaSFTPError.
    """

    # ==================================
    #  Constructor and Destructor stuff
    # ==================================
    def __init__(self,
                 host,
                 port,
                 auth,
                 key_path,
                 user,
                 passwd,
                 dst_user=None,
                 upload_path='uploads'):
        super(IrmaSFTP, self).__init__(host, port, auth, key_path, user,
                                       passwd, dst_user, upload_path)
        self._client = None
        self._connect()

    # =================
    #  Private methods
    # =================

    def _connect(self):
        """Open the SSH transport and SFTP client unless already connected.

        Authenticates with an RSA key file when ``self._auth == 'key'``,
        otherwise with user/password. On failure any half-open transport is
        closed and ``_conn``/``_client`` are reset to None, so a later call can
        retry. The error is then raised as IrmaSFTPError.
        """
        if self._conn is not None:
            log.warning("Already connected to sftp server")
            return
        try:
            self._conn = Transport((self._host, self._port))
            # Large window, and rekey limits raised to 2**32 bytes/packets.
            self._conn.window_size = pow(2, 27)
            self._conn.packetizer.REKEY_BYTES = pow(2, 32)
            self._conn.packetizer.REKEY_PACKETS = pow(2, 32)
            if self._auth == 'key':
                pkey = RSAKey.from_private_key_file(self._key_path)
                self._conn.connect(username=self._user, pkey=pkey)
            else:
                self._conn.connect(username=self._user, password=self._passwd)

            self._client = SFTPClient.from_transport(self._conn)
        except Exception as e:
            if self._conn is not None:
                try:
                    self._conn.close()
                except Exception:
                    # Best-effort cleanup; the original error is reported below.
                    pass
            # Without this reset, every retry would hit "Already connected"
            # and return with no SFTP client.
            self._conn = None
            self._client = None
            raise IrmaSFTPError("{0}".format(e))

    # ================
    #  Public methods
    # ================

    def mkdir(self, path):
        """ Create remote directory <path>"""
        try:
            dst_path = self._get_realpath(path)
            self._client.mkdir(dst_path)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def list(self, path):
        """ list remote directory <path>"""
        try:
            dst_path = self._get_realpath(path)
            return self._client.listdir(dst_path)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def upload_fobj(self, path, fobj):
        """ Upload <fobj> into remote directory <path>, named by its hash.

        Returns the remote file name (the hash of the data).
        """
        try:
            dstname = self._hash(fobj)
            path = self._tweaked_join(path, dstname)
            dstpath = self._get_realpath(path)
            self._client.putfo(fobj, dstpath)
            return dstname
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def download_fobj(self, path, remotename, fobj):
        """ Download <remotename> found in <path> into <fobj> and check its hash"""
        try:
            dstpath = self._get_realpath(path)
            full_dstpath = self._tweaked_join(dstpath, remotename)
            self._client.getfo(full_dstpath, fobj)
            # remotename is hashvalue of data
            self._check_hash(remotename, fobj)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def delete(self, path, filename):
        """ Delete <filename> in directory <path>"""
        try:
            dstpath = self._get_realpath(path)
            full_dstpath = self._tweaked_join(dstpath, filename)
            self._client.remove(full_dstpath)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def deletepath(self, path, deleteParent=False):
        """ Recursively delete all files and subdirectories under <path>.

        <path> itself is removed too when deleteParent is true (always the
        case for nested subdirectories).
        """
        try:
            for f in self.list(path):
                if self.is_file(path, f):
                    self.delete(path, f)
                else:
                    self.deletepath(self._tweaked_join(path, f),
                                    deleteParent=True)
            if deleteParent:
                dstpath = self._get_realpath(path)
                self._client.rmdir(dstpath)
        except Exception as e:
            reason = "{0} [{1}]".format(str(e), path)
            raise IrmaSFTPError(reason)

    def is_file(self, path, filename):
        """ Return True unless <filename> in <path> is a directory"""
        try:
            dstpath = self._get_realpath(path)
            full_dstpath = self._tweaked_join(dstpath, filename)
            st = self._client.stat(full_dstpath)
            return not stat.S_ISDIR(st.st_mode)
        except Exception as e:
            reason = "{0} [{1}]".format(e, path)
            raise IrmaSFTPError(reason)

    def rename(self, oldpath, newpath):
        """ Rename remote <oldpath> to <newpath>"""
        try:
            old_realpath = self._get_realpath(oldpath)
            new_realpath = self._get_realpath(newpath)
            self._client.rename(old_realpath, new_realpath)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))
Ejemplo n.º 56
0
class Session(BaseSession):
    """State of one client connection handled by the SSH proxy.

    A ``Session`` is created per accepted client socket. It owns the
    client-facing paramiko ``Transport`` (created lazily by the ``transport``
    property) and records what the client asked for (ssh/scp/sftp channels,
    agent forwarding, environment requests). The credential and remote
    address fields start empty here; they are presumably filled in by the
    authenticator and channel handlers.

    Use as a context manager: ``__exit__`` calls ``close``.
    """

    # Optional tuple of cipher names forced on the client-facing transport by
    # the ``transport`` property; ``None`` keeps paramiko's defaults.
    CIPHERS = None

    @classmethod
    @typechecked
    def parser_arguments(cls) -> None:
        """Register this class's command-line options (``--session-log-dir``)."""
        plugin_group = cls.parser().add_argument_group(cls.__name__)
        plugin_group.add_argument('--session-log-dir',
                                  dest='session_log_dir',
                                  help='directory to store ssh session logs')

    @typechecked
    def __init__(
        self, proxyserver: 'ssh_proxy_server.server.SSHProxyServer',
        client_socket: socket.socket, client_address: Union[Tuple[Text, int],
                                                            Tuple[Text, int,
                                                                  int, int]],
        authenticator: Type['ssh_proxy_server.authentication.Authenticator'],
        remoteaddr: Union[Tuple[Text, int], Tuple[Text, int, int,
                                                  int]]) -> None:
        """Initialise the session bookkeeping.

        The transport is not created here (see the ``transport`` property).

        :param proxyserver: owning proxy server (host key, interfaces,
            ``running`` flag)
        :param client_socket: accepted socket of the connecting client
        :param client_address: client address, as a 2-tuple (IPv4) or
            4-tuple (IPv6)
        :param authenticator: authenticator class; it is instantiated with
            this session
        :param remoteaddr: address tuple stored as ``socket_remote_address``
        """
        super().__init__()
        self.sessionid = uuid4()
        logging.info(
            f"{EMOJI['information']} session {stylize(self.sessionid, fg('light_blue') + attr('bold'))} created"
        )
        # Created lazily by the ``transport`` property.
        self._transport: Optional[paramiko.Transport] = None

        # Session channel accepted from the client in ``start``.
        self.channel = None

        self.proxyserver: 'ssh_proxy_server.server.SSHProxyServer' = proxyserver
        self.client_socket = client_socket
        self.client_address = client_address
        self.name = f"{client_address}->{remoteaddr}"
        self.closed = False

        self.agent_requested: threading.Event = threading.Event()

        self.ssh_requested: bool = False
        self.ssh_channel: Optional[paramiko.Channel] = None
        self.ssh_client: Optional[
            ssh_proxy_server.clients.ssh.SSHClient] = None
        self.ssh_pty_kwargs = None

        self.scp_requested: bool = False
        self.scp_channel = None
        self.scp_command: bytes = b''

        self.sftp_requested: bool = False
        self.sftp_channel = None
        self.sftp_client: Optional[
            ssh_proxy_server.clients.sftp.SFTPClient] = None
        self.sftp_client_ready = threading.Event()

        self.username: str = ''
        self.username_provided: Optional[str] = None
        self.password: Optional[str] = None
        self.password_provided: Optional[str] = None
        self.socket_remote_address = remoteaddr
        self.remote_address: Tuple[Optional[Text],
                                   Optional[int]] = (None, None)
        self.remote_key: Optional[PKey] = None
        self.accepted_key: Optional[PKey] = None
        self.agent: Optional[AgentProxy] = None
        self.authenticator: 'ssh_proxy_server.authentication.Authenticator' = authenticator(
            self)

        self.env_requests: Dict[bytes, bytes] = {}
        self.session_log_dir: Optional[str] = self.get_session_log_dir()

    @typechecked
    def get_session_log_dir(self) -> Optional[str]:
        """Return ``<session_log_dir>/<sessionid>`` with ``~`` expanded.

        Returns None when no ``--session-log-dir`` was configured.
        """
        if not self.args.session_log_dir:
            return None
        session_log_dir = os.path.expanduser(self.args.session_log_dir)
        return os.path.join(session_log_dir, str(self.sessionid))

    @property
    def running(self) -> bool:
        """Whether the session should keep being serviced.

        True while the proxy server is running, the session is not closed and
        at least one channel is still open. Before the session channel has
        been accepted it counts as open.
        """
        session_channel_open: bool = True
        ssh_channel_open: bool = False
        scp_channel_open: bool = False

        if self.channel is not None:
            session_channel_open = not self.channel.closed
        if self.ssh_channel is not None:
            ssh_channel_open = not self.ssh_channel.closed
        if self.scp_channel is not None:
            scp_channel_open = not self.scp_channel.closed
        open_channel_exists = session_channel_open or ssh_channel_open or scp_channel_open

        return self.proxyserver.running and open_channel_exists and not self.closed

    @property
    def transport(self) -> paramiko.Transport:
        """Return the client-facing transport, creating it on first access.

        On creation it:
        - runs ``key_negotiation.handle_key_negotiation``;
        - restricts ciphers to ``CIPHERS`` when set (``ValueError`` if
          ``CIPHERS`` is not a tuple);
        - adds the proxy's host key if there is one;
        - registers ``ProxySFTPServer`` as the ``sftp`` subsystem handler.
        """
        if self._transport is None:
            self._transport = Transport(self.client_socket)
            key_negotiation.handle_key_negotiation(self)
            if self.CIPHERS:
                if not isinstance(self.CIPHERS, tuple):
                    raise ValueError('ciphers must be a tuple')
                self._transport.get_security_options().ciphers = self.CIPHERS
            host_key: Optional[PKey] = self.proxyserver.host_key
            if host_key is not None:
                self._transport.add_server_key(host_key)
            self._transport.set_subsystem_handler(
                'sftp', ProxySFTPServer, self.proxyserver.sftp_interface, self)

        return self._transport

    @typechecked
    def _start_channels(self) -> bool:
        """Set up agent forwarding and authenticate towards the remote side.

        Returns True when the session may proceed and False when it must stop.

        ``sftp_client_ready`` is set when an ssh client already exists, or
        after a successful agent-based authentication. The fallback
        authentication paths return without setting it.
        """
        # create client or master channel
        if self.ssh_client:
            self.sftp_client_ready.set()
            return True

        if not self.agent or self.authenticator.REQUEST_AGENT_BREAKIN:
            try:
                # Give the client up to one second to request agent
                # forwarding, unless an agent break-in is forced.
                if self.agent_requested.wait(
                        1) or self.authenticator.REQUEST_AGENT_BREAKIN:
                    self.agent = AgentProxy(self.transport)
            except ChannelException:
                logging.error(
                    "Breakin not successful! Closing ssh connection to client")
                self.agent = None
                self.close()
                return False
        # Connect method start
        if not self.agent:
            if self.username_provided is None:
                logging.error("No username provided during login!")
                return False
            return self.authenticator.auth_fallback(
                self.username_provided) == paramiko.common.AUTH_SUCCESSFUL

        if self.authenticator.authenticate(
                store_credentials=False) != paramiko.common.AUTH_SUCCESSFUL:
            if self.username_provided is None:
                logging.error("No username provided during login!")
                return False
            if self.authenticator.auth_fallback(
                    self.username_provided) == paramiko.common.AUTH_SUCCESSFUL:
                return True
            else:
                self.transport.close()
                return False

        # Connect method end
        # NOTE(review): with no channel requested and an already inactive
        # transport this falls through and reports success — confirm intent.
        if not self.scp_requested and not self.ssh_requested and not self.sftp_requested:
            if self.transport.is_active():
                self.transport.close()
                return False

        self.sftp_client_ready.set()
        return True

    @typechecked
    def start(self) -> bool:
        """Serve the client connection until the session's channels are set up.

        Starts the SSH server side on the transport, waits for the client's
        session channel, then runs the client tunnel setup and
        ``_start_channels``. Returns True on success, False otherwise.
        """
        event = threading.Event()
        self.transport.start_server(
            event=event,
            server=self.proxyserver.authentication_interface(self))

        # Poll for the client's session channel and give up as soon as the
        # session stops running. The loop only exits with a channel in hand;
        # every other exit returns.
        while not self.channel:
            self.channel = self.transport.accept(0.5)
            if not self.running:
                self.transport.close()
                return False

        # paramiko sets ``event`` once transport negotiation has completed
        # (successfully or not).
        event.wait()

        if not self.transport.is_active():
            return False

        self.proxyserver.client_tunnel_interface.setup(self)

        if not self._start_channels():
            return False

        logging.info(
            f"{EMOJI['information']} {stylize(self.sessionid, fg('light_blue') + attr('bold'))} - session started"
        )
        return True

    @typechecked
    def close(self) -> None:
        """Tear the session down and mark it closed.

        Closes, in order:
        - the forwarded agent;
        - the upstream ssh client's transport (then waits for a graceful
          completion where possible);
        - every forwarder registered on the server interface;
        - the client transport.
        """
        if self.agent:
            self.agent.close()
            logging.debug("(%s) session agent cleaned up", self)
        if self.ssh_client:
            logging.debug("(%s) closing ssh client to remote", self)
            if self.ssh_client.transport:
                self.ssh_client.transport.close()
            # With graceful exit the completion_event can be polled to wait, well ..., for completion
            # it can also only be a graceful exit if the ssh client has already been established
            if self.transport.completion_event is not None:
                if self.transport.completion_event.is_set(
                ) and self.transport.is_active():
                    self.transport.completion_event.clear()
                    while self.transport.is_active():
                        if self.transport.completion_event.wait(0.1):
                            break
        if self.transport.server_object is not None:
            for f in cast(BaseServerInterface,
                          self.transport.server_object).forwarders:
                f.close()
                f.join()
        self.transport.close()
        logging.info(
            f"{EMOJI['information']} session {stylize(self.sessionid, fg('light_blue') + attr('bold'))} closed"
        )
        logging.debug(f"({self}) session closed")
        self.closed = True

    @typechecked
    def __str__(self) -> str:
        """Return ``"<client_address>-><remoteaddr>"``."""
        return self.name

    @typechecked
    def __enter__(self) -> 'Session':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Close the session when leaving a ``with`` block."""
        logging.debug("(%s) session exited", self)
        self.close()
Ejemplo n.º 57
0
class MyTSFTPRequestHandler(SocketServer.BaseRequestHandler):
    """socketserver handler that serves SSH/SFTP on each accepted socket.

    ``setup`` wraps ``self.request`` in a paramiko server ``Transport`` with
    an ``sftp`` subsystem. ``handle`` starts the SSH server on it. Without an
    ``event`` argument, paramiko's ``start_server`` returns once negotiation
    is done while the transport keeps running in its own thread, so
    ``handle`` returns early.

    NOTE(review): socketserver normally shuts down and closes
    ``self.request`` after ``handle``/``finish`` return. This presumably
    relies on the owning server not doing so — confirm against the server
    class.
    """

    # NOTE(review): BaseRequestHandler itself reads neither attribute;
    # presumably consumed elsewhere (or informational) — confirm.
    timeout = 60
    auth_timeout = 60

    def setup(self):
        """Create and configure the server-side transport for ``self.request``."""
        self.transport = Transport(self.request)
        self.transport.load_server_moduli()
        so = self.transport.get_security_options()
        so.digests = ('hmac-sha1', )
        # paramiko rejects unknown algorithm names with ValueError, so this
        # must be a real compression name ('zlib@openssh.com'), not a
        # placeholder.
        so.compression = ('zlib@openssh.com', 'none')
        self.transport.add_server_key(self.server.host_key)
        self.transport.set_subsystem_handler('sftp', MyTSFTPServer,
                                             MyTSFTPServerInterface)

    def handle(self):
        """Start the SSH server side of the transport."""
        self.transport.start_server(server=MyTServerInterface())

    def handle_timeout(self):
        """Close the transport (presumably used as a timeout hook)."""
        self.transport.close()
Ejemplo n.º 58
0
 def create_connection(cls, host, port, username, password):
     """Open an SFTP session to ``host:port`` and store it on ``cls._connection``.

     Authenticates with ``username``/``password``. If authentication or
     SFTP channel setup fails, the already-opened transport (socket and
     thread) is closed before the exception propagates, so a failed attempt
     does not leak it.
     """
     transport = Transport(sock=(host, port))
     try:
         transport.connect(username=username, password=password)
         cls._connection = SFTPClient.from_transport(transport)
     except Exception:
         transport.close()
         raise
Ejemplo n.º 59
0
Archivo: sftp.py Proyecto: yehias/irma
class IrmaSFTP(FTPInterface):
    """Irma SFTP handler

    This class handles the connection with a sftp server
    functions for interacting with it.
    """

    _Exception = IrmaSFTPError

    # ==================================
    #  Constructor and Destructor stuff
    # ==================================

    def __init__(self, host, port, auth, key_path, user, passwd,
                 dst_user=None, upload_path='uploads', hash_check=False,
                 autoconnect=True):
        """Create the handler and delegate configuration to ``FTPInterface``.

        All arguments are forwarded positionally, in the same order, to the
        base-class initializer.
        """
        # The transport and SFTP client handles must exist (as None) before
        # the base initializer runs, since ``connected()`` reads ``_conn``.
        self._conn = self._client = None
        super().__init__(host, port, auth, key_path, user, passwd,
                         dst_user, upload_path, hash_check, autoconnect)

    def connected(self):
        """Return True while a transport handle (``_conn``) is held."""
        transport = self._conn
        return transport is not None

    # ============================
    #  Overridden private methods
    # ============================

    def _connect(self):
        """Open the SSH transport and SFTP client to ``self._host:self._port``.

        Uses the RSA private key at ``self._key_path`` when ``self._auth`` is
        ``'key'``; otherwise authenticates with ``self._user`` and
        ``self._passwd``.

        ``self._conn``/``self._client`` are only set once everything
        succeeded. On failure the half-open transport is closed and the
        exception propagates, so ``connected()`` never reports a dead
        transport and no socket/thread is leaked.
        """
        conn = Transport((self._host, self._port))
        try:
            conn.window_size = pow(2, 27)
            # Raise the rekey thresholds so large transfers trigger fewer
            # re-negotiations.
            conn.packetizer.REKEY_BYTES = pow(2, 32)
            conn.packetizer.REKEY_PACKETS = pow(2, 32)
            if self._auth == 'key':
                pkey = RSAKey.from_private_key_file(self._key_path)
                conn.connect(username=self._user, pkey=pkey)
            else:
                conn.connect(username=self._user, password=self._passwd)
            client = SFTPClient.from_transport(conn)
        except Exception:
            conn.close()
            raise
        self._conn = conn
        self._client = client

        self._client = SFTPClient.from_transport(self._conn)

    def _disconnect(self, *, force=False):
        """Drop the SFTP client and the transport handle.

        Unless ``force`` is true, the transport is closed first. With
        ``force`` the handle is simply discarded. Safe to call when no
        transport is held.
        """
        self._client = None
        if not force and self._conn is not None:
            self._conn.close()
        self._conn = None

    def _upload(self, remote, fobj):
        """Write the contents of the file-like ``fobj`` to remote path ``remote``."""
        client = self._client
        client.putfo(fobj, remote)

    def _download(self, remote, fobj):
        """Copy the remote file ``remote`` into the writable file-like ``fobj``."""
        client = self._client
        client.getfo(remote, fobj)

    def _ls(self, remote):
        """Return the entry names of the remote directory ``remote``."""
        names = self._client.listdir(remote)
        return names

    def _is_file(self, remote):
        """Return True when ``remote`` is not a directory."""
        is_directory = self._is_dir(remote)
        return not is_directory

    def _is_dir(self, remote):
        """Return True when the remote path ``remote`` is a directory."""
        mode = self._client.stat(remote).st_mode
        return stat.S_ISDIR(mode)

    def _rm(self, remote):
        """Remove the remote file ``remote``."""
        client = self._client
        client.remove(remote)

    def _rmdir(self, remote):
        """Remove the remote directory ``remote`` (passed straight to SFTP ``rmdir``)."""
        client = self._client
        client.rmdir(remote)

    def _mkdir(self, remote):
        """Create the remote directory ``remote``."""
        client = self._client
        client.mkdir(remote)

    def _mv(self, oldremote, newremote):
        """Rename the remote entry ``oldremote`` to ``newremote``."""
        client = self._client
        client.rename(oldremote, newremote)