Exemple #1
0
 def run(self, project):
     """Run the project's executable on the remote VM and fetch its results.

     Starts the VM and copies ``project.target`` into ``self.remote_path``
     over SFTP. The file is executed there under ``timeout
     <self.timeout>s``, then ``CUnitAutomated-Results.xml`` is read back.

     Args:
         project: object exposing ``tempdir`` and ``target``; the executable
             is expected at ``os.path.join(project.tempdir, project.target)``.

     Returns:
         tuple: ``(return_code, data)``. ``return_code`` is the remote exit
         status of the last executed file. ``data`` is the raw content
         (bytes) of the last result file read, or ``''`` if none was found.
         ``(-1, '')`` is returned when the VM could not be started.

     Raises:
         FileNotFoundError: if the executable has not been created.
     """
     import posixpath  # remote paths are POSIX regardless of the local OS

     if not self.vm.start_VM():
         return -1, ''
     executable = os.path.join(project.tempdir, project.target)
     if not os.path.exists(executable):
         raise FileNotFoundError('Error: Executable file has not been created!')
     copy_to_vm = [executable]
     copy_from_vm = ['CUnitAutomated-Results.xml']
     print('Connecting to remote machine...')
     client = SSHClient()
     client.set_missing_host_key_policy(AutoAddPolicy())
     client.load_system_host_keys()
     return_code = 0
     data = ''
     try:
         client.connect(self.host, username=self.username, password=self.password, timeout=10)
         with client.open_sftp() as sftp:
             # Start from an empty remote working directory.
             try:
                 self.rmtree(sftp, self.remote_path)
             except FileNotFoundError:
                 pass
             try:
                 sftp.mkdir(self.remote_path)
             except OSError:
                 pass
             for f in copy_to_vm:
                 remote_file = posixpath.join(self.remote_path, os.path.basename(f))
                 sftp.put(f, remote_file)
                 sftp.chmod(remote_file, 0o777)
                 _, stdout, stderr = client.exec_command(
                     'cd {}; timeout {}s {}'.format(self.remote_path, self.timeout, remote_file))
                 # Drain the output before waiting for the exit status:
                 # recv_exit_status() alone can block once unread output
                 # fills the channel window.
                 stdout_text = ''.join(stdout)
                 stderr_text = ''.join(stderr)
                 return_code = stdout.channel.recv_exit_status()
                 print('[Remote] Error code: {}'.format(return_code))
                 # Test the raw output; testing it after prepending the
                 # '[Remote] ' prefix would always be true.
                 if stdout_text:
                     print('[Remote] STDOUT:')
                     print('[Remote] ' + stdout_text)
                 if stderr_text:
                     print('[Remote] STDERR:')
                     print('[Remote] ' + stderr_text)
             for f in copy_from_vm:
                 # get all result files
                 remote_file = posixpath.join(self.remote_path, os.path.basename(f))
                 try:
                     with tempfile.TemporaryFile() as local_file:
                         sftp.getfo(remote_file, local_file)
                         local_file.seek(0)
                         data = local_file.read()
                 except FileNotFoundError:
                     print('Remote file not found!')
             # Remove the remote working directory again.
             self.rmtree(sftp, self.remote_path)
     finally:
         # Release the SSH transport even when a transfer or command fails.
         client.close()
     if self.shutdown_vm_after:
         self.vm.stop_VM()
     return return_code, data
Exemple #2
0
 def ssh_connection(self, storage):
     """Open an SSH connection and an SFTP session for ``storage``.

     Connection parameters are read from ``storage.config`` (keys
     ``hostname``, ``port``, ``user`` and ``password``). Only the system
     known_hosts file is loaded; no missing-host-key policy is set here.

     Returns:
         tuple: ``(SSHClient, SFTPClient)``; the caller must close both.
     """
     vals = storage.config
     connection = SSHClient()
     connection.load_system_host_keys()
     try:
         connection.connect(
             hostname=vals.get('hostname', None),
             # Default to the standard SSH port; port 0 is never reachable.
             port=vals.get('port', 22),
             username=vals.get('user', None),
             password=vals.get('password', None),
         )
         sftp = connection.open_sftp()
     except Exception:
         # Don't leak a half-open transport when connecting fails.
         connection.close()
         raise
     return connection, sftp
 def _sftp_connect(self):
     """Open an SSH connection and SFTP session to the invoice partner's server.

     Server, port, user name and stored password are taken from
     ``self.integration_id.invoice_id.partner_id``. The password is passed
     through ``self._decrypt_value`` before use.

     Returns:
         tuple: ``(SSHClient, SFTPClient)``; the caller must close both.
     """
     partner = self.integration_id.invoice_id.partner_id
     connection = SSHClient()
     connection.load_system_host_keys()
     try:
         connection.connect(
             hostname=partner.ssh_server,
             port=int(partner.ssh_port),
             username=partner.ssh_name,
             password=self._decrypt_value(partner.ssh_pass),
         )
         sftp = connection.open_sftp()
     except Exception:
         # Don't leak a half-open transport when connecting fails.
         connection.close()
         raise
     return connection, sftp
Exemple #4
0
class datastore_sftp(datastore):
    """SFTP-backed datastore.

    A collection of independent helper functions rather than a real stateful
    class. No network connection is made in ``__init__``; call
    :meth:`connect` explicitly.
    """

    def __init__(self):
        # Remote path for saving all resources
        self.base_folder = settings.SFTP_DATASTORE_REMOTEBASEFOLDER
        # Local base folder for saving temporary files before upload
        self.tmp_folder = settings.SFTP_DATASTORE_LOCALTMPFOLDER
        # URL for downloading resources
        self.public_base_url = settings.SFTP_BASE_URL
        # Entries of base_folder, filled in by connect()
        self.buckets = []
        # Extra connection created via sftp.Connection in connect() (TODO: remove)
        self.connection = None
        self.ssh_client = SSHClient()
        # Unknown host keys are added automatically instead of rejected
        self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
        self.ssh_client.load_system_host_keys()
        # SFTP session opened by connect()
        self.sftp = None

    def connect(self):
        """Connect to the SFTP datastore and list the buckets (folders).

        Don't call this from ``__init__``: connecting blocks and would hang
        the whole application.

        Sets ``self.connection``, ``self.sftp`` and ``self.buckets``. If
        listing ``base_folder`` fails, the error is logged and the SFTP
        session is closed.
        """
        logger = logging.getLogger(__name__)
        # Informational, so INFO level; the password must never reach the logs.
        logger.info('Connecting SFTP %s:%s (%s, %s)',
                    settings.SFTP_DATASTORE_HOSTNAME,
                    settings.SFTP_DATASTORE_PORT,
                    settings.SFTP_DATASTORE_USER,
                    '********')

        # TODO: Remove
        self.connection = sftp.Connection(
            host=settings.SFTP_DATASTORE_HOSTNAME,
            port=settings.SFTP_DATASTORE_PORT,
            username=settings.SFTP_DATASTORE_USER,
            password=settings.SFTP_DATASTORE_PASSWORD,
            log=True
        )

        self.ssh_client.connect(
            settings.SFTP_DATASTORE_HOSTNAME,
            port=settings.SFTP_DATASTORE_PORT,
            username=settings.SFTP_DATASTORE_USER,
            password=settings.SFTP_DATASTORE_PASSWORD
        )
        self.sftp = self.ssh_client.open_sftp()

        # list all buckets (folders)
        try:
            self.buckets = self.sftp.listdir(path=self.base_folder)
            logger.info('Buckets: %s', self.buckets)
        except Exception as e:
            logger.error('Error Connecting SFTP %s' % str(e))
            self.sftp.close()
Exemple #5
0
class datastore_sftp(datastore):
    """SFTP-backed datastore.

    A collection of independent helper functions rather than a real stateful
    class. No network connection is made in ``__init__``; call
    :meth:`connect` explicitly.
    """

    def __init__(self):
        # Remote path for saving all resources
        self.base_folder = settings.SFTP_DATASTORE_REMOTEBASEFOLDER
        # Local base folder for saving temporary files before upload
        self.tmp_folder = settings.SFTP_DATASTORE_LOCALTMPFOLDER
        # URL for downloading resources
        self.public_base_url = settings.SFTP_BASE_URL
        # Entries of base_folder, filled in by connect()
        self.buckets = []
        # Extra connection created via sftp.Connection in connect() (TODO: remove)
        self.connection = None
        self.ssh_client = SSHClient()
        # Unknown host keys are added automatically instead of rejected
        self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
        self.ssh_client.load_system_host_keys()
        # SFTP session opened by connect()
        self.sftp = None

    def connect(self):
        """Connect to the SFTP datastore and list the buckets (folders).

        Don't call this from ``__init__``: connecting blocks and would hang
        the whole application.

        Sets ``self.connection``, ``self.sftp`` and ``self.buckets``. If
        listing ``base_folder`` fails, the error is logged and the SFTP
        session is closed.
        """
        logger = logging.getLogger(__name__)
        # The password must never reach the logs.
        logger.info('Connecting SFTP %s:%s (%s, %s)',
                    settings.SFTP_DATASTORE_HOSTNAME,
                    settings.SFTP_DATASTORE_PORT,
                    settings.SFTP_DATASTORE_USER,
                    '********')

        # TODO: Remove
        self.connection = sftp.Connection(
            host=settings.SFTP_DATASTORE_HOSTNAME,
            port=settings.SFTP_DATASTORE_PORT,
            username=settings.SFTP_DATASTORE_USER,
            password=settings.SFTP_DATASTORE_PASSWORD,
            log=True
        )

        self.ssh_client.connect(
            settings.SFTP_DATASTORE_HOSTNAME,
            port=settings.SFTP_DATASTORE_PORT,
            username=settings.SFTP_DATASTORE_USER,
            password=settings.SFTP_DATASTORE_PASSWORD
        )
        self.sftp = self.ssh_client.open_sftp()

        # list all buckets (folders)
        try:
            self.buckets = self.sftp.listdir(path=self.base_folder)
            logger.info('Buckets: %s', self.buckets)
        except Exception as e:
            logger.error('Error Connecting SFTP %s' % str(e))
            self.sftp.close()
Exemple #6
0
 def _efact_connect(self):
     """Open an SSH connection and SFTP session to the e-FACT server.

     Connection settings come from the ``ir.config_parameter`` keys
     ``account.invoice.efact.server``, ``.port``, ``.user`` and
     ``.password``. When the port parameter is unset or empty, the standard
     SSH port 22 is used instead of failing on ``int(None)``.

     Returns:
         tuple: ``(SSHClient, SFTPClient)``; the caller must close both.
     """
     get_param = self.env["ir.config_parameter"].get_param
     port = get_param("account.invoice.efact.port", default=None)
     connection = SSHClient()
     connection.load_system_host_keys()
     try:
         connection.connect(
             hostname=get_param("account.invoice.efact.server", default=None),
             port=int(port) if port else 22,
             username=get_param("account.invoice.efact.user", default=None),
             password=get_param("account.invoice.efact.password", default=None),
         )
         sftp = connection.open_sftp()
     except Exception:
         # Don't leak a half-open transport when connecting fails.
         connection.close()
         raise
     return connection, sftp
Exemple #7
0
    def fromPrivateKey(host,port=22,username=None,password=None,private_key=None):
        """Open an SSH + SFTP connection wrapped in an ``SSHClientSource``.

        Deprecated: see ``connect_v2``.

        Equivalent command lines::

            ssh -i /path/to/private_key user@localhost -p 2222
            ssh -i /path/to/private_key username@host -p port

        Args:
            host: remote host name or address.
            port: remote SSH port (default 22).
            username: login user.
            password: login password, used only when no private key is given.
            private_key: path to an RSA private key file; when non-empty the
                key is used instead of the password.

        Returns:
            SSHClientSource: with ``client``, ``ftp`` (SFTP session), ``host``
            and ``port`` set.

        Raises:
            Exception: if paramiko is not installed, or when the server's host
                key does not match the known one (``BadHostKeyException``).
        """
        if SSHClient is None:
            raise Exception("Paramiko not installed")

        src = SSHClientSource()

        pkey = None
        if private_key:  # non-null, non-empty
            passphrase = ""  # TODO: support passphrases
            pkey = paramiko.RSAKey.from_private_key_file(private_key, passphrase)

        client = SSHClient()
        client.load_system_host_keys()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        # The password is deliberately never printed.
        print("\n---\nConnect With: host=%s:%s user=%s privatekey=`%s`\n" % (
              host, port, username, private_key))

        # Authenticate with the loaded key if there is one, else the password.
        if pkey:
            credentials = {"pkey": pkey}
        else:
            credentials = {"password": password}

        try:
            client.connect(host, port=port, username=username,
                           timeout=1.0, compress=True, **credentials)
        except BadHostKeyException as e:
            client.close()
            sys.stderr.write("got:      %s\n" % e.key.asbytes())
            sys.stderr.write("expected: %s\n" % e.expected_key.asbytes())

            msg = "Connection error on %s:%s using privatekey=%s\n" % (
                host, port, private_key)
            msg += "you may need to clear ~/.ssh/known_hosts " + \
                "for entries related to %s:%s\n" % (host, port)
            sys.stderr.write(msg)
            raise Exception(msg)
        except Exception:
            # Authentication/timeout errors: don't leak the transport.
            client.close()
            raise

        try:
            ftp = client.open_sftp()
        except Exception:
            client.close()
            raise

        src.client = client
        src.ftp = ftp
        src.host = host
        src.port = port

        return src
Exemple #8
0
    def pull_one(self, unit, imsi):
        """Download ``lvm-service-log-01.txt`` from ``unit`` via a Soracom tunnel.

        The remote file ``mnt/data/lvm/logs/lvm-service-log-01.txt`` is
        copied to ``logs/<unit>/<YYYY>/<MM>/<DD>/<HH>/lvm-service-log-01.txt``.
        The timestamp is the local time at download.

        Side effects: sets ``self.last_path`` to the local hour directory
        (with trailing ``/``) and ``self.s3_path`` to the file path relative
        to ``logs/``.
        """
        # Open Soracom tunnel (duration argument 60; presumably seconds)
        self.soracom.tunnel(imsi, 60)

        ip = self.soracom.get_ip()
        portnumber = self.soracom.get_port()

        # Open FTP
        print("Connecting to " + str(unit) + " at " + ip + " on port " +
              str(portnumber))
        client = SSHClient()
        client.set_missing_host_key_policy(AutoAddPolicy())
        try:
            client.connect(hostname=ip,
                           port=portnumber,
                           username='******',
                           password='******',
                           timeout=60)
            sftp = client.open_sftp()

            print("FTP Connected")

            # Hour-bucketed relative directory: "<unit>/YYYY/MM/DD/HH/"
            hour_dir = str(unit) + '/' + datetime.now().strftime('%Y/%m/%d/%H') + '/'
            self.last_path = 'logs/' + hour_dir
            self.s3_path = hour_dir + 'lvm-service-log-01.txt'
            destination = self.last_path + 'lvm-service-log-01.txt'

            # exist_ok: a second pull within the same hour must not crash
            os.makedirs(self.last_path, exist_ok=True)

            # Transfer logfile
            sftp.get('mnt/data/lvm/logs/lvm-service-log-01.txt', destination)
            print("One file downloaded")

            sftp.close()
        finally:
            # Close the SSH transport too, not only the SFTP channel
            client.close()
        print("Connection to " + str(unit) + " closed")
Exemple #9
0
    def pull_all(self, unit, imsi, num_logs=10):
        """Download the rotated service logs of ``unit`` via a Soracom tunnel.

        Files are copied from ``mnt/data/lvm/logs/`` into ``logs/<unit>/``.
        Index 0 is ``lvm-service-log.txt``; index ``i > 0`` is
        ``lvm-service-log-<i as two digits>.txt``.

        Args:
            unit: unit identifier, used in the local directory name.
            imsi: identifier passed to ``self.soracom.tunnel``.
            num_logs: number of files to fetch (indices ``0 .. num_logs-1``).
                The default 10 fetches the base log plus ``-01`` .. ``-09``.
        """
        self.soracom.tunnel(imsi, 60)

        ip = self.soracom.get_ip()
        portnumber = self.soracom.get_port()

        print("Connecting to " + str(unit) + " at " + ip + " on port " +
              str(portnumber))
        client = SSHClient()
        client.set_missing_host_key_policy(AutoAddPolicy())
        try:
            client.connect(hostname=ip,
                           port=portnumber,
                           username='******',
                           password='******',
                           timeout=60)
            sftp = client.open_sftp()

            print("FTP Connected")

            # exist_ok: re-pulling the same unit must not crash
            os.makedirs('logs/' + str(unit), exist_ok=True)
            direc = 'logs/' + str(unit) + '/'

            for log in range(num_logs):
                print(str(log))
                # Zero-padded two-digit suffix; also correct for index >= 10.
                suffix = '' if log == 0 else '-%02d' % log
                name = 'lvm-service-log' + suffix + '.txt'
                remote_file = 'mnt/data/lvm/logs/' + name

                print(remote_file)
                sftp.get(remote_file, direc + name)
                print(remote_file + " file downloaded")

            sftp.close()
        finally:
            # Close the SSH transport too, not only the SFTP channel
            client.close()
        print("connection to " + str(unit) + " closed")
Exemple #10
0
def get_iodef(
        iodef_name_list,
        server='140.117.101.15',
        port=22, username='******',
        password=None,
        file_locate='/etc/iodef/{}/{}.xml'
):
    """Fetch IODEF XML documents from a remote host over SFTP.

    For each name, the remote path is ``file_locate.format(quote, name)``.
    ``quote`` is the name's dash-separated parts, except the last one,
    joined with ``/``. For example, ``'a-b-c'`` maps to
    ``/etc/iodef/a/b/a-b-c.xml`` with the default ``file_locate``.

    Args:
        iodef_name_list: names of the documents to fetch.
        server, port, username, password: SSH connection parameters.
        file_locate: format string with two ``{}`` placeholders (directory
            part, name).

    Returns:
        dict: name -> content read from the remote file. Names that could
        not be read are logged with a traceback and omitted.
    """
    logging.debug('Generating ssh client...')
    client = SSHClient()
    logging.debug('Loading system host key...')
    client.load_system_host_keys()

    # Accept unknown host keys; drop this if the server is in known_hosts.
    logging.debug('Setting missing host key policy to AutoAddPolicy...')
    client.set_missing_host_key_policy(AutoAddPolicy())

    return_dict = dict()
    try:
        logging.debug('Connecting to ssh {}@{}:{}...'.format(
            username, server, int(port),
        ))
        client.connect(
            server, port=int(port), username=username, password=password,
            timeout=5,
        )

        logging.debug('Generating sftp...')
        with client.open_sftp() as sftp_client:
            for each_iodef_name in iodef_name_list:
                try:
                    quote = '/'.join(each_iodef_name.split('-')[:-1])
                    file_path = file_locate.format(quote, each_iodef_name)

                    logging.debug('Touching remote file {}'.format(file_path))
                    with sftp_client.open(file_path) as f:
                        logging.debug('Loading remote file {}'.format(file_path))
                        return_dict[each_iodef_name] = f.read()
                except Exception:
                    # Best effort per file: log and continue with the next one.
                    logging.exception('Unable to retrieve {}'.format(each_iodef_name))
    finally:
        # Always release the SSH transport (the old code never closed it).
        client.close()

    return return_dict
Exemple #11
0
def create_backup_repository(service):
    """
    Configure backup repository.

    - create filesystem folders (a new UUID-named directory on the storage
      server, containing ``.ssh/authorized_keys``)
    - store ssh key (``service.last_report.ssh_key``)
    - create subaccount whose home directory is ``weblate/<dirname>``

    Returns:
        str: ``ssh://<user>@<server>:23/./backups`` repository URL built
        from the created subaccount.

    Raises:
        requests.HTTPError: if the subaccount API returns an error status.
    """
    dirname = str(uuid4())

    # Create folder and SSH key
    client = SSHClient()
    client.load_system_host_keys()
    try:
        client.connect(**settings.STORAGE_SERVER)
        with client.open_sftp() as ftp:
            ftp.mkdir(dirname)
            ftp.chdir(dirname)
            ftp.mkdir(".ssh")
            ftp.chdir(".ssh")
            with ftp.open("authorized_keys", "w") as handle:
                handle.write(service.last_report.ssh_key)
    finally:
        # The connection is no longer needed once the key is stored.
        client.close()

    # Create account on the service
    url = "https://robot-ws.your-server.de/storagebox/{}/subaccount".format(
        settings.STORAGE_BOX
    )
    response = requests.post(
        url,
        data={
            "homedirectory": "weblate/{}".format(dirname),
            "ssh": "1",
            "external_reachability": "1",
            "comment": "Weblate backup service {}".format(service.pk),
        },
        auth=(settings.STORAGE_USER, settings.STORAGE_PASSWORD),
    )
    # Fail with a clear HTTP error instead of a KeyError on an error payload.
    response.raise_for_status()
    data = response.json()
    return "ssh://{}@{}:23/./backups".format(
        data["subaccount"]["username"], data["subaccount"]["server"]
    )
class FDSConnection():
    """SSH/SFTP connection to the FDS server used to fetch report files.

    Connects in ``__init__`` with key-based authentication; only the host
    keys in ``settings.FDS_HOST_KEY`` are loaded. Call :meth:`close` when
    done.
    """

    # Remote directory (relative to the SFTP login directory) holding reports.
    REMOTE_DIR = 'yellow-net-reports'

    def __init__(self):
        self.client = SSHClient()
        self.client.load_host_keys(settings.FDS_HOST_KEY)
        self.client.connect(settings.FDS_HOST, username=settings.FDS_USER, key_filename=settings.FDS_PRIVATE_KEY,
                            port=settings.FDS_PORT)
        self.sftp = self.client.open_sftp()
        log.info("Connected to FDS")

    def get_files(self):
        """Download new report files into ``settings.FDS_DATA_PATH``.

        A remote file is skipped when a local file of the same name, or the
        same name with a ``.processed`` suffix, already exists. Remote files
        are not deleted.
        """
        log.info("Receiving files from FDS...")
        fds_data_path = os.path.join(settings.BASE_DIR, settings.FDS_DATA_PATH)

        # A set gives O(1) membership tests for the checks below.
        local_files = set(os.listdir(fds_data_path))

        # Address remote files by path instead of chdir'ing: a relative
        # chdir would resolve against the previous cwd and fail on the
        # second call.
        for name in self.sftp.listdir(self.REMOTE_DIR):
            if name not in local_files and name + '.processed' not in local_files:
                log.info("Receiving {}".format(name))
                self.sftp.get(self.REMOTE_DIR + '/' + name, os.path.join(fds_data_path, name))
            else:
                log.debug("Skipping already present file: {}".format(name))

    def close(self):
        """Close the SFTP session and the underlying SSH connection."""
        self.sftp.close()
        self.client.close()
Exemple #13
0
class RemoteSession():
    """SSH session plus SFTP channel to a remote host.

    ``IgnoreHostKeyPolicy`` is installed as the missing-host-key policy, and
    no password or key is passed explicitly to ``connect``.
    """

    def __init__(self, ip, username=REMOTE_USER):
        self.username = username
        # Kept so execute() can report which host a command failed on.
        self.ip = ip

        self.client = SSHClient()
        self.client.set_missing_host_key_policy(IgnoreHostKeyPolicy)
        self.client.connect(ip, username=self.username)

        self.sftp = self.client.open_sftp()

    def execute(self, command, quiet=None):
        """Run ``command`` on the remote host.

        On a non-zero exit status, the status and an equivalent ``ssh``
        command line are logged to the ``TPCH`` logger. The command and its
        stdout/stderr lines are then printed. Both are skipped when ``quiet``
        is true.

        Args:
            command: shell command line to run remotely.
            quiet: suppress failure reporting. ``None`` (default) falls back
                to a module-level ``quiet`` flag if one exists, else False.
        """
        if quiet is None:
            # ``quiet`` used to be read as a free name; keep honoring a
            # module-level flag if the script defines one.
            quiet = globals().get('quiet', False)

        _, stdout, stderr = self.client.exec_command(command)

        # Read the output before waiting for the exit status so a command
        # producing a lot of output cannot stall on a full channel window.
        out = stdout.read().decode('utf-8').splitlines()
        err = stderr.read().decode('utf-8').splitlines()
        rc = stdout.channel.recv_exit_status()

        if rc != 0 and not quiet:
            log = logging.getLogger('TPCH')

            log.error("ssh command returned %d" % rc)
            log.error("ssh -l %s %s %s" % (self.username, self.ip, command))
            print(command)
            for line in out:
                print(line)
            for line in err:
                print(line)
            print()

        return

    def download(self, src, dst):
        """Copy remote file ``src`` to local path ``dst`` over SFTP."""
        self.sftp.get(src, dst)
        return

    def close(self):
        """Close the SFTP channel and the SSH connection."""
        self.sftp.close()
        self.client.close()
Exemple #14
0
 def efact_check_history(self):
     """Import e-FACT status files from the server, then delete them there.

     The method changes into ``<remote login dir> + statout_path``
     (``statout_path`` is defined outside this method). Each file there,
     oldest access time first, becomes an ``l10n_es_facturae_efact_update``
     exchange record of the ``l10n_es_facturae_efact.efact_backend``
     backend, and its processing is queued with ``with_delay()``. Files are
     removed from the server only after all of them have been read.
     """
     efact = self.env.ref("l10n_es_facturae_efact.efact_backend")
     ICP = self.env["ir.config_parameter"].sudo()
     port = ICP.get_param("account.invoice.efact.port", default=None)
     connection = SSHClient()
     connection.load_system_host_keys()
     try:
         connection.connect(
             ICP.get_param("account.invoice.efact.server", default=None),
             # Use the standard SSH port when the parameter is not set
             # instead of failing on int(None).
             port=int(port) if port else 22,
             username=ICP.get_param("account.invoice.efact.user", default=None),
             password=ICP.get_param("account.invoice.efact.password",
                                    default=None),
         )
         sftp = connection.open_sftp()
         try:
             path = sftp.normalize(".")
             sftp.chdir(path + statout_path)
             attrs = sftp.listdir_attr(".")
             attrs.sort(key=lambda attr: attr.st_atime)
             to_remove = []
             for attr in attrs:
                 with sftp.open(attr.filename) as remote_file:
                     datas = remote_file.read()
                 update_record = efact.create_record(
                     "l10n_es_facturae_efact_update",
                     {
                         "edi_exchange_state": "input_received",
                         "exchange_filename": attr.filename,
                     },
                 )
                 update_record._set_file_content(datas)
                 efact.with_delay().exchange_process(update_record)
                 to_remove.append(attr.filename)
             # Delete only after every file has been read and recorded.
             for filename in to_remove:
                 sftp.remove(filename)
         finally:
             sftp.close()
     finally:
         connection.close()
Exemple #15
0
class sshShell:
    """Run commands and upload files on a remote host over SSH.

    Every operation opens its own connection and closes it afterwards.
    Unknown host keys are accepted automatically (``AutoAddPolicy``).

    Args:
        name: display name returned by :meth:`Name`.
        addr: ``"host:port"`` string.
        usr, pwd: login user and password.
        pk: optional private-key object passed to paramiko as ``pkey``.
        kf: optional key file name(s) passed to paramiko as ``key_filename``.
        encoding: encoding used to decode shell output.
    """

    def __init__(self,
                 name,
                 addr,
                 usr,
                 pwd,
                 pk=None,
                 kf=None,
                 encoding='utf-8'):
        self.name, self.usr, self.pwd, self.pk, self.kf, self.encoding = name, usr, pwd, pk, kf, encoding
        self.host, port = addr.split(":")
        self.port = int(port)
        # Remote path -> st_mode of directories already checked/created,
        # so check2CreateDir does not stat the same parents repeatedly.
        self.cacheStat = {}

    def Name(self):
        """Return the display name of this shell."""
        return self.name

    def __connect(self):
        """Open ``self.ssh`` with the configured credentials and short timeouts."""
        self.ssh = SSHClient()
        self.ssh.set_missing_host_key_policy(AutoAddPolicy())
        self.ssh.connect(self.host,
                         port=self.port,
                         username=self.usr,
                         password=self.pwd,
                         pkey=self.pk,
                         key_filename=self.kf,
                         banner_timeout=3.0,
                         auth_timeout=3.0)

    def __createShell(self):
        """Connect and open an interactive shell with a 2 s receive timeout."""
        self.__connect()
        self.shell = self.ssh.invoke_shell()
        self.shell.settimeout(2.0)

    def __recvBytes(self):
        """Return up to 999 raw bytes of shell output, or b'' on timeout/error."""
        try:
            return self.shell.recv(999) or b''
        except Exception:
            # A recv timeout is the expected "no more output" signal.
            return b''

    def __recv(self):
        """Return up to 999 bytes of shell output decoded, or '' on any error."""
        try:
            return (self.shell.recv(999) or b'').decode(self.encoding)
        except Exception:
            return ''

    def __send(self, cmd):
        """Send ``cmd`` followed by a carriage return to the shell."""
        self.shell.send(cmd + '\r')

    def __close(self):
        """Best-effort shutdown of the shell channel and the SSH client."""
        try:
            self.shell.shutdown(2)
        except Exception:
            pass

        try:
            self.ssh.close()
        except Exception:
            pass

    def Excute(self, cmd):
        """Run ``cmd`` in a fresh interactive shell and return its output.

        Output is collected until no data arrives within the 2 s receive
        timeout. On failure the formatted traceback is returned instead.
        The connection is always closed.
        """
        try:
            self.__createShell()
            print("shell hallo words: %s" % (self.__recv()))
            self.__send(cmd)
            chunks = []
            while True:
                chunk = self.__recvBytes()
                if not chunk:
                    break
                chunks.append(chunk)
            # Decode once: a multi-byte character may straddle two chunks.
            return b''.join(chunks).decode(self.encoding, errors='replace')
        except Exception:
            return traceback.format_exc()
        finally:
            self.__close()

    def Upload(self, files, dsDir):
        """Upload files and directory trees to ``dsDir`` on the remote host.

        Plain files go directly into ``dsDir``; directories are walked and
        their structure relative to the given directory is recreated below
        ``dsDir``.

        Returns:
            ``(True, [(local_path, (status, elapsed_ms)), ...])`` once the
            connection worked, where ``status`` is ``'ok'`` or a traceback
            string per file; ``(False, traceback_string)`` otherwise.
        """
        try:
            self.__connect()
            self.stftp = self.ssh.open_sftp()
            result = []
            for f in files:
                if os.path.isfile(f):
                    fn = os.path.basename(f)
                    result.append((f,
                                   self.UploadFile(self.stftp, f,
                                                   os.path.join(dsDir, fn))))
                elif os.path.isdir(f):
                    for root, _, fns in os.walk(f):
                        # Path of root relative to f ('' for f itself).
                        # root.replace(f, '') gave '/sub', which made
                        # os.path.join discard dsDir.
                        rel = os.path.relpath(root, f)
                        if rel == os.curdir:
                            rel = ''
                        for fn in fns:
                            local = os.path.join(root, fn)
                            result.append(
                                (local,
                                 self.UploadFile(self.stftp, local,
                                                 os.path.join(dsDir, rel, fn))))
            return True, result
        except Exception:
            return False, traceback.format_exc()
        finally:
            self.__close()

    def UploadFile(self, sftp, local, remote):
        """Upload one file, creating missing remote parent directories.

        Backslashes in ``remote`` are converted to ``/`` first.

        Returns:
            tuple: ``('ok', elapsed_ms)`` on success, otherwise
            ``(traceback_string, elapsed_ms)``.
        """
        t1 = time.time() * 1000
        try:
            remote = remote.replace("\\", "/")
            remote_dir = os.path.split(remote)[0]
            self.check2CreateDir(sftp, remote_dir)
            sftp.put(local, remote)
            return 'ok', time.time() * 1000 - t1
        except Exception:
            return traceback.format_exc(), time.time() * 1000 - t1

    def check2CreateDir(self, sftp, dir):
        """Ensure the remote directory ``dir`` and all its parents exist.

        Returns:
            ``False`` if a path component exists but is not a directory, or
            on an unexpected error; ``None`` otherwise.
        """
        try:
            absolute = dir.startswith('/')
            # Skip empty components so '/a//b/' yields 'a', 'b' only.
            parts = [p for p in dir.split('/') if p]
            for i in range(len(parts)):
                pdir = '/'.join(parts[:i + 1])
                if absolute:
                    # Exactly one leading slash (not '//a').
                    pdir = '/' + pdir
                if stat.S_ISDIR(self.cacheStat.get(pdir, 0)):
                    continue
                try:
                    a = sftp.stat(pdir)
                    self.cacheStat[pdir] = a.st_mode
                    if not stat.S_ISDIR(a.st_mode):
                        return False
                except IOError as e:
                    if e.errno == 2:  # ENOENT: create the missing level
                        sftp.mkdir(pdir)
                        self.cacheStat[pdir] = stat.S_IFDIR
                    else:
                        traceback.print_exc()
                        return False
        except Exception:
            traceback.print_exc()
            return False
Exemple #16
0
    df = pd.read_csv(tsv_file, sep='\t', header=0, quoting=3, dtype=str)
    for i in range(df.shape[0]):
        yield i, df.loc[i].to_dict()


# Provision every server listed in SSH_SERVER_FILE: install squid, upload the
# auth and config files over SFTP, then start the service.
# NOTE(review): the same global ``ssh`` client is reconnected for each server
# and is never closed between iterations in the visible code — confirm.
for i, server in load_servers(SSH_SERVER_FILE):
    logger.info('No.%d, ip: %s' % (i, server['hostname']))
    try:
        # 1. Establish the SSH connection
        ssh.connect(**server)
        logger.debug(">>> 1. 建立ssh连接")
        # 2. Install squid
        # NOTE(review): the exit status is awaited before stdout is read;
        # paramiko warns recv_exit_status() can block on large output.
        stdin, stdout, stderr = ssh.exec_command('yum install squid -y')
        if stdout.channel.recv_exit_status() == 0:
            logger.debug(">>> 2. 安装squid完成")
            with ssh.open_sftp() as sess:
                # 3. Add the authentication accounts (upload the password file)
                sess.put(AUTH_FILE, '/etc/squid/passwords')
                logger.debug(">>> 3. 上传auth文件完成")
                # 4. Update the squid.conf configuration file
                sess.put(SQUID_CONF, '/etc/squid/squid.conf')
                logger.debug(">>> 4. 上传squid.conf完成")
            # 5. Start the service with `systemctl start squid`
            stdout = ssh.exec_command('systemctl start squid')[1]
            if stdout.channel.recv_exit_status() == 0:
                logger.debug('>>> 5. squid服务已启动')
            else:
                logger.error("squid服务启动失败: " + ''.join(stdout.readlines()))
        else:
            logger.error("安装squid失败, ip: %s" % server['hostname'])
    except Exception as err:
Exemple #17
0
class Connection(Context):
    """
    A connection to an SSH daemon, with methods for commands and file transfer.

    **Basics**

    This class inherits from Invoke's `~invoke.context.Context`, as it is a
    context within which commands, tasks etc can operate. It also encapsulates
    a Paramiko `~paramiko.client.SSHClient` instance, performing useful high
    level operations with that `~paramiko.client.SSHClient` and
    `~paramiko.channel.Channel` instances generated from it.

    .. _connect_kwargs:

    .. note::
        Many SSH specific options -- such as specifying private keys and
        passphrases, timeouts, disabling SSH agents, etc -- are handled
        directly by Paramiko and should be specified via the
        :ref:`connect_kwargs argument <connect_kwargs-arg>` of the constructor.

    **Lifecycle**

    `.Connection` has a basic "`create <__init__>`, `connect/open <open>`, `do
    work <run>`, `disconnect/close <close>`" lifecycle:

    - `Instantiation <__init__>` imprints the object with its connection
      parameters (but does **not** actually initiate the network connection).

        - An alternate constructor exists for users :ref:`upgrading piecemeal
          from Fabric 1 <from-v1>`: `from_v1`

    - Methods like `run`, `get` etc automatically trigger a call to
      `open` if the connection is not active; users may of course call `open`
      manually if desired.
    - Connections do not always need to be explicitly closed; much of the
      time, Paramiko's garbage collection hooks or Python's own shutdown
      sequence will take care of things. **However**, should you encounter edge
      cases (for example, sessions hanging on exit) it's helpful to explicitly
      close connections when you're done with them.

      This can be accomplished by manually calling `close`, or by using the
      object as a contextmanager::

        with Connection('host') as c:
            c.run('command')
            c.put('file')

    .. note::
        This class rebinds `invoke.context.Context.run` to `.local` so both
        remote and local command execution can coexist.

    **Configuration**

    Most `.Connection` parameters honor :doc:`Invoke-style configuration
    </concepts/configuration>` as well as any applicable :ref:`SSH config file
    directives <connection-ssh-config>`. For example, to end up with a
    connection to ``admin@myhost``, one could:

    - Use any built-in config mechanism, such as ``/etc/fabric.yml``,
      ``~/.fabric.json``, collection-driven configuration, env vars, etc,
      stating ``user: admin`` (or ``{"user": "admin"}``, depending on config
      format.) Then ``Connection('myhost')`` would implicitly have a ``user``
      of ``admin``.
    - Use an SSH config file containing ``User admin`` within any applicable
      ``Host`` header (``Host myhost``, ``Host *``, etc.) Again,
      ``Connection('myhost')`` will default to an ``admin`` user.
    - Leverage host-parameter shorthand (described in `.Config.__init__`), i.e.
      ``Connection('admin@myhost')``.
    - Give the parameter directly: ``Connection('myhost', user='admin')``.

    The same applies to agent forwarding, gateways, and so forth.

    .. versionadded:: 2.0
    """

    # NOTE: these are initialized here to hint to invoke.Config.__setattr__
    # that they should be treated as real attributes instead of config proxies.
    # (Additionally, we're doing this instead of using invoke.Config._set() so
    # we can take advantage of Sphinx's attribute-doc-comment static analysis.)
    # Once an instance is created, these values will usually be non-None
    # because they default to the default config values.
    host = None
    original_host = None
    user = None
    port = None
    ssh_config = None
    gateway = None
    forward_agent = None
    connect_timeout = None
    connect_kwargs = None
    client = None
    transport = None
    _sftp = None
    _agent_handler = None
    default_host_key_policy = AutoAddPolicy

    @classmethod
    def from_v1(cls, env, **kwargs):
        """
        Alternate constructor which uses Fabric 1's ``env`` dict for settings.

        All keyword arguments besides ``env`` are passed unmolested into the
        primary constructor.

        .. warning::
            Because your own config overrides will win over data from ``env``,
            make sure you only set values you *intend* to change from your v1
            environment!

        For details on exactly which ``env`` vars are imported and what they
        become in the new API, please see :ref:`v1-env-var-imports`.

        User and port are only pulled from ``env.user`` / ``env.port`` when the
        effective host string does not already specify them (e.g. a v1
        ``host_string`` of ``user@host:port``). Otherwise the constructor's
        shorthand-vs-kwarg ambiguity check would reject the combination.

        :param env:
            An explicit Fabric 1 ``env`` dict (technically, any
            ``fabric.utils._AttributeDict`` instance should work) to pull
            configuration from.

        :raises InvalidV1Env: if ``env.host_string`` is empty.

        .. versionadded:: 2.4
        """
        # TODO: import fabric.state.env (need good way to test it first...)
        # TODO: how to handle somebody accidentally calling this in a process
        # where 'fabric' is fabric 2, and there's no fabric 1? Probably just a
        # re-raise of ImportError??
        # Our only requirement is a non-empty host_string
        if not env.host_string:
            raise InvalidV1Env(
                "Supplied v1 env has an empty `host_string` value! Please make sure you're calling Connection.from_v1 within a connected Fabric 1 session."  # noqa
            )
        # TODO: detect collisions with kwargs & except instead of overwriting?
        # (More Zen of Python compliant, but also, effort, and also, makes it
        # harder for users to intentionally overwrite!)
        connect_kwargs = kwargs.setdefault("connect_kwargs", {})
        kwargs.setdefault("host", env.host_string)
        # Inspect the host string that will actually reach __init__, since
        # that is where the shorthand-vs-kwarg ambiguity check happens.
        shorthand = derive_shorthand(kwargs["host"])
        # Skip user/port if the host string already carries them; otherwise
        # we hit our own ambiguity clause in __init__. v1 would also have been
        # doing this anyways (host string wins over other settings).
        if not shorthand["user"]:
            kwargs.setdefault("user", env.user)
        if not shorthand["port"]:
            # Run port through int(); v1 inexplicably has a string default...
            kwargs.setdefault("port", int(env.port))
        # key_filename defaults to None in v1, but in v2, we expect it to be
        # either unset, or set to a list. Thus, we only pull it over if it is
        # not None.
        if env.key_filename is not None:
            connect_kwargs.setdefault("key_filename", env.key_filename)
        # Obtain config values, if not given, from its own from_v1
        # NOTE: not using setdefault as we truly only want to call
        # Config.from_v1 when necessary.
        if "config" not in kwargs:
            kwargs["config"] = Config.from_v1(env)
        return cls(**kwargs)

    # TODO: should "reopening" an existing Connection object that has been
    # closed, be allowed? (See e.g. how v1 detects closed/semi-closed
    # connections & nukes them before creating a new client to the same host.)
    # TODO: push some of this into paramiko.client.Client? e.g. expand what
    # Client.exec_command does, it already allows configuring a subset of what
    # we do / will eventually do / did in 1.x. It's silly to have to do
    # .get_transport().open_session().
    def __init__(
        self,
        host,
        user=None,
        port=None,
        config=None,
        gateway=None,
        forward_agent=None,
        connect_timeout=None,
        connect_kwargs=None,
        inline_ssh_env=None,
    ):
        """
        Set up a new object representing a server connection.

        :param str host:
            the hostname (or IP address) of this connection.

            May include shorthand for the ``user`` and/or ``port`` parameters,
            of the form ``user@host``, ``host:port``, or ``user@host:port``.

            .. note::
                Due to ambiguity, IPv6 host addresses are incompatible with the
                ``host:port`` shorthand (though ``user@host`` will still work
                OK). In other words, the presence of >1 ``:`` character will
                prevent any attempt to derive a shorthand port number; use the
                explicit ``port`` parameter instead.

            .. note::
                If ``host`` matches a ``Host`` clause in loaded SSH config
                data, and that ``Host`` clause contains a ``Hostname``
                directive, the resulting `.Connection` object will behave as if
                ``host`` is equal to that ``Hostname`` value.

                In all cases, the original value of ``host`` is preserved as
                the ``original_host`` attribute.

                Thus, given SSH config like so::

                    Host myalias
                        Hostname realhostname

                a call like ``Connection(host='myalias')`` will result in an
                object whose ``host`` attribute is ``realhostname``, and whose
                ``original_host`` attribute is ``myalias``.

        :param str user:
            the login user for the remote connection. Defaults to
            ``config.user``.

        :param int port:
            the remote port. Defaults to ``config.port``.

        :param config:
            configuration settings to use when executing methods on this
            `.Connection` (e.g. default SSH port and so forth).

            Should be a `.Config` or an `invoke.config.Config`
            (which will be turned into a `.Config`).

            Default is an anonymous `.Config` object.

        :param gateway:
            An object to use as a proxy or gateway for this connection.

            This parameter accepts one of the following:

            - another `.Connection` (for a ``ProxyJump`` style gateway);
            - a shell command string (for a ``ProxyCommand`` style gateway).

            Default: ``None``, meaning no gatewaying will occur (unless
            otherwise configured; if one wants to override a configured gateway
            at runtime, specify ``gateway=False``.)

            .. seealso:: :ref:`ssh-gateways`

        :param bool forward_agent:
            Whether to enable SSH agent forwarding.

            Default: ``config.forward_agent``, overridden by an SSH config
            ``ForwardAgent`` directive (``yes``/``no``, case-insensitive) when
            present.

        :param int connect_timeout:
            Connection timeout, in seconds.

            Default: ``config.timeouts.connect``.

        .. _connect_kwargs-arg:

        :param dict connect_kwargs:
            Keyword arguments handed verbatim to
            `SSHClient.connect <paramiko.client.SSHClient.connect>` (when
            `.open` is called).

            `.Connection` tries not to grow additional settings/kwargs of its
            own unless it is adding value of some kind; thus,
            ``connect_kwargs`` is currently the right place to hand in paramiko
            connection parameters such as ``pkey`` or ``key_filename``. For
            example::

                c = Connection(
                    host="hostname",
                    user="myuser",
                    connect_kwargs={
                        "key_filename": "/home/myuser/.ssh/private.key",
                    },
                )

            Default: ``config.connect_kwargs``.

        :param bool inline_ssh_env:
            Whether to send environment variables "inline" as prefixes in front
            of command strings (``export VARNAME=value && mycommand here``),
            instead of trying to submit them through the SSH protocol itself
            (which is the default behavior). This is necessary if the remote
            server has a restricted ``AcceptEnv`` setting (which is the common
            default).

            The default value is the value of the ``inline_ssh_env``
            :ref:`configuration value <default-values>` (which itself defaults
            to ``False``).

            .. warning::
                This functionality does **not** currently perform any shell
                escaping on your behalf! Be careful when using nontrivial
                values, and note that you can put in your own quoting,
                backslashing etc if desired.

                Consider using a different approach (such as actual
                remote shell scripts) if you run into too many issues here.

            .. note::
                When serializing into prefixed ``FOO=bar`` format, we apply the
                builtin `sorted` function to the env dictionary's keys, to
                remove what would otherwise be ambiguous/arbitrary ordering.

            .. note::
                This setting has no bearing on *local* shell commands; it only
                affects remote commands, and thus, methods like `.run` and
                `.sudo`.

        :raises ValueError:
            if user or port values are given via both ``host`` shorthand *and*
            their own arguments. (We `refuse the temptation to guess`_).

        .. _refuse the temptation to guess:
            http://zen-of-python.info/
            in-the-face-of-ambiguity-refuse-the-temptation-to-guess.html#12

        .. versionchanged:: 2.3
            Added the ``inline_ssh_env`` parameter.
        """
        # NOTE: parent __init__ sets self._config; for now we simply overwrite
        # that below. If it's somehow problematic we would want to break parent
        # __init__ up in a manner that is more cleanly overrideable.
        super(Connection, self).__init__(config=config)

        #: The .Config object referenced when handling default values (for e.g.
        #: user or port, when not explicitly given) or deciding how to behave.
        if config is None:
            config = Config()
        # Handle 'vanilla' Invoke config objects, which need cloning 'into' one
        # of our own Configs (which grants the new defaults, etc, while not
        # squashing them if the Invoke-level config already accounted for them)
        elif not isinstance(config, Config):
            config = config.clone(into=Config)
        self._set(_config=config)
        # TODO: when/how to run load_files, merge, load_shell_env, etc?
        # TODO: i.e. what is the lib use case here (and honestly in invoke too)

        shorthand = self.derive_shorthand(host)
        host = shorthand["host"]
        err = "You supplied the {} via both shorthand and kwarg! Please pick one."  # noqa
        if shorthand["user"] is not None:
            if user is not None:
                raise ValueError(err.format("user"))
            user = shorthand["user"]
        if shorthand["port"] is not None:
            if port is not None:
                raise ValueError(err.format("port"))
            port = shorthand["port"]

        # NOTE: we load SSH config data as early as possible as it has
        # potential to affect nearly every other attribute.
        #: The per-host SSH config data, if any. (See :ref:`ssh-config`.)
        self.ssh_config = self.config.base_ssh_config.lookup(host)

        self.original_host = host
        #: The hostname of the target server.
        self.host = host
        if "hostname" in self.ssh_config:
            # TODO: log that this occurred?
            self.host = self.ssh_config["hostname"]

        #: The username this connection will use to connect to the remote end.
        self.user = user or self.ssh_config.get("user", self.config.user)
        # TODO: is it _ever_ possible to give an empty user value (e.g.
        # user='')? E.g. do some SSH server specs allow for that?

        #: The network port to connect on.
        self.port = port or int(self.ssh_config.get("port", self.config.port))

        # Gateway/proxy/bastion/jump setting: non-None values - string,
        # Connection, even eg False - get set directly; None triggers seek in
        # config/ssh_config
        #: The gateway `.Connection` or ``ProxyCommand`` string to be used,
        #: if any.
        self.gateway = gateway if gateway is not None else self.get_gateway()
        # NOTE: we use string above, vs ProxyCommand obj, to avoid spinning up
        # the ProxyCommand subprocess at init time, vs open() time.
        # TODO: make paramiko.proxy.ProxyCommand lazy instead?

        if forward_agent is None:
            # Default to config...
            forward_agent = self.config.forward_agent
            # But if ssh_config is present, it wins
            if "forwardagent" in self.ssh_config:
                # ssh_config(5) yes/no values are case-insensitive, so
                # normalize before mapping ("ForwardAgent Yes" is valid).
                # NOTE(review): values other than yes/no (e.g. an agent
                # socket path, supported by newer OpenSSH) still raise
                # KeyError here.
                map_ = {"yes": True, "no": False}
                forward_agent = map_[self.ssh_config["forwardagent"].lower()]
        #: Whether agent forwarding is enabled.
        self.forward_agent = forward_agent

        if connect_timeout is None:
            connect_timeout = self.ssh_config.get("connecttimeout",
                                                  self.config.timeouts.connect)
        if connect_timeout is not None:
            connect_timeout = int(connect_timeout)
        #: Connection timeout
        self.connect_timeout = connect_timeout

        #: Keyword arguments given to `paramiko.client.SSHClient.connect` when
        #: `open` is called.
        self.connect_kwargs = self.resolve_connect_kwargs(connect_kwargs)

        #: The `paramiko.client.SSHClient` instance this connection wraps.
        self.client = SSHClient()
        self.setup_ssh_client()

        #: A convenience handle onto the return value of
        #: ``self.client.get_transport()``.
        self.transport = None

        if inline_ssh_env is None:
            inline_ssh_env = self.config.inline_ssh_env
        #: Whether to construct remote command lines with env vars prefixed
        #: inline.
        self.inline_ssh_env = inline_ssh_env

    def setup_ssh_client(self):
        """
        Configure host key handling on ``self.client``.

        If ``default_host_key_policy`` is not ``None``, an instance of it is
        installed as the missing-host-key policy. Then every existing file
        named by the SSH config ``UserKnownHostsFile`` directive (default
        ``~/.ssh/known_hosts``) is loaded as known host keys.
        """
        policy_cls = self.default_host_key_policy
        if policy_cls is not None:
            logging.debug("host key policy: %s", policy_cls)
            self.client.set_missing_host_key_policy(policy_cls())
        known_hosts = self.ssh_config.get("userknownhostsfile",
                                          "~/.ssh/known_hosts")
        logging.debug("loading host keys from %s", known_hosts)
        # The directive may list several files separated by whitespace;
        # files that do not exist are skipped.
        for path in map(os.path.expanduser, known_hosts.split()):
            if not os.path.exists(path):
                continue
            self.client.load_host_keys(path)

    def resolve_connect_kwargs(self, connect_kwargs):
        """
        Compute the keyword arguments later handed to ``SSHClient.connect``.

        An explicit ``connect_kwargs`` replaces ``config.connect_kwargs``
        wholesale, with one exception: ``key_filename`` values are merged, in
        this order:

        1. the config value (which may include CLI flag values);
        2. the explicit argument's value;
        3. any SSH config ``IdentityFile`` entries.

        ``key_filename`` may be given as a single path or a list of paths;
        merged results are always lists.

        A fresh dict is always returned. Neither ``config.connect_kwargs`` nor
        the caller's ``connect_kwargs`` is modified, so repeated Connections
        built from one config do not accumulate duplicate key paths.

        :param dict connect_kwargs:
            Explicit kwargs, or ``None`` to start from ``config.connect_kwargs``.
        :returns: a new `dict` of connect() keyword arguments.
        """
        def as_list(value):
            # Normalize a single path / list / None into a new list.
            if value is None:
                return []
            if isinstance(value, (list, tuple)):
                return list(value)
            return [value]

        config_kwargs = self.config.connect_kwargs
        if connect_kwargs is None:
            # TODO: is it better to pre-empt conflicts w/ manually-handled
            # connect() kwargs (hostname, username, etc) here or in open()?
            # We're doing open() for now, but otherwise it feels better to
            # do it early instead of late.
            # Copy so the key_filename merging below can never leak back into
            # (and keep growing inside) the shared config object.
            resolved = dict(config_kwargs.items())
        else:
            resolved = dict(connect_kwargs.items())
            # Special case: key_filename gets merged instead of overridden.
            # Config value comes before kwarg value (because it may contain
            # CLI flag value.)
            if "key_filename" in config_kwargs:
                resolved["key_filename"] = as_list(
                    config_kwargs["key_filename"]
                ) + as_list(connect_kwargs.get("key_filename"))

        # SSH config identityfile values come last in the key_filename
        # 'hierarchy'.
        if "identityfile" in self.ssh_config:
            resolved["key_filename"] = as_list(
                resolved.get("key_filename")
            ) + as_list(self.ssh_config["identityfile"])

        return resolved

    def get_gateway(self):
        """
        Work out the gateway to use when none was passed explicitly.

        Lookup order:

        1. SSH config ``ProxyJump``, which yields a chain of `.Connection`
           objects.
        2. SSH config ``ProxyCommand``, returned as a plain string.
        3. ``config.gateway``, which may be ``None``.

        If any ProxyJump hop's host equals our own host, ``None`` is returned
        instead, to avoid endless self-proxying.
        """
        ssh_config = self.ssh_config
        if "proxyjump" not in ssh_config:
            if "proxycommand" in ssh_config:
                # Just a string, which we interpret as a proxy command.
                return ssh_config["proxycommand"]
            # Fallback: config value (may be None).
            return self.config.gateway
        # For "hop1,hop2,hop3", build the chain starting from the final hop
        # (which has no gateway of its own) and work back to the first hop.
        # Each new hop is tunnelled through the previously built one.
        gateway = None
        for hop in reversed(ssh_config["proxyjump"].split(",")):
            # A hop resolving to ourselves (usually via SSH config wildcards)
            # would recurse forever through __init__ -> get_gateway.
            # TODO: user/port are not compared, only the host.
            if self.derive_shorthand(hop)["host"] == self.host:
                return None
            # ProxyJump entries share our host shorthand format.
            hop_kwargs = {"config": self.config.clone()}
            if gateway is not None:
                hop_kwargs["gateway"] = gateway
            gateway = Connection(hop, **hop_kwargs)
        return gateway

    def __repr__(self):
        """
        Render e.g. ``<Connection host=example.com user=bob gw=proxyjump>``.

        ``host`` is always shown. ``user`` and ``port`` appear only when they
        differ from the config defaults. ``gw`` appears when a gateway is set,
        labelled by kind.
        """
        bits = [("host", self.host)]
        # TODO: maybe always show user/port regardless? Explicit is good...
        if self.user != self.config.user:
            bits.append(("user", self.user))
        if self.port != self.config.port:
            bits.append(("port", self.port))
        # Plain truthiness on purpose: gateway may be False to explicitly
        # disable a configured gateway, which must not be displayed.
        if self.gateway:
            if isinstance(self.gateway, string_types):
                gateway_kind = "proxycommand"
            else:
                gateway_kind = "proxyjump"
            bits.append(("gw", gateway_kind))
        rendered = ["{}={}".format(key, value) for key, value in bits]
        return "<Connection {}>".format(" ".join(rendered))

    def _identity(self):
        """Return the ``(host, user, port)`` tuple behind eq/ordering/hash."""
        # TODO: should gateway/keys/other init kwargs also count? Whether two
        # connections differing only in those are "the same" is unclear.
        host, user, port = self.host, self.user, self.port
        return host, user, port

    def __eq__(self, other):
        """Equal only to another `.Connection` with the same host/user/port."""
        return (
            isinstance(other, Connection)
            and self._identity() == other._identity()
        )

    def __lt__(self, other):
        """
        Order connections by their ``(host, user, port)`` identity.

        For non-`.Connection` operands this returns ``NotImplemented``, so
        Python raises the usual ``TypeError``. Previously an
        ``AttributeError`` leaked from the missing ``_identity`` method.
        """
        if not isinstance(other, Connection):
            return NotImplemented
        return self._identity() < other._identity()

    def __hash__(self):
        # Unlike Context/DataProxy (not usefully hashable), connections hash
        # on the same identity tuple that __eq__ compares, so they work in
        # sets and as dict keys.
        identity = self._identity()
        return hash(identity)

    def derive_shorthand(self, host_string):
        """
        Parse ``host_string`` via the module-level ``derive_shorthand``.

        Kept as a method for backwards compatibility, and so behavior could
        later depend on config or other attributes. The result is used
        elsewhere via its ``"host"``, ``"user"`` and ``"port"`` keys.
        """
        parsed = derive_shorthand(host_string)
        return parsed

    @property
    def is_connected(self):
        """
        Whether or not this connection is actually open.

        ``False`` while no transport is stored. Otherwise it mirrors the
        stored transport's ``active`` flag.

        .. versionadded:: 2.0
        """
        transport = self.transport
        if not transport:
            return False
        return transport.active

    def open(self):
        """
        Initiate an SSH connection to the host/port this object is bound to.

        Does nothing if already connected. Otherwise it merges
        ``connect_kwargs`` with our own ``username``/``hostname``/``port``. It
        then adds a gateway socket (if a gateway is set) and ``timeout`` (if
        ``connect_timeout`` is truthy), and calls
        `SSHClient.connect <paramiko.client.SSHClient.connect>`. The resulting
        transport is saved on ``self.transport``.

        Various connect-time settings (and/or their corresponding :ref:`SSH
        config options <ssh-config>`) are utilized here; for details, see
        :doc:`the configuration docs </concepts/configuration>`.

        :raises ValueError:
            if ``connect_kwargs`` contains ``hostname``, ``port`` or
            ``username``. Also raised if it contains ``timeout`` while
            ``connect_timeout`` is also set.

        .. versionadded:: 2.0
        """
        if self.is_connected:
            return
        err = "Refusing to be ambiguous: connect() kwarg '{}' was given both via regular arg and via connect_kwargs!"  # noqa
        # Always derived from our own attributes; never allowed here.
        for forbidden in ("hostname", "port", "username"):
            if forbidden in self.connect_kwargs:
                raise ValueError(err.format(forbidden))
        # Allowed from either source, but not both at once.
        timeout_given_twice = (
            "timeout" in self.connect_kwargs
            and self.connect_timeout is not None
        )
        if timeout_given_twice:
            raise ValueError(err.format("timeout"))
        # No conflicts: layer our own values over connect_kwargs.
        kwargs = dict(self.connect_kwargs)
        kwargs.update(username=self.user, hostname=self.host, port=self.port)
        if self.gateway:
            kwargs["sock"] = self.open_gateway()
        if self.connect_timeout:
            kwargs["timeout"] = self.connect_timeout
        # Drop an empty key_filename default so debug output is less noisy.
        if "key_filename" in kwargs and not kwargs["key_filename"]:
            del kwargs["key_filename"]
        self.client.connect(**kwargs)
        self.transport = self.client.get_transport()

    def open_gateway(self):
        """
        Obtain a socket-like object from `gateway`.

        :returns:
            A ``direct-tcpip`` `paramiko.channel.Channel`, if `gateway` was a
            `.Connection`; or a `~paramiko.proxy.ProxyCommand`, if `gateway`
            was a string.

        .. versionadded:: 2.0
        """
        gateway = self.gateway
        if not isinstance(gateway, string_types):
            # Inner-Connection (ProxyJump-style) gateway: make sure it is
            # open, then tunnel to our host/port through its transport.
            # TODO: logging; expose the channel / accept user-supplied
            # sockets; decide how to expose a timeout here.
            gateway.open()
            return gateway.transport.open_channel(
                kind="direct-tcpip",
                dest_addr=(self.host, int(self.port)),
                # src_addr must be 'empty but not None' values to encode
                # correctly into a network message.
                src_addr=("", 0),
            )
        # ProxyCommand-style gateway (the cheaper case): run the command
        # through a throwaway SSHConfig so %h/%p/etc tokens get expanded.
        # TODO: use the real SSH config once loading one properly is done.
        dummy = "Host {}\n    ProxyCommand {}"
        scratch_config = SSHConfig()
        scratch_config.parse(StringIO(dummy.format(self.host, gateway)))
        expanded = scratch_config.lookup(self.host)["proxycommand"]
        return ProxyCommand(expanded)

    def close(self):
        """
        Terminate the network connection to the remote end, if open.

        Any SFTP client memoized by `sftp` is closed and forgotten first.
        Without this, a later `sftp` call after reopening would return a
        client bound to the old, dead transport. The agent-forwarding handler
        (if any) is closed and forgotten too.

        If nothing is open, this method does nothing.

        .. versionadded:: 2.0
        """
        if self._sftp is not None:
            self._sftp.close()
            self._sftp = None
        if self.is_connected:
            self.client.close()
            if self.forward_agent and self._agent_handler is not None:
                self._agent_handler.close()
                self._agent_handler = None

    def __enter__(self):
        # Context-manager entry opens nothing by itself; methods decorated
        # with @opens connect on demand, and __exit__ calls close().
        connection = self
        return connection

    def __exit__(self, *exc_info):
        # Always close on block exit. Returning None (falsy) lets any
        # in-flight exception propagate.
        self.close()

    @opens
    def create_session(self):
        """
        Open and return a new session channel on this connection's transport.

        When agent forwarding is enabled, an ``AgentRequestHandler`` is
        attached to the channel and stored on ``self._agent_handler`` (which
        `close` shuts down).
        """
        session = self.transport.open_session()
        if not self.forward_agent:
            return session
        self._agent_handler = AgentRequestHandler(session)
        return session

    def _remote_runner(self):
        # Build the configured remote Runner bound to this connection,
        # passing along the inline-env preference.
        runner_cls = self.config.runners.remote
        return runner_cls(self, inline_env=self.inline_ssh_env)

    @opens
    def run(self, command, **kwargs):
        """
        Execute a shell command on the remote end of this connection.

        This method wraps an SSH-capable implementation of
        `invoke.runners.Runner.run`; see its documentation for details.

        .. warning::
            There are a few spots where Fabric departs from Invoke's default
            settings/behaviors; they are documented under
            `.Config.global_defaults`.

        .. versionadded:: 2.0
        """
        runner = self._remote_runner()
        return self._run(runner, command, **kwargs)

    @opens
    def sudo(self, command, **kwargs):
        """
        Execute a shell command, via ``sudo``, on the remote end.

        This behaves like `invoke.context.Context.sudo`, except that (like
        `run`) it honors per-host/per-connection configuration overrides in
        addition to the generic/global ones. For example, per-host sudo
        passwords may be configured.

        .. versionadded:: 2.0
        """
        runner = self._remote_runner()
        return self._sudo(runner, command, **kwargs)

    def local(self, *args, **kwargs):
        """
        Execute a shell command on the local system.

        This method is effectively a wrapper of `invoke.run`; see its docs for
        details and call signature.

        .. versionadded:: 2.0
        """
        # The parent Context.run executes through the local runner, so
        # delegating to it directly is all that is needed.
        context_run = super(Connection, self).run
        return context_run(*args, **kwargs)

    @opens
    def sftp(self):
        """
        Return a `~paramiko.sftp_client.SFTPClient` object.

        The first client opened is memoized on ``self._sftp`` and returned on
        every later call. As a result, per-client state (such as that managed
        by `~paramiko.sftp_client.SFTPClient.chdir`) is preserved.

        .. versionadded:: 2.0
        """
        client = self._sftp
        if client is None:
            client = self._sftp = self.client.open_sftp()
        return client

    def get(self, *args, **kwargs):
        """
        Get a remote file to the local filesystem or file-like object.

        Thin wrapper around `.Transfer.get`; see its documentation for all
        details.

        .. versionadded:: 2.0
        """
        transfer = Transfer(self)
        return transfer.get(*args, **kwargs)

    def put(self, *args, **kwargs):
        """
        Put a local file (or file-like object) onto the remote filesystem.

        Thin wrapper around `.Transfer.put`; see its documentation for all
        details.

        .. versionadded:: 2.0
        """
        transfer = Transfer(self)
        return transfer.put(*args, **kwargs)

    # TODO: yield the socket for advanced users? Other advanced use cases
    # (perhaps factor out socket creation itself)?
    # TODO: probably push some of this down into Paramiko
    @contextmanager
    @opens
    def forward_local(
        self,
        local_port,
        remote_port=None,
        remote_host="localhost",
        local_host="localhost",
    ):
        """
        Open a tunnel connecting ``local_port`` to the server's environment.

        For example, say you want to connect to a remote PostgreSQL database
        which is locked down and only accessible via the system it's running
        on. You have SSH access to this server, so you can temporarily make
        port 5432 on your local system act like port 5432 on the server::

            import psycopg2
            from fabric import Connection

            with Connection('my-db-server').forward_local(5432):
                db = psycopg2.connect(
                    host='localhost', port=5432, database='mydb'
                )
                # Do things with 'db' here

        This method is analogous to using the ``-L`` option of OpenSSH's
        ``ssh`` program.

        :param int local_port: The local port number on which to listen.

        :param int remote_port:
            The remote port number. Defaults to the same value as
            ``local_port``.

        :param str local_host:
            The local hostname/interface on which to listen. Default:
            ``localhost``.

        :param str remote_host:
            The remote hostname serving the forwarded remote port. Default:
            ``localhost`` (i.e., the host this `.Connection` is connected to.)

        :returns:
            Nothing; this method is only useful as a context manager affecting
            local operating system state.

        :raises ThreadException:
            When the ``with`` block exits, if the background tunnel manager
            thread (or one of the tunnels it spawned) raised an exception.

        .. versionadded:: 2.0
        """
        # Any falsy value (``None``, but also ``0``) means "use the same port
        # number on both ends".
        if not remote_port:
            remote_port = local_port

        # TunnelManager does all of the work, sitting in the background (so we
        # can yield) and spawning threads every time somebody connects to our
        # local port.
        # ``finished`` is the shutdown signal; it is set in the ``finally``
        # clause below once the caller's ``with`` block exits.
        finished = Event()
        manager = TunnelManager(
            local_port=local_port,
            local_host=local_host,
            remote_port=remote_port,
            remote_host=remote_host,
            # TODO: not a huge fan of handing in our transport, but...?
            transport=self.transport,
            finished=finished,
        )
        manager.start()

        # Return control to caller now that things ought to be operational
        try:
            yield
        # Teardown once user exits block
        finally:
            # Signal to manager that it should close all open tunnels
            finished.set()
            # Then wait for it to do so
            manager.join()
            # Raise threading errors from within the manager, which would be
            # one of:
            # - an inner ThreadException, which was created by the manager on
            # behalf of its Tunnels; this gets directly raised.
            # - some other exception, which would thus have occurred in the
            # manager itself; we wrap this in a new ThreadException.
            # NOTE: in these cases, some of the metadata tracking in
            # ExceptionHandlingThread/ExceptionWrapper/ThreadException (which
            # is useful when dealing with multiple nearly-identical sibling IO
            # threads) is superfluous, but it doesn't feel worth breaking
            # things up further; we just ignore it for now.
            # NOTE: raising here replaces any exception that was propagating
            # out of the caller's ``with`` block.
            wrapper = manager.exception()
            if wrapper is not None:
                if wrapper.type is ThreadException:
                    raise wrapper.value
                else:
                    raise ThreadException([wrapper])

            # TODO: cancel port forward on transport? Does that even make sense
            # here (where we used direct-tcpip) vs the opposite method (which
            # is what uses forward-tcpip)?

    # TODO: probably push some of this down into Paramiko
    @contextmanager
    @opens
    def forward_remote(
        self,
        remote_port,
        local_port=None,
        remote_host="127.0.0.1",
        local_host="localhost",
    ):
        """
        Open a tunnel connecting ``remote_port`` to the local environment.

        For example, say you're running a daemon in development mode on your
        workstation at port 8080, and want to funnel traffic to it from a
        production or staging environment.

        In most situations this isn't possible as your office/home network
        probably blocks inbound traffic. But you have SSH access to this
        server, so you can temporarily make port 8080 on that server act like
        port 8080 on your workstation::

            from fabric import Connection

            c = Connection('my-remote-server')
            with c.forward_remote(8080):
                c.run("remote-data-writer --port 8080")
                # Assuming remote-data-writer runs until interrupted, this will
                # stay open until you Ctrl-C...

        This method is analogous to using the ``-R`` option of OpenSSH's
        ``ssh`` program.

        :param int remote_port: The remote port number on which to listen.

        :param int local_port:
            The local port number. Defaults to the same value as
            ``remote_port``.

        :param str local_host:
            The local hostname/interface the forwarded connection talks to.
            Default: ``localhost``.

        :param str remote_host:
            The remote interface address to listen on when forwarding
            connections. Default: ``127.0.0.1`` (i.e. only listen on the remote
            localhost).

        :returns:
            Nothing; this method is only useful as a context manager affecting
            local operating system state.

        :raises:
            Whatever ``Transport.request_port_forward`` raises when the server
            refuses the forward (Paramiko documents ``SSHException``); in that
            case the ``with`` body never runs.

        .. versionadded:: 2.0
        """
        # Any falsy value (``None``, but also ``0``) means "same port locally".
        if not local_port:
            local_port = remote_port
        # Callback executes on each connection to the remote port and is given
        # a Channel hooked up to said port. (We don't actually care about the
        # source/dest host/port pairs at all; only whether the channel has data
        # to read and suchlike.)
        # We then pair that channel with a new 'outbound' socket connection to
        # the local host/port being forwarded, in a new Tunnel.
        # That Tunnel is then added to a shared data structure so we can track
        # & close them during shutdown.
        #
        # TODO: this approach is less than ideal because we have to share state
        # between ourselves & the callback handed into the transport's own
        # thread handling (which is roughly analogous to our self-controlled
        # TunnelManager for local forwarding). See if we can use more of
        # Paramiko's API (or improve it and then do so) so that isn't
        # necessary.
        tunnels = []

        def callback(channel, src_addr_tup, dst_addr_tup):
            sock = socket.socket()
            try:
                sock.connect((local_host, local_port))
            except socket.error:
                # Nothing is listening locally (or the host is unreachable).
                # This runs from Paramiko's transport machinery rather than
                # the caller's thread, so raising would not reach the user;
                # release both ends instead so the remote peer simply sees
                # its forwarded connection close.
                sock.close()
                channel.close()
                return
            # TODO: we don't actually need to generate the Events at our level,
            # do we? Just let Tunnel.__init__ do it; all we do is "press its
            # button" on shutdown...
            tunnel = Tunnel(channel=channel, sock=sock, finished=Event())
            tunnel.start()
            # Communication between ourselves & the Paramiko handling subthread
            tunnels.append(tunnel)

        # Ask Paramiko (really, the remote sshd) to call our callback whenever
        # connections are established on the remote iface/port. This happens
        # before the try/finally: if the request is refused there is nothing
        # to cancel or tear down, and the original error propagates unmasked.
        self.transport.request_port_forward(
            address=remote_host, port=remote_port, handler=callback
        )
        try:
            yield
        finally:
            # TODO: see above re: lack of a TunnelManager
            # TODO: and/or also refactor with TunnelManager re: shutdown logic.
            # E.g. maybe have a non-thread TunnelManager-alike with a method
            # that acts as the callback? At least then there's a tiny bit more
            # encapsulation...meh.
            #
            # Cancel the forward first so no new connection can spawn a Tunnel
            # after the shutdown loop below has already finished; the nested
            # finally still stops the existing tunnels if cancelling fails.
            try:
                self.transport.cancel_port_forward(
                    address=remote_host, port=remote_port
                )
            finally:
                for tunnel in tunnels:
                    tunnel.finished.set()
                    tunnel.join()
Exemple #18
0
class WorkerInterface(object):
    """An interface to perform tasks on the DAQ worker nodes.

    This is used to perform tasks on the computers running the data router and the ECC server. This includes things
    like cleaning up the data files at the end of each run.

    The connection is made using SSH, and the SSH config file at ``config_path`` is honored in making the connection.
    Additionally, the server *must* accept connections authenticated using a public key, and this public key must
    be available in your ``.ssh`` directory.

    Parameters
    ----------
    hostname : str
        The hostname to connect to.
    port : int, optional
        The port that the SSH server is listening on. The default is 22.
    username : str, optional
        The username to use. If it isn't provided, a username will be read from the SSH config file. If no username
        is listed there, the name of the user running the code will be used.
    config_path : str, optional
        The path to the SSH config file. The default is ``~/.ssh/config``.

    """
    def __init__(self, hostname, port=22, username=None, config_path=None):
        self.hostname = hostname
        self.client = SSHClient()

        self.client.load_system_host_keys()
        self.client.set_missing_host_key_policy(AutoAddPolicy())

        if config_path is None:
            config_path = os.path.join(os.path.expanduser('~'), '.ssh', 'config')
        self.config = SSHConfig()
        with open(config_path) as config_file:
            self.config.parse(config_file)

        # Resolve aliases from the SSH config (e.g. ``Host daq`` -> real hostname/user).
        if hostname in self.config.get_hostnames():
            host_cfg = self.config.lookup(hostname)
            full_hostname = host_cfg.get('hostname', hostname)
            if username is None:
                username = host_cfg.get('user', None)  # If none, it will try the user running the server.
        else:
            full_hostname = hostname

        self.client.connect(full_hostname, port, username=username)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.client.close()

    def find_data_router(self):
        """Find the working directory of the data router process.

        The directory is found using ``lsof``, which must be available on the remote system.

        Returns
        -------
        str
            The directory where the data router is running, and therefore writing data.

        Raises
        ------
        RuntimeError
            If ``lsof`` finds something strange instead of a process called ``dataRouter``,
            or finds no matching process at all.

        """
        # -Fcn makes lsof emit one field per line: 'c<command>' then 'n<path of cwd>'.
        stdin, stdout, stderr = self.client.exec_command('lsof -a -d cwd -c dataRouter -Fcn')
        for line in stdout:
            if line[0] == 'c' and not re.match('cdataRouter', line):
                raise RuntimeError("lsof found {} instead of dataRouter".format(line[1:].strip()))
            elif line[0] == 'n':
                return line[1:].strip()
        else:
            raise RuntimeError("lsof didn't find dataRouter")

    def get_graw_list(self):
        """Get a list of GRAW files in the data router's working directory.

        Returns
        -------
        list[str]
            A list of the full paths to the GRAW files.

        """
        data_dir = self.find_data_router()

        with self.client.open_sftp() as sftp:
            full_list = sftp.listdir(data_dir)

        graws = filter(lambda s: re.match(r'.*\.graw$', s), full_list)
        full_graw_paths = (os.path.join(data_dir, filename) for filename in graws)

        return list(full_graw_paths)

    def working_dir_is_clean(self):
        """Check if there are GRAW files in the data router's working directory.

        Returns
        -------
        bool
            True if there are no GRAW files in the working directory, False otherwise.
        """
        return len(self.get_graw_list()) == 0

    def _check_process_status(self, process_name):
        """Checks if the given process is running.

        Parameters
        ----------
        process_name : str
            The name of the process to look for (used as a regular expression).

        Returns
        -------
        bool
            True if the process is running.
        """

        _, stdout, _ = self.client.exec_command('ps -e')

        for line in stdout:
            if re.search(process_name, line):
                return True
        else:
            return False

    def check_ecc_server_status(self):
        """Checks if the ECC server is running.

        Returns
        -------
        bool
            True if ``getEccSoapServer`` is running.
        """
        return self._check_process_status(r'getEccSoapServer')

    def check_data_router_status(self):
        """Checks if the data router is running.

        Returns
        -------
        bool
            True if ``dataRouter`` is running.
        """
        return self._check_process_status(r'dataRouter')

    def build_run_dir_path(self, experiment_name, run_number):
        """Get the path to the directory for a given run.

        This returns a path of the format ``experiment_name/run_name`` under the directory where the data router
        is running. The ``run_name``, in this case, has the format ``run_NNNN``.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment directory.
        run_number : int
            The run number.

        Returns
        -------
        run_dir : str
            The full path to the run directory. *This should be escaped before passing it to a shell command.*

        """
        pwd = self.find_data_router()
        run_name = 'run_{:04d}'.format(run_number)  # run_0001, run_0002, etc.
        run_dir = os.path.join(pwd, experiment_name, run_name)
        return run_dir

    def organize_files(self, experiment_name, run_number):
        """Organize the GRAW files at the end of a run.

        This will get a list of the files written in the working directory of the data router and move them to
        the directory ``./experiment_name/run_name``, which will be created if necessary. For example, if
        the ``experiment_name`` is "test" and the ``run_number`` is 4, the files will be placed in ``./test/run_0004``.

        Parameters
        ----------
        experiment_name : str
            A name for the experiment directory.
        run_number : int
            The current run number.

        """
        run_dir = self.build_run_dir_path(experiment_name, run_number)

        graws = self.get_graw_list()

        with self.client.open_sftp() as sftp:
            mkdir_recursive(sftp, run_dir)
            for srcpath in graws:
                _, srcfile = os.path.split(srcpath)
                destpath = os.path.join(run_dir, srcfile)
                sftp.rename(srcpath, destpath)

    def backup_config_files(self, experiment_name, run_number, file_paths, backup_root):
        """Makes a copy of the config files on the remote computer.

        The files are copied to a subdirectory ``experiment_name/run_name`` of ``backup_root``.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.
        run_number : int
            The run number.
        file_paths : iterable of str
            The *full* paths to the config files.
        backup_root : str
            Where the backups should be written.

        """
        run_name = 'run_{:04d}'.format(run_number)
        backup_dest = os.path.join(backup_root, experiment_name, run_name)

        with self.client.open_sftp() as sftp:
            mkdir_recursive(sftp, backup_dest)
            for source_path in file_paths:
                dest_path = os.path.join(backup_dest, os.path.basename(source_path))
                with sftp.open(source_path, 'r') as src, sftp.open(dest_path, 'w') as dest:
                    buffer = src.read()
                    dest.write(buffer)

    def tail_file(self, path, num_lines=50):
        """Retrieve the last ``num_lines`` lines of a text file on the remote host.

        Note that this assumes the file is ASCII-encoded plain text.

        Parameters
        ----------
        path : str
            Path to the file.
        num_lines : int
            The number of lines to include.

        Returns
        -------
        str
            The tail of the file's contents. A newline at the very end of the file terminates the last line
            rather than counting as an extra one (as with ``tail -n``). An empty file yields ``''``.
        """
        with self.client.open_sftp() as sftp:
            with sftp.open(path, 'r') as f:
                # Only the end of the file is transferred, in blocks, instead of
                # one SFTP round-trip per byte.
                data = self._tail_bytes(f, f.stat().st_size, num_lines)

        return data.decode('ascii')

    @staticmethod
    def _tail_bytes(f, size, num_lines, block_size=4096):
        """Return the last ``num_lines`` lines of the binary file-like ``f`` (length ``size`` bytes).

        ``f`` must support ``seek(offset)`` and ``read(n)``. The file is read backwards in
        ``block_size`` chunks until enough line breaks have been seen or the start is reached.
        """
        if size <= 0 or num_lines <= 0:
            return b''

        chunks = []
        newlines = 0
        pos = size
        while pos > 0:
            step = min(block_size, pos)
            pos -= step
            f.seek(pos)
            chunk = f.read(step)
            if not chunks and chunk.endswith(b'\n'):
                # The file's own trailing newline ends the last line; it does
                # not separate it from a further line, so don't count it.
                newlines -= 1
            newlines += chunk.count(b'\n')
            chunks.append(chunk)
            if newlines >= num_lines:
                break

        data = b''.join(reversed(chunks))
        # Walk back over num_lines line breaks (ignoring a trailing one); the
        # tail starts just after the last break found, or at offset 0.
        cut = len(data) - 1 if data.endswith(b'\n') else len(data)
        for _ in range(num_lines):
            cut = data.rfind(b'\n', 0, cut)
            if cut == -1:
                break
        return data[cut + 1:]
Exemple #19
0
class SFTP(Communicator):
    """Communicator that transfers files to a remote host over SFTP.

    The SSH/SFTP session is opened lazily by the ``client`` property and is
    re-opened whenever the underlying SSH transport is no longer active.
    ``host``, ``port``, ``user`` and ``password`` are read from attributes
    presumably provided by ``Communicator`` -- TODO confirm.
    """
    ssh_client = None
    # SFTP client handed out by ``client``; None until a connection succeeds.
    _active_client = None

    @property
    def client(self):
        """Return a usable SFTP client, (re)connecting when needed.

        The cached client is reused only while the SSH transport is alive and a
        previous attempt actually produced an SFTP client. Returns None when
        connecting fails (the failure is logged by ``_client``).
        """
        # get_transport() is None when an earlier connect() never got as far
        # as creating a transport (e.g. the host was unreachable).
        transport = self.ssh_client.get_transport() if self.ssh_client else None
        if (transport is not None and transport.is_active()
                and self._active_client is not None):
            return self._active_client

        self._active_client = self._client()

        return self._active_client

    def _client(self):
        """Open a fresh SSH connection and return an SFTP client on it, or None on failure."""

        logger.info(u'Conecting through SSH to the server (%s:%s)', self.host,
                    self.port)

        if self.ssh_client is not None:
            # Release the previous (dead or half-open) connection so its
            # socket and transport are not leaked.
            self.ssh_client.close()

        try:
            self.ssh_client = SSHClient()
            self.ssh_client.set_missing_host_key_policy(
                paramiko.AutoAddPolicy())
            # NOTE(review): self.port is logged above but not passed here, so
            # the default SSH port is always used -- confirm this is intended.
            self.ssh_client.connect(self.host,
                                    username=self.user,
                                    password=self.password,
                                    compress=True)
        except ssh_exception.AuthenticationException:
            logger.error(
                u'Fail while connecting through SSH. Check your creadentials.')
            return None
        except ssh_exception.NoValidConnectionsError:
            logger.error(
                u'Fail while connecting through SSH. Check your credentials or the server availability.'
            )
            return None
        else:
            return self.ssh_client.open_sftp()

    def mkdir(self, path):
        """Create remote directory ``path``; an already existing path only logs a warning."""

        logger.info(u'Creating directory (%s)', path)

        try:
            self.client.mkdir(path)
            logger.debug(u'Directory has being created (%s)', path)
        except IOError:
            # mkdir fails for existing paths too; stat tells the two apart.
            try:
                self.client.stat(path)
                logger.warning(u'Directory already exists (%s)', path)
            except IOError as e:
                logger.error(u'Fail while creating directory (%s): %s', path,
                             e.strerror)
                raise

    def chdir(self, path):
        """Change the SFTP session's working directory; errors are logged and re-raised."""

        logger.info(u'Changing to directory (%s)', path)

        try:
            self.client.chdir(path)
        except IOError as e:
            logger.error(u'Fail while accessing directory (%s): %s', path,
                         e.strerror)
            raise

    def put(self, from_fl, to_fl):
        """Upload local file ``from_fl`` to remote path ``to_fl``.

        Failures are logged, not raised: a missing file (ENOENT, on either
        side) is reported as "file not found", anything else with its error text.
        """
        import errno

        logger.info(u'Copying file from (%s) to (%s)', from_fl, to_fl)

        try:
            self.client.put(from_fl, to_fl)
            logger.debug(u'File has being copied (%s)', to_fl)
        except (IOError, OSError) as e:
            # On Python 3 IOError is an alias of OSError, so the exception
            # type alone cannot separate "not found" from other failures.
            if e.errno == errno.ENOENT:
                logger.error(u'Fail while copying file (%s), file not found',
                             to_fl)
            else:
                logger.error(u'Fail while copying file (%s): %s', to_fl,
                             e.strerror or e)
		line=f.readline()
		if line=="":
			break
		line=line.strip()
		line=line.split(" ")
		if(line[0]=="RUN"):
			writeLog("=============================================")
			cs=client.get_transport().open_session()
			stdout=cs.makefile()
			stderr=cs.makefile_stderr()
			command=" ".join(line[1:])
			writeLog("Running Command "+command)
			cs.exec_command(command)
			out=stdout.read()
			err=stderr.read()
			rc=cs.recv_exit_status()
			writeLog("Exit code "+str(rc))
			writeLog("====== STDOUT =====\n"+out)
			writeLog("====== STDERR =====\n"+err)
			cs.close()
			writeLog("=============================================")
		elif(line[0]=="COPY"):
			writeLog("=============================================")
			sftp=client.open_sftp()
			source=line[1].strip()
			destination=line[2].strip()
			writeLog("COPY {0} to {1}".format(source,destination))
			sftp.put(source,destination)
			writeLog("=============================================")

Exemple #21
0
class RemoteClient(object):
    """Remote Client is a wrapper over SSHClient with utility functions.

    Args:
        host (string): The hostname of the server to connect. It can be an IP
            address of the server also.
        user (string, optional): The user to connect to the remote server. It
            defaults to root

    Attributes:
        host (string): The hostname passed in as a the argument
        user (string): The user to connect as to the remote server
        client (:class:`paramiko.client.SSHClient`): The SSHClient object used
            for all the communications with the remote server.
        sftpclient (:class:`paramiko.sftp_client.SFTPClient`): The SFTP object
            for all the file transfer operations over the SSH.
    """
    def __init__(self, host, user='******'):
        self.host = host
        self.user = user
        self.client = SSHClient()
        self.sftpclient = None
        self.client.set_missing_host_key_policy(AutoAddPolicy())
        self.client.load_system_host_keys()

    def startup(self):
        """Function that starts SSH connection and makes client available for
        carrying out the functions.
        """
        self.client.connect(self.host, port=22, username=self.user)
        self.sftpclient = self.client.open_sftp()

    def _require_connection(self):
        """Raise ClientNotSetupException unless startup() has connected the client."""
        # self.client is always created in __init__, so testing it alone can
        # never fail; an SSHClient only has a transport once connected.
        if not self.client or self.client.get_transport() is None:
            raise ClientNotSetupException(
                'Cannot run procedure. Client not initialized')

    def download(self, remote, local):
        """Downloads a file from remote server to the local system.

        Args:
            remote (string): location of the file in remote server
            local (string): path where the file should be saved

        Returns:
            string: a status message, starting with ``"Error:"`` on failure.
        """
        if not self.sftpclient:
            raise ClientNotSetupException(
                'Cannot download file. Client not initialized')

        # Check the source first so a failure is attributed to the right end
        # (on Python 3 IOError is OSError, so types cannot tell them apart).
        try:
            self.sftpclient.stat(remote)
        except (IOError, OSError):
            return "Error: Remote location %s doesn't exist." % remote

        try:
            self.sftpclient.get(remote, local)
        except (IOError, OSError):
            return "Error: Local file %s doesn't exist." % local

        return "Download successful. File at: {0}".format(local)

    def upload(self, local, remote):
        """Uploads the file from local location to remote server.

        Args:
            local (string): path of the local file to upload
            remote (string): location on remote server to put the file

        Returns:
            string: a status message, starting with ``"Error:"`` on failure.
        """
        import os

        if not self.sftpclient:
            raise ClientNotSetupException(
                'Cannot upload file. Client not initialized')

        # Check the source first so a failure is attributed to the right end.
        if not os.path.isfile(local):
            return "Error: Local file %s doesn't exist." % local

        try:
            self.sftpclient.put(local, remote)
        except (IOError, OSError):
            return "Error: Remote location %s doesn't exist." % remote

        return "Upload successful. File at: {0}".format(remote)

    def exists(self, filepath):
        """Returns whether a file exists or not in the remote server.

        Args:
            filepath (string): path to the file to check for existance

        Returns:
            True if it exists, False if it doesn't
        """
        import shlex

        self._require_connection()
        # Quote the path so spaces/metacharacters can't break (or inject into)
        # the remote shell command; stat's exit status is the answer.
        _, cout, _ = self.client.exec_command(
            'stat {0}'.format(shlex.quote(filepath)))
        return cout.channel.recv_exit_status() == 0

    def run(self, command):
        """Run a command in the remote server.

        Args:
            command (string): the command to be run on the remote server

        Returns:
            tuple of three items holding the data read from stdin, stdout and
            stderr; a stream that cannot be read (stdin is write-only) yields ''.
        """
        self._require_connection()

        buffers = self.client.exec_command(command)
        output = []
        for buf in buffers:
            try:
                output.append(buf.read())
            except IOError:
                output.append('')

        return tuple(output)

    def close(self):
        """Close the SSH Connection
        """
        self.client.close()
Exemple #22
0
    def connect_v2(host, port, username, password=None, private_key=None,
                   config_path=None):
        """
        Connect to ``host`` over SSH, open an SFTP channel and wrap both.

        ``hostname``, ``username`` and ``port`` found in the user's SSH config
        (as returned by ``getSSHConfig(config_path, host)``) override the
        arguments. A ``proxycommand`` entry is used as the connection socket
        and raises the connect timeout from 10 s to 45 s to allow for
        two-factor prompts. The SSH agent and default key files are not used.

        Args:
            host: Host name or alias to connect to.
            port: SSH port.
            username: Login user name.
            password: Optional password.
            private_key: Optional path to an RSA private key file (passphrases
                are not supported yet).
            config_path: SSH config file passed to ``getSSHConfig``.

        Returns:
            SSHClientSource with ``client``, ``ftp`` (SFTP client), ``host``,
            ``port`` and ``config`` populated.

        Raises:
            Exception: if the server's host key does not match the known one.
            Any other connection error is re-raised after closing the client.

        Note:
            it may be possible to escalate root privileges
            stdin,stdout,stderr = ssh.exec_command("sudo su; whoami", get_pty=True)
            various strategies exist for entering the password.
        """
        # build a configuration from the options
        cfg = {"hostname": host,
               "port": port,
               "username": username,
               "timeout": 10.0,
               "compress": True,
               "allow_agent": False,
               "look_for_keys": False}
        if password is not None:
            cfg["password"] = password

        user_config = getSSHConfig(config_path, host)

        # copy the settings from the userconfig, overwriting the cfg.
        for k in ('hostname', 'username', 'port'):
            if k in user_config:
                cfg[k] = user_config[k]

        # proxy command allows for two factor authentication
        if 'proxycommand' in user_config:
            cfg['sock'] = paramiko.ProxyCommand(user_config['proxycommand'])
            cfg['timeout'] = 45.0  # give extra time for two factor

        # Echo the effective settings for debugging, masking the password so
        # it never ends up on stdout or in captured logs.
        if 'password' in cfg:
            print(dict(cfg, password='******'))
        else:
            print(cfg)

        cfg['pkey'] = None
        if private_key:  # non-null, non-empty
            passphrase = ""  # TODO: support passphrases
            cfg['pkey'] = paramiko.RSAKey.from_private_key_file(
                private_key, passphrase)

        client = SSHClient()
        client.load_system_host_keys()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            client.connect(**cfg)
        except BadHostKeyException as e:
            client.close()
            sys.stderr.write("got:      %s\n" % e.key.asbytes())
            sys.stderr.write("expected: %s\n" % e.expected_key.asbytes())

            msg = "Connection error on %s:%s using privatekey=%s\n" % (
                host, port, private_key)
            msg += "you may need to clear ~/.ssh/known_hosts " + \
                "for entries related to %s:%s\n" % (host, port)
            sys.stderr.write(msg)
            raise Exception(msg)
        except Exception:
            # Don't leak the half-open transport on any other failure.
            client.close()
            raise

        try:
            ftp = client.open_sftp()
        except Exception:
            client.close()
            raise

        src = SSHClientSource()
        src.client = client
        src.ftp = ftp

        src.host = host
        src.port = port
        src.config = cfg

        return src
Exemple #23
0
class _SSHConnection:
    """
    Helper class handling SFTP communication.

    If provided with a hostname, automatically initiates the SFTP session.

    :param hostname: The SSH host, defaults to :const:`None`.
    :type hostname: str, optional
    :param port: The port where the SSH host is listening, defaults to :const:`None`.
    :type port: int, optional
    :param username: The username on the target SSH host, defaults to :const:`None`.
    :type username: str, optional

    :ivar client: The SSH session handler.
    :vartype client: paramiko.client.SSHClient
    :ivar sftp_session: The SFTP session handler.
    :vartype sftp_session: paramiko.sftp_client.SFTPClient
    """

    client = None
    sftp_session = None

    def __init__(self,
                 hostname: str = None,
                 port: int = None,
                 username: str = None):
        """Initialize object."""
        if hostname:
            self.set_session(hostname, port, username)

    def set_session(self,
                    hostname: str,
                    port: int = None,
                    username: str = None):
        """
        Set up a SFTP session.

        :param str hostname: The SSH host.
        :param port: The port where the SSH host is listening, defaults to
          :const:`None`.
        :type port: int, optional
        :param username: The username on the target SSH host, defaults to :const:`None`.
        :type username: str, optional
        """
        self.client = SSHClient()
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        if port and username:
            self.client.connect(hostname, port, username)
        elif port:
            self.client.connect(hostname, port)
        elif username:
            self.client.connect(hostname, username=username)
        else:
            self.client.connect(hostname)

        self.sftp_session = self.client.open_sftp()

    def close_session(self):
        """Close the SFTP session."""
        if self.is_active():
            self.client.close()

    def mkdir_p(self, path: str):
        """
        Simulate the :command:`mkdir -p <path>` Unix command.

        It creates the directory and all its non-existing parents.

        :param str path: The path to a directory.
        """
        dirs = []

        while len(path) > 1:
            dirs.append(path)
            path = os.path.dirname(path)

        if len(path) == 1 and not path.startswith("/"):
            dirs.append(path)

        while dirs:
            path = dirs.pop()
            try:
                self.sftp_session.stat(path)
            except IOError:
                self.sftp_session.mkdir(path, mode=0o755)

    # def write_file(self, content: str, path: str):
    #     """
    #     Create or overwrite a file under `path` and store `content` in it.
    #
    #     It creates the parent directories if required.
    #
    #     :param str content: The file content.
    #     :param str path: Path to the new file.
    #     """
    #     self.mkdir_p(os.path.normpath(os.path.dirname(path)))
    #
    #     with self.sftp_session.open(path, "w") as sftp_file:
    #         sftp_file.write(content)

    def put_file(self, local_path, remote_path):
        """
        Copy the local file at `local_path` into `remote_path`.

        It creates the parent directories if required.

        :param str local_path: Path to the local file.
        :param str remote_path: Path to the remote file.
        """
        self.mkdir_p(os.path.normpath(os.path.dirname(remote_path)))
        self.sftp_session.put(local_path, remote_path)

    def is_active(self) -> bool:
        """Check whether the SSH session is active or not.

        :return: `True` if the SSH session is active, `False` otherwise.
        :rtype: bool
        """
        if self.client:
            transport = self.client.get_transport()

            if transport is not None and transport.is_active():
                try:
                    transport.send_ignore()
                    return True
                except EOFError:
                    return False

        return False
class ExternalDataLoader:
    """
    Loads rendered EXR images (noisy input + 8192-spp ground truth) from an
    external server over SFTP.

    Each scene is a sub-directory of ``remote_dir``. The file counting below
    assumes every scene directory holds one file per (image id, spp) pair for
    every value in ``possible_spp`` -- TODO confirm against the server layout.
    """

    # Scenes used when the caller does not pass ``scenes``. A tuple, so the
    # default cannot be mutated and leak into other instances.
    DEFAULT_SCENES = (
        "bathroom2", "car2", "classroom", "house", "room2",
        "room3", "spaceship", "staircase",
    )

    def __init__(self,
                 password,
                 scenes=None,
                 batch_size=64,
                 hostname="143.248.38.66",
                 port=3202,
                 username="******",
                 remote_dir="/home/siit/navi/data/input_data/deep_learning_denoising/renderings/"):
        """
        Connect to the server and index the files of every scene.

        :param password: SSH password for ``username``.
        :param scenes: Scene directory names to load (defaults to DEFAULT_SCENES).
        :param batch_size: Number of images per batch returned by ``get_batch``.
        :param hostname: SSH host to connect to.
        :param port: SSH port.
        :param username: SSH user name.
        :param remote_dir: Remote directory containing one sub-directory per scene.
        """
        self.scenes = list(self.DEFAULT_SCENES if scenes is None else scenes)
        self.batch_size = batch_size

        self.client = SSHClient()
        self.client.load_system_host_keys()
        self.client.set_missing_host_key_policy(AutoAddPolicy())
        self.client.connect(hostname=hostname,
                            port=port,
                            username=username,
                            password=password)

        # For Documentation see here: http://docs.paramiko.org/en/2.4/api/sftp.html#paramiko.sftp_client.SFTPClient
        self.sftp_client = self.client.open_sftp()
        self.sftp_client.chdir(remote_dir)

        self.possible_spp = [128, 256, 512, 1024, 8192]  # 8192 is Ground Truth

        # scene name -> list of file names in that scene's directory
        self.scene_files = OrderedDict()

        # Scan for total examples
        self.total_files = 0
        for scene_name in self.scenes:
            file_list = self.sftp_client.listdir(scene_name)
            file_amount = len(file_list)
            print("In {}: {}".format(scene_name, file_amount))
            self.total_files += file_amount
            self.scene_files[scene_name] = file_list

    @property
    def batches_amount(self):
        """Number of batches needed to cover all images (last one may be short)."""
        return math.ceil(self.files_amount / self.batch_size)

    @property
    def files_amount(self):
        """Number of distinct images (all files divided by the number of spp levels)."""
        return self.total_files // len(self.possible_spp)

    def get_data(self, idx, spp=128):
        """
        Return ``(input_image, ground_truth_image)`` for the global index ``idx``.

        Indices run over the scenes in order: the first scene's images come
        first, then the second scene's, and so on.

        :param idx: 0-based image index, ``0 <= idx < files_amount``.
        :param spp: Samples-per-pixel of the input image; must be in ``possible_spp``.
        :raises FileNotFoundError: if ``spp`` is not an available sample count.
        :raises IndexError: if ``idx`` is out of range.
        """
        if spp not in self.possible_spp:
            raise FileNotFoundError("There is no file for this spp.")

        if idx < 0:
            raise IndexError("Index must be non-negative")
        if idx >= self.files_amount:
            raise IndexError("Index too big")

        # Find the scene containing idx; `total` ends up as the number of
        # images in all preceding scenes.
        total = 0
        scene_name = None
        for scene, file_list in self.scene_files.items():
            offset = len(file_list) // len(self.possible_spp)
            # `>=`: idx == total + offset is the first image of the NEXT scene.
            if idx >= total + offset:
                total += offset
                continue
            scene_name = scene
            break

        if scene_name is None:
            # Possible when per-scene counts are not exact multiples of the
            # number of spp levels.
            raise IndexError("Index too big")

        # NOTE: plain substring match on the spp digits, as before.
        all_file_spp = sorted(
            name for name in self.scene_files[scene_name] if str(spp) in name)
        file_name = all_file_spp[idx - total]

        file_x = self.get_file(os.path.join(scene_name, file_name))

        file_id, _ = file_name.split("-")
        # The ground truth shares the image id and has the 8192-spp suffix.
        # NOTE(review): assumes names of the form "<id>-<spp:05d>.exr"; the dash
        # removed by split() is re-added here -- confirm against the server.
        file_name_y = file_id + "-08192.exr"
        file_y = self.get_file(os.path.join(scene_name, file_name_y))

        return file_x, file_y

    def get_batch(self, idx, spp=128):
        """
        Return ``(X, Y)`` lists of input / ground-truth images for batch ``idx``.

        Every batch has ``batch_size`` images except possibly the last one,
        which holds the remainder.

        :raises IndexError: if ``idx`` is not in ``range(batches_amount)``.
        """
        if idx < 0 or idx >= self.batches_amount:
            raise IndexError("Batch index out of range")

        start = idx * self.batch_size
        # min(): a full last batch (files_amount % batch_size == 0) keeps its size.
        stop = min(start + self.batch_size, self.files_amount)

        X = []
        Y = []
        for file_idx in range(start, stop):
            x, y = self.get_data(file_idx, spp=spp)
            X.append(x)
            Y.append(y)
        return X, Y

    def get_file(self, file_path):
        """Download ``file_path`` (relative to the remote dir) and open it with pyexr."""
        import tempfile

        # Unique local temp file instead of a fixed "image.exr" in the CWD, so
        # concurrent loaders do not clobber each other.
        fd, local_path = tempfile.mkstemp(suffix=".exr")
        os.close(fd)
        try:
            self.sftp_client.get(file_path, local_path)
            return pyexr.open(local_path)
        finally:
            # Always remove the temp file, also when the download fails.
            os.remove(local_path)
class MySSHClient:
    """Convenience wrapper around paramiko's ``SSHClient``.

    Except for ``listdir``, methods log through the module-level ``logger`` and
    report their outcome as a two-element list ``[success, detail]`` instead of
    raising.
    """

    def __init__(self):
        self.ssh_client = SSHClient()

    # Connect and log in
    def connect(self, hostname, port, username, password):
        """Open the SSH connection (unknown host keys are auto-accepted).

        :return: ``[True, '']`` on success, ``[False, error_text]`` on failure.
        """
        try:
            logger.info('正在远程连接主机:%s' % hostname)
            self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
            self.ssh_client.connect(hostname=hostname,
                                    port=port,
                                    username=username,
                                    password=password)
            return [True, '']
        except Exception as e:
            logger.error('连接出错了%s' % e)
            return [False, '%s' % e]

    # List directory entries (ls)
    def listdir(self, path):
        """Return the entry names of the remote directory ``path``.

        Unlike the other methods, errors propagate as exceptions.
        """
        # Open outside the try: if open_sftp() fails there is nothing to close,
        # and the original error is no longer masked by an UnboundLocalError.
        sftp_client = self.ssh_client.open_sftp()
        try:
            return sftp_client.listdir(path)
        finally:
            sftp_client.close()

    # Execute a remote command
    def exec_command(self, command):
        """Run ``command`` remotely and log its standard output.

        :return: ``[True, stdout_bytes]`` on success, ``[False, error_text]``
            on failure.
        """
        try:
            logger.info('正在执行命令:' + command)
            stdin, stdout, stderr = self.ssh_client.exec_command(command)
            logger.info('命令输出:')
            output = stdout.read()  # read the command output
            logger.info(output)
            # Previously returned the builtin `tuple` type object by mistake.
            return [True, output]
        except Exception as e:
            logger.error('执行命令: %s出错' % command)
            return [False, '%s' % e]

    # Download a file (not a directory)
    def download_file(self, remotepath, localpath):
        """Download the remote file ``remotepath`` to ``localpath``.

        Missing local parent directories are created.

        :return: ``[True, '']`` on success, ``[False, error_text]`` on failure.
        """
        try:
            localpath = os.path.abspath(localpath)
            # Presumably repairs Windows paths whose backslashes were
            # interpreted as escape sequences (e.g. "\t").
            localpath = localpath.replace('\t', '/t').replace(
                '\n', '/n').replace('\r', '/r').replace('\b', '/b')
            localpath = localpath.replace('\f', '/f')
            logger.info('转换后的本地目标路径为:%s' % localpath)
            head, tail = os.path.split(localpath)
            if not tail:
                logger.warning('下载文件:%s 到本地:%s失败,本地文件名不能为空' %
                               (remotepath, localpath))
                return [
                    False,
                    '下载文件:%s 到本地:%s失败,本地文件名不能为空' % (remotepath, localpath)
                ]
            if not os.path.exists(head):
                logger.info('本地路径:%s不存在,正在创建目录' % head)
                OtherTools().mkdirs_once_many(head)

            sftp_client = self.ssh_client.open_sftp()
            try:
                logger.info('正在下载远程文件:%s 到本地:%s' % (remotepath, localpath))
                sftp_client.get(remotepath, localpath)
            finally:
                sftp_client.close()
            return [True, '']
        except Exception as e:
            logger.error('下载文件:%s 到本地:%s 出错:%s' % (remotepath, localpath, e))
            return [False, '下载文件:%s 到本地:%s 出错:%s' % (remotepath, localpath, e)]

    # Delete a remote file (not a directory)
    def remove_file(self, remotepath):
        """Delete the remote file ``remotepath``.

        :return: ``[True, '']`` on success, ``[False, error_text]`` on failure.
        """
        try:
            sftp_client = self.ssh_client.open_sftp()
            try:
                sftp_client.remove(remotepath)
            finally:
                sftp_client.close()
            # The old message had two %s for one argument, which raised
            # TypeError and turned every successful delete into a failure.
            logger.info('删除文件:%s 成功' % remotepath)
            return [True, '']
        except Exception as e:
            logger.error('删除文件:%s 出错:%s' % (remotepath, e))
            return [False, '删除文件:%s 出错:%s' % (remotepath, e)]

    # Upload a file (not a directory)
    def upload_file(self, localpath, remotepath):
        """Upload the local file ``localpath`` to ``remotepath``.

        ``remotepath`` must start with '/' or '.'; its directory part is
        normalized on the server before the upload.

        :return: ``[True, '']`` on success, ``[False, error_text]`` on failure.
        """
        try:
            localpath = localpath.rstrip('\\').rstrip('/')
            # Presumably repairs Windows paths whose backslashes were
            # interpreted as escape sequences (e.g. "\t").
            localpath = localpath.replace('\t', '/t').replace(
                '\n', '/n').replace('\r', '/r').replace('\b', '/b')
            localpath = localpath.replace('\f', '/f')
            localpath = os.path.abspath(localpath)
            logger.info('转换后的本地文件路径为:%s' % localpath)

            remotepath = remotepath.rstrip('\\').rstrip('/')
            head, tail = os.path.split(localpath)
            if not tail:
                logger.error('上传文件:%s 到远程:%s失败,本地文件名不能为空' %
                             (localpath, remotepath))
                return [
                    False,
                    '上传文件:%s 到远程:%s失败,本地文件名不能为空' % (localpath, remotepath)
                ]
            if not os.path.exists(head):
                # Three placeholders for three arguments (the old format string
                # had two and raised TypeError).
                logger.error('上传文件:%s 到远程:%s失败,父路径不存在:%s' %
                             (localpath, remotepath, head))
                return [
                    False,
                    '上传文件:%s 到远程:%s失败,父路径不存在:%s' % (localpath, remotepath, head)
                ]

            if not (remotepath.startswith('/') or remotepath.startswith('.')):
                logger.error('上传文件:%s 到远程:%s失败,远程路径填写不规范%s' %
                             (localpath, remotepath, remotepath))
                return [
                    False,
                    '上传文件:%s 到远程:%s失败,远程路径填写不规范%s' %
                    (localpath, remotepath, remotepath)
                ]
            sftp_client = self.ssh_client.open_sftp()
            try:
                head, tail = os.path.split(remotepath)

                head = sftp_client.normalize(head)  # normalize the remote directory
                remotepath = head + '/' + tail
                logger.info('规范化后的远程目标路径:%s' % remotepath)

                logger.info('正在上传文件:%s 到远程:%s' % (localpath, remotepath))
                sftp_client.put(localpath, remotepath)
            finally:
                sftp_client.close()
            return [True, '']
        except Exception as e:
            logger.error('上传文件:%s 到远程:%s 出错:%s' % (localpath, remotepath, e))
            return [False, '上传文件:%s 到远程:%s 出错:%s' % (localpath, remotepath, e)]

    def close(self):
        """Close the SSH connection."""
        self.ssh_client.close()
Exemple #26
0
    def cron_ssh_move_documents(
        self,
        host=False,
        port=False,
        user=False,
        password=False,
        ssh_path=False,
    ):
        """Move scanned documents from a remote SFTP folder to the local scanner path.

        Connection settings that are not passed explicitly are read from the
        ``hash_search_document_scanner_queue_ssh.*`` system parameters. Each
        remote file is downloaded into ``hash_search_document_scanner.path`` and
        passed to ``process_document``. Unless the ``scanner_single_commit``
        context flag is set, that call goes through ``with_delay()`` on its own
        cursor, which is committed per file. The remote file is then deleted.

        :return: False when no local destination path is configured, else True.
        """
        config = self.env["ir.config_parameter"]
        dest_path = config.sudo().get_param(
            "hash_search_document_scanner.path", default=False)
        if not dest_path:
            return False
        if not host:
            host = config.get_param(
                "hash_search_document_scanner_queue_ssh.host", default=False)
        if not port:
            # Port 0 is never a valid SSH port; fall back to the SSH default.
            port = int(config.get_param(
                "hash_search_document_scanner_queue_ssh.port", default="0")) or 22
        if not user:
            user = config.get_param(
                "hash_search_document_scanner_queue_ssh.user", default=False)
        if not password:
            password = config.get_param(
                "hash_search_document_scanner_queue_ssh.password",
                default=False,
            )
        if not ssh_path:
            ssh_path = config.get_param(
                "hash_search_document_scanner_queue_ssh.ssh_path",
                default=False,
            )

        connection = SSHClient()
        connection.load_system_host_keys()
        # try/finally: the SFTP session and the SSH connection are released
        # even when a document fails and the exception is re-raised.
        try:
            connection.connect(hostname=host,
                               port=port,
                               username=user,
                               password=password)
            sftp = connection.open_sftp()
            try:
                if ssh_path:
                    sftp.chdir(ssh_path)
                elements = sftp.listdir_attr(".")
                # Files accessed within the last 60 s are skipped, presumably
                # because they may still be uploading -- confirm.
                min_time = int(time.time()) - 60
                single_commit = self.env.context.get("scanner_single_commit", False)
                ignore_time = self.env.context.get("scanner_ignore_time", False)
                for element in elements:
                    if element.st_atime > min_time and not ignore_time:
                        continue
                    filename = element.filename
                    new_element = os.path.join(dest_path, filename)
                    if not single_commit:
                        new_cr = Registry(self.env.cr.dbname).cursor()
                    try:
                        sftp.get(filename, new_element)
                        if single_commit:
                            obj = self.env[self._name].browse()
                        else:
                            obj = (api.Environment(
                                new_cr, self.env.uid,
                                self.env.context)[self._name].browse().with_delay())
                        obj.process_document(new_element)
                        if not single_commit:
                            new_cr.commit()
                    except Exception:
                        # Drop the partial local copy so it is not processed twice.
                        if os.path.exists(new_element):
                            os.unlink(new_element)
                        if not single_commit:
                            new_cr.rollback()  # error, rollback everything atomically
                        raise
                    finally:
                        if not single_commit:
                            new_cr.close()
                    # Remove the remote file only after it was processed.
                    sftp.remove(element.filename)
            finally:
                sftp.close()
        finally:
            connection.close()
        return True
Exemple #27
0
class SSHperformer(Performer):
    """Performer that executes commands and transfers files over SSH (paramiko).

    The connection is opened lazily on first use.
    """
    provider_name = 'ssh'
    settings_class = SSHPerformerSettings

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client = None
        self.logger = getLogger(__name__)

    def _connect(self):
        """Connect using ``self.settings`` and enable SSH agent forwarding.

        ``self.client`` is only set once the connection succeeded, so a failed
        attempt is retried on the next call instead of reusing a dead client.

        :raises PerformerError: if no connection to the host can be made.
        """
        client = SSHClient()
        client.set_missing_host_key_policy(AutoAddPolicy())
        client.load_system_host_keys()

        connection_details = {}
        if self.settings.port:
            connection_details['port'] = self.settings.port
        if self.settings.username:
            connection_details['username'] = self.settings.username
        if self.settings.password:
            connection_details['password'] = self.settings.password
        try:
            client.connect(self.settings.hostname, **connection_details)
        except NoValidConnectionsError as e:
            raise PerformerError('Cant connect to %s' % self.settings.hostname) from e

        # ssh agent forwarding
        session = client.get_transport().open_session()
        AgentRequestHandler(session)
        self.client = client

    def _ensure_connected(self):
        """Connect if no connection has been established yet."""
        if not self.client:
            self._connect()

    def _paramiko_exec_command(self, command, bufsize=-1, timeout=None):
        """Start ``command`` and return binary ``(stdin, stdout, stderr)`` files.

        Replacement for ``paramiko.client.exec_command(command)`` that keeps the
        output binary: https://github.com/paramiko/paramiko/issues/291
        (inspired by https://gist.github.com/smurn/4d45a51b3a571fa0d35d).
        """
        chan = self.client.get_transport().open_session(timeout=timeout)
        chan.settimeout(timeout)
        chan.exec_command(command)
        stdin = chan.makefile('wb', bufsize)
        stdout = chan.makefile('rb', bufsize)
        stderr = chan.makefile_stderr('rb', bufsize)
        return stdin, stdout, stderr

    def execute(self, command, logger=None, writein=None, max_lines=None):
        """Run ``command`` remotely and return its collected output.

        :param writein: Optional data written to the command's stdin.
        :raises CommandError: if the command exits with a non-zero status.
        """
        self.logger.debug("Execute command: '%s'" % command)
        self._ensure_connected()

        stdin, stdout, stderr = self._paramiko_exec_command(command)

        # read stdout asynchronously - in 'realtime'
        output_reader = OutputReader(stdout, logger=logger or self.output_logger, max_lines=max_lines)

        if writein:
            # write writein to stdin, then signal EOF
            stdin.write(writein)
            stdin.flush()
            stdin.channel.shutdown_write()

        # wait for end of output
        output = output_reader.output()

        # wait for exit code
        exit_code = stdout.channel.recv_exit_status()

        if exit_code:
            err = stderr.read().decode('utf-8').strip()
            self.logger.debug('command error: %s' % err)
            raise CommandError(command, exit_code, err)

        return output

    def send_file(self, source, target):
        """Upload the local file ``source`` (``~`` expanded) to remote ``target``."""
        self.logger.debug("Send file: '%s' '%s'" % (source, target))
        source = expanduser(source)
        self._ensure_connected()
        sftp = self.client.open_sftp()
        try:
            sftp.put(source, target)
        finally:
            sftp.close()

    @contextmanager
    def get_fo(self, remote_path):
        """Yield a file object holding the content of ``remote_path``.

        The data is spooled in memory up to ~1 MB, then on disk. The yielded
        file is rewound to the beginning so it can be read directly.
        """
        from tempfile import SpooledTemporaryFile
        self.logger.debug('SSH Get fo: %s' % remote_path)
        self._ensure_connected()
        sftp = self.client.open_sftp()
        try:
            with SpooledTemporaryFile(1024000) as fo:
                sftp.getfo(remote_path, fo)
                # getfo leaves the position at EOF; rewind for the caller.
                fo.seek(0)
                yield fo
        finally:
            sftp.close()
Exemple #28
0
class SFTPClient:
    """Wrapper around paramiko's SFTPClient with recursive listing and a resumable mirror.

    On connect, every non-dunder bound method of the underlying paramiko
    ``SFTPClient`` (except ``_log``) is copied onto this instance, so calls such
    as ``self.chdir`` / ``self.getcwd`` go straight to paramiko.
    """
    MAX_PACKET_SIZE = SFTPFile.__dict__['MAX_REQUEST_SIZE']

    ssh_client = None
    client = None
    raise_exceptions = False
    original_arguments = {}
    debug = False

    _log = logging.getLogger(LOG_NAME)
    _dircache = []

    def __init__(self, **kwargs):
        # Per-instance directory cache; the class-level list would otherwise be
        # shared (and appended to) by every instance.
        self._dircache = []
        self.original_arguments = kwargs.copy()
        self._connect(**kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close_all()

    def _connect(self, **kwargs):
        """Open the SSH + SFTP connection and graft paramiko's methods onto self.

        Required kwargs: ``username``, ``password``. Optional: ``hostname``
        (default 'localhost'), ``port`` (22), ``timeout``, ``keepalive`` (5 s),
        ``look_for_keys``, ``raise_exceptions``, ``is_reconnect``.
        """
        kwargs_to_paramiko = dict(
            look_for_keys=kwargs.pop('look_for_keys', True),
            username=kwargs.pop('username'),
            port=kwargs.pop('port', 22),
            allow_agent=False,
            timeout=kwargs.pop('timeout', None),
        )
        host = kwargs.pop('hostname', 'localhost')
        password = kwargs.pop('password')
        keepalive = kwargs.pop('keepalive', 5)
        if password:
            kwargs_to_paramiko['password'] = password
        self.raise_exceptions = kwargs.pop('raise_exceptions', False)

        self.ssh_client = SSHClient()
        self.ssh_client.load_system_host_keys()
        self.ssh_client.connect(host, **kwargs_to_paramiko)

        self.client = self.ssh_client.open_sftp()
        channel = self.client.get_channel()
        channel.settimeout(kwargs_to_paramiko['timeout'])
        channel.get_transport().set_keepalive(keepalive)

        # 'Extend' the SFTPClient class
        is_reconnect = kwargs.pop('is_reconnect', False)
        members = inspect.getmembers(self.client,
                                     predicate=inspect.ismethod)
        self._log.debug('Dynamically adding methods from original SFTPClient')
        for (method_name, method) in members:
            if method_name.startswith('__') or method_name == '_log':
                self._log.debug('Ignorning {}()'.format(method_name))
                continue

            # On first connect refuse to shadow our own attributes; on
            # reconnect the previously grafted methods are simply replaced.
            if not is_reconnect and hasattr(self, method_name):
                raise AttributeError('Not overwriting property "{}". This '
                                     'version of Paramiko is not '
                                     'supported.'.format(method_name))

            self._log.debug('Adding method {}()'.format(method_name))
            setattr(self, method_name, method)

    def close_all(self):
        """Close the SFTP session and the SSH connection."""
        self.client.close()
        self.ssh_client.close()

    def clear_directory_cache(self):
        """Forget which local directories ``mirror`` has already created."""
        self._dircache = []

    def listdir_attr_recurse(self, path='.'):
        """Yield ``(remote_path, SFTPAttributes)`` for every non-directory below ``path``.

        Unreadable sub-directories are skipped unless ``raise_exceptions`` is set.
        """
        from stat import S_ISDIR

        for da in self.client.listdir_attr(path=path):
            full_path = path_join(path, da.filename)
            # S_ISDIR checks the file type bits; the old `st_mode & 0o700 ==
            # 0o700` tested owner rwx permissions, so executables were treated
            # as directories and directories without owner rwx as files.
            if S_ISDIR(da.st_mode):
                try:
                    yield from self.listdir_attr_recurse(full_path)
                except IOError:
                    if self.raise_exceptions:
                        raise
            else:
                yield (full_path, da,)

    def mirror(self,
               path='.',
               destroot='.',
               keep_modes=True,
               keep_times=True,
               resume=True):
        """Download every file below remote ``path`` into local ``destroot``.

        A local file whose size equals the remote size is considered complete
        (hash verification is in the util module). A shorter local file is
        resumed when ``resume`` is true, otherwise it is re-downloaded. A longer
        local file is re-downloaded. After timeouts or SFTP errors the
        connection is re-established and the transfer resumed, if ``resume``.

        :return: Number of files actually transferred.
        """
        from stat import S_ISDIR

        n = 0
        cwd = self.getcwd()

        for _path, info in self.listdir_attr_recurse(path=path):
            if S_ISDIR(info.st_mode):
                continue

            dest_path = path_join(destroot, dirname(_path))
            dest = path_join(dest_path, basename(_path))

            if dest_path not in self._dircache:
                try:
                    makedirs(dest_path)
                except FileExistsError:
                    pass
                self._dircache.append(dest_path)

            if isdir(dest):
                continue

            try:
                current_size = os.stat(dest).st_size
            except OSError:
                current_size = None  # no local copy yet

            # Only size is used to determine complete-ness here
            if current_size != info.st_size:
                resume_seek = None
                if current_size is not None and current_size < info.st_size:
                    resume_seek = current_size
                    if resume:
                        self._log.info('Resuming file {} at {} '
                                       'bytes'.format(dest, current_size))
                self._mirror_fetch(_path, dest, info.st_size, resume_seek,
                                   resume, cwd)
                # Do not count files that were already downloaded
                n += 1

            # Okay to fix existing files even if they are already downloaded
            try:
                if keep_modes:
                    chmod(dest, info.st_mode)
                if keep_times:
                    utime(dest, (info.st_atime, info.st_mtime,))
            except IOError:
                pass

        return n

    def _mirror_fetch(self, remote_path, dest, total_size, resume_seek, resume, cwd):
        """Download one file, reconnecting and resuming after timeouts / SFTP errors.

        With ``resume`` false the first such error is re-raised.
        NOTE(review): with ``resume`` true there is no retry limit, so an error
        of these types that keeps recurring retries forever.
        """
        while True:
            try:
                if resume_seek and resume:
                    self._mirror_resume(remote_path, dest, resume_seek, total_size)
                else:
                    self._log.info('Downloading {} -> '
                                   '{}'.format(remote_path, dest))
                    self.client.get(remote_path, dest)
                return
            except (socket.timeout, SFTPError) as e:
                if isinstance(e, socket.timeout):
                    self._log.error('Connection timed out')
                else:
                    self._log.error('{!s}'.format(e))

                if not resume:
                    self._log.debug('Not resuming (resume = {}, exception: {})'.format(resume, e))
                    raise

                # Resume at position - 10 bytes
                resume_seek = self._mirror_restart_offset(dest)
                self._log.info('Resuming GET {} at {} '
                               'bytes'.format(remote_path, resume_seek))

                self._log.debug('Re-establishing connection')
                self.original_arguments['is_reconnect'] = True
                self._connect(**self.original_arguments)
                if cwd:
                    self.chdir(cwd)

    @staticmethod
    def _mirror_restart_offset(dest):
        """Offset to resume ``dest`` from: 10 bytes before its end, or None if absent."""
        try:
            size = os.stat(dest).st_size
        except OSError:
            return None
        return max(size - 10, 0)

    def _mirror_resume(self, remote_path, dest, offset, total_size):
        """Write bytes ``[offset, total_size)`` of ``remote_path`` into ``dest`` at ``offset``."""
        # Split the remaining range into requests of at most MAX_PACKET_SIZE
        # bytes; this also covers the final chunk when the remainder is an
        # exact multiple of the packet size.
        chunks = []
        position = offset
        while position < total_size:
            length = min(self.MAX_PACKET_SIZE, total_size - position)
            chunks.append((position, length))
            position += length

        with self.client.open(remote_path) as rf:
            # 'r+b', not 'ab': in append mode every write lands at the end of
            # the file regardless of seek(), duplicating the overlap bytes.
            with open(dest, 'r+b') as f:
                f.seek(offset)
                f.truncate()
                for chunk in rf.readv(chunks):
                    f.write(chunk)

    def __str__(self):
        return '{} (wrapped by {}.SFTPClient)'.format(
            str(self.client), __name__)
    __unicode__ = __str__
Exemple #29
0
    def archive_service(self, service):
        """Actually do the archiving step for the given Service
        """

        # Create the base directory for this service, i.e. where we put logs.
        base_dir = os.path.join(self.base_dir, service.name, service.host)
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        if "<DATE->" not in service.pattern:
            # We ignore services that don't have a <DATE-> in their pattern
            print "Warning:", service.name, "does not include date. Ignoring."

        # Connect to remote
        client = SSHClient()
        # TODO: Use something other than auto add policy?
        client.set_missing_host_key_policy(AutoAddPolicy())
        client.connect(
            service.host,
            username=service.account,
            compress=True,
            allow_agent=self.use_ssh_agent,
        )

        # Fetch list of files from the remote
        glob = service.pattern.replace("<DATE->", "????-??-??")
        cmd = FIND_COMMAND_TEMPLATE % {
            "dir": service.directory,
            "glob": glob,
        }
        _, stdout, _ = client.exec_command(cmd)
        files = stdout.readlines()
        files[:] = list(f.strip() for f in files)
        files.sort()

        # Filter the files to ones we want to archive
        files = filter_by_age(
            files,
            lambda d: d.days > service.days_to_keep_on_remote
        )

        # For each file download to a pending file name (optionally gzipping)
        # and only after it has succesfully been downloaded do we optionally
        # delete from the remote.
        sftp = client.open_sftp()
        for file_name in files:
            local_name = os.path.join(base_dir, os.path.basename(file_name))
            if not file_name.endswith(".gz"):
                local_name += ".gz"
            pending_name = local_name + ".download"

            if os.path.exists(pending_name):
                os.remove(pending_name)

            if os.path.exists(local_name):
                print "Warning: ", local_name, "already exists"
                continue

            # Set up progress bar for downloads
            if self.verbose:
                widgets = [
                    os.path.basename(file_name), " ",
                    progressbar.Percentage(),
                    ' ', progressbar.Bar(),
                    ' ', progressbar.ETA(),
                    ' ', progressbar.FileTransferSpeed(),
                ]
                pb = progressbar.ProgressBar(widgets=widgets)

                def progress_cb(bytes_downloaded, total_size):
                    pb.max_value = total_size
                    pb.update(bytes_downloaded)
            else:
                def progress_cb(bytes_downloaded, total_size):
                    pass

            if self.verbose or self.dry_run:
                print "Archiving: %s:%s to %s" % (
                    service.host, file_name, local_name,
                )

            if not self.dry_run:
                # If filename does not end with '.gz' then we compress while
                # we download
                # TODO: Should we be preserving last modified times?
                if not file_name.endswith(".gz"):
                    with gzip.open(pending_name, 'wb', compresslevel=9) as f:
                        sftp.getfo(file_name, f, callback=progress_cb)
                else:
                    sftp.get(file_name, pending_name, callback=progress_cb)

                if self.verbose:
                    pb.finish()

                os.rename(pending_name, local_name)

                if self.remove:
                    if self.verbose:
                        print "Removing remote"
                    sftp.remove(file_name)

        sftp.close()
        client.close()

        # We now go and delete any files that are older than the retention
        # period, if specified
        if service.retention_period_days:
            local_files = list(
                os.path.join(dirpath, filename)
                for dirpath, _, filenames in os.walk(base_dir)
                for filename in filenames
            )

            files_to_delete = filter_by_age(
                local_files,
                lambda d: d.days > service.retention_period_days
            )

            for file_name in files_to_delete:
                if self.verbose or self.dry_run:
                    print "Deleting file due to retention policy: %s" % (
                        file_name,
                    )

                if not self.dry_run:
                    os.remove(file_name)
Exemple #30
0
class RemoteClient(object):
    """Remote Client is a wrapper over SSHClient with utility functions.

    Args:
        host (string): The hostname of the server to connect. It can be an IP
            address of the server also.
        ip (string, optional): IP address to retry with if connecting by
            hostname fails.
        user (string, optional): The user to connect to the remote server.

    Attributes:
        host (string): The hostname passed in as the argument
        user (string): The user to connect as to the remote server
        client (:class:`paramiko.client.SSHClient`): The SSHClient object used
            for all the communications with the remote server.
        sftpclient (:class:`paramiko.sftp_client.SFTPClient`): The SFTP object
            for all the file transfer operations over the SSH.
    """
    def __init__(self, host, ip=None, user='******'):
        self.host = host
        self.ip = ip
        self.user = user
        self.client = SSHClient()
        self.sftpclient = None
        self.client.set_missing_host_key_policy(AutoAddPolicy())
        self.client.load_system_host_keys()
        logging.debug("RemoteClient created for host: %s", host)

    def startup(self):
        """Function that starts SSH connection and makes client available for
        carrying out the functions. It tries with the hostname, if it fails
        it tries with the IP address if supplied

        Raises:
            ClientNotSetupException: if no connection could be established.
        """
        try:
            logging.debug("Trying to connect to remote server %s", self.host)
            self.client.connect(self.host, port=22, username=self.user)
            self.sftpclient = self.client.open_sftp()
        except (SSHException, socket.error):
            if self.ip:
                logging.warning("Connection with hostname failed. Retrying "
                                "with IP")
                self._try_with_ip()
            else:
                logging.error("Connection to %s failed.", self.host)
                raise ClientNotSetupException('Could not connect to the host.')

    def _try_with_ip(self):
        """Connect using ``self.ip``; raise ClientNotSetupException on failure."""
        try:
            logging.debug("Connecting to IP:%s User:%s", self.ip, self.user)
            self.client.connect(self.ip, port=22, username=self.user)
            self.sftpclient = self.client.open_sftp()
        except (SSHException, socket.error):
            logging.error("Connection with IP (%s) failed.", self.ip)
            raise ClientNotSetupException('Could not connect to the host.')

    def download(self, remote, local):
        """Downloads a file from remote server to the local system.

        Args:
            remote (string): location of the file in remote server
            local (string): path where the file should be saved

        Returns:
            string: a success or error message
        """
        if not self.sftpclient:
            raise ClientNotSetupException(
                'Cannot download file. Client not initialized')

        try:
            self.sftpclient.get(remote, local)
            return "Download successful. File at: {0}".format(local)
        except OSError:
            return "Error: Local file %s doesn't exist." % local
        except IOError:
            return "Error: Remote location %s doesn't exist." % remote

    def upload(self, local, remote):
        """Uploads the file from local location to remote server.

        Args:
            local (string): path of the local file to upload
            remote (string): location on remote server to put the file

        Returns:
            string: a success or error message
        """
        if not self.sftpclient:
            raise ClientNotSetupException(
                'Cannot upload file. Client not initialized')

        try:
            self.sftpclient.put(local, remote)
            return "Upload successful. File at: {0}".format(remote)
        except OSError:
            return "Error: Local file %s doesn't exist." % local
        except IOError:
            return "Error: Remote location %s doesn't exist." % remote

    def exists(self, filepath):
        """Returns whether a file exists or not in the remote server.

        Args:
            filepath (string): path to the file to check for existence

        Returns:
            True if `stat` succeeds on the path, False otherwise
        """
        if not self.client:
            raise ClientNotSetupException(
                'Cannot run procedure. Client not initialized')
        try:
            from shlex import quote
        except ImportError:  # Python 2
            from pipes import quote
        # Quote the path so spaces / shell metacharacters cannot break or
        # inject into the remote command; judge by exit status rather than by
        # output length (which could previously return None).
        _, cout, _ = self.client.exec_command('stat {0}'.format(quote(filepath)))
        return cout.channel.recv_exit_status() == 0

    def run(self, command):
        """Run a command in the remote server.

        Args:
            command (string): the command to be run on the remote server

        Returns:
            tuple of three strings read from stdin, stdout and stderr; a stream
            that cannot be read (stdin is write-only) yields ''
        """
        if not self.client:
            raise ClientNotSetupException(
                'Cannot run procedure. Client not initialized')

        #buffers = self.client.exec_command(command, timeout=30)
        buffers = self.client.exec_command(command)
        output = []
        for buf in buffers:
            try:
                output.append(buf.read())
            except IOError:
                output.append('')

        return tuple(output)

    def get_file(self, filename):
        """Reads content of filename on remote server

        Args:
            filename (string): name of file to be read from remote server

        Returns:
            tuple: (True, file-like object rewound to the start) or
                (False, error)
        """
        f = StringIO.StringIO()
        try:
            # getfo returns the byte count, which is 0 (falsy) for an empty
            # file; report success explicitly as documented.
            self.sftpclient.getfo(filename, f)
            f.seek(0)
            return True, f
        except Exception as err:
            return False, err

    def put_file(self, filename, filecontent):
        """Puts content to a file on remote server

        Args:
            filename (string): name of file to be written on remote server
            filecontent (string): content of file

        Returns:
            tuple: (True, remote file size) or (False, error)
        """
        f = StringIO.StringIO()
        f.write(filecontent)
        f.seek(0)

        try:
            r = self.sftpclient.putfo(f, filename)
            return True, r.st_size
        except Exception as err:
            return False, err

    def mkdir(self, dirname):
        """Creates a new directory.

        Args:
            dirname (string): the full path of the directory that needs to be
                created

        Returns:
            a tuple containing the success or failure of operation and dirname
                on success and error on failure
        """
        try:
            self.sftpclient.mkdir(dirname)
            return True, dirname
        except Exception as err:
            return False, err

    def listdir(self, dirname):
        """Lists all the files and folders in a directory.

        Args:
            dirname (string): the full path of the directory that needs to be
                listed

        Returns:
            tuple: (True, list of entry names) or (False, error)
        """
        try:
            r = self.sftpclient.listdir(dirname)
            return True, r
        except Exception as err:
            return False, err

    def close(self):
        """Close the SFTP session (if open) and the SSH connection.
        """
        if self.sftpclient is not None:
            self.sftpclient.close()
            self.sftpclient = None
        self.client.close()

    def __repr__(self):
        return "RemoteClient({0}, ip={1}, user={2})".format(
            self.host, self.ip, self.user)
Exemple #31
0
class SFTPClient(object):
    """Dynamic extension on paramiko's SFTPClient.

    On construction this connects over SSH, opens a paramiko SFTP session and
    copies that session's bound methods onto the instance, so the object can
    be used like a paramiko ``SFTPClient``. On top of that it adds a
    recursive directory listing and a resumable ``mirror`` download.
    """

    # Chunk size for resumed reads, borrowed from paramiko's SFTPFile.
    MAX_PACKET_SIZE = SFTPFile.__dict__['MAX_REQUEST_SIZE']

    ssh_client = None
    client = None
    raise_exceptions = False
    original_arguments = {}
    debug = False

    _log = logging.getLogger(LOG_NAME)
    # Class-level default kept for compatibility. Every instance gets its own
    # list in __init__ so one client's created-directory cache is not shared
    # with (and mutated through) every other instance.
    _dircache = []

    def __init__(self, **kwargs):
        """Connect immediately; ``kwargs`` are remembered for reconnects.

        See ``_connect`` for the accepted keyword arguments.
        """
        self._dircache = []
        self.original_arguments = kwargs.copy()
        self._connect(**kwargs)

    def __enter__(self):
        """For use with a with statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """For use with a with statement; closes both connections."""
        self.close_all()

    def _connect(self, **kwargs):
        """Open the SSH + SFTP connection and graft SFTP methods onto self.

        Keyword arguments read here: ``username`` and ``password`` (required
        keys; an empty password is not sent), ``hostname`` (default
        'localhost'), ``port`` (22), ``look_for_keys`` (True), ``timeout``
        (None), ``keepalive`` (5), ``raise_exceptions`` (False) and
        ``is_reconnect`` (False). Other keys are ignored.

        Raises:
            AttributeError: on a first connect, if a method of the paramiko
                SFTP client would overwrite an attribute of this class.
        """
        kwargs_to_paramiko = dict(
            look_for_keys=kwargs.pop('look_for_keys', True),
            username=kwargs.pop('username'),
            port=kwargs.pop('port', 22),
            allow_agent=False,
            timeout=kwargs.pop('timeout', None),
        )
        host = kwargs.pop('hostname', 'localhost')
        password = kwargs.pop('password')
        keepalive = kwargs.pop('keepalive', 5)
        if password:
            kwargs_to_paramiko['password'] = password
        self.raise_exceptions = kwargs.pop('raise_exceptions', False)

        self.ssh_client = SSHClient()
        self.ssh_client.load_system_host_keys()
        self.ssh_client.connect(host, **kwargs_to_paramiko)

        self.client = self.ssh_client.open_sftp()
        channel = self.client.get_channel()
        channel.settimeout(kwargs_to_paramiko['timeout'])
        channel.get_transport().set_keepalive(keepalive)

        # 'Extend' the SFTPClient class
        is_reconnect = kwargs.pop('is_reconnect', False)
        members = inspect.getmembers(self.client,
                                     predicate=inspect.ismethod)
        self._log.debug('Dynamically adding methods from original SFTPClient')
        for (method_name, method) in members:
            if method_name[0:2] == '__' or method_name == '_log':
                self._log.debug('Ignorning {}()'.format(method_name))
                continue

            # On a reconnect the previously grafted methods are expected to
            # exist and must be replaced with ones bound to the new session.
            if not is_reconnect and hasattr(self, method_name):
                raise AttributeError('Not overwriting property "{}". This '
                                     'version of Paramiko is not '
                                     'supported.'.format(method_name))

            self._log.debug('Adding method {}()'.format(method_name))
            setattr(self, method_name, method)

    def _reconnect(self, cwd=None):
        """Drop the current connection and reconnect with the original arguments.

        ``cwd``, when truthy, is restored as the remote working directory.
        """
        self._log.debug('Re-establishing connection')
        try:
            self.close_all()
        except Exception as err:  # the old transport is usually already dead
            self._log.debug('Ignoring error while closing stale connection: '
                            '{!s}'.format(err))
        self.original_arguments['is_reconnect'] = True
        self._connect(**self.original_arguments)
        if cwd:
            self.chdir(cwd)

    def close_all(self):
        """Close client and SSH client handles."""
        self.client.close()
        self.ssh_client.close()

    def clear_directory_cache(self):
        """Reset directory cache."""
        self._dircache = []

    def listdir_attr_recurse(self, path='.'):
        """Yield ``(remote_path, SFTPAttributes)`` for every non-directory under ``path``.

        Subdirectories that cannot be listed (IOError) are skipped unless
        ``raise_exceptions`` is set.
        """
        from stat import S_ISDIR

        for da in self.client.listdir_attr(path=path):
            # Use the file-type bits. Testing the owner rwx bits mistook
            # executable files for directories, which silently dropped them.
            if S_ISDIR(da.st_mode):
                try:
                    for x in self.listdir_attr_recurse(
                            path_join(path, da.filename)):
                        yield x
                except IOError:
                    if self.raise_exceptions:
                        raise
            else:
                yield (path_join(path, da.filename), da,)

    def _get_callback(self, start_time, _log):
        """Build a ``(bytes_transferred, total_bytes)`` progress logger.

        The returned callback logs only when the whole number of seconds
        since ``start_time`` is a multiple of ``LOG_INTERVAL``.
        """
        def cb(tx_bytes, total_bytes):
            total_time = (datetime.now() - start_time).total_seconds()
            total_time_s = floor(total_time)

            if (total_time_s % LOG_INTERVAL) != 0:
                return

            nsize_tx = naturalsize(tx_bytes,
                                   binary=True,
                                   format='%.2f')
            nsize_total = naturalsize(total_bytes,
                                      binary=True,
                                      format='%.2f')

            # A transfer can finish within the clock's resolution; do not
            # divide by a zero elapsed time.
            speed_in_s = tx_bytes / total_time if total_time > 0 else 0
            speed_in_s = naturalsize(speed_in_s,
                                     binary=True,
                                     format='%.2f')

            _log.info('Downloaded {} / {} in {} ({}/s)'.format(
                nsize_tx,
                nsize_total,
                naturaldelta(datetime.now() - start_time),
                speed_in_s))

        return cb

    def _resume_download(self, remote_path, dest, offset, total_size):
        """Write bytes ``[offset, total_size)`` of ``remote_path`` into ``dest``.

        The local file is first truncated to ``offset``, so the fetched data
        lands exactly at ``offset`` even though the file is opened in append
        mode (where ``seek()`` has no effect on writes).
        """
        read_tuples = [(pos, min(self.MAX_PACKET_SIZE, total_size - pos))
                       for pos in range(offset, total_size,
                                        self.MAX_PACKET_SIZE)]
        with self.client.open(remote_path) as rf:
            with open(dest, 'ab') as f:
                f.truncate(offset)
                for chunk in rf.readv(read_tuples):
                    f.write(chunk)

    def mirror(self,
               path='.',
               destroot='.',
               keep_modes=True,
               keep_times=True,
               resume=True):
        """
        Mirror a remote directory to a local location.

        path is the remote directory. destroot must be the location where
        destroot/path will be created (the path must not already exist).

        keep_modes and keep_times are boolean to ensure permissions and time
        are retained respectively.

        Pass resume=False to disable file resumption. Partially downloaded
        local files are then fetched again in full, and a timeout/SFTP error
        is re-raised instead of triggering a reconnect.

        Returns the number of files transferred. Files whose local size
        already equals the remote size are not counted.
        """
        from stat import S_ISDIR

        n = 0
        cwd = self.getcwd()

        for _path, info in self.listdir_attr_recurse(path=path):
            if S_ISDIR(info.st_mode):
                continue

            # Resume offset for THIS file; reset every iteration so a value
            # left over from the previous file can never be reused.
            resume_seek = None

            dest_path = path_join(destroot, dirname(_path))
            dest = path_join(dest_path, basename(_path))

            if dest_path not in self._dircache:
                try:
                    makedirs(dest_path)
                except OSError:
                    pass  # typically "already exists"
                self._dircache.append(dest_path)

            if isdir(dest):
                continue

            try:
                with open(dest, 'rb'):
                    current_size = os.stat(dest).st_size

                    if current_size != info.st_size:
                        # Only a shorter local file can be completed by
                        # appending; a larger one is downloaded again.
                        if current_size < info.st_size:
                            resume_seek = current_size
                            if resume:
                                self._log.info('Resuming file {} at {} '
                                               'bytes'.format(dest,
                                                              current_size))
                        raise IOError()  # ugly goto
            except IOError:
                while True:
                    try:
                        # Only size is used to determine complete-ness here
                        # Hash verification is in the util module
                        if resume_seek and resume:
                            self._resume_download(_path, dest, resume_seek,
                                                  info.st_size)
                            resume_seek = None
                        else:
                            dest = realpath(dest)
                            self._log.info('Downloading {} -> '
                                           '{}'.format(_path, dest))

                            start_time = datetime.now()
                            self.client.get(_path, dest)

                            self._get_callback(start_time, self._log)(
                                info.st_size, info.st_size)

                        # Do not count files that were already downloaded
                        n += 1

                        break
                    except (socket.timeout, SFTPError) as e:
                        if isinstance(e, socket.timeout):
                            self._log.error('Connection timed out')
                        else:
                            self._log.error('{!s}'.format(e))

                        if not resume:
                            self._log.debug('Not resuming (resume = {}, '
                                            'exception: {})'.format(resume,
                                                                    e))
                            raise

                        # Back off 10 bytes from what reached the disk. The
                        # file may not exist yet if nothing was written.
                        try:
                            resume_seek = max(os.stat(dest).st_size - 10, 0)
                        except OSError:
                            resume_seek = None
                        self._log.info('Resuming GET {} at {} '
                                       'bytes'.format(_path, resume_seek))

                        self._reconnect(cwd)

            # Okay to fix existing files even if they are already downloaded
            try:
                if keep_modes:
                    chmod(dest, info.st_mode)
                if keep_times:
                    utime(dest, (info.st_atime, info.st_mtime,))
            except IOError:
                pass

        return n

    def __str__(self):
        """Return string representation."""
        return '{} (wrapped by {}.SFTPClient)'.format(
            str(self.client), __name__)
    __unicode__ = __str__
Exemple #32
0
class Libimobiledevice:
    """Run libimobiledevice tools on a remote Mac over SSH.

    Every command runs on ``MAC_HOST_IP`` inside a login bash shell, so the
    ``idevice*`` tools and ``curl`` must be on that host's PATH.
    """

    def __init__(self):
        """Open the SSH connection to the Mac host (3 second connect timeout)."""
        self.ssh_client = SSHClient()
        self.ssh_client.set_missing_host_key_policy(AutoAddPolicy())
        self.ssh_client.connect(MAC_HOST_IP,
                                port=MAC_PORT,
                                username=MAC_USERNAME,
                                password=MAC_PASSWORD,
                                timeout=3)
        # uuid -> interactive shell channel running idevicesyslog
        self.ssh_dict = {}

    def device_list(self):
        """Return the device UDIDs printed by ``idevice_id -l`` (empty list if none)."""
        result = self.execute_some_command("idevice_id -l")
        if result.strip():
            return result.strip().split('\n')
        return []

    def screenshot(self, uuid):
        """Take a screenshot into ``WORK_DIR/<uuid>.png`` and return its remote path."""
        result = self.execute_some_command(
            f"cd {WORK_DIR};idevicescreenshot -u {uuid} {uuid}.png")
        result = result.replace('Screenshot saved to ', '').strip()
        return WORK_DIR + result

    def _device_info_then_upload(self, uuid, to, remote_file_path):
        """Upload ``remote_file_path`` for user ``to`` with ``uuid``'s device info.

        Returns a dict with ``device_info`` (None if it could not be read)
        and ``file_info`` (the ``data`` field of the server's JSON reply, or
        None when the reply is not a JSON object).
        """
        try:
            device_info = self.device_info(uuid)
        except Exception:
            device_info = None
        # The value is embedded inside a double-quoted curl -F argument.
        file_info = self.curl_upload_file(
            to, remote_file_path,
            json.dumps(device_info, ensure_ascii=False).replace('"', '\\"'))

        try:
            file_info = json.loads(file_info).get('data')
        except (ValueError, AttributeError):
            file_info = None
        return {
            "device_info": device_info,
            "file_info": file_info,
        }

    def screenshot_device_info_then_upload(self, uuid, to):
        """Screenshot device ``uuid`` and upload it for user ``to``."""
        remote_file_path = self.screenshot(uuid)
        return self._device_info_then_upload(uuid, to, remote_file_path)

    def syslog_start(self, uuid):
        """Start capturing ``uuid``'s syslog into ``WORK_DIR/<uuid>.txt``.

        Returns False if the device is not connected, True otherwise.
        """
        if uuid not in self.device_list():
            return False
        ssh = self.ssh_client.invoke_shell()
        ssh.send(f"cd {WORK_DIR};idevicesyslog -u {uuid}>{uuid}.txt\n")
        self.ssh_dict[uuid] = ssh
        return True

    def syslog_stop(self, uuid):
        """Send Ctrl-C to the syslog capture and return the log's remote path.

        Raises KeyError if ``syslog_start`` was not called for ``uuid``.
        """
        self.ssh_dict[uuid].send(chr(3))
        return WORK_DIR + uuid + '.txt'

    def syslog_device_info_then_upload(self, uuid, to):
        """Stop the syslog capture for ``uuid`` and upload the log for user ``to``."""
        remote_file_path = self.syslog_stop(uuid)
        return self._device_info_then_upload(uuid, to, remote_file_path)

    def device_info(self, uuid):
        """Return selected ``ideviceinfo`` fields for ``uuid`` as strings.

        Raises IndexError when a field is missing from the tool output.
        """
        result = self.execute_some_command(f"ideviceinfo -u {uuid}")
        result += self.execute_some_command(
            f"ideviceinfo -u {uuid} -q com.apple.disk_usage.factory")
        result += self.execute_some_command(
            f"ideviceinfo -u {uuid} -q com.apple.mobile.battery")
        return {
            "DeviceName":
            re.findall('DeviceName: (.*?)\n', result, re.I)[0],
            "ProductType":
            re.findall('ProductType: (.*?)\n', result, re.I)[0],
            "ProductVersion":
            re.findall('ProductVersion: (.*?)\n', result, re.I)[0],
            "BatteryCurrentCapacity":
            re.findall('BatteryCurrentCapacity: (.*?)\n', result, re.I)[0],
            "BatteryIsCharging":
            re.findall('BatteryIsCharging: (.*?)\n', result, re.I)[0],
            "TotalDataCapacity":
            re.findall('TotalDataCapacity: (.*?)\n', result, re.I)[0],
            "TotalDataAvailable":
            re.findall('TotalDataAvailable: (.*?)\n', result, re.I)[0],
        }

    def execute_some_command(self, command):
        """Run ``command`` in a login bash shell on the Mac and return stdout.

        The command is shell-quoted as a single argument, so commands that
        contain single quotes (e.g. a device named "Bob's iPhone") reach
        bash intact instead of breaking the surrounding quoting.
        """
        import shlex

        stdin, stdout, stderr = self.ssh_client.exec_command(
            f"bash -lc {shlex.quote(command)}", timeout=10)
        return stdout.read().decode()

    def ssh_logout(self):
        """Close the SSH connection."""
        self.ssh_client.close()

    def curl_upload_file(self, username, filepath, device_info):
        """POST ``filepath`` (on the Mac) to ``DOMAIN``'s file API with curl.

        ``device_info`` must already be escaped for a double-quoted shell
        string. Returns the raw response body, or ``"{}"`` if it is empty.
        """
        result = self.execute_some_command(
            f'curl {DOMAIN}/api/tmt/files/ -F "file=@{filepath}" -F "username={username}" -F "device_info={device_info}"'
        )
        if result:
            return result
        return "{}"

    def upload_file(self, local_file_path, remote_file_path):
        """
        Upload a local file to the Mac over SFTP.
        """
        # The SFTP session is closed even if the transfer fails.
        with self.ssh_client.open_sftp() as sftp:
            sftp.put(local_file_path, remote_file_path)

    def download_file(self, remote_file_path, local_file_path):
        """
        Download a file from the Mac over SFTP.
        """
        # The SFTP session is closed even if the transfer fails.
        with self.ssh_client.open_sftp() as sftp:
            sftp.get(remote_file_path, local_file_path)
Exemple #33
0
class Client:
    """Interactive remote command runner over paramiko SSH.

    Connection settings are resolved from the ``[user@]host[:port]`` string,
    optional YAML config files merged into the global ``config``, and
    ``~/.ssh/config`` (when present).
    """

    # Template values used by ``format``. Class attribute, shared by all
    # instances.
    context = {}

    def __init__(self, hostname, configpath=None, dry_run=False):
        ssh_config = SSHConfig()
        if not hostname:
            print(red('"hostname" must be defined'))
            sys.exit(1)
        parsed = self.parse_host(hostname)
        hostname = parsed.get('hostname')
        username = parsed.get('username')
        if configpath:
            if not isinstance(configpath, (list, tuple)):
                configpath = [configpath]
            for path in configpath:
                self._load_config(path, hostname)
        user_ssh_config = Path.home() / '.ssh/config'
        # A missing ~/.ssh/config is normal; fall back to paramiko defaults.
        if user_ssh_config.exists():
            with user_ssh_config.open() as fd:
                ssh_config.parse(fd)
        ssh_config = ssh_config.lookup(hostname)
        self.dry_run = dry_run
        self.hostname = config.hostname or ssh_config['hostname']
        self.username = (username or config.username
                         or ssh_config.get('user', getuser()))
        # An explicit host:port wins over the ssh_config Port entry. The
        # parsed port used to be discarded, so every connection went to 22.
        self.port = parsed.get('port') or int(ssh_config.get('port', 22))
        self.formatter = Formatter()
        self.key_filenames = []
        if config.key_filename:
            self.key_filenames.append(config.key_filename)
        if 'identityfile' in ssh_config:
            self.key_filenames.extend(ssh_config['identityfile'])
        self.sudo = ''
        self.cd = None
        self.screen = None
        self.env = {}
        self._sftp = None
        self.proxy_command = ssh_config.get('proxycommand',
                                            config.proxy_command)
        self.open()

    def open(self):
        """Connect to ``self.hostname``, optionally through a ProxyCommand.

        Exits the process on a bad host key.
        """
        self._client = SSHClient()
        self._client.load_system_host_keys()
        self._client.set_missing_host_key_policy(WarningPolicy())
        print(f'Connecting to {self.username}@{self.hostname}')
        if self.proxy_command:
            print('ProxyCommand:', self.proxy_command)
        sock = (paramiko.ProxyCommand(self.proxy_command)
                if self.proxy_command else None)
        try:
            self._client.connect(hostname=self.hostname,
                                 port=self.port,
                                 username=self.username,
                                 sock=sock,
                                 key_filename=self.key_filenames)
        except paramiko.ssh_exception.BadHostKeyException:
            sys.exit('Connection error: bad host key')
        self._transport = self._client.get_transport()

    def close(self):
        """Close the SFTP session (if one was opened) and the SSH connection."""
        print(f'\nDisconnecting from {self.username}@{self.hostname}')
        if self._sftp is not None:
            self._sftp.close()
            self._sftp = None
        self._client.close()

    def _load_config(self, path, hostname):
        """Merge the YAML file ``path`` into the global ``config``.

        A top-level mapping keyed by ``hostname`` overrides the file's other
        top-level values for that host.
        """
        with Path(path).open() as fd:
            # safe_load only builds plain data. A bare yaml.load() without a
            # Loader is unsafe and is rejected by PyYAML >= 6.
            conf = yaml.safe_load(fd) or {}
        if hostname in conf:
            conf.update(conf[hostname])
        config.update(conf)

    def parse_host(self, host_string):
        """Split ``[user@]host[:port]`` into a dict.

        Returns ``{'username', 'hostname', 'port'}``; missing parts are None
        and ``port`` is an int when given. Strings with more than one ':'
        are treated as bare IPv6 addresses (no port).
        """
        user_hostport = host_string.rsplit('@', 1)
        hostport = user_hostport.pop()
        user = user_hostport[0] if user_hostport and user_hostport[0] else None

        # IPv6: can't reliably tell where addr ends and port begins, so don't
        # try (and don't bother adding special syntax either, user should avoid
        # this situation by using port=).
        if hostport.count(':') > 1:
            host = hostport
            port = None
        # IPv4: can split on ':' reliably.
        else:
            host_port = hostport.rsplit(':', 1)
            host = host_port.pop(0) or None
            port = host_port[0] if host_port and host_port[0] else None

        if port is not None:
            port = int(port)

        return {'username': user, 'hostname': host, 'port': port}

    def _build_command(self, cmd, **kwargs):
        """Wrap ``cmd`` with cd/env/sudo/screen settings and apply ``format``.

        Parts are joined with single spaces. The user's command text is no
        longer rewritten: the old global ``'  ' -> ' '`` replace also
        collapsed spaces inside ``cmd`` itself.
        """
        if self.cd:
            cmd = f'cd {self.cd}; {cmd}'
        parts = []
        if self.sudo:
            parts.append(self.sudo)
        if self.env:
            parts.append(' '.join(f'{k}={v}' for k, v in self.env.items()))
        parts.append(f"sh -c $'{cmd}'")
        cmd = self.format(' '.join(parts))
        if self.screen:
            cmd = f'screen -UD -RR -S {self.screen} {cmd}'
        return cmd.strip()

    def _call_command(self, cmd, **kwargs):
        """Run ``cmd`` on a PTY, forwarding local stdin and echoing stdout live.

        Returns a ``Status(stdout, stderr, code)``. Exits the process via
        ``self.exit`` when the exit code is non-zero.
        """
        channel = self._transport.open_session()
        try:
            size = os.get_terminal_size()
        except IOError:
            channel.get_pty()  # Fails when ran from pytest.
        else:
            channel.get_pty(width=size.columns, height=size.lines)
        channel.exec_command(cmd)
        channel.setblocking(False)  # Allow to read from empty buffer.
        stdout = channel.makefile('r', -1)
        stderr = channel.makefile_stderr('r', -1)
        proxy_stdout = b''
        buf = b''
        while True:
            while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                # TODO compute bytes_to_read like in invoke?
                data = sys.stdin.read(1)
                if data:
                    channel.sendall(data)
                else:
                    break
            if not channel.recv_ready():
                if buf:  # We may have read some buffer yet, let's output it.
                    sys.stdout.write(buf.decode())
                    sys.stdout.flush()
                    buf = b''
                if channel.exit_status_ready():
                    break
                continue
            try:
                data = stdout.read(1)
            except Exception:  # Not sure how to catch socket.timeout properly.
                pass
            else:
                proxy_stdout += data
                buf += data
                if data == b'\n':
                    sys.stdout.write(buf.decode())
                    sys.stdout.flush()
                    buf = b''
                continue
            time.sleep(paramiko.io_sleep)
        channel.setblocking(True)  # Make sure we now wait for stderr.
        ret = Status(proxy_stdout.decode(),
                     stderr.read().decode().strip(),
                     channel.recv_exit_status())
        channel.close()
        if ret.code:
            self.exit(ret.stderr, ret.code)
        return ret

    def exit(self, msg, code=1):
        """Print ``msg`` in red and exit the process with ``code``."""
        print(red(msg))
        sys.exit(code)

    def __call__(self, cmd, **kwargs):
        """Build, echo and run ``cmd``; in dry-run mode only echo it."""
        cmd = self._build_command(cmd, **kwargs)
        print(gray(cmd))
        if self.dry_run:
            return Status('¡DRY RUN!', '¡DRY RUN!', 0)
        with character_buffered():
            return self._call_command(cmd, **kwargs)

    def format(self, tpl):
        """Fill ``{name}`` fields of ``tpl`` from ``self.context``; exit on a missing key."""
        try:
            return self.formatter.vformat(tpl, None, self.context)
        except KeyError as e:
            print(red(f'Missing key {e}'))
            sys.exit(1)

    @property
    def sftp(self):
        """Lazily opened SFTP session on the current SSH connection."""
        if not self._sftp:
            self._sftp = self._client.open_sftp()
        return self._sftp
Exemple #34
0
class GLEAM(Base):
    """Reader/downloader for GLEAM yearly netCDF files (EPSG:4326 grid)."""

    def __init__(self, filename=None):
        Base.__init__(self, filename)
        # set default spatial reference
        sr = osr.SpatialReference()
        sr.ImportFromEPSG(4326)
        self.srs = sr.ExportToWkt()
        self.doy = 0  # day of the year to retrieve
        self.ssh = None

    def open(self, filename):
        """Open ``filename`` as a netCDF dataset; returns True when opened."""
        self.nc = nc.Dataset(filename)
        return self.nc is not None

    def close(self):
        """Close the open netCDF dataset (if any) and drop the reference."""
        dataset = getattr(self, 'nc', None)
        if dataset is not None:
            dataset.close()
        self.nc = None

    def getcolrow(self, ds, lon, lat):
        """Convert lon/lat to fractional (col, row) on the fixed global grid.

        ``ds`` is unused; the grid is defined by LEFT, TOP and SIZE.
        """
        col = (lon - LEFT) / SIZE
        row = (TOP - lat) / SIZE
        return (col, row)

    def extractdate(self, name):
        ''' Return January 1st of the year in a GLEAM filename, or None.'''
        #SMroot_2017_GLEAM_v3.2b.nc
        pat = r'(?P<var>\w+)_(?P<year>\d{4})_GLEAM_v(?P<ver>\d\.\d\w)\.nc$'
        m = re.search(pat, name)
        if m is not None:
            year = int(m.group('year'))
            return datetime.date(year, 1, 1)
        return None

    def get_dataset(self, name):
        """Return the netCDF variable called ``name``."""
        return self.nc.variables[name]

    def transbox(self, ds, bbox, topix=False, clip=False):
        """Normalise ``bbox`` (x1, y1, x2, y2), optionally in pixel units.

        With ``topix`` the box is converted to integer col/row indices; with
        ``clip`` those are clamped to the WIDTH x HEIGHT grid. The result
        always has y1 <= y2.
        """
        x1, y1, x2, y2 = bbox

        if clip or topix:

            px1, py1 = self.getcolrow(ds, x1, y1)
            px2, py2 = self.getcolrow(ds, x2, y2)

            if clip:
                px1 = max(0, min(px1, WIDTH - 1))
                px2 = max(0, min(px2, WIDTH - 1))
                py1 = max(0, min(py1, HEIGHT - 1))
                py2 = max(0, min(py2, HEIGHT - 1))

            if topix:
                x1 = int(px1)
                x2 = int(px2)
                y1 = int(py1)
                y2 = int(py2)

        if y1 > y2:
            return (x1, y2, x2, y1)
        else:
            return (x1, y1, x2, y2)

    def get_data(self, ds, bbox=None):
        """Return day ``self.doy`` of ``ds`` (optionally clipped to ``bbox``), transposed."""
        if isinstance(ds, str):
            ds = self.get_dataset(ds)

        if bbox is None:
            # no bounding box: return entire tile
            data = ds[self.doy]
        else:
            # clip bounding box
            x1, y1, x2, y2 = self.transbox(ds, bbox, topix=True, clip=True)
            data = ds[self.doy, int(x1):int(x2), int(y1):int(y2)]
        # need to transpose data for gdal: y is first index
        return np.transpose(data)

    def iter_data(self, ds, bbox=None):
        ''' yield a tile for every day in this dataset '''
        if isinstance(ds, str):
            ds = self.get_dataset(ds)
        times = self.nc.variables['time']
        self.doy = 0
        for time in times:
            date = DAY0 + timedelta(days=int(time))
            yield (date, self.get_data(ds, bbox))
            self.doy += 1

    def connect(self, host, port=0, timeout=-999):
        """Open SSH + SFTP sessions to ``host`` with the GLEAM credentials.

        ``timeout`` (seconds) is passed to paramiko when positive. The
        default sentinel -999 (or any value <= 0 / None) means no timeout;
        previously the parameter was accepted but ignored.
        """
        if self.ssh is None:
            self.ssh = SSHClient()
        self.ssh.set_missing_host_key_policy(AutoAddPolicy())
        self.ssh.connect(hostname=host,
                         port=port,
                         username=GLEAM_USERNAME,
                         password=GLEAM_PASSWORD,
                         timeout=timeout if timeout and timeout > 0 else None)
        self.ftp = self.ssh.open_sftp()
        return self.ftp

    def download(self, filename, folder, overwrite=True):
        """Download remote ``filename`` (relative to the SFTP cwd) into ``folder``."""
        print(filename)
        localpath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            os.makedirs(folder)
        if os.path.exists(localpath):
            if not overwrite:
                print(localpath + ' exists')
                return
        self.ftp.get(filename, localpath)

    def download_tile(self, folder, tile, dest, overwrite=True):
        """Download the first file in remote ``folder`` whose name contains ``tile``.

        With ``tile`` None the first listed file is downloaded.
        """
        self.ftp.chdir(folder)
        files = self.ftp.listdir()
        for filename in files:
            if tile is None or tile in filename:
                self.download(filename, dest, overwrite)
                break

    def download_dataset(self, name, years, version, dest, overwrite=True):
        """Download dataset ``name`` for each year in ``years`` into ``dest``."""
        if self.ssh is None:
            self.connect(GLEAM_HOST, GLEAM_PORT)
        if not dest.endswith('/'):
            dest += '/'
        for year in years:
            folder = GLEAM_PATH.format(version=version, year=year)
            self.download_tile(folder, name, dest, overwrite)

    def create_tif(self, filename, extent, data, template, etype):
        """Write 2-D ``data`` to a single-band GeoTIFF at ``filename``.

        ``extent`` supplies the origin (``extent[0]`` as left, ``extent[3]``
        as top) with pixel size SIZE. ``template`` is unused. An existing
        file is replaced.
        """
        if os.path.exists(filename):
            os.remove(filename)
        else:
            dirname = os.path.dirname(filename)
            # dirname is '' for a bare file name; os.makedirs('') raises.
            if dirname and not os.path.exists(dirname):
                os.makedirs(dirname)
        ysize, xsize = data.shape
        tif = gdal.GetDriverByName('GTiff').Create(filename,
                                                   xsize,
                                                   ysize,
                                                   eType=etype)
        tif.SetProjection(self.srs)
        tif.SetGeoTransform([extent[0], SIZE, 0, extent[3], 0, -SIZE])
        band = tif.GetRasterBand(1)
        band.WriteArray(data)
Exemple #35
0
class SSH:
    """SSH helper that runs commands in one persistent remote shell.

    ``exec_command``/``exec_command_with_stream`` reuse a single interactive
    channel and detect command completion through an EOF marker line. On
    Windows hosts (detected in ``__enter__``) they are replaced by one-shot
    exec-channel variants. Use as a context manager.
    """

    def __init__(self,
                 hostname,
                 port=22,
                 username='******',
                 pkey=None,
                 password=None,
                 default_env=None,
                 connect_timeout=10):
        self.stdout = None
        self.client = None
        self.channel = None
        self.sftp = None
        # Marker echoed after every command, followed by the exit code.
        self.eof = 'Spug EOF 2108111926'
        self.already_init = False
        self.default_env = self._make_env_command(default_env)
        self.regex = re.compile(r'Spug EOF 2108111926 (-?\d+)[\r\n]?')
        self.arguments = {
            'hostname':
            hostname,
            'port':
            port,
            'username':
            username,
            'password':
            password,
            'pkey':
            RSAKey.from_private_key(StringIO(pkey))
            if isinstance(pkey, str) else pkey,
            'timeout':
            connect_timeout,
            'banner_timeout':
            30
        }

    @staticmethod
    def generate_key():
        """Return a new 2048-bit RSA key as (private PEM text, 'ssh-rsa <base64>')."""
        key_obj = StringIO()
        key = RSAKey.generate(2048)
        key.write_private_key(key_obj)
        return key_obj.getvalue(), 'ssh-rsa ' + key.get_base64()

    def get_client(self):
        """Return the connected SSHClient, connecting on first use."""
        if self.client is not None:
            return self.client
        self.client = SSHClient()
        self.client.set_missing_host_key_policy(AutoAddPolicy)
        self.client.connect(**self.arguments)
        return self.client

    def ping(self):
        return True

    def add_public_key(self, public_key):
        """Append ``public_key`` to the remote ``~/.ssh/authorized_keys``.

        Raises Exception with the command output on a non-zero exit code.
        """
        import shlex

        # shlex.quote is real shell quoting; Python's repr() is not, and
        # breaks on key comments containing quotes or backslashes.
        command = f'mkdir -p -m 700 ~/.ssh && \
        echo {shlex.quote(public_key)} >> ~/.ssh/authorized_keys && \
        chmod 600 ~/.ssh/authorized_keys'

        exit_code, out = self.exec_command_raw(command)
        if exit_code != 0:
            raise Exception(f'add public key error: {out}')

    def exec_command_raw(self, command, environment=None):
        """Run ``command`` on a fresh exec channel; return (exit_code, output).

        stderr is merged into the output. All output is drained before
        waiting for the exit status. Waiting first can deadlock once the
        remote fills the channel window, and a single ``recv(-1)`` returned
        only what happened to be buffered.
        """
        channel = self.client.get_transport().open_session()
        if environment:
            channel.update_environment(environment)
        channel.set_combine_stderr(True)
        channel.exec_command(command)
        chunks = []
        while True:
            data = channel.recv(8192)
            if not data:  # b'' means the remote closed the stream
                break
            chunks.append(data)
        code = channel.recv_exit_status()
        return code, self._decode(b''.join(chunks))

    def exec_command(self, command, environment=None):
        """Run ``command`` in the persistent shell; return (exit_code, output).

        exit_code stays -1 if the EOF marker is never seen.
        """
        channel = self._get_channel()
        command = self._handle_command(command, environment)
        channel.send(command)
        out, exit_code = '', -1
        for line in self.stdout:
            match = self.regex.search(line)
            if match:
                exit_code = int(match.group(1))
                line = line[:match.start()]
                out += line
                break
            out += line
        return exit_code, out

    def _win_exec_command_with_stream(self, command, environment=None):
        """Windows variant: yield (exit_status, line) per output line, then the final code."""
        channel = self.client.get_transport().open_session()
        if environment:
            channel.update_environment(environment)
        channel.set_combine_stderr(True)
        channel.get_pty(width=102)
        channel.exec_command(command)
        stdout = channel.makefile("rb", -1)
        out = stdout.readline()
        while out:
            yield channel.exit_status, self._decode(out)
            out = stdout.readline()
        yield channel.recv_exit_status(), self._decode(out)

    def exec_command_with_stream(self, command, environment=None):
        """Run ``command`` in the persistent shell, yielding (exit_code, chunk).

        exit_code is -1 until the EOF marker is seen; the final yield
        carries the real exit code (or -1 if the stream ended first).
        """
        channel = self._get_channel()
        command = self._handle_command(command, environment)
        channel.send(command)
        exit_code, line = -1, ''
        while True:
            line = self._decode(channel.recv(8196))
            if not line:
                break
            match = self.regex.search(line)
            if match:
                exit_code = int(match.group(1))
                line = line[:match.start()]
                break
            yield exit_code, line
        yield exit_code, line

    def put_file(self, local_path, remote_path):
        sftp = self._get_sftp()
        sftp.put(local_path, remote_path)

    def put_file_by_fl(self, fl, remote_path, callback=None):
        sftp = self._get_sftp()
        sftp.putfo(fl, remote_path, callback=callback)

    def list_dir_attr(self, path):
        sftp = self._get_sftp()
        return sftp.listdir_attr(path)

    def sftp_stat(self, path):
        sftp = self._get_sftp()
        return sftp.stat(path)

    def remove_file(self, path):
        sftp = self._get_sftp()
        sftp.remove(path)

    def _get_channel(self):
        """Return the persistent shell channel, creating and priming it once.

        Priming disables zle/echo and clears the prompt, then waits up to
        about 10 s (100 x 0.1 s) for the EOF marker before giving up.
        """
        if self.channel:
            return self.channel

        counter = 0
        self.channel = self.client.invoke_shell()
        command = 'set +o zle\nset -o no_nomatch\nexport PS1= && stty -echo\n'
        if self.default_env:
            command += f'{self.default_env}\n'
        command += f'echo {self.eof} $?\n'
        self.channel.send(command.encode())
        while True:
            if self.channel.recv_ready():
                line = self._decode(self.channel.recv(8196))
                if self.regex.search(line):
                    self.stdout = self.channel.makefile('r')
                    break
            elif counter >= 100:
                self.client.close()
                raise Exception('Wait spug response timeout')
            else:
                counter += 1
                time.sleep(0.1)
        return self.channel

    def _get_sftp(self):
        """Return the cached SFTP session, opening it on first use."""
        if self.sftp:
            return self.sftp

        self.sftp = self.client.open_sftp()
        return self.sftp

    def _make_env_command(self, environment):
        """Build an ``export K='v' ...`` line (None for empty input).

        '-' in keys becomes '_'; single quotes in string values are escaped.
        """
        if not environment:
            return None
        str_envs = []
        for k, v in environment.items():
            k = k.replace('-', '_')
            if isinstance(v, str):
                v = v.replace("'", "'\"'\"'")
            str_envs.append(f"{k}='{v}'")
        str_envs = ' '.join(str_envs)
        return f'export {str_envs}'

    def _handle_command(self, command, environment):
        """Wrap ``command`` so the shell runs it from a base64-decoded temp file.

        The first call also creates the temp file and a trap removing it.
        """
        new_command = commands = ''
        if not self.already_init:
            commands = 'export SPUG_EXEC_FILE=$(mktemp)\n'
            commands += 'trap \'rm -f $SPUG_EXEC_FILE\' EXIT\n'
            self.already_init = True

        env_command = self._make_env_command(environment)
        if env_command:
            new_command += f'{env_command}\n'
        new_command += command
        new_command += f'\necho {self.eof} $?\n'
        b64_command = base64.standard_b64encode(new_command.encode())
        commands += f'echo {b64_command.decode()} | base64 -di > $SPUG_EXEC_FILE\n'
        commands += 'source $SPUG_EXEC_FILE\n'
        return commands

    def _decode(self, content):
        """Decode bytes as UTF-8, falling back to GBK (ignoring bad bytes)."""
        try:
            content = content.decode()
        except UnicodeDecodeError:
            content = content.decode(encoding='GBK', errors='ignore')
        return content

    def __enter__(self):
        """Connect; on Windows hosts switch to the exec-channel command variants."""
        self.get_client()
        transport = self.client.get_transport()
        if 'windows' in transport.remote_version.lower():
            self.exec_command = self.exec_command_raw
            self.exec_command_with_stream = self._win_exec_command_with_stream
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection and forget every per-connection handle.

        The shell channel, its stdout file, the SFTP session and the
        temp-file bootstrap all belong to the closed transport. They are
        reset so the object can be entered again in a new ``with`` block.
        """
        self.client.close()
        self.client = None
        self.channel = None
        self.stdout = None
        self.sftp = None
        self.already_init = False
Exemple #36
0
class ParamikoConnection(BaseConnection):
    def __init__(self, **kwargs):
        """Store connection options and parse the ``host`` option.

        :param kwargs: Connection options. ``host`` is ``'hostname'`` or
            ``'user@hostname'``; ``timeout`` is read later by ``connect``.
            The connection is not opened here.
        """
        self._kwargs = kwargs
        self._ssh: Union[SSHClient, None] = None

        self.host = None
        self.user = '******'

        host = kwargs.get('host', None)
        if host:
            # Split on the *last* '@' so a user name that itself contains '@'
            # stays intact ('a@b@host' -> user 'a@b', host 'host'). Without
            # '@', rpartition yields an empty user and the whole host string.
            user, _, self.host = host.rpartition('@')
            # Keep the default user when none (or an empty one) is given.
            if user:
                self.user = user
        self._port = SSH_PORT

    def connect(self):
        """Open the SSH connection to ``self.host`` as ``self.user``.

        paramiko's logger is lowered to ERROR. Unknown host keys are
        auto-accepted and the system known-hosts file is loaded. SSH-agent
        authentication is allowed. The optional ``timeout`` connection option
        is passed through to ``SSHClient.connect``.
        """
        logging.getLogger('paramiko').setLevel(logging.ERROR)
        self._ssh = client = SSHClient()
        client.set_missing_host_key_policy(AutoAddPolicy())
        client.load_system_host_keys()
        connect_timeout = self._kwargs.get('timeout', None)
        client.connect(
            self.host,
            port=self._port,
            username=self.user,
            allow_agent=True,
            timeout=connect_timeout,
        )

    def close(self):
        """Close the SSH connection opened by ``connect`` (no-op if none).

        Previously this did nothing, which left the client's socket (and the
        paramiko transport behind it) open until garbage collection.
        """
        if self._ssh is not None:
            self._ssh.close()
            self._ssh = None

    def connected(self):
        """Return ``True`` if an SSH connection is open and its transport active."""
        if self._ssh is None:
            return False
        transport = self._ssh.get_transport()
        # get_transport() is None before connect() / after the client closed.
        return transport is not None and transport.is_active()

    def exec_command(self, cmd):
        """Run *cmd* on the remote host and collect its output.

        A new session channel is opened with SSH-agent forwarding requested
        (``AgentRequestHandler``). Stdout and stderr are read until the
        command has exited and the server has sent EOF (or closed the
        channel). The channel is always closed afterwards.

        :param cmd: Command line to execute remotely.
        :returns: ``(stdin, stdout, stderr, exit_status)``. ``stdin`` is
            always ``''`` (writing to the remote stdin is not implemented).
            ``stdout``/``stderr`` are decoded as UTF-8, with undecodable
            bytes replaced.
        """
        transport = self._ssh.get_transport()
        ch = transport.open_session()
        try:
            AgentRequestHandler(ch)
            ch.exec_command(cmd)

            stdin = ''
            stdout_chunks = []
            stderr_chunks = []
            size = 1024 * 1024
            # TODO: write stdin to channel
            while True:
                # Sample completion *before* draining: data always precedes
                # EOF/close on the channel, so once they are seen, the drain
                # below is guaranteed to collect everything that is left.
                finished = ch.exit_status_ready() and (ch.eof_received or ch.closed)

                # Only read what is already buffered. A blocking read on one
                # stream could wait forever while the remote process is
                # blocked writing the other one (full window) -> deadlock.
                got_data = False
                while ch.recv_stderr_ready():
                    stderr_chunks.append(ch.recv_stderr(size))
                    got_data = True
                while ch.recv_ready():
                    stdout_chunks.append(ch.recv(size))
                    got_data = True

                if finished:
                    break
                if not got_data:
                    # Wait (up to 1 s) for more data on either stream.
                    select.select([ch], [], [], 1)

            exit_status = ch.recv_exit_status()
        finally:
            ch.close()

        # Decode once over the full byte stream so multibyte characters that
        # straddle chunk boundaries are not turned into replacement chars.
        stdout = b''.join(stdout_chunks).decode(errors='replace')
        stderr = b''.join(stderr_chunks).decode(errors='replace')
        return stdin, stdout, stderr, exit_status

    def put_file(self, local, remote):
        """Upload *local* to the remote path *remote* over SFTP.

        :param local: Local file path, or a readable file-like object (any
            object with a callable ``seek``), which is uploaded via ``putfo``.
        :param remote: Destination path on the remote host.
        :returns: *remote*.
        """
        is_file_object = hasattr(local, 'seek') and callable(local.seek)
        # Each call opens its own SFTP session; close it when done instead of
        # leaking it for the lifetime of the SSH connection.
        with self._ssh.open_sftp() as sftp:
            if is_file_object:
                sftp.putfo(local, remote)
            else:
                sftp.put(local, remote)
        return remote

    def fetch_file(self, remote, local):
        """Download the remote path *remote* into *local* over SFTP.

        :param remote: Source path on the remote host.
        :param local: Local destination path, or a writable file-like object
            (any object with a callable ``seek``), which is filled via
            ``getfo``.
        :returns: *remote* (kept for backward compatibility).
        """
        is_file_object = hasattr(local, 'seek') and callable(local.seek)
        # Close the per-call SFTP session instead of leaking it.
        with self._ssh.open_sftp() as sftp:
            if is_file_object:
                sftp.getfo(remote, local)
            else:
                sftp.get(remote, local)
        return remote

    def put_dir(self, local_path, remote_path):
        """Recursively upload the local directory *local_path* into *remote_path*.

        The last component of *local_path* is recreated under *remote_path*.
        For example, ``put_dir('/a/b', '/tmp')`` uploads into ``/tmp/b``, and
        a trailing separator on *local_path* gives the same result. Missing
        remote directories are created; existing ones are reused.

        :param local_path: Existing local directory (checked with ``assert``).
        :param remote_path: Remote parent directory.
        :returns: List of the values returned by ``sftp.put`` for each
            uploaded file, in walk order. NOTE: despite the historical name
            ``remote_paths``, these are paramiko's return values (attribute
            objects), not path strings.
        """
        # Remote SFTP paths always use '/', regardless of the local OS.
        import posixpath

        assert os.path.isdir(local_path)
        if os.path.basename(local_path):
            strip = os.path.dirname(local_path)
        else:
            # Trailing separator: go up one extra level so the directory
            # itself is still recreated on the remote side.
            strip = os.path.dirname(os.path.dirname(local_path))

        def ensure_remote_dir(sftp, path):
            # Create remote directory *path* unless something already exists there.
            try:
                sftp.lstat(path)
            except FileNotFoundError:
                sftp.mkdir(path)

        remote_paths = []
        with self._ssh.open_sftp() as sftp:
            for context, dirs, files in os.walk(local_path):
                # Local directory -> path relative to `strip`, '/'-separated.
                rcontext = context.replace(strip, '', 1)
                rcontext = rcontext.replace(os.sep, '/').lstrip('/')
                rcontext = posixpath.join(remote_path, rcontext)
                ensure_remote_dir(sftp, rcontext)

                for d in dirs:
                    # Check the subdirectory itself (the old code re-checked
                    # `rcontext`, so this step never did anything).
                    ensure_remote_dir(sftp, posixpath.join(rcontext, d))
                for f in files:
                    local_file = os.path.join(context, f)
                    remote_file = posixpath.join(rcontext, f)
                    remote_paths.append(sftp.put(local_file, remote_file))

        return remote_paths

    # def fetch_dir(self):
    #     pass

    def stat(self, remote):
        """Return SFTP ``lstat`` info for the remote path *remote*.

        Uses ``lstat``, so a symlink is described itself rather than its
        target.
        """
        # Close the per-call SFTP session; the returned attributes object does
        # not need the session to stay open.
        with self._ssh.open_sftp() as sftp:
            return sftp.lstat(remote)