Beispiel #1
1
def main():
    """Build an eXe source tarball for a branch and optionally publish it.

    Command line: ``<script> version [--install] [--local | username password]``.

    The branch is exported from SVN into a temporary directory, the
    ``exe/webui/firefox`` directory from this checkout is added, and the tree
    is tarred up. Unless ``--local`` is given the tarball is uploaded to
    eduforge over SFTP. When run as root, the tarball and a Gentoo ebuild are
    installed into the local portage overlay, and ``--install`` also runs
    ``ebuild ... install``.
    """
    if len(sys.argv) < 2:
        print 'Usage: %s [version] [--install] [--local|username password]' % sys.argv[0]
        print 'Where [version] is the branch you want to checkout'
        print 'and username and password are for your eduforge account'
        print 'Eg. %s 0.7 --local' % sys.argv[0]
    else:
        version = sys.argv[1]
        branch = 'http://exe.cfdl.auckland.ac.nz/svn/exe/branches/%s' % version
        # Directory containing this script; used to find files in the
        # current checkout.
        origDir = Path(sys.argv[0]).abspath().dirname()
        # NOTE(review): TempDirPath is not shown here; presumably it creates a
        # fresh temporary directory and returns it as a Path -- confirm.
        tmp = TempDirPath()
        os.chdir(tmp)
        # Export (not check out) the branch into tmp/exe.
        # NOTE(review): version is interpolated into shell commands unquoted.
        os.system('svn export %s exe' % branch)
        # Copy exe/webui/firefox from this checkout into the exported tree.
        (origDir/'../../exe/webui/firefox').copytree(tmp/'exe/exe/webui/firefox')
        os.chdir(tmp/'exe')
        # The tarball is written next to the export: tmp/exe-<version>-source.tgz.
        tarball = Path('../exe-%s-source.tgz' % version).abspath()
        os.system('tar czf %s *' % tarball)
        os.chdir(tmp)
        if '--local' not in sys.argv:
            # Upload the tarball to eduforge over SFTP.
            try:
                from paramiko import Transport
            except ImportError:
                print 'To upload you need to install paramiko python library from:'
                print 'http://www.lag.net/paramiko'
                sys.exit(2)
            from socket import socket, gethostbyname
            s = socket()
            s.connect((gethostbyname('shell.eduforge.org'), 22))
            t = Transport(s)
            # No host key is passed, so the server's key is not verified.
            t.connect()
            # Username and password are the last two command-line arguments.
            # NOTE(review): if they were omitted, other arguments (e.g. the
            # version) are silently used as credentials.
            t.auth_password(sys.argv[-2], sys.argv[-1])
            f = t.open_sftp_client()
            f.chdir('/home/pub/exe')
            f.put(tarball.encode('utf8'), tarball.basename().encode('utf8'))
            # NOTE(review): the SFTP client and transport are never closed.
        if os.getuid() == 0:
            # Running as root: make the tarball available in portage's distfiles.
            tarball.copyfile('/usr/portage/distfiles/' + tarball.basename())
        os.chdir(tmp/'exe/installs/gentoo')
        # Use the 0.7 ebuild as a template when the branch has no ebuild for
        # this version.
        newEbuildFilename = Path('exe-%s.ebuild' % version).abspath()
        if not newEbuildFilename.exists():
            Path('exe-0.7.ebuild').copy(newEbuildFilename)
        if os.getuid() == 0:
            # Recreate the overlay package directory from scratch.
            ebuildDir = Path('/usr/local/portage/dev-python/exe')
            if ebuildDir.exists():
                ebuildDir.rmtree()
            ebuildDir.makedirs()
            os.chdir(ebuildDir)
            newEbuildFilename.copy(ebuildDir)
            filesDir = ebuildDir/'files'
            filesDir.makedirs()
            Path(tmp/'exe/installs/gentoo/all-config.patch').copy(filesDir)
            if '--local' not in sys.argv:
                # Remove the local distfiles copy and clear GENTOO_MIRRORS,
                # presumably so 'ebuild fetch' downloads the tarball just
                # uploaded (via the ebuild's SRC_URI) -- confirm.
                oldTarball = Path('/usr/portage/distfiles/')/tarball.basename()
                if oldTarball.exists():
                    oldTarball.remove()
                os.environ['GENTOO_MIRRORS']=''
                os.system('ebuild %s fetch' % newEbuildFilename.basename())
            # Regenerate the Manifest / digests for the new ebuild.
            os.system('ebuild %s manifest' % newEbuildFilename.basename())
            os.system('ebuild %s digest' % newEbuildFilename.basename())
            if '--install' in sys.argv:
                os.system('ebuild %s install' % newEbuildFilename.basename())
 def __ensure_connection(self,
                         transport: paramiko.Transport,
                         force: bool = False) -> paramiko.Transport:
     """Return a live, authenticated transport, reconnecting when needed.

     A new transport is opened when *transport* is None, is no longer
     active, or *force* is true. The old transport (if any) is closed first.
     If setting up the new transport fails, it is closed again so that its
     socket is not leaked.

     Args:
         transport: The current transport, or None if none exists yet.
         force: Reconnect even if *transport* is still active.

     Returns:
         *transport* itself if it is still usable, otherwise a newly
         connected and authenticated transport.

     Raises:
         RuntimeError: If the configured credential is of an unknown kind.
         ConnectionError: If the TCP connection, the SSH negotiation or the
             authentication fails.
     """
     if transport is not None and transport.is_active() and not force:
         return transport

     if transport is not None:
         transport.close()

     logger.info('Connecting to {} on port {}'.format(
         self.__host, self.__port))
     try:
         # Building a Transport from an address opens the TCP socket, so
         # socket-level failures surface here.
         new_transport = paramiko.Transport((self.__host, self.__port))
     except OSError as exc:
         raise ConnectionError(
             'Cerulean was disconnected and could not reconnect') from exc

     try:
         if isinstance(self.__credential, PasswordCredential):
             logger.debug('Authenticating using a password')
             new_transport.connect(username=self.__credential.username,
                                   password=self.__credential.password)
         elif isinstance(self.__credential, PubKeyCredential):
             logger.debug('Authenticating using a public key')
             key = self.__get_key_from_file(
                 self.__credential.public_key,
                 self.__credential.passphrase)
             new_transport.connect(username=self.__credential.username,
                                   pkey=key)
         else:
             raise RuntimeError('Unknown kind of credential')
     except (paramiko.SSHException, OSError) as exc:
         new_transport.close()
         raise ConnectionError(
             'Cerulean was disconnected and could not reconnect') from exc
     except BaseException:
         # Any other failure (e.g. unknown credential, unreadable key):
         # don't leak the half-open transport, but keep the original error.
         new_transport.close()
         raise

     logger.info('Connection (re)established')
     return new_transport
Beispiel #3
0
def main():
    """Rename the built Windows installers and upload them to eduforge.

    The eduforge username is the last command-line argument; the password
    is prompted for. Exits with status 2 if paramiko is not installed.
    """
    try:
        from paramiko import Transport
    except ImportError:
        print
        print 'To upload you need to install paramiko python library from:'
        # Trailing comma: the Python 2 print statement suppresses the newline,
        # so the next message continues on the same line after a space.
        print 'http://www.lag.net/paramiko',
        print 'or on ubuntu go: apt-get install python2.4-paramiko'
        print
        sys.exit(2)
    server = 'shell.eduforge.org'
    basedir = '/home/pub/exe/'
    print 'Please enter password for %s@%s:' % (sys.argv[-1], server)
    password = getpass()
    print 'Renaming files'
    # NOTE(review): ``release`` is not defined in this function; presumably a
    # module-level version string -- confirm.
    # renameFile's return value replaces the old path (presumably the
    # renamed file) -- confirm against renameFile.
    install = Path('eXe_install_windows.exe')
    newName = Path('eXe-install-%s.exe' % release)
    install = renameFile(install, newName)
    ready2run = Path('exes.exe')
    newName = Path('eXe-ready2run-%s.exe' % release)
    ready2run = renameFile(ready2run, newName)
    print 'Uploading'
    print 'connecting to %s...' % server
    from socket import socket, gethostbyname
    s = socket()
    s.connect((gethostbyname(server), 22))
    t = Transport(s)
    # No host key is passed, so the server's key is not verified.
    t.connect()
    t.auth_password(sys.argv[-1], password)
    sftp = t.open_sftp_client()
    sftp.chdir(basedir)
    upFile(sftp, install)
    upFile(sftp, ready2run)
    # NOTE(review): the SFTP client and transport are never closed.
Beispiel #4
0
    def upload_dir(self, remote_path):
        """Upload the local project to the Sonar host and run the scanner there.

        The directory ``self.local_project_path`` is copied recursively into
        *remote_path*, and ``self.sonar_properties`` is copied into the
        uploaded project directory. ``self.sonar_scanner_command`` is then run
        inside that directory over SSH. A dashboard link is printed when the
        command's stdout matches ``self.success_regex``; otherwise the
        command output is printed.

        :param remote_path: Directory on the Sonar host (a POSIX path).
        """
        import posixpath

        # Last component of the local project path. normpath first, so a
        # trailing separator does not yield an empty name.
        project_name = os.path.basename(
            os.path.normpath(self.local_project_path))
        # Remote paths are POSIX regardless of the local OS.
        remote_project_dir = posixpath.join(remote_path, project_name)

        transport = Transport((self.sonar_host, self.ssh_port))
        try:
            transport.connect(username=self.username, password=self.password)
            with SCPClient(transport) as scp:
                scp.put(self.local_project_path,
                        recursive=True,
                        remote_path=remote_path)
                scp.put(os.path.join(self.local_project_path,
                                     self.sonar_properties),
                        remote_path=remote_project_dir)
        finally:
            transport.close()

        # execute command
        ssh = SSHClient()
        try:
            ssh.set_missing_host_key_policy(AutoAddPolicy())
            # Use the same SSH port as the upload above.
            ssh.connect(self.sonar_host,
                        port=self.ssh_port,
                        username=self.username,
                        password=self.password)
            _, ssh_stdout, ssh_stderr = ssh.exec_command(
                'cd {0} && {1}'.format(remote_project_dir,
                                       self.sonar_scanner_command))
            output = ssh_stdout.read().decode('utf-8')
            errors = ssh_stderr.read().decode('utf-8')
        finally:
            ssh.close()

        if re.search(self.success_regex, output):
            print('Link to dashboard: http://{0}:{1}/dashboard?id={2}'.format(
                self.sonar_host, self.sonar_ui_port, self.project_key))
        else:
            print('Failed to generate Sonar Qube report. Error:\n{0}'.format(
                output))
            # Scanner errors usually go to stderr; show them too.
            if errors:
                print(errors)
Beispiel #5
0
def sftp_server():
    """
    Set up an in-memory SFTP server thread. Yields the client Transport/socket.

    The resulting client Transport (along with all the server components) will
    be the same object throughout the test session; the `sftp` fixture then
    creates new higher level client objects wrapped around the client
    Transport, as necessary.

    On teardown both the client and the server Transport are closed, so their
    worker threads stop instead of outliving the session.
    """
    # Sockets & transports
    socks = LoopSocket()
    sockc = LoopSocket()
    sockc.link(socks)
    tc = Transport(sockc)
    ts = Transport(socks)
    # Auth
    host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
    ts.add_server_key(host_key)
    # Server setup
    event = threading.Event()
    server = StubServer()
    ts.set_subsystem_handler('sftp', SFTPServer, StubSFTPServer)
    ts.start_server(event, server)
    # Wait (so client has time to connect? Not sure. Old.)
    event.wait(1.0)
    # Make & yield connection.
    tc.connect(username='******', password='******')
    try:
        yield tc
    finally:
        tc.close()
        ts.close()
Beispiel #6
0
def remote_scp(ftp_type, host_ip, remote_path, local_path, username, password):
    """Download or upload one file over SFTP (port 22, password auth).

    :param ftp_type: ``'remoteRead'`` downloads *remote_path* to *local_path*
        (default: ``/tmp/<remote file name>``); ``'remoteWrite'`` uploads
        *local_path* to *remote_path*.
    :return: True on success, False if an error occurred (the error is
        printed, not raised).
    """
    ssh_port = 22
    conn = None
    try:
        conn = Transport((host_ip, ssh_port))
        conn.connect(username=username, password=password)
        sftp = SFTPClient.from_transport(conn)
        if ftp_type == 'remoteRead':
            print('read')
            # Name used for the default destination and the progress
            # message; always computed so it exists even when the caller
            # supplies local_path.
            filename = os.path.split(remote_path)[-1]
            if not local_path:
                local_path = os.path.join('/tmp', filename)
            print('开始从服务器下载文件......')
            sftp.get(remote_path, local_path)
            print(f'文件{filename}已经下载到本地')

        if ftp_type == "remoteWrite":
            print('write')
            sftp.put(local_path, remote_path)

        return True
    except IOError as e:
        print('没有找到目录', e)
    except Exception as e:
        print('error!!!', e)
    finally:
        # Close the transport on every path, not only on success.
        if conn is not None:
            conn.close()
    return False
    def write_redirects_to_sftp(self, from_path, to_path, cron):
        """Upload the file *from_path* to *to_path* on the configured SFTP host.

        Connection settings come from ``app.config`` (``SFTP_*`` keys). The
        server must present the host key configured in
        ``SFTP_REMOTE_HOST_PUBLIC_KEY``. Authentication uses the RSA key at
        ``SFTP_SSH_KEY_PATH``.

        :param cron: If true, return a plain-text status line; otherwise a
            JSON string with ``type`` and ``message`` keys.
        :return: A success or failure message. Errors are logged, not raised.
        """
        import logging

        remote_server = None
        try:
            ssh_key_object = RSAKey(filename=app.config['SFTP_SSH_KEY_PATH'],
                                    password=app.config['SFTP_SSH_KEY_PASSPHRASE'])

            remote_server_public_key = HostKeyEntry.from_line(app.config['SFTP_REMOTE_HOST_PUBLIC_KEY']).key
            # This will throw a warning, but the (string, int) tuple will automatically be parsed into a Socket object
            remote_server = Transport((app.config['SFTP_REMOTE_HOST'], 22))
            remote_server.connect(hostkey=remote_server_public_key, username=app.config['SFTP_USERNAME'], pkey=ssh_key_object)

            sftp = SFTPClient.from_transport(remote_server)
            sftp.put(from_path, to_path)
            succeeded = True
        except Exception:
            # Deliberate best-effort: report failure to the caller, but keep
            # the cause in the logs instead of discarding it.
            logging.getLogger(__name__).exception(
                'SFTP publish from %s to %s failed', from_path, to_path)
            succeeded = False
        finally:
            if remote_server is not None:
                remote_server.close()

        if succeeded:
            if cron:
                return 'SFTP publish from %s to %s succeeded' % (from_path, to_path)
            return fjson.dumps({
                'type': 'success',
                'message': 'Redirect updates successful'
            })
        if cron:
            return 'SFTP publish from %s to %s failed' % (from_path, to_path)
        return fjson.dumps({
            'type': 'danger',
            'message': 'Redirect updates failed'
        })
Beispiel #8
0
def sftp_server():
    """
    Set up an in-memory SFTP server thread. Yields the client Transport/socket.

    The resulting client Transport (along with all the server components) will
    be the same object throughout the test session; the `sftp` fixture then
    creates new higher level client objects wrapped around the client
    Transport, as necessary.

    On teardown both the client and the server Transport are closed, so their
    worker threads stop instead of outliving the session.
    """
    # Sockets & transports
    socks = LoopSocket()
    sockc = LoopSocket()
    sockc.link(socks)
    tc = Transport(sockc)
    ts = Transport(socks)
    # Auth
    host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
    ts.add_server_key(host_key)
    # Server setup
    event = threading.Event()
    server = StubServer()
    ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer)
    ts.start_server(event, server)
    # Wait (so client has time to connect? Not sure. Old.)
    event.wait(1.0)
    # Make & yield connection.
    tc.connect(username="******", password="******")
    try:
        yield tc
    finally:
        tc.close()
        ts.close()
Beispiel #9
0
class SshConnection:
    """SSH connection (tsc) holding a Transport and one session channel."""

    def __init__(self, name, host, port=22, user=None, password=None):
        """Connect to *host*:*port* with password auth and open a session.

        Failures are printed rather than raised. In that case anything
        already opened is closed, and ``connection`` and ``session`` are None.
        """

        self.connection = None
        self.session = None
        self.host = host
        self.port = port
        self.loginAccount = {
            'user': user,
            'pass': password
        }
        self.name = name
        # Passed to set_keepalive(): seconds between keepalive packets.
        self.timeout = 4

        self.nbytes = 4096

        try:
            self.connection = Transport((self.host, self.port))
            self.connection.connect(username=self.loginAccount['user'],
                                    password=self.loginAccount['pass'])
            self.connection.set_keepalive(self.timeout)
            self.session = self.connection.open_channel(kind='session')

            print('ssh connection created!')
        except Exception as e:
            # Release whatever was opened before the failure.
            self.close()
            print('ssh connection failed: {er}'.format(er=e))

    def close(self):
        """Close the session channel and the transport; safe to call repeatedly."""
        if self.session is not None:
            self.session.close()
            self.session = None
        if self.connection is not None:
            self.connection.close()
            self.connection = None
Beispiel #10
0
class SFTPCopy:
    """Recursively copy a local directory tree to a remote host over SFTP."""

    def __init__(self, host, port, username, password):
        """Connect to *host*:*port* with password auth and open an SFTP client."""
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.t = Transport((self.host, self.port))
        try:
            # Note: pass these as keyword arguments; connect()'s first
            # positional parameter is the host key.
            self.t.connect(username=self.username, password=self.password)
            self.sftp = SFTPClient.from_transport(self.t)
        except Exception:
            self.t.close()
            raise

    @staticmethod
    def _remote_dest(localdir, remotedir, root, name):
        """Map entry *name* in local dir *root* (under *localdir*) to its remote path."""
        import posixpath

        # relpath handles localdir with or without a trailing separator.
        rel = path.relpath(root, localdir)
        parts = [] if rel == path.curdir else rel.split(path.sep)
        # Remote paths are POSIX regardless of the local OS.
        return posixpath.join(remotedir, *parts, name)

    def copyData(self, localdir, remotedir):
        """Mirror *localdir* into *remotedir*, creating directories as needed.

        Errors for individual directories/files are printed and skipped.
        """
        for root, dirs, files in walk(localdir):
            for dir_name in dirs:
                dest = self._remote_dest(localdir, remotedir, root, dir_name)
                try:
                    self.sftp.mkdir(dest)
                except Exception as e:
                    print('ERROR: make directory %s %s' % (dest, e))
            for file_name in files:
                src = path.join(root, file_name)
                dest = self._remote_dest(localdir, remotedir, root, file_name)
                try:
                    self.sftp.put(src, dest)
                except Exception as e:
                    print('ERROR: touch file %s %s' % (dest, e))

    def close(self):
        """Close the SFTP client and its transport."""
        self.sftp.close()
        self.t.close()
Beispiel #11
0
class SFTPUploader(object):
    """Uploads files to (s)ftp"""
    def __init__(self, sftp_settings):
        """Connect using the HOST, PORT, USER and PASSWORD keys of *sftp_settings*."""
        self.transport = Transport(
            (sftp_settings['HOST'], int(sftp_settings['PORT'])))
        try:
            self.transport.connect(username=sftp_settings['USER'],
                                   password=sftp_settings['PASSWORD'])
            self.connection = SFTPClient.from_transport(self.transport)
        except Exception:
            # Don't leave the opened socket behind if authentication fails.
            self.transport.close()
            raise

        logger.debug(
            "SFTPUploader initiated. Sending files to %s:%s",
            sftp_settings['HOST'], sftp_settings['PORT'])

    def __del__(self):
        # Close each handle independently: __init__ may have failed before
        # ``connection`` (or even ``transport``) was assigned.
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()
        transport = getattr(self, 'transport', None)
        if transport is not None:
            transport.close()

        logger.debug("SFTPUploader session completed. Connection closed.")

    def upload_file(self, local_filepath, filename):
        """Upload *local_filepath* as ``./<filename>``, replacing any existing file."""
        logger.debug("SFTPUploader: Uploading file %s", local_filepath)
        remote_path = './{filename}'.format(filename=filename)
        try:
            self.connection.remove(path=remote_path)
        except IOError:
            # No existing file to replace (or it cannot be removed);
            # put() below reports any real problem.
            pass
        self.connection.put(localpath=local_filepath, remotepath=remote_path)
Beispiel #12
0
def sftpclient(sftpserver):
    """Yield an SFTP client logged in to the server described by *sftpserver*.

    After the consumer is done, the client and then its transport are closed.
    """
    address = (sftpserver.host, sftpserver.port)
    link = Transport(address)
    link.connect(username="******", password="******")
    client = SFTPClient.from_transport(link)
    yield client
    client.close()
    link.close()
Beispiel #13
0
class Client:
    """SSH client built directly on a paramiko Transport.

    *host* is split on ``@`` into a username and an address of the form
    ``hostname:port``. *mode* is stored but not used in the code shown here.
    """
    def __init__(self, host, mode):
        self.__mode = mode
        # Split "user@address"; any part containing ':' becomes a
        # (name, int(port)) tuple, so the address yields (hostname, port).
        # NOTE(review): an address without ':port' stays a plain string, which
        # create_connection() below does not accept -- confirm expected input.
        self.__username, self.__address = list(
            map(lambda x: x if ':' not in x else (x.split(':')[0], int(x.split(':')[1])), host.split('@')))
        self.__password = None
        self.__session = None

        # The TCP connection is opened immediately, at construction time.
        self.__transport = Transport(create_connection(self.__address))

    def __authenticate(self):
        # NOTE(review): the next line appears corrupted by credential
        # redaction: code between the password prompt and what looks like a
        # get_pty(...) call (probably from another method) was swallowed.
        # As written it is not valid Python.
        self.__password = input(self.__username + '@' + ':'.join([str(x) for x in self.__address]) + '\'s password: '******'vt100', width=10, height=10)
        self.__session.invoke_shell()

    def __handler(self):
        # Open a session channel, then hand off to __command (not shown here).
        self.__session = self.__transport.open_session()
        self.__command()

    def run(self):
        self.__authenticate()
def mx(src: str):
    """Download a map from Mania Exchange and upload it to the server over SFTP.

    :param src: Reference of the form ``"<prefix>:<map_id>"``, e.g. ``"mx:12345"``.
    :raises ValueError: If *src* does not contain exactly one ``:``.
    """
    import posixpath

    spl = src.split(':')
    if len(spl) != 2:
        # ValueError still satisfies handlers for BaseException, and is
        # also caught by ordinary ``except Exception`` handlers.
        raise ValueError('Invalid url')
    map_id = spl[1]
    default_name = '{}.Map.Gbx'.format(map_id)
    print('Downloading from MX: {}'.format(map_id))
    with urlopen('https://tm.mania-exchange.com/tracks/download/{}'.format(
            map_id)) as res:
        # res.headers is an email.message.Message; get_filename() parses the
        # Content-Disposition "filename" parameter (the cgi module is
        # deprecated and removed in Python 3.13).
        filename = res.headers.get_filename(default_name)
        data = res.read()
    # The name is chosen by the remote server: keep only its last path
    # component so it cannot escape the destination directory.
    filename = posixpath.basename(filename.replace('\\', '/'))
    if filename in ('', '.', '..'):
        filename = default_name

    t = Transport((settings['host'], settings['port']))
    try:
        print('Connecting')
        t.connect(username=settings['user'],
                  pkey=paramiko.rsakey.RSAKey.from_private_key_file(
                      settings['pkey']))
        print('Connected')
        client = sftp.SFTPClient.from_transport(t)
        try:
            print('Uploading file')
            dst = settings['dest']
            with client.open('{}/{}'.format(dst, filename), 'wb') as remote_file:
                remote_file.write(data)
            print('Done')
        finally:
            client.close()
    finally:
        t.close()
class FtpClient:
    """Connection methods for the remote OpenCart server.

    Despite the name, this uses SFTP over SSH (paramiko), not FTP. The
    connection is opened in ``__init__``.
    """

    HOST = '10.0.1.32'
    USERNAME = '******'
    PASSWORD = '******'
    PORT = 22

    remote_path = '/var/log/apache2/'
    remote_file = '/var/log/apache2/access.log.1'
    local_path = 'logs/access.log'

    def __init__(self):
        """Open the SSH transport and SFTP session to the OpenCart server."""
        self.transport = Transport(sock=(self.HOST, self.PORT))
        try:
            self.transport.connect(username=self.USERNAME, password=self.PASSWORD)
            self.ftp_connection = SFTPClient.from_transport(self.transport)
        except Exception:
            # Don't leave the opened socket behind if authentication fails.
            self.transport.close()
            raise

    def open_ftp_connection(self):
        """Print the listing of ``remote_path`` and return the SFTP client."""

        print(self.ftp_connection.listdir(path=self.remote_path))
        return self.ftp_connection

    def close_ftp_connection(self):
        """Close the SFTP session and its underlying SSH transport.

        :return: The (now closed) SFTP client.
        """
        self.ftp_connection.close()
        # Closing only the SFTP channel would leave the SSH transport open.
        self.transport.close()
        return self.ftp_connection

    def download(self):
        """Download ``remote_file`` from the server to ``local_path``."""

        self.ftp_connection.get(remotepath=self.remote_file,
                                localpath=self.local_path,
                                callback=None)
 def ParamikoMethod(self):
     """Return an SSHClient wired to an already-authenticated Transport.

     A Transport to ``self.ip`` port 22 is authenticated with the stored
     username/password (no host key is passed). The client's private
     ``_transport`` attribute is set directly instead of calling
     ``SSHClient.connect``.
     """
     address = (self.ip, 22)
     authenticated = Transport(address)
     authenticated.connect(username=self.username, password=self.password)
     client = SSHClient()
     client.set_missing_host_key_policy(AutoAddPolicy())
     client._transport = authenticated
     return client
Beispiel #17
0
def upload_file(host, port, usr, psw, local_path, remote_path):
    """Upload every regular file in *local_path* to *remote_path* over SFTP.

    ``.css`` and ``.js`` files are minified with ``cssmini`` / ``jsmini``
    (read as UTF-8), and the minified text is uploaded under the original
    name. Other files are uploaded unchanged. Subdirectories are skipped.
    The transport is always closed.

    :param remote_path: Destination directory on the server (POSIX path).
    """
    import io
    import os
    import posixpath

    minifiers = {'css': cssmini, 'js': jsmini}
    file_count = 0
    print('-' * 50 + '\n')
    print('Start uploading files')
    transport = Transport((host, port))
    try:
        transport.connect(username=usr, password=psw)
        sftp = SFTPClient.from_transport(transport)
        for file_name in listdir(local_path):
            local_file = os.path.join(local_path, file_name)
            if not os.path.isfile(local_file):
                continue
            remote_file = posixpath.join(remote_path, file_name)
            minify = minifiers.get(file_name.split('.')[-1])
            if minify is None:
                sftp.put(local_file, remote_file)
            else:
                with open(local_file, 'r', encoding='utf-8') as source:
                    minified = minify(source.read())
                # Upload straight from memory; no temporary file needed.
                sftp.putfo(io.BytesIO(minified.encode('utf-8')), remote_file)
            print(file_name + 'Upload completed')
            file_count += 1
    finally:
        transport.close()
    print('All files have been uploaded, Connection has been closed, Number of files:{}'.format(file_count))
    print('-' * 50 + '\n')
 def scp(self, src_file, dest_path):
     """Upload *src_file* to *dest_path* on ``self.host`` over SFTP.

     A new transport is opened for each call and always closed afterwards.

     :raises IOError: If the upload fails (e.g. missing local file or
         remote directory).
     """
     transport = Transport((self.host, int(self.port)))
     try:
         transport.connect(username=self.user, password=self.passwd)
         sftp = SFTPClient.from_transport(transport)
         try:
             sftp.put(src_file, dest_path)
         finally:
             sftp.close()
     finally:
         transport.close()
Beispiel #19
0
 def scp(self, srcFile, destPath):
     """Upload *srcFile* to *destPath* on ``self.host`` over SFTP.

     A new transport is opened for each call and always closed afterwards.

     :raises IOError: If the upload fails (e.g. missing local file or
         remote directory).
     """
     transport = Transport((self.host, int(self.port)))
     try:
         transport.connect(username=self.user, password=self.passwd)
         sftp = SFTPClient.from_transport(transport)
         try:
             sftp.put(srcFile, destPath)
         finally:
             sftp.close()
     finally:
         transport.close()
Beispiel #20
0
    def __init__(self, host, username, password):
        """Open an SFTP session to *host* and store it as ``self.connection``.

        *host* is passed to Transport as-is, as a string rather than a
        (host, port) tuple. On failure the error is logged with its traceback,
        the transport is closed, and ``self.connection`` is None.
        """
        logging.basicConfig(format='%(levelname)s - %(message)s', level=logging.INFO)
        self.connection = None
        transport = None
        try:
            transport = Transport(sock=host)
            transport.connect(username=username, password=password)
            self.connection = SFTPClient.from_transport(transport)
        except Exception:
            if transport is not None:
                transport.close()
            logging.exception('Can not able to conenct to server')
Beispiel #21
0
 def __init__(self, user, passwd, ip):
     """Connect to *ip* on port 22 and open an SFTP client.

     ``self.status`` is ``'Success'`` or ``'Failed'``. On failure
     ``self.sftpObject`` is None and the transport is closed.
     """
     self.ip = ip
     self.sftpObject = None
     transport = None
     try:
         transport = Transport((ip, 22))
         transport.connect(username=user, password=passwd)
         self.sftpObject = SFTPClient.from_transport(transport)
         self.status = 'Success'
     except Exception:
         if transport is not None:
             transport.close()
         self.status = 'Failed'
Beispiel #22
0
 def sftp_authentication(self):
     """Open an SFTP session to ``ip_server`` on port 222 with the key ./priv_key.

     :return: ``(sftp_client, transport)`` on success, or None if anything
         failed. The error is printed and the transport closed.
     """
     transport = None
     try:
         transport = Transport((ip_server, 222))
         # Resolved relative to the current working directory.
         mykey = RSAKey.from_private_key_file('./priv_key')
         transport.connect(username='******', pkey=mykey)
         sftp_client = SFTPClient.from_transport(transport)
         return sftp_client, transport
     except Exception as e:
         if transport is not None:
             transport.close()
         print("SFTP Authentication Fail:" + str(e))
         return None
Beispiel #23
0
 def transport(self):
     """Upload the packaged web and server archives to the nginx host.

     ``<project_url>/web/caweb/html.zip`` goes to ``<nginx_url>/html.zip``.
     ``<project_url>/server/caserver.zip`` goes to
     ``<nginx_url>/server/new_caserver.zip``. The connection is closed even
     if an upload fails.
     """
     transport = Transport((self.host, self.port))
     try:
         transport.connect(username=self.username, password=self.password)
         sftp = SFTPClient.from_transport(transport)
         try:
             sftp.put(f'{self.project_url}/web/caweb/html.zip',
                      f'{self.nginx_url}/html.zip')
             sftp.put(f'{self.project_url}/server/caserver.zip',
                      f'{self.nginx_url}/server/new_caserver.zip')
         finally:
             sftp.close()
         print("transport finished!")
     finally:
         transport.close()
Beispiel #24
0
def scp_remote_to_local(remote_file_path, local_file_path, ip, username,
                        password):
    """Copies a file from remote to local.

    An empty or None *username* / *password* falls back to
    DEFAULT_CVM_USERNAME / DEFAULT_CVM_PASSWD. The SFTP session and the
    transport are closed before returning, even on error.
    """
    if not username:
        username = DEFAULT_CVM_USERNAME
    if not password:
        password = DEFAULT_CVM_PASSWD
    transport = Transport((ip, 22))
    try:
        transport.connect(username=username, password=password)
        sftp = SFTPClient.from_transport(transport)
        try:
            sftp.get(remote_file_path, local_file_path)
        finally:
            sftp.close()
    finally:
        transport.close()
Beispiel #25
0
    def __init__(self, config):
        """Open an SFTP session using the SFTP_* settings on *config*.

        :raises SFTPError: If the connection or authentication fails; the
            underlying error is chained as ``__cause__``.
        """
        transport = None
        try:
            transport = Transport((config.SFTP_HOST, config.SFTP_PORT))
            transport.connect(username=config.SFTP_USERNAME,
                              password=config.SFTP_PASSWORD)
            self.sftp_client = SFTPClient.from_transport(transport)
        except (SSHException, socket.error) as e:
            # Don't leave the opened socket behind on failure.
            if transport is not None:
                transport.close()
            raise SFTPError(f'SFTP connection failed due to: {e}.') from e

        self.path = config.SFTP_PATH
        self.temp_file_path = config.TEMP_FILE_PATH
Beispiel #26
0
 def sftpOpenConnection(self, target):
     """ opens an sftp connection to the given target host;
         credentials are taken from /etc/ssh/sftp_passwd

     Hostname, username and password come from ``self.getCredentials``. The
     server must present the first host key recorded for that hostname in
     /etc/ssh/ssh_known_hosts.

     :raises KeyError: If the hostname has no entry in the known-hosts file.
     """
     from paramiko import Transport, SFTPClient
     from paramiko.util import load_host_keys
     hostname, username, password = self.getCredentials(target)
     hostkeys = load_host_keys('/etc/ssh/ssh_known_hosts')
     # First key recorded for this host; works on Python 2 and 3 (indexing
     # items() only works where items() returns a list).
     hostkey = next(iter(hostkeys[hostname].values()))
     trans = Transport((hostname, 22))
     try:
         trans.connect(username=username, password=password, hostkey=hostkey)
         return SFTPClient.from_transport(trans)
     except Exception:
         trans.close()
         raise
Beispiel #27
0
def test_simple():
    """Smoke-test SSHClientSession on top of a pre-authenticated Transport.

    Prints the output of ``pwd``, runs ``cd /home``, and prints ``pwd``
    again, so you can see whether the directory change carried over between
    commands.
    """
    address = ("192.168.1.2", 22)
    transport = Transport(address)
    transport.connect(username="******", password="******")
    session = SSHClientSession()
    session._transport = transport

    def show_pwd():
        out = session.exec_command('pwd')[1]
        print("output:", out.read().decode())

    show_pwd()
    session.exec_command('cd /home')
    show_pwd()
    session.close()
Beispiel #28
0
    def _createConnection(self):
        """
        @see: L{_createConnection<datafinder.persistence.common.connection.pool.ConnectionPool._createConnection>}

        Open an SSH transport to the configured host on the default SSH port,
        authenticate with username/password, and return an SFTP client on it.

        @raise PersistenceError: If connecting or authenticating fails; a
            partially opened transport is closed first.
        """

        connection = None
        try:
            connection = Transport((self._configuration.hostname, constants.DEFAULT_SSH_PORT))
            connection.connect(username=self._configuration.username, password=self._configuration.password)
            return connection.open_sftp_client()
        except (SSHException, socket.error, socket.gaierror) as error:
            if connection is not None:
                connection.close()
            errorMessage = u"Unable to establish SFTP connection to host '%s'! " \
                           % (self._configuration.hostname) + "\nReason: '%s'" % str(error)
            raise PersistenceError(errorMessage)
Beispiel #29
0
    def _createConnection(self):
        """
        @see: L{_createConnection<datafinder.persistence.common.connection.pool.ConnectionPool._createConnection>}

        Open an SSH transport to the configured TSM host on the default SSH
        port, authenticate with username/password, and return the transport.

        @raise PersistenceError: If connecting or authenticating fails; a
            partially opened transport is closed first.
        """

        connection = None
        try:
            connection = Transport((self._configuration.hostname, DEFAULT_SSH_PORT))
            connection.connect(username=self._configuration.username, password=self._configuration.password)
            return connection
        except (SSHException, socket.error, socket.gaierror) as error:
            if connection is not None:
                connection.close()
            errorMessage = u"Unable to establish SSH connection to TSM host '%s'! " \
                           % (self._configuration.hostname) + "\nReason: '%s'" % str(error)
            raise PersistenceError(errorMessage)
Beispiel #30
0
def fetch_paypal_report(
    date: str,
    paypal_credentials: dict,
    paypal_report_prefix: str,
    paypal_report_check_column_name: str,
    s3_bucket: str,
    s3_path: str,
    overwrite: bool,
):
    """Fetch and format the PayPal SFTP report for *date*.

    :param date: Report date as ``YYYY-MM-DD``.
    :param paypal_credentials: Mapping with ``username`` and ``password``.
    :param overwrite: If false and a file already exists under the S3 key
        for *date*, the task is skipped.
    :return: ``(date, formatted_report)``.
    :raises signals.SKIP: If not overwriting and the S3 file already exists.
    :raises Exception: If no remote report file is found for *date*.
    """
    logger = prefect.context.get("logger")
    logger.info("Pulling Paypal report for {}".format(date))

    if not overwrite:
        # If we're not overwriting and the file already exists, raise a skip
        date_path = get_s3_path_for_date(date)
        s3_key = s3_path + date_path

        logger.info("Checking for existence of: {}".format(s3_key))

        existing_file = list_object_keys_from_s3.run(s3_bucket, s3_key)

        if existing_file:
            raise signals.SKIP(
                'File {} already exists and we are not overwriting. Skipping.'.
                format(s3_key))
        else:
            logger.info(
                "File not found, continuing download for {}.".format(date))

    # Transport takes the address as one (host, port) tuple. A second
    # positional argument would be used as the SSH window size and the
    # configured port ignored.
    transport = Transport((config.paypal.host, config.paypal.port))
    try:
        transport.connect(username=paypal_credentials.get('username'),
                          password=paypal_credentials.get('password'))
        sftp_connection = SFTPClient.from_transport(transport)
        try:
            query_date = datetime.datetime.strptime(date, "%Y-%m-%d")
            remote_filename = get_paypal_filename(query_date,
                                                  paypal_report_prefix,
                                                  sftp_connection,
                                                  config.paypal.remote_path)
            if not remote_filename:
                raise Exception(
                    "Remote File Not found for date: {0}".format(date))
            sftp_connection.chdir(config.paypal.remote_path)
            check_paypal_report(sftp_connection, remote_filename,
                                paypal_report_check_column_name)
            formatted_report = format_paypal_report(sftp_connection,
                                                    remote_filename, date)
            return date, formatted_report
        finally:
            sftp_connection.close()
    finally:
        transport.close()
Beispiel #31
0
def get_remote_file(ssh_config: SshClientConfig, local_path: Path, remote_path: Path,
                    private_key_path: str = "/home/mycroft/.ssh/id_rsa"):
    """Use SFTP to copy a file from a remote server to the local server.

    :param ssh_config: Configuration of the SSH session used for SFTP
    :param local_path: Destination path of the file being transferred
    :param remote_path: Source path of the file being transferred
    :param private_key_path: RSA private key used for authentication
        (default ``/home/mycroft/.ssh/id_rsa``)
    """
    ssh_key = RSAKey.from_private_key_file(private_key_path)
    ssh_transport = Transport((ssh_config.remote_server, int(ssh_config.ssh_port)))
    try:
        # hostkey=None: the server's host key is not verified.
        ssh_transport.connect(hostkey=None, username=ssh_config.remote_user, pkey=ssh_key)
        sftp_client = SFTPClient.from_transport(ssh_transport)
        try:
            sftp_client.get(str(remote_path), str(local_path))
        finally:
            sftp_client.close()
    finally:
        ssh_transport.close()
Beispiel #32
0
    def _send_hat_script(self, dest_directory='/tmp'):
        """Upload ``self.body`` as a script file to ``self.remote_server``.

        The body is written to a local temporary file. That file is uploaded
        over SFTP (port 22) into *dest_directory*, named after the temp file
        plus a trailing ``_``. The local temporary file is then removed.

        :param dest_directory: Remote directory that receives the script.
        :return: Remote path of the uploaded script.
        """
        import os
        import posixpath

        with NamedTemporaryFile(mode='w', delete=False) as script_file:
            script_file.write(self.body)
        local_name = script_file.name
        try:
            moveable_script_name = posixpath.join(
                dest_directory, os.path.basename(local_name) + '_')
            # Transport takes one (host, port) tuple; a second positional
            # argument would be used as the SSH window size instead.
            transport = Transport((self.remote_server, 22))
            try:
                transport.connect(username=self.remote_user,
                                  password=self.remote_password)
                file_client = SFTPClient.from_transport(transport)
                try:
                    file_client.put(local_name, moveable_script_name)
                finally:
                    file_client.close()
            finally:
                transport.close()
        finally:
            os.remove(local_name)
        return moveable_script_name
Beispiel #33
0
 def sftpOpen(self, target):
     """ opens an sftp connection to the given target host;
         credentials are taken from /etc/ssh/sftp_passwd

     NOTE(review): as written, *target* is unused and the connection uses
     self.hostname / self.username / self.password / self.hostkey -- confirm
     whether the /etc/ssh/sftp_passwd note above is still accurate.
     """
     from paramiko import Transport, SFTPClient
     #from paramiko.util import load_host_keys
     # NOTE(review): the next line appears corrupted by credential redaction.
     # It seems to print the username and password, and labels
     # self.password as 'self.hostkey'. Printing credentials should be removed.
     print '>>> sftpOpen self.username: '******'             self.password: '******'             self.hostkey: ', self.password
     print '             self.hostname: ', self.hostname
     trans = Transport((self.hostname, 22))
     # Allow a slow server up to 120 s to send its SSH banner.
     trans.banner_timeout = 120
     # hostkey is passed so the server's key is checked against self.hostkey.
     trans.connect(username=self.username,
                   password=self.password,
                   hostkey=self.hostkey)
     return SFTPClient.from_transport(trans)
Beispiel #34
0
def scp_bela(host='bbb'):
    """Return an SCPClient connected to *host* as described in ~/.ssh/config.

    Hostname, user and port come from the matching ``~/.ssh/config`` entry.
    When the entry sets none, the current login name and port 22 are used.
    Authentication uses the SSH "none" method. No host key is passed to
    ``connect``, so the server's key is not verified. ``SCPClient`` can also
    be given a ``progress`` callback if transfer progress is wanted.

    If connecting or the "none" authentication fails, the transport is
    closed and the error is re-raised.
    """
    import getpass

    ssh_config = SSHConfig()
    ssh_config_file = os.path.expanduser('~/.ssh/config')
    if os.path.exists(ssh_config_file):
        with open(ssh_config_file) as f:
            ssh_config.parse(f)
    bbb = ssh_config.lookup(host)
    user = bbb.get('user') or getpass.getuser()
    # SSH config values are strings; honour a configured Port.
    sf = Transport((bbb['hostname'], int(bbb.get('port', 22))))
    try:
        sf.connect(username=user)
        sf.auth_none(user)
    except Exception:
        sf.close()
        raise
    return SCPClient(sf)
Beispiel #35
0
def _get_transport(hostname: str, username: str, path_to_private_key: str):
    """Return the shared, authenticated SSH transport, creating it if needed.

    The server's host key is checked with ``_validate_key`` *before* any
    authentication, so the private key is never offered to an unverified
    host. The module-level cache is filled only after the transport is
    verified and authenticated. A cached transport that is no longer active
    is replaced. Access is serialised by ``_transport_lock``.

    :raises SSHAuthenticationError: If the transport is not authenticated.
    """
    global _transport
    with _transport_lock:
        if _transport is not None:
            if _transport.is_active():
                return _transport
            _transport.close()
            _transport = None

        key = RSAKey.from_private_key_file(path_to_private_key)
        # create_connection closes its socket itself if connecting fails.
        sock = socket.create_connection((hostname, 22))
        transport = Transport(sock)
        try:
            transport.start_client()
            _validate_key(hostname, transport.get_remote_server_key())
            transport.auth_publickey(username, key)
            if not transport.is_authenticated():
                raise SSHAuthenticationError(
                    'Couldn\'t authenticate the ssh transport')
        except BaseException:
            # Never cache (or leak) a transport that failed verification.
            transport.close()
            raise
        _transport = transport
        return _transport
Beispiel #36
0
class CSftpParamiko:
    """Upload single files to a fixed SFTP server (hard-coded host/credentials).

    Each ``fileSave`` call opens and closes its own connection. Instances
    can also be used as context managers; leaving the ``with`` block closes
    any connection that is still open.
    """

    def __init__(self):
        # set host, user, passwd
        self.hostname = '172.16.3.36'
        self.username = '******'
        self.password = '******'
        self.transport = None
        self.sftp_client = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        """Close any open connection; exceptions are not suppressed."""
        self._close()

    def _close(self):
        """Close the SFTP client and transport if open; return True on success."""
        ok = True
        try:
            if self.sftp_client is not None:
                self.sftp_client.close()
            if self.transport is not None:
                self.transport.close()
        except Exception:
            print("[SFTP_P]: close fail")
            ok = False
        self.sftp_client = None
        self.transport = None
        return ok

    def fileSave(self, local_file, remote_path):
        """Upload *local_file* to *remote_path* on the configured server.

        Errors are printed rather than raised, and the connection is always
        closed afterwards.

        :return: True if the upload succeeded, False otherwise.
        """
        # connect to sftp server
        try:
            self.transport = Transport((self.hostname, 22))
            self.transport.connect(username=self.username,
                                   password=self.password)
            self.sftp_client = SFTPClient.from_transport(self.transport)
        except Exception:
            print("[SFTP_P]: connect fail")
            self._close()
            # Nothing to upload without a connection.
            return False

        try:
            self.sftp_client.put(local_file, remote_path)
            uploaded = True
        except Exception:
            print("[SFTP_P]: put fail")
            uploaded = False
        finally:
            self._close()
        return uploaded
Beispiel #37
0
class SFTPConnection:
    """
    Handle a SFTP (SSH File Transfer Protocol) connection.

    The target is given as ``sftp://host[:port][/path]``; the port defaults
    to 22 and, when a path is present, the session changes into it after
    connecting.  Exactly one of *password* / *private_key* (an RSA key in
    PEM text form) must be supplied.
    """
    def __init__(self, url, user_name, password=None, private_key=None):
        self.url = url
        self.user_name = user_name
        if password and private_key:
            raise SFTPError(
                "Password and private_key cannot be defined simultaneously")
        self.password = password
        self.private_key = private_key

    @staticmethod
    def _host_and_port(schema):
        """Return ``(hostname, port)`` for a parsed sftp URL; port defaults to 22."""
        port = schema.port
        return schema.hostname, (int(port) if port is not None else 22)

    def connect(self):
        """ Get a handle to a remote connection """
        # Check URL
        schema = urlparse(self.url)
        if schema.scheme != 'sftp':
            raise SFTPError('Not a valid sftp url %s, type is %s' %
                            (self.url, schema.scheme))
        self.transport = Transport(self._host_and_port(schema))
        # Add authentication to transport
        try:
            if self.password:
                self.transport.connect(username=self.user_name,
                                       password=self.password)
            elif self.private_key:
                self.transport.connect(username=self.user_name,
                                       pkey=RSAKey.from_private_key(
                                           StringIO(self.private_key)))
            else:
                raise SFTPError("No password or private_key defined")
            # Connect
            self.conn = SFTPClient.from_transport(self.transport)
        except (socket.gaierror, error) as msg:
            self.transport.close()
            raise SFTPError(
                str(msg) + ' while establishing connection to %s' %
                (self.url, ))
        except Exception:
            # Don't leave the transport (socket + thread) behind on failure.
            self.transport.close()
            raise
        # Go to specified directory, if the URL names one
        if schema.path:
            try:
                self.conn.chdir(schema.path)
            except IOError as msg:
                self.transport.close()
                raise SFTPError(
                    str(msg) + ' while changing to dir -%r-' % (schema.path, ))
def sftp_client(ip, uname, pw, r_path, l_path):
    """Download remote file *r_path* from *ip* (port 22) into local *l_path*.

    Authenticates with user name / password.  Any failure is reported as
    ``sftp <ip> failed`` on stdout rather than raised, and the transport is
    closed in every case.
    """
    t = None
    try:
        t = Transport((ip, 22))
        t.connect(username=uname, password=pw)
        sftp = SFTPClient.from_transport(t)
        sftp.get(r_path, l_path)
    except Exception:
        print('sftp %s failed' % (ip,))
    finally:
        if t is not None:
            t.close()
Beispiel #39
0
def main():
    """List the ``upload`` directory of a local test SFTP server two ways.

    Connects to 127.0.0.1:2222 as ``foo``/``pass``, checks that ``listdir``
    and ``listdir_iter`` report the same names in the same order, and prints
    how many entries are not directories.  The connection is always closed.
    """
    t = Transport(('127.0.0.1', 2222))
    try:
        t.connect(None, 'foo', 'pass')

        sftp = SFTPClient.from_transport(t)
        loc = 'upload'
        items_1 = list(sftp.listdir(loc))
        items_2 = []
        files = []
        for item in sftp.listdir_iter(loc):
            items_2.append(item.filename)
            if not stat.S_ISDIR(item.st_mode):
                files.append(item)
        assert items_1 == items_2
        print(f'{len(items_2)} items listed but only {len(files)} files')
    finally:
        t.close()
Beispiel #40
0
def listdir(hostname, path="/var/tmp", filter="", port=1035, username="", password=""):
    """
        paramiko sftp listdir wrapper, with option to filter files

        Returns the names in remote directory *path* that match the regular
        expression *filter* (``re.match``, i.e. anchored at the start; the
        default ``""`` matches everything).  An invalid expression prints a
        message and exits with status 1 before any connection is opened.
        The SFTP connection is closed before returning.
    """
    # Validate the pattern first so a typo neither costs a network round trip
    # nor leaves a connection open when we exit.
    try:
        rex = re.compile(filter)
    except re.error:
        print("Invalid regular expression: " + filter)
        sys.exit(1)

    # Paramiko client configuration
    t = Transport((hostname, port))
    try:
        t.connect(username=username, password=password)
        sftp = SFTPClient.from_transport(t)
        return [x for x in sftp.listdir(path) if rex.match(x)]
    finally:
        t.close()
Beispiel #41
0
    def upload_report_to_sftp(self, client_id, report_date, absolute_filename):
        """
        Upload the given file, using SFTP, to the configured FTP server. The
        file should be uploaded to the appropriate directory for the specified
        client and the date of the report.

        :param client_id: id of the ``Clients`` row the report belongs to.
        :param report_date: date passed to ``_get_sftp_dirs`` to choose the
            target folders.
        :param absolute_filename: local path of the report; its basename is
            used as the remote file name.
        :raises Clients.DoesNotExist: if no client has *client_id* (logged
            first).
        :raises Exception: any error during the SFTP session is logged and
            re-raised.
        """
        try:
            client = Clients.objects.get(id=client_id)
        except Clients.DoesNotExist:
            logger.exception(u'No configuration for client {0}.'.format(client_id))
            raise

        filename = basename(absolute_filename)
        base_folder, env_folder, year_folder, month_folder = self._get_sftp_dirs(client, report_date)

        # Bound up front so the cleanup below only closes what was actually
        # opened instead of tripping over unbound names.
        transport = None
        sftp = None
        try:
            logger.debug(u'SFTP logging on to {0} as {1}'.format(settings.SFTP_SERVER, settings.SFTP_USERNAME))
            transport = Transport((settings.SFTP_SERVER, settings.SFTP_PORT))
            transport.connect(username=settings.SFTP_USERNAME, password=settings.SFTP_PASSWORD)
            sftp = SFTPClient.from_transport(transport)

            logger.debug(u'SFTP dir {0}/{1}/{2}/{3}'.format(base_folder, env_folder, year_folder, month_folder))
            sftp.chdir(base_folder)
            self._make_or_change_sftp_dir(sftp, env_folder)
            self._make_or_change_sftp_dir(sftp, year_folder)
            self._make_or_change_sftp_dir(sftp, month_folder)

            logger.debug(u'SFTP uploading {0}'.format(filename))
            sftp.put(absolute_filename, filename)
        except Exception:
            logger.exception(u'Unrecoverable exception during SFTP upload process.')
            raise
        finally:
            logger.debug(u'SFTP logging off')

            if sftp is not None:
                try:
                    sftp.close()
                except Exception:
                    logger.exception(u'SFTP exception while closing SFTP session.')

            if transport is not None:
                try:
                    transport.close()
                except Exception:
                    logger.exception(u'SFTP exception while closing SSH connection.')
Beispiel #42
0
def sftp_connect(config):
    """Open an SFTP session using public-key auth against a pinned host key.

    *config* must provide ``hostname``, ``port``, ``username``,
    ``privatekey_file``, ``privatekey_passphrase`` and ``hostkeys_file``.
    On an ``SSHException`` the message is printed and the process exits with
    status 1, after closing any half-open transport.
    """
    t = None
    try:
        hostname = config["hostname"]
        port = int(config["port"])
        username = config["username"]
        pkeyfile = config["privatekey_file"]
        pkeypassword = config["privatekey_passphrase"]
        host_keys = load_host_keys(config["hostkeys_file"])
        _hostkeytype, hostkey = get_hostkeytype_and_hostkey(host_keys,
                                                            hostname)
        t = Transport((hostname, port))
        pkey = get_privatekey_from_file(pkeyfile, pkeypassword)
        t.connect(username=username, pkey=pkey, hostkey=hostkey)

        return SFTPClient.from_transport(t)

    except SSHException as e:
        if t is not None:
            t.close()
        print(" [*] " + str(e))
        sys.exit(1)
class SFTPConnection:
  """
  Handle a SFTP (SSH File Transfer Protocol) connection.

  The target is given as ``sftp://host[:port][/path]``; the port defaults
  to 22 and, when a path is present, the session changes into it after
  connecting.  Exactly one of *password* / *private_key* (an RSA key in
  PEM text form) must be supplied.
  """

  def __init__(self, url, user_name, password=None, private_key=None):
    self.url = url
    self.user_name = user_name
    if password and private_key:
      raise SFTPError("Password and private_key cannot be defined simultaneously")
    self.password = password
    self.private_key = private_key

  @staticmethod
  def _host_and_port(schema):
    """Return ``(hostname, port)`` for a parsed sftp URL; port defaults to 22."""
    port = schema.port
    return schema.hostname, (int(port) if port is not None else 22)

  def connect(self):
    """ Get a handle to a remote connection """
    # Check URL
    schema = urlparse(self.url)
    if schema.scheme != 'sftp':
      raise SFTPError('Not a valid sftp url %s, type is %s' %(self.url, schema.scheme))
    self.transport = Transport(self._host_and_port(schema))
    # Add authentication to transport
    try:
      if self.password:
        self.transport.connect(username=self.user_name, password=self.password)
      elif self.private_key:
        self.transport.connect(username=self.user_name,
                               pkey=RSAKey.from_private_key(StringIO(self.private_key)))
      else:
        raise SFTPError("No password or private_key defined")
      # Connect
      self.conn = SFTPClient.from_transport(self.transport)
    except (socket.gaierror, error) as msg:
      self.transport.close()
      raise SFTPError(str(msg) + ' while establishing connection to %s' % (self.url,))
    except Exception:
      # Don't leave the transport (socket + thread) behind on failure.
      self.transport.close()
      raise
    # Go to specified directory, if the URL names one
    if schema.path:
      try:
        self.conn.chdir(schema.path)
      except IOError as msg:
        self.transport.close()
        raise SFTPError(str(msg) + ' while changing to dir -%r-' % (schema.path,))
Beispiel #44
0
def get_sftp_conn(config):
    """Make a SFTP connection, returns sftp client and connection objects

    The remote is read from ``config['remote_location']`` (a URL whose netloc
    may carry an explicit ``:port``; 22 otherwise).  Each private key yielded
    by ``get_ssh_keys`` is tried in turn on a fresh transport until one
    authenticates against a server presenting the expected host key.

    :returns: ``(sftp, transport)``; the caller owns both and must close them.
    :raises SaChannelUpdateTransportError: if no key works or any other error
        occurs while connecting.
    """
    remote = config.get('remote_location')
    parts = urlparse(remote)

    if ':' in parts.netloc:
        hostname, port = parts.netloc.split(':')
    else:
        hostname = parts.netloc
        port = 22
    port = int(port)

    username = config.get('remote_username') or getuser()
    luser = get_local_user(username)
    sshdir = get_ssh_dir(config, luser)
    hostkey = get_host_keys(hostname, sshdir)

    try:
        keys = get_ssh_keys(sshdir)
        while not keys.empty():
            try:
                # NOTE(review): PKey is paramiko's abstract key base class;
                # unless it is bound to a concrete key type on import, this
                # call is not expected to load a key -- confirm the import.
                key = PKey.from_private_key_file(keys.get())
            except (PasswordRequiredException, SSHException):
                continue
            # Transport.connect() starts the transport's thread, which can
            # only happen once, so every attempt needs its own transport.
            transport = Transport((hostname, port))
            try:
                transport.connect(
                    hostkey=hostkey,
                    username=username,
                    password=None,
                    pkey=key)
                sftp = SFTPClient.from_transport(transport)
            except (PasswordRequiredException, SSHException):
                transport.close()
                continue
            except Exception:
                transport.close()
                raise
            return sftp, transport
        raise SaChannelUpdateTransportError("SFTP connection failed")
    except SaChannelUpdateTransportError:
        # Already the right type; don't wrap it a second time.
        raise
    except Exception as msg:
        raise SaChannelUpdateTransportError(msg)
Beispiel #45
0
class RemoteFileManager(object):
    """Thin helper around a password-authenticated paramiko SFTP session.

    Call :meth:`connect` before any other method, and :meth:`close` when done.
    """

    def __init__(self, host, user, password, port=22):
        self._host = host
        self._port = port
        self._user = user
        self._pass = password

    def connect(self):
        """Open the SSH transport and SFTP session to the configured host."""
        self._transport = Transport((self._host, self._port))
        self._transport.connect(username=self._user, password=self._pass)
        self._client = SFTPClient.from_transport(self._transport)

    def close(self):
        """Close the SFTP session and transport opened by :meth:`connect`, if any."""
        client = getattr(self, '_client', None)
        if client is not None:
            client.close()
        transport = getattr(self, '_transport', None)
        if transport is not None:
            transport.close()

    def files(self, path=None):
        """Return the entry names in remote *path* (``''`` when omitted)."""
        path = '' if path is None else path
        return self._client.listdir(path)

    def file_text(self, path):
        """Return all lines of remote file *path*; the handle is closed afterwards."""
        with self._client.open(path, 'r') as remote_file:
            return remote_file.readlines()

    def lock_wait(self, path, sleep_time=3):
        """Block while *path* can be opened on the remote, polling every *sleep_time* s.

        Any error while opening *path* is taken to mean the lock is gone.
        """
        while True:
            try:
                lock_file = self._client.open(path, 'r')
            except Exception:
                ## No lock file present on remote
                break
            # Opened only to probe existence; don't leak a handle per poll.
            lock_file.close()
            print('Waiting for lock on sums file to be released...')
            sleep(sleep_time)

    def file_iter(self, path):
        """Yield the lines of remote file *path*, creating it empty if missing."""
        try:
            self._client.stat(path)
        except IOError:
            # Opening in 'w' mode creates an empty file.
            self._client.open(path, 'w').close()

        with self._client.open(path, 'r') as remote_file:
            for line in remote_file:
                yield line
Beispiel #46
0
def main():
    try:
        from paramiko import Transport
    except ImportError:
        print
        print "To upload you need to install paramiko python library from:"
        print "http://www.lag.net/paramiko",
        print "or on ubuntu go: apt-get install python2.4-paramiko"
        print
        sys.exit(2)
    # Setup for eduforge
    server = "shell.eduforge.org"
    basedir = "/home/pub/exe/"
    print "Please enter password for %s@%s:" % (sys.argv[-1], server)
    password = getpass()
    # Get the version
    # Rename the files
    print "Renaming files"
    install = Path("eXe_install_windows.exe")
    newName = Path("eXe-install-%s.exe" % release)
    install = renameFile(install, newName)
    ready2run = Path("exes.exe")
    newName = Path("eXe-ready2run-%s.exe" % release)
    ready2run = renameFile(ready2run, newName)
    # Upload
    print "Uploading"
    print "connecting to %s..." % server
    from socket import socket, gethostbyname

    s = socket()
    s.connect((gethostbyname(server), 22))
    t = Transport(s)
    t.connect()
    t.auth_password(sys.argv[-1], password)
    sftp = t.open_sftp_client()
    sftp.chdir(basedir)
    upFile(sftp, install)
    upFile(sftp, ready2run)
Beispiel #47
0
    def __call__(self):
        """Upload the chronicon XML export of the context to the SFTP server.

        Host and credentials come from the recensio registry settings.  A
        human-readable status message is returned in every case; failures are
        logged rather than raised.  The SSH connection and the blob file handle
        are always closed.
        """
        registry = getUtility(IRegistry)
        recensio_settings = registry.forInterface(IRecensioSettings)
        host = recensio_settings.xml_export_server
        username = recensio_settings.xml_export_username
        password = recensio_settings.xml_export_password
        if not host:
            return 'no host configured'
        log.info("Starting XML export to sftp")

        exporter = getUtility(IFactory, name='chronicon_exporter')()
        export_xml = exporter.get_export_obj(self.context)
        if export_xml is None:
            msg = "Could not get export file object: {0}".format(
                exporter.export_filename)
            log.error(msg)
            return msg

        zipstream = export_xml.getFile()
        transport = None
        blob_file = None
        try:
            transport = Transport((host, 22))
            transport.connect(username=username, password=password)
            sftp = SFTPClient.from_transport(transport)
            blob_file = zipstream.getBlob().open()
            attribs = sftp.putfo(blob_file, self.filename)
        except (IOError, SSHException) as ioe:
            msg = "Export failed, {0}: {1}".format(ioe.__class__.__name__, ioe)
            log.error(msg)
            return msg
        finally:
            # Release the blob handle and the SSH connection on every path.
            if blob_file is not None:
                blob_file.close()
            if transport is not None:
                transport.close()

        expected_size = zipstream.get_size()
        if attribs.st_size == expected_size:
            msg = "Export successful"
            log.info(msg)
            return msg
        else:
            msg = "Export failed, {0}/{1} bytes transferred".format(
                attribs.st_size, expected_size)
            log.error(msg)
            return msg
Beispiel #48
0
class ssh_wrapper():
    """Keep one authenticated SSH transport open to a host.

    On construction it tries to load ``~/.ssh/id_rsa`` and authenticates with
    that key; if the key file cannot be read it falls back to the given
    password.  The connection is kept alive with a 60 s keepalive.
    """

    def __init__(self, host, port, user, pw):
        self.transport = Transport((host, port))
        self._datos = {'host': host, 'port': port, 'user': user, 'pw': pw}
        self.rsa_key = None

        try:
            self.setPrivateKey(expanduser('~/.ssh/id_rsa'))
        except IOError:
            # No readable default key: conectar() falls back to the password.
            self.rsa_key = None
        self.conectar()

    def conectar(self):
        """Authenticate the current transport (key if loaded, else password)."""
        if self.rsa_key is None:
            self.transport.connect(username=self._datos['user'],
                                   password=self._datos['pw'])
        else:
            self.transport.connect(username=self._datos['user'],
                                   pkey=self.rsa_key)
        self.transport.set_keepalive(60)

    def setPrivateKey(self, path):
        """Load the RSA private key at *path* for subsequent connections."""
        self.rsa_key = RSAKey.from_private_key_file(path)

    def getCiphers(self):
        """Return the transport's preferred cipher names."""
        return self.transport.get_security_options().ciphers

    def setCipher(self, cipher):
        """Reconnect, restricting the handshake to the single *cipher*."""
        # Ciphers must be chosen before the handshake, so start over with a
        # fresh transport -- releasing the current one first.
        self.transport.close()
        self.transport = Transport((self._datos['host'],
                                    self._datos['port']))

        self.transport.get_security_options().ciphers = [cipher, ]

        # Same authentication choice (key vs. password) as the initial connect.
        self.conectar()

    def getSftp(self):
        """Open a new SFTP session on the current transport."""
        return SFTPClient.from_transport(self.transport)

    def getSsh(self):
        # NOTE(review): builds self.ssh but neither connects it nor returns
        # it -- confirm whether callers expect a return value.
        self.ssh = SSHClient()
        self.ssh.set_missing_host_key_policy(AutoAddPolicy())
Beispiel #49
0
class AuthTest (unittest.TestCase):
    """Client/server authentication tests over a pair of linked LoopSockets."""

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def start_server(self):
        """Start the server side with the test RSA host key and a NullServer."""
        host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
        self.public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        self.event = threading.Event()
        self.server = NullServer()
        self.assertFalse(self.event.is_set())
        self.ts.start_server(self.event, self.server)

    def verify_finished(self):
        """Wait up to 1 s for the server event, then check the server is active."""
        self.event.wait(1.0)
        self.assertTrue(self.event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_bad_auth_type(self):
        """
        verify that we get the right exception when an unsupported auth
        type is requested.
        """
        self.start_server()
        with self.assertRaises(BadAuthenticationType) as ctx:
            self.tc.connect(hostkey=self.public_host_key,
                            username='******', password='******')
        # Exact type, not just a subclass.
        self.assertEqual(BadAuthenticationType, type(ctx.exception))
        self.assertEqual(['publickey'], ctx.exception.allowed_types)

    def test_2_bad_password(self):
        """
        verify that a bad password gets the right exception, and that a retry
        with the right password works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        with self.assertRaises(AuthenticationException):
            self.tc.auth_password(username='******', password='******')
        self.tc.auth_password(username='******', password='******')
        self.verify_finished()

    def test_3_multipart_auth(self):
        """
        verify that multipart auth works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password(username='******', password='******')
        self.assertEqual(['publickey'], remain)
        key = DSSKey.from_private_key_file(test_path('test_dss.key'))
        remain = self.tc.auth_publickey(username='******', key=key)
        self.assertEqual([], remain)
        self.verify_finished()

    def test_4_interactive_auth(self):
        """
        verify keyboard-interactive auth works.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)

        def handler(title, instructions, prompts):
            self.got_title = title
            self.got_instructions = instructions
            self.got_prompts = prompts
            return ['cat']
        remain = self.tc.auth_interactive('commie', handler)
        self.assertEqual(self.got_title, 'password')
        self.assertEqual(self.got_prompts, [('Password', False)])
        self.assertEqual([], remain)
        self.verify_finished()

    def test_5_interactive_auth_fallback(self):
        """
        verify that a password auth attempt will fallback to "interactive"
        if password auth isn't supported but interactive is.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password('commie', 'cat')
        self.assertEqual([], remain)
        self.verify_finished()

    def test_6_auth_utf8(self):
        """
        verify that utf-8 encoding happens in authentication.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password('utf8', _pwd)
        self.assertEqual([], remain)
        self.verify_finished()

    def test_7_auth_non_utf8(self):
        """
        verify that non-utf-8 encoded passwords can be used for broken
        servers.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        remain = self.tc.auth_password('non-utf8', '\xff')
        self.assertEqual([], remain)
        self.verify_finished()

    def test_8_auth_gets_disconnected(self):
        """
        verify that we catch a server disconnecting during auth, and report
        it as an auth failure.
        """
        self.start_server()
        self.tc.connect(hostkey=self.public_host_key)
        with self.assertRaises(AuthenticationException):
            self.tc.auth_password('bad-server', 'hello')

    def test_9_auth_non_responsive(self):
        """
        verify that authentication times out if server takes to long to
        respond (or never responds).
        """
        self.tc.auth_timeout = 1  # 1 second, to speed up test
        self.start_server()
        self.tc.connect()
        with self.assertRaises(AuthenticationException) as ctx:
            self.tc.auth_password('unresponsive-server', 'hello')
        self.assertIn('Authentication timeout', str(ctx.exception))
Beispiel #50
0
# -*- coding: utf-8 -*- 
__author__ = 'Administrator'


from paramiko import Transport,RSAKey,SFTPClient
import os

host = "54.187.91.18"
port = 22

private_key_file = os.path.expanduser("~/.ssh/id_rsa")
key = RSAKey.from_private_key_file(private_key_file)

local_path = "/data/websites/test/a.php"
# SFTPClient.put() needs the full remote *file* name; giving only the target
# directory makes the server try to open the directory itself for writing.
remote_path = "/data/websites/" + os.path.basename(local_path)

t = Transport((host, port))
try:
    t.connect(username="******", pkey=key)

    sftp = SFTPClient.from_transport(t)
    try:
        sftp.put(local_path, remote_path)
    except Exception as e:
        print(e)
    finally:
        sftp.close()
finally:
    # Always drop the SSH connection, even if authentication failed.
    t.close()

Beispiel #51
0
class SFTPStore(Store):
    """implements the sftp:// storage backend

        configuration via openssh/sftp style urls and
        .ssh/config files

        does not support password authentication or password
        protected authentication keys

        The server must present the host key recorded in
        ~/.ssh/known_hosts for the configured host alias.

        NOTE(review): self.netloc and self.path are read but never set here;
        presumably the Store base class derives them from *url* -- confirm."""
    def __init__(self, url, **kw):
        if '@' in self.netloc:
            user, self.netloc = self.netloc.split('@')
        else:
            user = None

        self.config = SSHHostConfig(self.netloc, user)

        host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
        try:
            self.hostkey = list(host_keys[self.config['hostkeyalias']].values())[0]
        except Exception:
            # Show which host configuration failed to match, then re-raise.
            print(str(self.config))
            raise

        if 'identityfile' in self.config:
            key_file = os.path.expanduser(self.config['identityfile'])
            # Try RSA first; fall back to DSS only when the file cannot be
            # parsed as an RSA key.  Keys must be loaded via filename= --
            # the first positional argument of the key classes is a Message.
            try:
                self.auth_key = RSAKey(filename=key_file)
            except SSHException as e:
                if str(e) == 'Unable to parse file':
                    self.auth_key = DSSKey(filename=key_file)
                else:
                    raise
        else:
            filename = os.path.expanduser('~/.ssh/id_rsa')
            if os.path.exists(filename):
                self.auth_key = RSAKey(filename=filename)
            else:
                filename = os.path.expanduser('~/.ssh/id_dsa')
                if os.path.exists(filename):
                    self.auth_key = DSSKey(filename=filename)

        self.__connect()

    def __connect(self):
        """Open the transport (verifying the known host key) and chdir to self.path."""
        self.t = Transport((self.config['hostname'], self.config['port']))
        self.t.connect(hostkey=self.hostkey,
                       username=self.config['user'], pkey=self.auth_key)
        self.client = SFTPClient.from_transport(self.t)
        self.client.chdir(self.path)

    def __build_fn(self, name):
        return "%s/%s" % (self.path, name)

    def list(self, type):
        """Return the names in self.path matching the pattern for *type*."""
        return list(filter(type_patterns[type].match, self.client.listdir(self.path)))

    def get(self, type, name):
        """Return an open binary read handle for *name*."""
        return self.client.open(self.__build_fn(name), mode='rb')

    def put(self, type, name, fp):
        """Copy the readable binary stream *fp* to *name* in 4 KiB chunks."""
        with self.client.open(self.__build_fn(name), mode='wb') as remote_file:
            buf = fp.read(4096)
            while len(buf) > 0:
                remote_file.write(buf)
                buf = fp.read(4096)

    def delete(self, type, name):
        self.client.remove(self.__build_fn(name))

    def stat(self, type, name):
        """Return ``{'size': <bytes>}`` for *name*; raise NotFoundError if absent."""
        try:
            stat = self.client.stat(self.__build_fn(name))
            return {'size': stat.st_size}
        except IOError:
            raise NotFoundError

    def close(self):
        """connection has to be explicitly closed, otherwise
            it will hold the process running indefinitely"""
        self.client.close()
        self.t.close()
class TransportTest(ParamikoTest):
    """Client/server Transport tests over a pair of linked LoopSockets."""

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """Start a NullServer and connect + authenticate the client to it.

        *client_options* / *server_options*, when given, are called with the
        matching side's SecurityOptions before the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        # NOTE(review): presumably str(host_key) is the public key blob under
        # Python 2 paramiko; newer code uses host_key.asbytes() -- confirm.
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        with self.assertRaises(ValueError):
            o.ciphers = ('aes256-cbc', 'made-up-cipher')
        with self.assertRaises(TypeError):
            o.ciphers = 23

    def test_2_compute_key(self):
        self.tc.K = long(123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929)
        self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                         hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=str(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc',)
            options.digests = ('hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        with self.assertRaises(SSHException):
            chan.exec_command('no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())
Beispiel #53
0
class SSH(object):
    """Wrapper for Paramiko Transport and Channel for Expect-like sessions.

    Output read from the remote shell accumulates in ``received_text`` and is
    echoed to stdout unless ``quiet`` is set.
    """

    def __init__(self, user, passwd=""):
        """Initialize the SSH wrapper but do not connect.

        Args:
            user: login name passed to ``Transport.connect``.
            passwd: password passed to ``Transport.connect``.
        """
        self.user = user
        self.passwd = passwd
        # Matches prompts such as "user@host>", "user@host#" or "user@host%".
        self.prompt_re = r"([a-z]+@[a-zA-Z0-9\.\-\_]+[>#%])"
        self.reset_timeout_on_newlines = True
        self._timeout_sec = 10
        self.quiet = 0
        self.received_text = ""

    def connect(self, hostname_address):
        """Connect to the SSH target and open an interactive shell.

        Args:
            hostname_address: socket or ``(host, port)`` tuple, as accepted
                by ``paramiko.Transport``.

        Raises:
            NetDevError: if a prompt is not seen within ``timeout`` seconds.
        """
        self._transport = Transport(hostname_address)
        self._transport.connect(username=self.user, password=self.passwd)
        self.ssh_channel = self._transport.open_channel("session")
        self.pty = self.ssh_channel.get_pty()
        self.ssh_channel.invoke_shell()
        self.wait_for_prompt()
        # Presumably disables CLI output paging on the device. Without the
        # trailing newline the command is never executed. Waiting for the
        # prompt it produces keeps its output from being taken as the reply
        # to the first command sent afterwards.
        self.ssh_channel.send("set cli screen-length 0\n")
        self.received_text = ""
        self.wait_for_prompt()

    def txrx_status(self):
        """returns send and receive status as string formatted for screen"""
        chan = self.ssh_channel
        return "Send Ready: {}, Receive Ready: {}".format(chan.send_ready(), chan.recv_ready())

    @property
    def timeout(self):
        """Seconds wait_for_regex waits for a match before giving up."""
        return self._timeout_sec

    @timeout.setter
    def timeout(self, new_timeout_sec):
        """timeout_sec setter"""
        self._timeout_sec = new_timeout_sec

    def wait_for_prompt(self):
        """call self.wait_for_regex using self.prompt_re as argument"""
        return self.wait_for_regex(self.prompt_re)

    def wait_for_regex(self, expression, wait_sec=0.1):
        """Read from the channel until ``expression`` matches the output.

        Received data is appended to ``received_text`` and echoed to stdout
        unless ``quiet`` is set. When ``reset_timeout_on_newlines`` is true,
        the timeout restarts after each read while ``received_text`` contains
        a newline.

        Args:
            expression: regular expression searched for in ``received_text``.
            wait_sec: pause between polls of the channel, in seconds.

        Returns:
            The accumulated ``received_text``.

        Raises:
            NetDevError: if no match is found within ``timeout`` seconds.
        """
        chan = self.ssh_channel
        # time.clock() measured CPU time on Unix, so time spent in sleep()
        # never counted toward the timeout. It was also removed in Python 3.8.
        clock = getattr(time, "monotonic", time.time)
        done = None
        start_time = clock()
        while (clock() - start_time) <= self.timeout and not done:
            time.sleep(wait_sec)
            if chan.recv_ready():
                while chan.recv_ready():
                    this_receive = chan.recv(1024)
                    # Channel.recv returns bytes on Python 3; keep a str buffer.
                    # NOTE: a multi-byte character split across two 1024-byte
                    # reads is decoded with replacement characters.
                    if not isinstance(this_receive, str):
                        this_receive = this_receive.decode("utf-8", "replace")
                    self.received_text += this_receive
                    # print to screen if quiet is not set
                    if not self.quiet:
                        stdout.write(this_receive)
                        stdout.flush()
                # reset the start_time if we encounter newlines and
                # self.reset_timeout_on_newlines is True
                if self.reset_timeout_on_newlines and "\n" in self.received_text:
                    start_time = clock()
            done = re.search(expression, self.received_text)
        # if done is not set at this point, we consider it to be a timeout
        if not done:
            raise NetDevError("Timeout waiting for expression match: '{}'".format(expression))
        return self.received_text

    def _sendline(self, line):
        """Send one stripped line followed by a newline.

        Raises:
            NetDevError: if the channel is not ready to send.
        """
        chan = self.ssh_channel
        if not chan.send_ready():
            raise NetDevError("Attempted to send when send not ready")
        chan.send(line.strip() + "\n")

    def send(self, textblock):
        """Send ``textblock`` one line at a time, waiting for the prompt
        after each line.

        Leading and trailing blank lines of ``textblock`` are not sent.
        ``received_text`` is cleared before each wait and again at the end.
        """
        configtext = textblock.strip()
        for line in configtext.splitlines():
            self._sendline(line)
            self.received_text = ""
            self.wait_for_prompt()
        self.received_text = ""
Beispiel #54
0
    def submitJobToFramework(self, **kwargs):
        """Submit a scheduled job to the framework daemon.

        Builds ``job submit`` daemon arguments from the job's parameters. When
        a file-feeder parameter is present, the referenced file is uploaded
        over SFTP first. If that upload fails, the job is recorded as FAILED
        and is not submitted. Afterwards, a schedule with a
        ``scheduled_start`` is deactivated.

        Keyword Args:
            unScheduledJob: the schedule entry to submit (required).
        """
        jobCommand = 'job'

        daemonArgs = DaemonArgs(self.config)
        daemonArgs.command = jobCommand
        unScheduledJob = kwargs['unScheduledJob']

        fileFeederUploadedFile = None
        # Start from an empty parameter list before adding this job's entries.
        del daemonArgs.param[:]

        for parameter in unScheduledJob.parameters.all():
            # Only complete (service, key, value) triples are forwarded.
            if not (parameter.service and parameter.param_key and parameter.param_value):
                continue

            value = parameter.param_value
            if parameter.service == settings.FILE_FEEDER_ID:
                # The framework is pointed at the uploaded copy of the file.
                fileFeederUploadedFile = parameter.param_value
                value = os.path.join(self.sftpRemotePath, parameter.param_value)

            parameterString = '%s.%s=%s' % (parameter.service, parameter.param_key, value)
            self.logger.debug("add parameter string: %s" % parameterString)
            daemonArgs.param.append([parameterString])

        uploadOk = True
        if fileFeederUploadedFile is not None:
            self.logger.debug("is file feeder")
            uploadOk = self._uploadFeederFile(fileFeederUploadedFile)
            if not uploadOk:
                self.saveJob(Job.FAILED_STATUS, None, unScheduledJob)

        # Submitting after a failed upload would point the framework at a
        # file that never arrived and record the job a second time.
        if uploadOk:
            # set job workflow
            daemonArgs.jd_workflow = unScheduledJob.workflow.name
            try:
                setattr(daemonArgs, jobCommand, 'submit')
                frameworkJobId = self.sendFrameworkCommand(jobCommand, daemonArgs)
                self.saveJob(Job.PROCESSING_STATUS, frameworkJobId, unScheduledJob)
            except WorkflowNotDeployedException:
                # The workflow is not deployed in the framework. To prevent the
                # scheduler retrying continuously we disable this job.
                # NOTE(review): DEACTIVATE_STATUS here vs DEACTIVATED_STATUS
                # below - confirm both constants exist on Schedule.
                unScheduledJob.status = Schedule.DEACTIVATE_STATUS
                unScheduledJob.save()
            except Exception as e:
                self.logger.error("Submitting job to framework failed: %s" % e)
                self.saveJob(Job.FAILED_STATUS, None, unScheduledJob)
            finally:
                daemonArgs.clean(jobCommand)

        if unScheduledJob.scheduled_start is not None:
            unScheduledJob.status = Schedule.DEACTIVATED_STATUS
            unScheduledJob.save()

    def _uploadFeederFile(self, fileName):
        """Upload ``fileName`` from MEDIA_ROOT into ``self.sftpRemotePath``.

        Authenticates with ``self.sftpPassword`` when set. Otherwise uses the
        RSA or DSS private key configured on this object.

        Returns:
            True if the upload succeeded, False if it failed (error is logged).
        """
        sftp = None
        transport = None
        try:
            transport = Transport((self.sftpHost, self.sftpPort))
            if self.sftpPassword:
                transport.connect(username=self.sftpUsername, password=self.sftpPassword)
            else:
                privateKey = None
                keyType = (self.sftpPrivateKeyType or '').lower()
                if keyType == 'rsa':
                    privateKey = RSAKey.from_private_key_file(self.sftpPrivateKey, password=self.sftpPrivateKeyPassword)
                elif keyType == 'dss':
                    privateKey = DSSKey.from_private_key_file(self.sftpPrivateKey, password=self.sftpPrivateKeyPassword)
                transport.connect(username=self.sftpUsername, pkey=privateKey)

            sftp = SFTPClient.from_transport(transport)

            filePath = os.path.join(settings.MEDIA_ROOT, fileName)
            remotePath = os.path.join(self.sftpRemotePath, fileName)

            self.logger.debug("uploading file from %s to %s on remote machine" % (filePath, remotePath))

            sftp.put(filePath, remotePath)
            # 0o644 is valid in Python 2.6+ and 3; plain 0644 is Python 2 only.
            sftp.chmod(remotePath, 0o644)

            self.logger.debug("put OK")
            return True

        except IOError as e:
            self.logger.error("IOError: %s. Will continue with next scheduled job." % e)
        except PasswordRequiredException as e:
            self.logger.error("PasswordRequiredException: %s. Will continue with next scheduled job." % e)
        except SSHException as e:
            self.logger.error("SSH Exception: %s. Will continue with next scheduled job." % e)
        except Exception as e:
            self.logger.error("Unkown SFTP problem. Will continue with next scheduled job. %s" % e)
        finally:
            if sftp is not None:
                sftp.close()
            if transport is not None:
                transport.close()
        return False
    def connect(self, host, port=22, user=None, passw=None, cert=None, path='/', timeout=10):
        """Method connects to server

        Args:
           host (str): server host
           port (int): server port, default protocol port
           user (str): username
           passw (str): password
           cert (str): path to certificate file
           path (str): server path
           timeout (int): timeout

        Returns:
           bool: result

        Raises:
           event: ftp_before_connect
           event: ftp_after_connect

        """

        t = None
        try:

            # Never write the password itself to the debug log.
            masked_passw = '***' if passw is not None else None
            message = '{0}/{1}@{2}:{3}{4} cert:{5}, timeout:{6}'.format(
                user, masked_passw, host, port, path, cert, timeout)
            self._mh.demsg('htk_on_debug_info', self._mh._trn.msg(
                'htk_ftp_connecting', message), self._mh.fromhere())

            ev = event.Event(
                'ftp_before_connect', host, port, user, passw, cert, path, timeout)
            if (self._mh.fire_event(ev) > 0):
                # Event handlers may rewrite the connection parameters.
                host = ev.argv(0)
                port = ev.argv(1)
                user = ev.argv(2)
                passw = ev.argv(3)
                cert = ev.argv(4)
                path = ev.argv(5)
                timeout = ev.argv(6)

            self._host = host
            self._port = port
            self._user = user
            self._passw = passw
            self._cert = cert

            if ev.will_run_default():
                # NOTE: changes the process-wide default socket timeout.
                setdefaulttimeout(timeout)
                t = Transport((host, self._port))

                if user is not None or cert is not None:
                    pkey = RSAKey.from_private_key_file(
                        self._cert) if cert is not None else None
                    t.connect(username=user, password=passw, pkey=pkey)
                    self._client = SFTPClient.from_transport(t)

                self._is_connected = True

                if path is not None:
                    self.change_dir(path)

            ev = event.Event('ftp_after_connect')
            self._mh.fire_event(ev)

            return True

        except (SSHException, NoValidConnectionsError, error) as ex:
            # Release the transport opened by this call so a failed connect
            # does not leave a half-open session behind.
            if t is not None:
                t.close()
                self._is_connected = False
            self._mh.demsg(
                'htk_on_error', 'error: {0}'.format(ex), self._mh.fromhere())
            return False
Beispiel #56
0
class TransportTest(ParamikoTest):
    """Exercise paramiko Transport by linking a client Transport (tc) and a
    server Transport (ts) over a pair of in-process loopback sockets."""

    def setUp(self):
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(self, client_options=None, server_options=None):
        """Start ``self.ts`` as a server and connect ``self.tc`` to it.

        ``client_options`` / ``server_options``, when given, are called with
        the corresponding transport's security options before the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')

        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertFalse(event.is_set())
        self.ts.start_server(event, self.server)

        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')

        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual((b'aes256-cbc', b'blowfish-cbc'), o.ciphers)
        o.ciphers = (b'aes256-cbc', b'blowfish-cbc')
        self.assertEqual((b'aes256-cbc', b'blowfish-cbc'), o.ciphers)
        # Unknown cipher names and non-sequence values must be rejected.
        self.assertRaises(ValueError, setattr, o, 'ciphers',
                          (b'aes256-cbc', b'made-up-cipher'))
        self.assertRaises(TypeError, setattr, o, 'ciphers', 23)

    def test_2_compute_key(self):
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = unhexlify(b'0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key(b'C', 32)
        self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                         hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
        public_host_key = RSAKey(data=bytes(host_key))
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertFalse(event.is_set())
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = (b'aes256-cbc',)
            options.digests = (b'hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual(b'aes256-cbc', self.tc.local_cipher)
        self.assertEqual(b'aes256-cbc', self.tc.remote_cipher)
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual(b'*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        # The test server is expected to refuse 'no'; the client must raise.
        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        self.assertRaises(SSHException, chan.exec_command, 'no')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send(b'Hello there.\n')
        schan.send_stderr(b'This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual(b'This is on stderr.\n', f.readline())
        self.assertEqual(b'', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send(b'Hello there.\n')
        schan.send_stderr(b'This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'This is on stderr.\n', f.readline())
        self.assertEqual(b'', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send(b'communist j. cat\n')
        f = schan.makefile()
        self.assertEqual(b'communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual(b'', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        try:
            self.tc.open_channel(b'bogus')
            self.fail('expected exception')
        except ChannelException as x:
            self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, x.code)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send(b'Hello there.\n')
        self.assertFalse(chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual(b'Hello there.\n', f.readline())
        self.assertEqual(b'', f.readline())
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send(b'hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(b'', chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for i in range(20):
            chan.send(b'x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for i in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = (b'zlib',)
        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command(b'yes')
        schan = self.ts.accept(1.0)

        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send(b'x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        # tests show this is actually compressed to *52 bytes*!  including packet overhead!  nice!! :)
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(52, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command(b'yes')
        schan = self.ts.accept(1.0)

        requested = []
        def handler(c, addr):
            requested.append(addr)
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual(b'MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send(b'hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []
        def handler(c, origin_addr, server_addr):
            requested.append(origin_addr)
            requested.append(server_addr)
            self.tc._queue_incoming_channel(c)

        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect((b'127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
        cch = self.tc.accept()

        sch.send(b'hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward(b'127.0.0.1', port)
        self.assertTrue(self.server._listen is None)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel(b'direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr(b'hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = b'*' * 1024
        while total < 1024 * 1024:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < 1024 * 1024)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()
                self.last = None

            def run(self):
                try:
                    for i in range(1, 1+self.iterations):
                        if self.done_event.is_set():
                            break
                        self.watchdog_event.set()
                        self.chan.send(b"x" * 2048)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            # Read from this thread's own channel rather than
                            # the enclosing test's `chan` variable.
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        else:
                            if random.randint(0, 1):
                                time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]
        def _negotiate_keys_wrapper(self, m):
            if self.local_kex_init is None: # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST).encode())
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)    # bytes to add
                self._send_message(m2)
            return _negotiate_keys(self, m)
        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500    # The deadlock does not happen every time, but it
                            # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer, checking that both threads keep making
        # progress at least once every `timeout` seconds.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()
Beispiel #57
0
class TransportTest(unittest.TestCase):
    """
    Tests for ``Transport`` that run entirely in-process.

    Each test gets a client transport (``self.tc``) and a server transport
    (``self.ts``) built on a pair of linked ``LoopSocket`` objects, so the
    SSH handshake and all channel traffic stay inside this process.
    """

    def setUp(self):
        # Link the two loopback sockets so each one reads what the other
        # writes; the client transport uses sockc, the server uses socks.
        self.socks = LoopSocket()
        self.sockc = LoopSocket()
        self.sockc.link(self.socks)
        self.tc = Transport(self.sockc)
        self.ts = Transport(self.socks)

    def tearDown(self):
        self.tc.close()
        self.ts.close()
        self.socks.close()
        self.sockc.close()

    def setup_test_server(
        self, client_options=None, server_options=None, connect_kwargs=None,
    ):
        """
        Start ``self.ts`` as a server and connect ``self.tc`` to it.

        :param client_options: optional callable, called with the client
            transport's ``SecurityOptions`` before connecting.
        :param server_options: optional callable, called with the server
            transport's ``SecurityOptions`` before the server starts.
        :param connect_kwargs: keyword arguments for ``self.tc.connect()``.
            When None, the test host key plus a username/password are used;
            pass ``{}`` to connect without authenticating.

        The ``NullServer`` used is stored as ``self.server``.  The test fails
        if the server's start event is not set within one second after
        ``connect()`` returns, or if the server transport is not active.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)

        if client_options is not None:
            client_options(self.tc.get_security_options())
        if server_options is not None:
            server_options(self.ts.get_security_options())

        event = threading.Event()
        self.server = NullServer()
        self.assertTrue(not event.is_set())
        self.ts.start_server(event, self.server)
        if connect_kwargs is None:
            connect_kwargs = dict(
                hostkey=public_host_key,
                username='******',
                password='******',
            )
        self.tc.connect(**connect_kwargs)
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_1_security_options(self):
        """
        verify that SecurityOptions.ciphers can be reassigned, and that an
        unknown cipher name or a value of the wrong type is rejected.
        """
        o = self.tc.get_security_options()
        self.assertEqual(type(o), SecurityOptions)
        self.assertNotEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        o.ciphers = ('aes256-cbc', 'blowfish-cbc')
        self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
        # An unknown cipher name raises ValueError...
        self.assertRaises(ValueError, setattr, o, 'ciphers',
                          ('aes256-cbc', 'made-up-cipher'))
        # ...and a value of the wrong type raises TypeError.
        self.assertRaises(TypeError, setattr, o, 'ciphers', 23)

    def test_1b_security_options_reset(self):
        """
        verify that every SecurityOptions field can be re-assigned its own
        current value.
        """
        o = self.tc.get_security_options()
        # should not throw any exceptions
        o.ciphers = o.ciphers
        o.digests = o.digests
        o.key_types = o.key_types
        o.kex = o.kex
        o.compression = o.compression

    def test_2_compute_key(self):
        """
        verify _compute_key() against a known-answer vector for a fixed
        shared secret K and exchange hash H (used as the session id).
        """
        self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
        self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
        self.tc.session_id = self.tc.H
        key = self.tc._compute_key('C', 32)
        self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
                          hexlify(key).upper())

    def test_3_simple(self):
        """
        verify that we can establish an ssh link with ourselves across the
        loopback sockets.  this is hardly "simple" but it's simpler than the
        later tests. :)
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.assertEqual(None, self.tc.get_username())
        self.assertEqual(None, self.ts.get_username())
        self.assertEqual(False, self.tc.is_authenticated())
        self.assertEqual(False, self.ts.is_authenticated())
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.tc.get_username())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.tc.is_authenticated())
        self.assertEqual(True, self.ts.is_authenticated())

    def test_3a_long_banner(self):
        """
        verify that a long banner doesn't mess up the handshake.
        """
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        # Push the banner onto the wire before the server side starts.
        self.socks.send(LONG_BANNER)
        self.ts.start_server(event, server)
        self.tc.connect(hostkey=public_host_key,
                        username='******', password='******')
        event.wait(1.0)
        self.assertTrue(event.is_set())
        self.assertTrue(self.ts.is_active())

    def test_4_special(self):
        """
        verify that the client can demand odd handshake settings, and can
        renegotiate keys in mid-stream.
        """
        def force_algorithms(options):
            options.ciphers = ('aes256-cbc',)
            options.digests = ('hmac-md5-96',)
        self.setup_test_server(client_options=force_algorithms)
        self.assertEqual('aes256-cbc', self.tc.local_cipher)
        self.assertEqual('aes256-cbc', self.tc.remote_cipher)
        # hmac-md5-96 truncates the MAC to 96 bits = 12 bytes.
        self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
        self.assertEqual(12, self.tc.packetizer.get_mac_size_in())

        self.tc.send_ignore(1024)
        self.tc.renegotiate_keys()
        self.ts.send_ignore(1024)

    @slow
    def test_5_keepalive(self):
        """
        verify that the keepalive will be sent.
        """
        self.setup_test_server()
        self.assertEqual(None, getattr(self.server, '_global_request', None))
        self.tc.set_keepalive(1)
        time.sleep(2)
        self.assertEqual('*****@*****.**', self.server._global_request)

    def test_6_exec_command(self):
        """
        verify that exec_command() does something reasonable.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        # A command that is not valid UTF-8 must be rejected.
        self.assertRaises(SSHException, chan.exec_command,
                          b'command contains \xfc and is not a valid UTF-8 string')

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        f = chan.makefile_stderr()
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

        # now try it with combined stdout/stderr
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)
        schan.send('Hello there.\n')
        schan.send_stderr('This is on stderr.\n')
        schan.close()

        chan.set_combine_stderr(True)
        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('This is on stderr.\n', f.readline())
        self.assertEqual('', f.readline())

    def test_6a_channel_can_be_used_as_context_manager(self):
        """
        verify that client and server channels can be used as context
        managers (``with`` blocks) while exchanging data.
        """
        self.setup_test_server()

        with self.tc.open_session() as chan:
            with self.ts.accept(1.0) as schan:
                chan.exec_command('yes')
                schan.send('Hello there.\n')
                schan.close()

                f = chan.makefile()
                self.assertEqual('Hello there.\n', f.readline())
                self.assertEqual('', f.readline())

    def test_7_invoke_shell(self):
        """
        verify that invoke_shell() does something reasonable.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        chan.send('communist j. cat\n')
        f = schan.makefile()
        self.assertEqual('communist j. cat\n', f.readline())
        chan.close()
        self.assertEqual('', f.readline())

    def test_8_channel_exception(self):
        """
        verify that ChannelException is thrown for a bad open-channel request.
        """
        self.setup_test_server()
        try:
            self.tc.open_channel('bogus')
            self.fail('expected exception')
        except ChannelException as e:
            self.assertEqual(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, e.code)

    def test_9_exit_status(self):
        """
        verify that get_exit_status() works.
        """
        self.setup_test_server()

        chan = self.tc.open_session()
        schan = self.ts.accept(1.0)
        chan.exec_command('yes')
        schan.send('Hello there.\n')
        self.assertTrue(not chan.exit_status_ready())
        # trigger an EOF
        schan.shutdown_read()
        schan.shutdown_write()
        schan.send_exit_status(23)
        schan.close()

        f = chan.makefile()
        self.assertEqual('Hello there.\n', f.readline())
        self.assertEqual('', f.readline())
        # Poll for the exit status for up to ~5 seconds.
        count = 0
        while not chan.exit_status_ready():
            time.sleep(0.1)
            count += 1
            if count > 50:
                raise Exception("timeout")
        self.assertEqual(23, chan.recv_exit_status())
        chan.close()

    def test_A_select(self):
        """
        verify that select() on a channel works.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()

        # detect eof?
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)
        self.assertEqual(bytes(), chan.recv(16))

        # make sure the pipe is still open for now...
        p = chan._pipe
        self.assertEqual(False, p._closed)
        chan.close()
        # ...and now is closed.
        self.assertEqual(True, p._closed)

    def test_B_renegotiate(self):
        """
        verify that a transport can correctly renegotiate mid-stream.
        """
        self.setup_test_server()
        # Lower the rekey threshold so the 20 KiB sent below forces a rekey.
        self.tc.packetizer.REKEY_BYTES = 16384
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        self.assertEqual(self.tc.H, self.tc.session_id)
        for i in range(20):
            chan.send('x' * 1024)
        chan.close()

        # allow a few seconds for the rekeying to complete
        for i in range(50):
            if self.tc.H != self.tc.session_id:
                break
            time.sleep(0.1)
        self.assertNotEqual(self.tc.H, self.tc.session_id)

        schan.close()

    def test_C_compression(self):
        """
        verify that zlib compression is basically working.
        """
        def force_compression(o):
            o.compression = ('zlib',)
        self.setup_test_server(force_compression, force_compression)
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Named so as not to shadow the ``bytes`` builtin.
        sent_before = self.tc.packetizer._Packetizer__sent_bytes
        chan.send('x' * 1024)
        sent_after = self.tc.packetizer._Packetizer__sent_bytes
        block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size']
        mac_size = self.tc._mac_info[self.tc.local_mac]['size']
        # tests show this is actually compressed to *52 bytes*!  including packet overhead!  nice!! :)
        self.assertTrue(sent_after - sent_before < 1024)
        self.assertEqual(16 + block_size + mac_size, sent_after - sent_before)

        chan.close()
        schan.close()

    def test_D_x11(self):
        """
        verify that an x11 port can be requested and opened.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []
        def handler(c, addr_port):
            addr, port = addr_port
            requested.append((addr, port))
            self.tc._queue_incoming_channel(c)

        self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
        cookie = chan.request_x11(0, single_connection=True, handler=handler)
        self.assertEqual(0, self.server._x11_screen_number)
        self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
        self.assertEqual(cookie, self.server._x11_auth_cookie)
        self.assertEqual(True, self.server._x11_single_connection)

        x11_server = self.ts.open_x11_channel(('localhost', 6093))
        x11_client = self.tc.accept()
        self.assertEqual('localhost', requested[0][0])
        self.assertEqual(6093, requested[0][1])

        x11_server.send('hello')
        self.assertEqual(b'hello', x11_client.recv(5))

        x11_server.close()
        x11_client.close()
        chan.close()
        schan.close()

    def test_E_reverse_port_forwarding(self):
        """
        verify that a client can ask the server to open a reverse port for
        forwarding.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        requested = []
        def handler(c, origin_addr_port, server_addr_port):
            requested.append(origin_addr_port)
            requested.append(server_addr_port)
            self.tc._queue_incoming_channel(c)

        # Port 0 lets the server pick a free port; it is returned here.
        port = self.tc.request_port_forward('127.0.0.1', 0, handler)
        self.assertEqual(port, self.server._listen.getsockname()[1])

        cs = socket.socket()
        cs.connect(('127.0.0.1', port))
        ss, _ = self.server._listen.accept()
        sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
        cch = self.tc.accept()

        sch.send('hello')
        self.assertEqual(b'hello', cch.recv(5))
        sch.close()
        cch.close()
        ss.close()
        cs.close()

        # now cancel it.
        self.tc.cancel_port_forward('127.0.0.1', port)
        self.assertTrue(self.server._listen is None)

    def test_F_port_forwarding(self):
        """
        verify that a client can forward new connections from a locally-
        forwarded port.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # open a port on the "server" that the client will ask to forward to.
        greeting_server = socket.socket()
        greeting_server.bind(('127.0.0.1', 0))
        greeting_server.listen(1)
        greeting_port = greeting_server.getsockname()[1]

        cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
        sch = self.ts.accept(1.0)
        cch = socket.socket()
        cch.connect(self.server._tcpip_dest)

        ss, _ = greeting_server.accept()
        ss.send(b'Hello!\n')
        ss.close()
        sch.send(cch.recv(8192))
        sch.close()

        self.assertEqual(b'Hello!\n', cs.recv(7))
        cs.close()
        # Release the real OS sockets; tearDown only closes the transports
        # and loopback sockets.
        cch.close()
        greeting_server.close()

    def test_G_stderr_select(self):
        """
        verify that select() on a channel works even if only stderr is
        receiving data.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        # nothing should be ready
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.send_stderr('hello\n')

        # something should be ready now (give it 1 second to appear)
        for i in range(10):
            r, w, e = select.select([chan], [], [], 0.1)
            if chan in r:
                break
            time.sleep(0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        self.assertEqual(b'hello\n', chan.recv_stderr(6))

        # and, should be dead again now
        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

        schan.close()
        chan.close()

    def test_H_send_ready(self):
        """
        verify that send_ready() indicates when a send would not block.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)

        self.assertEqual(chan.send_ready(), True)
        total = 0
        K = '*' * 1024
        limit = 1+(64 * 2 ** 15)
        # Keep sending without the server reading; send_ready() must turn
        # False before `limit` bytes have been queued.
        while total < limit:
            chan.send(K)
            total += len(K)
            if not chan.send_ready():
                break
        self.assertTrue(total < limit)

        schan.close()
        chan.close()
        self.assertEqual(chan.send_ready(), True)

    def test_I_rekey_deadlock(self):
        """
        Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent

        Note: When this test fails, it may leak threads.
        """

        # Test for an obscure deadlocking bug that can occur if we receive
        # certain messages while initiating a key exchange.
        #
        # The deadlock occurs as follows:
        #
        # In the main thread:
        #   1. The user's program calls Channel.send(), which sends
        #      MSG_CHANNEL_DATA to the remote host.
        #   2. Packetizer discovers that REKEY_BYTES has been exceeded, and
        #      sets the __need_rekey flag.
        #
        # In the Transport thread:
        #   3. Packetizer notices that the __need_rekey flag is set, and raises
        #      NeedRekeyException.
        #   4. In response to NeedRekeyException, the transport thread sends
        #      MSG_KEXINIT to the remote host.
        #
        # On the remote host (using any SSH implementation):
        #   5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
        #   6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
        #
        # In the main thread:
        #   7. The user's program calls Channel.send().
        #   8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
        #   9. Transport._send_user_message waits for Transport.clear_to_send
        #      to be set (i.e., it waits for re-keying to complete).
        #      Channel.lock is still held.
        #
        # In the Transport thread:
        #   10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
        #       is called to handle it.
        #   11. Channel._window_adjust tries to acquire Channel.lock, but it
        #       blocks because the lock is already held by the main thread.
        #
        # The result is that the Transport thread never processes the remote
        # host's MSG_KEXINIT packet, because it becomes deadlocked while
        # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.

        # We set up two separate threads for sending and receiving packets,
        # while the main thread acts as a watchdog timer.  If the timer
        # expires, a deadlock is assumed.

        class SendThread(threading.Thread):
            """Send `iterations` 2048-byte chunks on `chan`, pinging a watchdog."""
            def __init__(self, chan, iterations, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.iterations = iterations
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    for _ in range(self.iterations):
                        if self.done_event.is_set():
                            break
                        # Signal the watchdog that this thread is progressing.
                        self.watchdog_event.set()
                        self.chan.send("x" * 2048)
                finally:
                    # Finished or crashed: end the test and wake the watchdog.
                    self.done_event.set()
                    self.watchdog_event.set()

        class ReceiveThread(threading.Thread):
            """Drain `chan` until `done_event` is set, pinging a watchdog."""
            def __init__(self, chan, done_event):
                threading.Thread.__init__(self, None, None, self.__class__.__name__)
                self.daemon = True
                self.chan = chan
                self.done_event = done_event
                self.watchdog_event = threading.Event()

            def run(self):
                try:
                    while not self.done_event.is_set():
                        if self.chan.recv_ready():
                            self.chan.recv(65536)
                            self.watchdog_event.set()
                        else:
                            # Randomly stall to vary the interleaving with
                            # the sender.
                            if random.randint(0, 1):
                                time.sleep(random.randint(0, 500) / 1000.0)
                finally:
                    # Finished or crashed: end the test and wake the watchdog.
                    self.done_event.set()
                    self.watchdog_event.set()

        self.setup_test_server()
        # A tiny rekey threshold makes the server renegotiate constantly.
        self.ts.packetizer.REKEY_BYTES = 2048

        chan = self.tc.open_session()
        chan.exec_command('yes')
        schan = self.ts.accept(1.0)

        # Monkey patch the client's Transport._handler_table so that the client
        # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
        # MSG_KEXINIT.  This is used to simulate the effect of network latency
        # on a real MSG_CHANNEL_WINDOW_ADJUST message.
        self.tc._handler_table = self.tc._handler_table.copy()  # copy per-class dictionary
        _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]
        def _negotiate_keys_wrapper(transport, m):
            if transport.local_kex_init is None: # Remote side sent KEXINIT
                # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
                # before responding to the incoming MSG_KEXINIT.
                m2 = Message()
                m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
                m2.add_int(chan.remote_chanid)
                m2.add_int(1)    # bytes to add
                transport._send_message(m2)
            return _negotiate_keys(transport, m)
        self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper

        # Parameters for the test
        iterations = 500    # The deadlock does not happen every time, but it
                            # should after many iterations.
        timeout = 5

        # This event is set when the test is completed
        done_event = threading.Event()

        # Start the sending thread
        st = SendThread(schan, iterations, done_event)
        st.start()

        # Start the receiving thread
        rt = ReceiveThread(chan, done_event)
        rt.start()

        # Act as a watchdog timer, checking that each thread signals progress
        # at least once every `timeout` seconds; if one does not, assume a
        # deadlock.
        deadlocked = False
        while not deadlocked and not done_event.is_set():
            for event in (st.watchdog_event, rt.watchdog_event):
                event.wait(timeout)
                if done_event.is_set():
                    break
                if not event.is_set():
                    deadlocked = True
                    break
                event.clear()

        # Tell the threads to stop (if they haven't already stopped).  Note
        # that if one or more threads are deadlocked, they might hang around
        # forever (until the process exits).
        done_event.set()

        # Assertion: We must not have detected a timeout.
        self.assertFalse(deadlocked)

        # Close the channels
        schan.close()
        chan.close()

    def test_J_sanitze_packet_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(4095, MIN_PACKET_SIZE),
                             (None, DEFAULT_MAX_PACKET_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_packet_size(val), correct)

    def test_K_sanitze_window_size(self):
        """
        verify that we conform to the rfc of packet and window sizes.
        """
        for val, correct in [(32767, MIN_WINDOW_SIZE),
                             (None, DEFAULT_WINDOW_SIZE),
                             (2**32, MAX_WINDOW_SIZE)]:
            self.assertEqual(self.tc._sanitize_window_size(val), correct)

    @slow
    def test_L_handshake_timeout(self):
        """
        verify that we can get a handshake timeout.
        """
        # Tweak client Transport instance's Packetizer instance so
        # its read_message() sleeps a bit. This helps prevent race conditions
        # where the client Transport's timeout timer thread doesn't even have
        # time to get scheduled before the main client thread finishes
        # handshaking with the server.
        # (Doing this on the server's transport *sounds* more 'correct' but
        # actually doesn't work nearly as well for whatever reason.)
        class SlowPacketizer(Packetizer):
            def read_message(self):
                time.sleep(1)
                return super(SlowPacketizer, self).read_message()
        # NOTE: prettttty sure since the replaced .packetizer Packetizer is now
        # no longer doing anything with its copy of the socket...everything'll
        # be fine. Even tho it's a bit squicky.
        self.tc.packetizer = SlowPacketizer(self.tc.sock)
        # Continue with regular test red tape.
        host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
        public_host_key = RSAKey(data=host_key.asbytes())
        self.ts.add_server_key(host_key)
        event = threading.Event()
        server = NullServer()
        self.assertTrue(not event.is_set())
        self.tc.handshake_timeout = 0.000000000001
        self.ts.start_server(event, server)
        self.assertRaises(EOFError, self.tc.connect,
                          hostkey=public_host_key,
                          username='******',
                          password='******')

    def test_M_select_after_close(self):
        """
        verify that select works when a channel is already closed.
        """
        self.setup_test_server()
        chan = self.tc.open_session()
        chan.invoke_shell()
        schan = self.ts.accept(1.0)
        schan.close()

        # give client a moment to receive close notification
        time.sleep(0.1)

        r, w, e = select.select([chan], [], [], 0.1)
        self.assertEqual([chan], r)
        self.assertEqual([], w)
        self.assertEqual([], e)

    def test_channel_send_misc(self):
        """
        verify behaviours sending various instances to a channel
        """
        self.setup_test_server()
        text = u"\xa7 slice me nicely"
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # TypeError raised on non string or buffer type
            self.assertRaises(TypeError, chan.send, object())
            self.assertRaises(TypeError, chan.sendall, object())

            # sendall() accepts a unicode instance
            chan.sendall(text)
            expected = text.encode("utf-8")
            self.assertEqual(sfile.read(len(expected)), expected)

    @needs_builtin('buffer')
    def test_channel_send_buffer(self):
        """
        verify sending buffer instances to a channel
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # send() accepts buffer instances
            sent = 0
            while sent < len(data):
                sent += chan.send(buffer(data, sent, 8))
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() accepts a buffer instance
            chan.sendall(buffer(data))
            self.assertEqual(sfile.read(len(data)), data)

    @needs_builtin('memoryview')
    def test_channel_send_memoryview(self):
        """
        verify sending memoryview instances to a channel
        """
        self.setup_test_server()
        data = 3 * b'some test data\n whole'
        with self.tc.open_session() as chan:
            schan = self.ts.accept(1.0)
            if schan is None:
                self.fail("Test server transport failed to accept")
            sfile = schan.makefile()

            # send() accepts memoryview slices
            sent = 0
            view = memoryview(data)
            while sent < len(view):
                sent += chan.send(view[sent:sent+8])
            self.assertEqual(sfile.read(len(data)), data)

            # sendall() accepts a memoryview instance
            chan.sendall(memoryview(data))
            self.assertEqual(sfile.read(len(data)), data)

    def test_server_rejects_open_channel_without_auth(self):
        """
        verify that opening a session on an unauthenticated transport raises
        ChannelException with OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED.
        """
        # unittest assertions are used instead of bare ``assert`` so the
        # checks are not stripped when running under ``python -O``.
        try:
            self.setup_test_server(connect_kwargs={})
            self.tc.open_session()
        except ChannelException as e:
            self.assertEqual(e.code, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
        else:
            self.fail("Did not raise ChannelException!")

    def test_server_rejects_arbitrary_global_request_without_auth(self):
        """
        verify that a global request from an unauthenticated client fails.
        """
        self.setup_test_server(connect_kwargs={})
        # NOTE: this dummy global request kind would normally pass muster
        # from the test server.
        self.tc.global_request('acceptable')
        # Global requests never raise exceptions, even on failure (not sure why
        # this was the original design...ugh.) Best we can do to tell failure
        # happened is that the client transport's global_response was set back
        # to None; if it had succeeded, it would be the response Message.
        err = "Unauthed global response incorrectly succeeded!"
        self.assertTrue(self.tc.global_response is None, err)

    def test_server_rejects_port_forward_without_auth(self):
        """
        verify that a port-forward request from an unauthenticated client
        raises SSHException.
        """
        # NOTE: at protocol level port forward requests are treated same as a
        # regular global request, but Paramiko server implements a special-case
        # method for it, so it gets its own test. (plus, THAT actually raises
        # an exception on the client side, unlike the general case...)
        self.setup_test_server(connect_kwargs={})
        try:
            self.tc.request_port_forward('localhost', 1234)
        except SSHException as e:
            self.assertTrue("forwarding request denied" in str(e))
        else:
            self.fail("Did not raise SSHException!")
Beispiel #58
0
class Transporter(object):
    '''
    Helper around paramiko for working with one remote machine.

    Link() opens a socket + paramiko Transport (SFTP runs over it) and,
    optionally, a separate mySSHClient connection used by RemoteCommand();
    Unlink() closes everything. The other methods copy files in either
    direction via SFTP or the local `scp` binary, run remote commands, and
    keep a pickled MD5 bookkeeping file so that only changed remote files
    are collected.
    '''

    def __init__(self,
                 timeout,
                 remoteMachine,
                 port,
                 username,
                 password,
                 loggerName,
                 pk,
                 bufsize=8192):
        '''
        Gets: timeout (socket timeout in seconds, see Link), remoteMachine,
              port, username, password (a false/empty password makes Link
              authenticate the Transport with the private key instead),
              loggerName (passed to logging.getLogger), pk (path of a DSS
              private key file; "~" is expanded), bufsize (buffer size for
              remote writes in _WriteLocalDataToRemoteFile).
        '''
        import logging
        self.logger = logging.getLogger(loggerName)

        self.timeout = timeout
        self.remoteMachine = remoteMachine
        self.port = port
        self.username = username
        self.password = password
        self.bufsize = bufsize
        self.pk = DSSKey.from_private_key_file(expanduser(pk))

        # Connection handles are populated by Link(). Pre-setting them lets
        # Unlink(), RemoteCommand() and the copy helpers log a clean error
        # instead of failing with AttributeError when Link() never ran.
        self.sock = None
        self.transport = None
        self.ssh = None
        self.sftp = None
        self.client = None
        self.protocol = []
        self.Errors = []
        self.command = None

        self.logger.info('Object succesfully created')
        self.logger.debug('timeout = %s, remoteMachine = %s, port = %s, username = %s' % (self.timeout, self.remoteMachine, self.port, self.username))

    def Link(self,
             sftp=None,
             ssh=True):
        '''
        Connect to the remote machine.

        A socket + paramiko Transport is always opened; it is authenticated
        with the password, or with the private key when no password is set.
        SFTP runs over that Transport. SSH uses its own mySSHClient
        connection with password authentication.

        Gets: sftp (truthy to open an SFTP client), ssh (truthy to open an
              SSH client for RemoteCommand).
        Returns: (ssh, sftp) -- the requested values, with ssh replaced by
                 False / sftp replaced by None when that connection could not
                 be brought up.
        Raises: the socket error when the TCP connection itself fails.
        '''
        self.sock = socket()
        self.sock.settimeout(self.timeout)
        self.Errors = []
        self.protocol = []
        self.logger.info('Trying to start link using parameters [SSH=%s,SFTP=%s]' % (ssh, sftp))
        try:
            self.sock.connect((self.remoteMachine, self.port))
        except Exception:
            self.Unlink()
            self.logger.error('Link: Could not open a socket [remoteMachine = %s, port = %s]' % (self.remoteMachine, self.port))
            # Nothing below can work without a connected socket; surface the
            # real connection error instead of building a Transport on the
            # closed socket.
            raise

        self.transport = Transport(self.sock)
        try:
            if not self.password:
                self.logger.info('Trying to connect using private key..')
                self.transport.connect(username=self.username, pkey=self.pk)
            else:
                self.logger.info('Trying to connect using password..')
                self.transport.connect(username=self.username, password=self.password)
        except Exception:
            if not self.password:
                self.logger.error('Private Key not installed and no password specified!')
            self.Unlink()
            self.logger.error('Could not connect using SFTP [remoteMachine = %s, Username = %s, port = %s]' % (self.remoteMachine, self.username, self.port))
            # SFTP rides on this Transport, so it cannot be brought up either.
            sftp = None

        if ssh:
            try:
                sshClient = mySSHClient()
                self.protocol.append(sshClient)
                self.ssh = sshClient
                # NOTE(review): AutoAddPolicy trusts unknown host keys, which
                # leaves the first connection open to MITM attacks.
                sshClient.set_missing_host_key_policy(AutoAddPolicy())
                sshClient.connect(self.remoteMachine, port=self.port, username=self.username, password=self.password)
                self.logger.info('ssh connection created..')
            except Exception:
                self.logger.error('Could not connect using SSH [remoteMachine = %s, Username = %s, port = %s]' % (self.remoteMachine, self.username, self.port))
                ssh = False

        if sftp:
            if ssh is None:
                self.logger.error('SSH must be enabled in order to connect with SFTP')
            try:
                sftpClient = SFTPClient.from_transport(self.transport)
                if sftpClient is None:
                    raise Exception('SFTP subsystem could not be opened')
                self.protocol.append(sftpClient)
                self.sftp = sftpClient
                self.client = sftpClient
                self.logger.info('sftp connection created..')
            except Exception:
                self.logger.error('Could not bring up SFTP connection')
                self.sftp = None
                sftp = None

        return ssh, sftp

    def RemoteCommand(self, command, check_exit=True, OSD=None, timeout=40):
        '''
        Run `command` on the remote machine through the SSH client of Link().

        The special command 'close_connection' runs nothing remotely; it
        closes all connections via Unlink() and returns [].

        Gets: command, check_exit (treat any stderr output as failure),
              OSD (also print the output to stdout), timeout (passed to
              exec_command; defaults to the previously hard-coded 40).
        Returns: the command's stdout as a list of lines.
        Raises: Exception('Could not complete the command ...') on any
                failure; the underlying cause is logged first.
        '''
        output_lines = []
        try:
            if not getattr(self, 'ssh', None):
                raise Exception('SSH must be enabled on remote machine in order to run commands.')
            # Compare by value: `is` only matched interned literals, so a
            # 'close_connection' built at runtime was executed remotely.
            if command == 'close_connection':
                try:
                    self.logger.info('Trying to close connections..')
                    self.Unlink()
                    self.logger.info('Connections closed succesfully..')
                except Exception:
                    self.logger.error('Could not close connections! ')
            else:
                self.command = command
                stdin, self.output, stderr = self.ssh.exec_command(command, timeout=timeout)
                output_lines = self.output.readlines()
                fprint = ''.join(output_lines)
                if OSD:
                    print(fprint)  # On Screen Display formatted output
                self.logger.debug('output of remote command `%s` on [%s]: %s' % (command, self.remoteMachine, fprint))
                if check_exit:
                    errorList = stderr.readlines()
                    if errorList:
                        self.logger.error(errorList)
                        raise Exception('Failed to execute command %s on [%s]' % (command, self.remoteMachine))
                    self.logger.info('Command `%s` on [%s] completed successfuly' % (command, self.remoteMachine))
            return output_lines
        except Exception as err:
            self.logger.error('Remote command `%s` on [%s] failed: %s' % (command, self.remoteMachine, err))
            raise Exception('Could not complete the command %s' % command)

    def CopyToRemote(self,
                     localFolder,
                     remoteFolder,
                     fileName,
                     copyProtocol='scp',
                     scpCommand='scp -o ConnectTimeout=30 -p -P',
                     response=True):
        '''
        Copy localFolder/fileName to remoteFolder/fileName on the remote host.

        Gets: localFolder, remoteFolder, fileName,
              copyProtocol ('sftp' uses the client opened by Link, 'scp'
              runs the local scp binary),
              scpCommand (scp command prefix; the port is appended right
              after it, so it must end with the port flag),
              response (value returned on success).
        Returns: `response` on success, False on failure.
        Raises: Exception for an unknown protocol or an unexpected SFTP error.
        '''
        localFilePath = pJoin(localFolder, fileName)
        remoteFilePath = pJoin(remoteFolder, fileName)
        if copyProtocol == 'sftp':
            try:
                self.sftp.put(localFilePath, remoteFilePath)
                self.logger.info('Copied file %s to %s:%s' % (localFilePath, self.remoteMachine, remoteFilePath))
            except AttributeError:
                # self.sftp is None/missing: Link() did not open SFTP.
                self.logger.error('SFTP connection is not open')
                return False
            except Exception:
                raise Exception('Could not connect using SFTP')
            return response

        if copyProtocol == 'scp':
            # NOTE(review): the line runs through a shell with unquoted
            # paths, so names must not contain spaces or shell metacharacters.
            scpLine = '%s %s %s %s@%s:%s ' % (scpCommand, self.port, localFilePath, self.username, self.remoteMachine, remoteFilePath)
            try:
                status, output = getstatusoutput(scpLine)
            except Exception:
                self.logger.error('Could not transfer file %s to [%s] using SCP' % (remoteFilePath, self.remoteMachine))
                return False
            if status != 0:
                self.logger.error(output)
                self.logger.error('Error %s: %s' % (status, self.Errors))
                return False
            self.logger.info('Copied file from "%s" to "%s:%s" successfuly' % (localFilePath, self.remoteMachine, remoteFilePath))
            return response

        raise Exception('Unknown copy protocol "%s", scp and sftp support only' % copyProtocol)

    def RenameLocalFile(self, localFolder, OrigFileName, NewFileName):
        '''
        Rename localFolder/OrigFileName to localFolder/NewFileName.

        Raises: Exception when the local rename fails.
        '''
        self.logger.info('Renaming %s to %s' % (OrigFileName, NewFileName))
        localFilePath = pJoin(localFolder, OrigFileName)
        renamedFilePath = pJoin(localFolder, NewFileName)
        try:
            rename(localFilePath, renamedFilePath)
        except OSError:
            raise Exception('Could not rename file %s to %s' % (localFilePath, renamedFilePath))

    def Unlink(self):
        '''
        Close the SSH client, SFTP client, Transport and socket.

        Best effort: a handle that is missing or fails to close is logged
        and the remaining ones are still closed.
        '''
        handles = (('ssh', 'No SSH client to close...'),
                   ('sftp', 'No SFTP client to close...'),
                   ('transport', 'No transport to close...'),
                   ('sock', 'No socket to close...'))
        for attrName, failureMessage in handles:
            try:
                getattr(self, attrName).close()
            except Exception:
                self.logger.error(failureMessage)

    def dotTmp(self,
               localFileFolder,
               localFileName,
               remoteFileFolder,
               remoteFileName,
               index,
               runType='put'):
        '''
        Upload a local file as remoteFileFolder/<remoteFileName>.<index>,
        going through a "<name>.<index>.tmp" file that is renamed into place.

        Gets: localFileFolder, localFileName,
              remoteFileFolder, remoteFileName,
              index, runType ('put' uses SFTP put, 'open' writes the local
              text lines through an SFTP file handle).
        Returns: True on success, False otherwise. When the temp upload
                 fails, the existing remote file is left untouched.
        '''
        self.Errors = []
        response = True
        localFilePath = pJoin(localFileFolder, localFileName)
        remoteTempFilePath = pJoin(remoteFileFolder, '%s.%s.tmp' % (remoteFileName, index))
        remoteFilePath = pJoin(remoteFileFolder, '%s.%s' % (remoteFileName, index))

        try:
            if runType == 'put':
                self.client.put(localFilePath, remoteTempFilePath)
            elif runType == 'open':
                self._WriteLocalDataToRemoteFile(localFilePath, remoteTempFilePath)
            else:
                raise ValueError('Unknown runType %r' % (runType,))
        except Exception:
            self.logger.error('Failed creating "%s" on the remote server "%s"' % (remoteFilePath, self.remoteMachine))
            # Do not remove the current remote file when there is nothing to
            # replace it with.
            return False

        try:
            self.client.remove(remoteFilePath)
        except Exception:
            pass  # the destination usually does not exist yet

        try:
            self.client.rename(remoteTempFilePath, remoteFilePath)
        except Exception:
            self.logger.error('Failed moving "%s" to "%s" on the remote server "%s"' % (remoteTempFilePath, remoteFilePath, self.remoteMachine))
            response = False

        if self.Errors:
            self.logger.error(self.Errors)
        return response

    def MD5Compare(self,
                   MD5File,
                   localFolder,
                   DPIFileType,
                   DateTimeFormat='%Y%m%d%H%M%S',
                   Collect=True,
                   Rename=True,
                   PRX_MD5Command='list-files',
                   copyProtocol='scp'):
        '''
        Process the remote files whose MD5 changed since the previous run.

        The remote listing comes from GetMD5FromPrx; the previous listing of
        every machine is kept as a pickle in MD5File, shaped
        {remoteMachine: {fileType: {fileName: md5}}}. For every changed file
        of the requested types the file is optionally collected (from the
        remote folder named after its type into localFolder) and renamed to
        <name>.<remoteMachine>.<timestamp><ext>.report, and its new MD5 is
        recorded. A file whose collection fails keeps its old MD5 so it is
        retried next run.

        Gets: MD5File, localFolder, DPIFileType (iterable of file types),
              DateTimeFormat, Collect, Rename, PRX_MD5Command, copyProtocol.
        Returns: list of the changed file names that were processed (empty
                 when everything is up to date).
        Raises: Exception when the remote MD5 listing cannot be obtained.
        '''
        CurrentTimeDate = datetime.now().strftime(DateTimeFormat)
        DPIMD5Dict = self.GetMD5FromPrx(rqstType=DPIFileType, md5cmd=PRX_MD5Command)
        if DPIMD5Dict is None:
            # GetMD5FromPrx already logged the underlying error.
            raise Exception('Could not get MD5 list from [%s]' % self.remoteMachine)

        if path.exists(MD5File):
            LastSavedDict = self._readMD5fromFile(MD5File)
        else:
            LastSavedDict = {}

        if self.remoteMachine in LastSavedDict:
            LocalSavedDict = LastSavedDict[self.remoteMachine]
        else:
            # First time this machine is seen.
            LocalSavedDict = dict((ftype, {}) for ftype in DPIFileType)
        self.logger.debug('Local Saved Files before copy = %s' % LocalSavedDict)

        changedFiles = []
        if LocalSavedDict == DPIMD5Dict:
            self.logger.info('All files are up to date. No files copied')
            return changedFiles

        for ftype in DPIFileType:
            for fileName, md5 in DPIMD5Dict.get(ftype, {}).items():
                if md5 == LocalSavedDict.get(ftype, {}).get(fileName, ''):
                    continue
                if Collect:
                    collected = self.Collector(localFolder=localFolder, remoteFolder=ftype, copyProtocol=copyProtocol,
                                               collectedFileName=fileName, collectedSuffix='',
                                               getLatestOnly=False, removeRemoteFiles=False)
                    if not collected:
                        self.logger.error('Could not collect %s from [%s]' % (fileName, self.remoteMachine))
                        continue
                if Rename:
                    name, ext = path.splitext(fileName)
                    self.RenameLocalFile(OrigFileName=fileName,
                                         NewFileName=name + '.' + self.remoteMachine + '.' + CurrentTimeDate + ext + '.report',
                                         localFolder=localFolder)
                LocalSavedDict.setdefault(ftype, {})[fileName] = md5
                changedFiles.append(fileName)

        LastSavedDict[self.remoteMachine] = LocalSavedDict
        self.logger.debug('Local Saved Files after copy = %s' % LocalSavedDict)
        self._writeMD5toFile(MD5File, LastSavedDict)
        return changedFiles

    def CreatePrxStat(self,
                      remoteMachine,
                      localFolder,
                      DPI_Statistics,
                      filename,
                      Prefix,
                      Suffix,
                      cmdTimeout,
                      DateTimeFormat='%Y%m%d%H%M%S'):
        '''
        Create a statistics file from the output of a remote command
        (e.g. list-protocols --statistics).

        Writes localFolder/<Prefix>-<filename>.<remoteMachine>.<time><Suffix>
        containing a "timestamp: <epoch seconds>" line, a
        "system-id: <remoteMachine>" line and then the command output.
        Raises: whatever RemoteCommand raises, or the error from writing.
        '''
        CurrentTimeDate = datetime.now().strftime(DateTimeFormat)
        # cmdTimeout is the remote command timeout, so pass it by keyword
        # rather than positionally into the check_exit slot.
        command_output = self.RemoteCommand(DPI_Statistics, timeout=cmdTimeout)
        timestamp = int(time.time())
        f = '%s-%s.%s.%s%s' % (Prefix, filename, remoteMachine, CurrentTimeDate, Suffix)
        try:
            with open(pJoin(localFolder, f), 'w') as statfile:
                statfile.write('timestamp: ' + str(timestamp) + '\n')
                statfile.write('system-id: ' + remoteMachine + '\n')
                statfile.writelines(command_output)
        except Exception:
            self.logger.error('Failed to Create %s file' % f)
            raise

    def dotDone(self,
                localFileFolder,
                localFileName,
                remoteFileFolder,
                remoteFileName,
                doneFileFolder,
                doneFileName,
                doneFileSuffix,
                index,
                runType='put'):
        '''
        Upload a local file, a ".delta" copy of it, then an empty done marker.

        The data goes to remoteFileFolder/<remoteFileName>.<index> and
        doneFileFolder/<remoteFileName>.delta.<index>; the empty marker is
        doneFileFolder/<doneFileName>.<doneFileSuffix>.<index>. The marker is
        only created when both uploads succeeded.

        Gets: localFileFolder, localFileName,
              remoteFileFolder, remoteFileName,
              doneFileFolder, doneFileName,
              doneFileSuffix, index,
              runType ('put' or 'open', see dotTmp).
        Returns: True on success, False otherwise.
        '''
        self.Errors = []
        response = True
        localFilePath = pJoin(localFileFolder, localFileName)
        doneFilePath = pJoin(doneFileFolder, '%s.%s.%s' % (doneFileName, doneFileSuffix, index))
        remoteFilePath = pJoin(remoteFileFolder, '%s.%s' % (remoteFileName, index))
        remoteDoneFilePath = pJoin(doneFileFolder, '%s.delta.%s' % (remoteFileName, index))

        try:
            if runType == 'put':
                self.client.put(localFilePath, remoteFilePath)
                self.client.put(localFilePath, remoteDoneFilePath)
            elif runType == 'open':
                self._WriteLocalDataToRemoteFile(localFilePath, remoteFilePath)
                self._WriteLocalDataToRemoteFile(localFilePath, remoteDoneFilePath)
            else:
                raise ValueError('Unknown runType %r' % (runType,))
        except Exception:
            self.logger.error('Failed creating "%s" on the remote server "%s"' % (remoteFilePath, self.remoteMachine))
            # Never publish a done marker for data that did not arrive.
            return False

        try:
            self.client.open(doneFilePath, 'w').close()
        except Exception:
            self.logger.error('Failed creating done file "%s" on the remote server "%s"' % (doneFilePath, self.remoteMachine))
            response = False

        if self.Errors:
            self.logger.error(self.Errors)

        return response

    def Collector(self,
                  localFolder,
                  remoteFolder,
                  copyProtocol='sftp',
                  collectedFileName='',
                  collectedSuffix='',
                  getLatestOnly=True,
                  removeRemoteFiles=True):
        '''
        Copy files from remoteFolder on the remote machine into localFolder.

        With SFTP exactly one of collectedFileName (a single file) or
        collectedSuffix (every file ending in "." plus the text after the
        last "." of the suffix, e.g. ".csv") must be given. getLatestOnly
        restricts the copy to the file with the newest mtime;
        removeRemoteFiles deletes each remote file after it was copied
        successfully. With scp only collectedFileName is used and both
        flags are ignored.

        Returns: True when every selected file was copied, False otherwise
                 (also when nothing matched or the protocol is unknown).
        Raises: Exception when neither/both of collectedFileName and
                collectedSuffix are given (SFTP), when an scp copy fails, or
                when a remote file cannot be removed.
        '''
        self.Errors = []
        if copyProtocol == 'sftp':
            CollectedFiles = []   # matching names, in remote listing order
            RemoteFilesDict = {}  # name -> SFTP stat result
            if collectedFileName != '' and collectedSuffix == '':
                remoteFilePath = pJoin(remoteFolder, collectedFileName)
                try:
                    # Keyed by name: _GetLatestFile returns a key of this
                    # dict, which is then used as the file name.
                    RemoteFilesDict[collectedFileName] = self.client.stat(remoteFilePath)
                    CollectedFiles.append(collectedFileName)
                    self.logger.debug('Collected file name: %s' % (collectedFileName))
                except Exception:
                    self.logger.error('Could not get file stat for %s' % (collectedFileName))

            elif collectedSuffix != '' and collectedFileName == '':
                RemoteFiles = self.client.listdir(remoteFolder)
                self.logger.debug('Files in remote dir: %s' % (RemoteFiles))
                suffixMatch = search(r'(.*)(\.)(.*)', collectedSuffix)
                if suffixMatch is None:
                    self.logger.error('Invalid files suffix given: %s' % collectedSuffix)
                    return False
                suffix = suffixMatch.group(3)
                self.logger.debug('Filtered Suffix: %s' % (suffix))
                # Plain string test, so regex metacharacters in the suffix
                # are matched literally.
                ending = '.' + suffix
                for remoteFile in RemoteFiles:
                    if not remoteFile.endswith(ending):
                        continue
                    try:
                        RemoteFilesDict[remoteFile] = self.client.stat(pJoin(remoteFolder, remoteFile))
                    except Exception:
                        continue
                    CollectedFiles.append(remoteFile)
                self.logger.debug('Remote Files Dic: %s' % RemoteFilesDict)
            else:
                raise Exception('No remote file name or suffix specified')

            if not RemoteFilesDict:
                self.logger.debug('%s' % (len(RemoteFilesDict)))
                return False

            if getLatestOnly:
                targets = [self._GetLatestFile(RemoteFilesDict)]
            else:
                targets = CollectedFiles

            response = True
            for remoteFile in targets:
                try:
                    fetched = self._GetRemoteFile(localFolder, remoteFolder, remoteFile, copyProtocol=copyProtocol)
                except Exception:
                    fetched = False
                if not fetched:
                    self.logger.error('Failed getting "%s" from remote host "%s" to local folder "%s"'
                                      % (pJoin(remoteFolder, remoteFile), self.remoteMachine, localFolder))
                    # Keep the remote copy so a later run can still collect it.
                    response = False
                    continue
                if removeRemoteFiles:
                    self._RemoveRemoteFile(remoteFolder, remoteFile)
            return response

        if copyProtocol == 'scp':
            if getLatestOnly:
                self.logger.error('Get latest only is not supported when using scp')
            if removeRemoteFiles:
                self.logger.error('Remote file removal is not supported when using scp')
            try:
                fetched = self._GetRemoteFile(localFolder, remoteFolder, collectedFileName, copyProtocol=copyProtocol)
            except Exception:
                fetched = False
            if not fetched:
                raise Exception('Could not get file %s from %s to %s' % (collectedFileName, remoteFolder, localFolder))
            return True

        self.logger.error('Unknown copy protocol %s, scp and sftp support only' % copyProtocol)
        return False

    def _writeMD5toFile(self, MD5File, dictionary=None):
        '''
        Pickle `dictionary` (default: empty dict) into MD5File.

        Returns: the (closed) file object on success, None after logging the
                 error on failure.
        '''
        if dictionary is None:
            dictionary = {}
        try:
            with open(MD5File, 'wb') as MD5ToSave:
                pickle.dump(dictionary, MD5ToSave)
            return MD5ToSave
        except Exception:
            self.logger.error('could not write MD5 file:\n%s:%s' % (sys.exc_info()[0], sys.exc_info()[1]))

    def _readMD5fromFile(self, MD5File):
        '''
        Load the dictionary saved by _writeMD5toFile.

        NOTE: unpickling can execute code; MD5File must only be writable by
        trusted users.
        '''
        with open(MD5File, 'rb') as LastSavedFile:
            return pickle.load(LastSavedFile)

    def _WriteLocalDataToRemoteFile(self, localFilePath, remoteFilePath):
        '''
        Copy the text lines of localFilePath into remoteFilePath through an
        SFTP file handle opened with self.bufsize.

        Gets: localFilePath, remoteFilePath.
        Returns: N/A
        '''
        with open(localFilePath, 'r') as fh:
            Lines = fh.readlines()

        rfh = self.client.open(filename=remoteFilePath, mode='w', bufsize=self.bufsize)
        try:
            for line in Lines:
                rfh.write(line)
        finally:
            rfh.close()

    def _GetRemoteFile(self, localFolder, remoteFolder, collectedFileName, copyProtocol):
        '''
        Copy remoteFolder/collectedFileName to localFolder/collectedFileName.

        Gets: localFolder, remoteFolder, collectedFileName,
              copyProtocol ('sftp' or 'scp').
        Returns: True on success, False on failure (errors are logged).
        '''
        remoteFilePath = pJoin(remoteFolder, collectedFileName)
        localFilePath = pJoin(localFolder, collectedFileName)
        if copyProtocol == 'sftp':
            try:
                self.client.get(remoteFilePath, localFilePath)
            except Exception:
                self.logger.error('Could not get file %s from [%s] using SFTP' % (remoteFilePath, self.remoteMachine))
                return False
            self.logger.info('Copied file from "%s:%s" to "%s" successfuly' % (self.remoteMachine, remoteFilePath, localFilePath))
            return True

        if copyProtocol == 'scp':
            status, output = getstatusoutput('scp -o ConnectTimeout=5 -p -P %s %s@%s:%s %s' % (self.port, self.username, self.remoteMachine,
                                                                                              remoteFilePath, localFilePath))
            if status != 0:
                self.logger.error(output)
                self.logger.error('Could not get file %s from [%s] using SCP' % (remoteFilePath, self.remoteMachine))
                if self.Errors:
                    self.logger.error(self.Errors)
                return False
            self.logger.info('Copied file from "%s:%s" to "%s" successfuly' % (self.remoteMachine, remoteFilePath, localFilePath))
            return True

        self.logger.error('Unknown copy protocol %s, scp and sftp support only' % copyProtocol)
        return False

    def GetMD5FromPrx(self, rqstType, md5cmd='list-files'):
        '''
        Run `md5cmd` remotely and parse its per-type sections.

        A line is a section header when line[1:].strip()[:-1] equals one of
        the requested types (e.g. "[type]"). Inside a requested section every
        line with at least two fields is read as "... <md5> <file>" (last two
        fields). A non-empty single-field line ends the section, which is how
        headers of types that were not requested are skipped.

        Returns: {type: {file: md5}} with an entry for every requested type,
                 or None (after logging) when the command or parsing fails.
        '''
        try:
            fileType = dict((r, {}) for r in rqstType)
            currentType = None  # requested section being read; None = skip
            for line in self.RemoteCommand(md5cmd):
                header = line[1:].strip()[:-1]
                if header in fileType:
                    currentType = header
                    continue
                lineList = line.split()
                if currentType is None or not lineList:
                    continue
                if len(lineList) < 2:
                    currentType = None
                    continue
                fileType[currentType][lineList[-1]] = lineList[-2]  # {file: md5}
            self.logger.debug('File Dictionary on PRX %s' % fileType)
            return fileType
        except Exception:
            self.logger.error('Could not get md5cmd from PRX: %s:%s' % (sys.exc_info()[0], sys.exc_info()[1]))
            return None

    def _GetLatestFile(self, SftpClientFilesDict):
        '''
        Gets: SftpClientFilesDict ({name: stat result with st_mtime}),
              which must not be empty.
        Returns: the name with the newest st_mtime (first one on ties; a
                 missing mtime counts as 0).
        '''
        return max(SftpClientFilesDict,
                   key=lambda name: SftpClientFilesDict[name].st_mtime or 0)

    def _RemoveRemoteFile(self, remoteFileFolder, remoteFile):
        '''
        Delete remoteFileFolder/remoteFile on the remote machine.

        Gets: remoteFileFolder, remoteFile.
        Raises: Exception when the removal fails.
        '''
        remoteFilePath = pJoin(remoteFileFolder, remoteFile)
        try:
            self.client.remove(remoteFilePath)
        except Exception:
            raise Exception('Failed removing remote file "%s"...' % remoteFilePath)
        self.logger.info('Removing remote file "%s"' % remoteFilePath)
Beispiel #59
0
class IrmaSFTP(IrmaFTP):
    """Irma SFTP handler

    This class handles the connection with a sftp server and provides
    functions for interacting with it.
    """
    # ==================================
    #  Constructor and Destructor stuff
    # ==================================
    def __init__(self, host, port, user, passwd,
                 dst_user=None, upload_path='uploads'):
        super(IrmaSFTP, self).__init__(host, port, user,
                                       passwd, dst_user, upload_path)
        self._client = None
        self._connect()

    # =================
    #  Private methods
    # =================

    def _connect(self):
        """Open the SSH transport and the SFTP client on top of it.

        On failure the half-open transport is closed and ``self._conn`` stays
        None, so a later call can retry instead of being refused with
        "Already connected".

        :raises IrmaSFTPError: if connecting or authenticating fails.
        """
        if self._conn is not None:
            # NOTE(review): assumes ``log`` is a logging.Logger; ``warn`` is
            # its deprecated alias of ``warning``.
            log.warning("Already connected to sftp server")
            return
        conn = None
        try:
            conn = Transport((self._host, self._port))
            # Larger window and rekey thresholds for big file transfers.
            conn.window_size = pow(2, 27)
            conn.packetizer.REKEY_BYTES = pow(2, 32)
            conn.packetizer.REKEY_PACKETS = pow(2, 32)
            conn.connect(username=self._user, password=self._passwd)
            client = SFTPClient.from_transport(conn)
        except Exception as e:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass  # best effort; report the original error below
            raise IrmaSFTPError("{0}".format(e))
        self._conn = conn
        self._client = client

    # ================
    #  Public methods
    # ================

    def mkdir(self, path):
        """Create remote directory <path>."""
        try:
            dst_path = self._get_realpath(path)
            self._client.mkdir(dst_path)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def list(self, path):
        """List remote directory <path>."""
        try:
            dst_path = self._get_realpath(path)
            return self._client.listdir(dst_path)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def upload_fobj(self, path, fobj):
        """Upload the content of <fobj> into remote directory <path>.

        The remote file is named after ``self._hash(fobj)``; that name is
        returned.
        """
        try:
            dstname = self._hash(fobj)
            path = self._tweaked_join(path, dstname)
            dstpath = self._get_realpath(path)
            self._client.putfo(fobj, dstpath)
            return dstname
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def download_fobj(self, path, remotename, fobj):
        """Download <remotename> found in <path> into <fobj>.

        The data is then checked against <remotename>, which is the hash
        value of the data.
        """
        try:
            dstpath = self._get_realpath(path)
            full_dstpath = self._tweaked_join(dstpath, remotename)
            self._client.getfo(full_dstpath, fobj)
            self._check_hash(remotename, fobj)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def delete(self, path, filename):
        """Delete <filename> in directory <path>."""
        try:
            dstpath = self._get_realpath(path)
            full_dstpath = self._tweaked_join(dstpath, filename)
            self._client.remove(full_dstpath)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))

    def deletepath(self, path, deleteParent=False):
        """Recursively delete all files and subdirectories of <path>.

        <path> itself is removed as well when <deleteParent> is true.
        """
        try:
            for f in self.list(path):
                if self.is_file(path, f):
                    self.delete(path, f)
                else:
                    self.deletepath(self._tweaked_join(path, f),
                                    deleteParent=True)
            if deleteParent:
                dstpath = self._get_realpath(path)
                self._client.rmdir(dstpath)
        except Exception as e:
            reason = "{0} [{1}]".format(str(e), path)
            raise IrmaSFTPError(reason)

    def is_file(self, path, filename):
        """Return True unless <path>/<filename> is a directory."""
        try:
            dstpath = self._get_realpath(path)
            full_dstpath = self._tweaked_join(dstpath, filename)
            st = self._client.stat(full_dstpath)
            return not stat.S_ISDIR(st.st_mode)
        except Exception as e:
            reason = "{0} [{1}]".format(e, path)
            raise IrmaSFTPError(reason)

    def rename(self, oldpath, newpath):
        """Rename remote <oldpath> to <newpath>."""
        try:
            old_realpath = self._get_realpath(oldpath)
            new_realpath = self._get_realpath(newpath)
            self._client.rename(old_realpath, new_realpath)
        except Exception as e:
            raise IrmaSFTPError("{0}".format(e))