Example #1
    def __init__(self,
                 session,
                 device_handler,
                 async_mode=False,
                 timeout=30,
                 raise_mode=RaiseMode.NONE):
        """
        *session* is the :class:`~ncclient.transport.Session` instance

        *device_handler* is the :class:`~ncclient.devices.*.*DeviceHandler` instance

        *async_mode* specifies whether the request is to be made asynchronously, see :attr:`is_async`

        *timeout* is the timeout for a synchronous request, see :attr:`timeout`

        *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`
        """
        self._session = session
        try:
            for cap in self.DEPENDS:
                self._assert(cap)
        except AttributeError:
            pass
        self._async = async_mode
        self._timeout = timeout
        self._raise_mode = raise_mode
        self._id = uuid4().urn  # Keeps things simple instead of having a class attr with running ID that has to be locked
        self._listener = RPCReplyListener(session, device_handler)
        self._listener.register(self._id, self)
        self._reply = None
        self._error = None
        self._event = Event()
        self._device_handler = device_handler
        self.logger = SessionLoggerAdapter(logger, {'session': session})
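
In normal use this constructor is not called directly: each NETCONF operation (Get, EditConfig, ...) is an RPC subclass that the high-level manager API instantiates per request. A minimal sketch of that flow (host, credentials and the timeout value below are placeholders):

from ncclient import manager

# manager.connect() builds the transport session and device handler; each
# operation call below then constructs an RPC subclass like the one above,
# which registers itself with the shared RPCReplyListener by message-id.
with manager.connect(host="192.0.2.1", port=830,
                     username="admin", password="admin",
                     hostkey_verify=False) as m:
    m.timeout = 60                      # per-request synchronous timeout
    reply = m.get_config(source="running")
    print(reply.xml)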
Example #2
File: ssh.py Project: vitus133/ncclient
    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        # parsing-related, see _parse()
        self._device_handler = device_handler
        self._parsing_state10 = 0
        self._parsing_pos10 = 0
        self._parsing_pos11 = 0
        self._parsing_state11 = 0
        self._expchunksize = 0
        self._curchunksize = 0
        self._inendpos = 0
        self._size_num_list = []
        self._message_list = []
        self._closing = threading.Event()

        self.logger = SessionLoggerAdapter(logger, {'session': self})
Example #3
    def __init__(self, session):
        """
        DOM Parser

        :param session: ssh session object
        """
        self._session = session
        self._parsing_pos10 = 0
        self.logger = SessionLoggerAdapter(logger, {'session': self._session})
Example #4
File: rpc.py Project: siming85/ncclient
    def __init__(self, session, device_handler, async_mode=False, timeout=30, raise_mode=RaiseMode.NONE):
        """
        *session* is the :class:`~ncclient.transport.Session` instance

        *device_handler* is the :class:`~ncclient.devices.*.*DeviceHandler` instance

        *async_mode* specifies whether the request is to be made asynchronously, see :attr:`is_async`

        *timeout* is the timeout for a synchronous request, see :attr:`timeout`

        *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`
        """
        self._session = session
        try:
            for cap in self.DEPENDS:
                self._assert(cap)
        except AttributeError:
            pass
        self._async = async_mode
        self._timeout = timeout
        self._raise_mode = raise_mode
        self._id = uuid4().urn # Keeps things simple instead of having a class attr with running ID that has to be locked
        self._listener = RPCReplyListener(session, device_handler)
        self._listener.register(self._id, self)
        self._reply = None
        self._error = None
        self._event = Event()
        self._device_handler = device_handler
        self.logger = SessionLoggerAdapter(logger, {'session': session})
Example #5
    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        # parsing-related, see _parse()
        self._device_handler = device_handler
        self._parsing_state10 = 0
        self._parsing_pos10 = 0
        self._parsing_pos11 = 0
        self._parsing_state11 = 0
        self._expchunksize = 0
        self._curchunksize = 0
        self._inendpos = 0
        self._size_num_list = []
        self._message_list = []
        self._closing = threading.Event()

        self.logger = SessionLoggerAdapter(logger, {'session': self})
Example #6
File: ssh.py Project: emitch05/ncclient
    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        self._device_handler = device_handler
        self._message_list = []
        self._closing = threading.Event()
        self.parser = DefaultXMLParser(self)  # SAX or DOM parser

        self.logger = SessionLoggerAdapter(logger, {'session': self})
Example #7
 def __init__(self, capabilities):
     Thread.__init__(self)
     self.setDaemon(True)
     self._listeners = set()
     self._lock = Lock()
     self.setName('session')
     self._q = Queue()
     self._notification_q = Queue()
     self._client_capabilities = capabilities
     self._server_capabilities = None  # yet
     self._base = NetconfBase.BASE_10
     self._id = None  # session-id
     self._connected = False  # to be set/cleared by subclass implementation
     self.logger = SessionLoggerAdapter(logger, {'session': self})
     self.logger.debug('%r created: client_capabilities=%r', self,
                       self._client_capabilities)
     self._device_handler = None  # Should be set by child class
Example #8
 def __new__(cls, session, device_handler):
     with RPCReplyListener.creation_lock:
         instance = session.get_listener_instance(cls)
         if instance is None:
             instance = object.__new__(cls)
             instance._lock = Lock()
             instance._id2rpc = {}
             instance._device_handler = device_handler
             #instance._pipelined = session.can_pipeline
             session.add_listener(instance)
             instance.logger = SessionLoggerAdapter(logger,
                                                    {'session': session})
         return instance
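
Example #8 is a per-session singleton: the first RPC on a session creates the listener and caches it on the session via add_listener/get_listener_instance, so later RPCs on the same session reuse it. A generic, self-contained sketch of the same pattern (the names here are illustrative, not ncclient API):

import threading

class PerOwnerListener(object):
    """At most one listener instance per owner object, created lazily and thread-safely."""

    _creation_lock = threading.Lock()

    def __new__(cls, owner):
        with cls._creation_lock:
            # Reuse a previously created instance if the owner already has one.
            instance = getattr(owner, "_listener_instance", None)
            if instance is None:
                instance = object.__new__(cls)
                instance._lock = threading.Lock()
                instance._id2rpc = {}
                owner._listener_instance = instance
            return instance

class Owner(object):
    pass

owner = Owner()
assert PerOwnerListener(owner) is PerOwnerListener(owner)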
Example #9
 def __init__(self, capabilities):
     Thread.__init__(self)
     self.setDaemon(True)
     self._listeners = set()
     self._lock = Lock()
     self.setName('session')
     self._q = Queue()
     self._notification_q = Queue()
     self._client_capabilities = capabilities
     self._server_capabilities = None # yet
     self._base = NetconfBase.BASE_10
     self._id = None # session-id
     self._connected = False # to be set/cleared by subclass implementation
     self.logger = SessionLoggerAdapter(logger, {'session': self})
     self.logger.debug('%r created: client_capabilities=%r',
                       self, self._client_capabilities)
     self._device_handler = None # Should be set by child class
Example #10
File: ssh.py Project: vitus133/ncclient
class SSHSession(Session):

    "Implements a :rfc:`4742` NETCONF session over SSH."

    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        # parsing-related, see _parse()
        self._device_handler = device_handler
        self._parsing_state10 = 0
        self._parsing_pos10 = 0
        self._parsing_pos11 = 0
        self._parsing_state11 = 0
        self._expchunksize = 0
        self._curchunksize = 0
        self._inendpos = 0
        self._size_num_list = []
        self._message_list = []
        self._closing = threading.Event()

        self.logger = SessionLoggerAdapter(logger, {'session': self})

    def _dispatch_message(self, raw):
        self.logger.info("Received:\n%s", raw)
        return super(SSHSession, self)._dispatch_message(raw)

    def _parse(self):
        "Messages are delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes every time this method is called. Retains state across method calls and if a byte has been read it will not be considered again."
        return self._parse10()

    def _parse10(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a chunk has been read it will not be
        considered again."""
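        # Illustration of the :rfc:`4742` 1.0 framing this method consumes:
        # each message in the buffer ends with MSG_DELIM ("]]>]]>"), e.g.
        #
        #   <rpc-reply message-id="101"> ... </rpc-reply>]]>]]>
        #
        # so parsing is essentially a split on that delimiter.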

        self.logger.debug("parsing netconf v1.0")
        buf = self._buffer
        buf.seek(self._parsing_pos10)
        if MSG_DELIM in buf.read().decode('UTF-8'):
            buf.seek(0)
            msg, _, remaining = buf.read().decode('UTF-8').partition(MSG_DELIM)
            msg = msg.strip()
            if sys.version < '3':
                self._dispatch_message(msg.encode())
            else:
                self._dispatch_message(msg)
            # create new buffer which contains remaining of old buffer
            self._buffer = StringIO()
            self._buffer.write(remaining.encode())
            self._parsing_pos10 = 0
            if len(remaining) > 0:
                # There could be another entire message in the
                # buffer, so we should try to parse again.
                self.logger.debug(
                    'Trying another round of parsing since there is still data'
                )
                self._parse10()
        else:
            # handle case that MSG_DELIM is split over two chunks
            self._parsing_pos10 = buf.tell() - MSG_DELIM_LEN
            if self._parsing_pos10 < 0:
                self._parsing_pos10 = 0

    def _parse11(self):
        """Messages are split into chunks. Chunks and messages are delimited
        by the regex #RE_NC11_DELIM defined earlier in this file. Each
        time we get called here either a chunk delimiter or an
        end-of-message delimiter should be found iff there is enough
        data. If there is not enough data, we will wait for more. If a
        delimiter is found in the wrong place, a #NetconfFramingError
        will be raised."""
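        # Illustration of the :rfc:`6242` chunked framing this method consumes,
        # here as a single chunk (a message may span several such chunks):
        #
        #   \n#23\n<rpc message-id="102"/>\n##\n
        #
        # "#<n>" headers give the chunk size in bytes and are matched by
        # group(1) of RE_NC11_DELIM; the trailing "##" end-of-message
        # delimiter is matched by group(2).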

        self.logger.debug("_parse11: starting")

        # suck in whole string that we have (this is what we will work on in
        # this function) and initialize a couple of useful values
        self._buffer.seek(0, os.SEEK_SET)
        data = self._buffer.getvalue()
        data_len = len(data)
        start = 0
        self.logger.debug('_parse11: working with buffer of %d bytes',
                          data_len)
        while start < data_len:
            # match to see if we found at least some kind of delimiter
            self.logger.debug(
                '_parse11: matching from %d bytes from start of buffer', start)
            re_result = RE_NC11_DELIM.match(data[start:].decode('utf-8'))
            if not re_result:

                # not found any kind of delimiter just break; this should only
                # ever happen if we just have the first few characters of a
                # message such that we don't yet have a full delimiter
                self.logger.debug('_parse11: no delimiter found, buffer="%s"',
                                  data[start:].decode())
                break

            # save useful variables for reuse
            re_start = re_result.start()
            re_end = re_result.end()
            self.logger.debug('_parse11: regular expression start=%d, end=%d',
                              re_start, re_end)

            # If the regex doesn't start at the beginning of the buffer,
            # we're in trouble, so throw an error
            if re_start != 0:
                raise NetconfFramingError(
                    '_parse11: delimiter not at start of match buffer',
                    data[start:])

            if re_result.group(2):
                # we've found the end of the message, need to form up
                # whole message, save back remainder (if any) to buffer
                # and dispatch the message
                start += re_end
                message = ''.join(self._message_list)
                self._message_list = []
                self.logger.debug('_parse11: found end of message delimiter')
                self._dispatch_message(message)
                break

            elif re_result.group(1):
                # we've found a chunk delimiter, and group(1) is the digit
                # string that will tell us how many bytes past the end of
                # where it was found that we need to have available to
                # save the next chunk off
                self.logger.debug('_parse11: found chunk delimiter')
                digits = int(re_result.group(1))
                self.logger.debug('_parse11: chunk size %d bytes', digits)
                if (data_len - start) >= (re_end + digits):
                    # we have enough data for the chunk
                    fragment = textify(data[start + re_end:start + re_end +
                                            digits])
                    self._message_list.append(fragment)
                    start += re_end + digits
                    self.logger.debug('_parse11: appending %d bytes', digits)
                    self.logger.debug('_parse11: fragment = "%s"', fragment)
                else:
                    # we don't have enough bytes, just break out for now
                    # after updating start pointer to start of new chunk
                    start += re_start
                    self.logger.debug(
                        '_parse11: not enough data for chunk yet')
                    self.logger.debug('_parse11: setting start to %d', start)
                    break

        # Now out of the loop, need to see if we need to save back any content
        if start > 0:
            self.logger.debug(
                '_parse11: saving back rest of message after %d bytes, original size %d',
                start, data_len)
            self._buffer = StringIO(data[start:])
            if start < data_len:
                self.logger.debug(
                    '_parse11: still have data, may have another full message!'
                )
                self._parse11()
        self.logger.debug('_parse11: ending')

    def load_known_hosts(self, filename=None):
        """Load host keys from an openssh :file:`known_hosts`-style file. Can
        be called multiple times.

        If *filename* is not specified, looks in the default locations i.e. :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
        """

        if filename is None:
            filename = os.path.expanduser('~/.ssh/known_hosts')
            try:
                self._host_keys.load(filename)
            except IOError:
                # for windows
                filename = os.path.expanduser('~/ssh/known_hosts')
                try:
                    self._host_keys.load(filename)
                except IOError:
                    pass
        else:
            self._host_keys.load(filename)

    def close(self):
        self._closing.set()
        if self._transport.is_active():
            self._transport.close()

        # Wait for the transport thread to close.
        while self.is_alive() and (self is not threading.current_thread()):
            self.join(10)

        if self._channel:
            self._channel.close()
        self._channel = None
        self._connected = False

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(self,
                host,
                port=PORT_NETCONF_DEFAULT,
                timeout=None,
                unknown_host_cb=default_unknown_host_cb,
                username=None,
                password=None,
                key_filename=None,
                allow_agent=True,
                hostkey_verify=True,
                hostkey_b64=None,
                look_for_keys=True,
                ssh_config=None,
                sock_fd=None):
        """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.

        To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.

        *host* is the hostname or IP address to connect to

        *port* is by default 830 (PORT_NETCONF_DEFAULT), but some devices use the default SSH port of 22 (PORT_SSH_DEFAULT) so this may need to be specified

        *timeout* is an optional timeout for socket connect

        *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)

        *username* is the username to use for SSH authentication

        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it

        *key_filename* is a filename where the private key to be used can be found

        *allow_agent* enables querying SSH agent (if found) for keys

        *hostkey_verify* enables hostkey verification from ~/.ssh/known_hosts

        *hostkey_b64* if set, only connect when the server presents a public hostkey matching this value (obtain it from the server's /etc/ssh/ssh_host_*.pub files or via ssh-keyscan)

        *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)

        *ssh_config* enables parsing of an OpenSSH configuration file when set to its path (e.g. :file:`~/.ssh/config`) or to True (in which case :file:`~/.ssh/config` is used).

        *sock_fd* is an already open socket which shall be used for this connection. Useful for NETCONF outbound ssh. Use host=None together with a valid sock_fd number
        """
        if not (host or sock_fd):
            raise SSHError("Missing host or socket fd")

        self._host = host

        # Optionally, parse .ssh/config
        config = {}
        if ssh_config is True:
            ssh_config = "~/.ssh/config" if sys.platform != "win32" else "~/ssh/config"
        if ssh_config is not None:
            config = paramiko.SSHConfig()
            config.parse(open(os.path.expanduser(ssh_config)))

            # Save default Paramiko SSH port so it can be reverted
            paramiko_default_ssh_port = paramiko.config.SSH_PORT

            # Change the default SSH port to the port specified by the user so expand_variables
            # replaces %p with the passed-in port rather than 22 (the default paramiko.config.SSH_PORT)

            paramiko.config.SSH_PORT = port

            config = config.lookup(host)

            # paramiko.config.SSHConfig.expand_variables is called by lookup, so we can set the SSH port
            # back to the default
            paramiko.config.SSH_PORT = paramiko_default_ssh_port

            host = config.get("hostname", host)
            if username is None:
                username = config.get("user")
            if key_filename is None:
                key_filename = config.get("identityfile")
            if hostkey_verify:
                userknownhostsfile = config.get("userknownhostsfile")
                if userknownhostsfile:
                    self.load_known_hosts(
                        os.path.expanduser(userknownhostsfile))

        if username is None:
            username = getpass.getuser()

        if sock_fd is None:
            if config.get("proxycommand"):
                self.logger.debug("Configuring Proxy. %s",
                                  config.get("proxycommand"))
                sock = paramiko.proxy.ProxyCommand(config.get("proxycommand"))
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                              socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    try:
                        sock = socket.socket(af, socktype, proto)
                        sock.settimeout(timeout)
                    except socket.error:
                        continue
                    try:
                        sock.connect(sa)
                    except socket.error:
                        sock.close()
                        continue
                    break
                else:
                    raise SSHError("Could not open socket to %s:%s" %
                                   (host, port))
        else:
            if sys.version_info[0] < 3:
                s = socket.fromfd(int(sock_fd), socket.AF_INET,
                                  socket.SOCK_STREAM)
                sock = socket.socket(socket.AF_INET,
                                     socket.SOCK_STREAM,
                                     _sock=s)
            else:
                sock = socket.fromfd(int(sock_fd), socket.AF_INET,
                                     socket.SOCK_STREAM)
            sock.settimeout(timeout)

        self._transport = paramiko.Transport(sock)
        self._transport.set_log_channel(logger.name)
        if config.get("compression") == 'yes':
            self._transport.use_compression()

        if hostkey_b64:
            # If we need to connect with a specific hostkey, negotiate for only its type
            hostkey_obj = None
            for key_cls in [
                    paramiko.DSSKey, paramiko.Ed25519Key, paramiko.RSAKey,
                    paramiko.ECDSAKey
            ]:
                try:
                    hostkey_obj = key_cls(data=base64.b64decode(hostkey_b64))
                except paramiko.SSHException:
                    # Not a key of this type - try the next
                    pass
            if not hostkey_obj:
                # We've tried all known host key types and haven't found a suitable one to use - bail
                raise SSHError(
                    "Couldn't find suitable paramiko key class for host key %s"
                    % hostkey_b64)
            self._transport._preferred_keys = [hostkey_obj.get_name()]
        elif self._host_keys:
            # Else set preferred host keys to those we possess for the host
            # (avoids situation where known_hosts contains a valid key for the host, but that key type is not selected during negotiation)
            if port == PORT_SSH_DEFAULT:
                known_hosts_lookup = host
            else:
                known_hosts_lookup = '[%s]:%s' % (host, port)
            known_host_keys_for_this_host = self._host_keys.lookup(
                known_hosts_lookup)
            if known_host_keys_for_this_host:
                self._transport._preferred_keys = [
                    x.key.get_name()
                    for x in known_host_keys_for_this_host._entries
                ]

        # Connect
        try:
            self._transport.start_client()
        except paramiko.SSHException as e:
            raise SSHError('Negotiation failed: %s' % e)

        server_key_obj = self._transport.get_remote_server_key()
        fingerprint = _colonify(hexlify(server_key_obj.get_fingerprint()))

        if hostkey_verify:
            is_known_host = False

            # For looking up entries for non-standard SSH ports (i.e. other than 22) in known_hosts
            # we enclose the host in brackets and append the port number
            if port == PORT_SSH_DEFAULT:
                known_hosts_lookup = host
            else:
                known_hosts_lookup = '[%s]:%s' % (host, port)

            if hostkey_b64:
                # If hostkey specified, remote host /must/ use that hostkey
                if (hostkey_obj.get_name() == server_key_obj.get_name()
                        and hostkey_obj.asbytes() == server_key_obj.asbytes()):
                    is_known_host = True
            else:
                # Check known_hosts
                is_known_host = self._host_keys.check(known_hosts_lookup,
                                                      server_key_obj)

            if not is_known_host and not unknown_host_cb(host, fingerprint):
                raise SSHUnknownHostError(known_hosts_lookup, fingerprint)

        # Authenticating with our private key/identity
        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, (str, bytes)):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        self._auth(username, password, key_filenames, allow_agent,
                   look_for_keys)

        self._connected = True  # there was no error authenticating
        self._closing.clear()

        # TODO: leopoul: Review, test, and if needed rewrite this part
        subsystem_names = self._device_handler.get_ssh_subsystem_names()
        for subname in subsystem_names:
            self._channel = self._transport.open_session()
            self._channel_id = self._channel.get_id()
            channel_name = "%s-subsystem-%s" % (subname, str(self._channel_id))
            self._channel.set_name(channel_name)
            try:
                self._channel.invoke_subsystem(subname)
            except paramiko.SSHException as e:
                self.logger.info("%s (subsystem request rejected)", e)
                handle_exception = self._device_handler.handle_connection_exceptions(
                    self)
                # Ignore the exception, since we continue to try the different
                # subsystem names until we find one that can connect.
                # have to handle exception for each vendor here
                if not handle_exception:
                    continue
            self._channel_name = self._channel.get_name()
            self._post_connect()
            return
        raise SSHError(
            "Could not open connection, possibly due to unacceptable"
            " SSH subsystem name.")

    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        saved_exception = None

        for key_filename in key_filenames:
            for cls in (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey):
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    self.logger.debug("Trying key %s from %s",
                                      hexlify(key.get_fingerprint()),
                                      key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    self.logger.debug("Trying SSH agent key %s",
                                      hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        keyfiles = []
        if look_for_keys:
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/.ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))
            # look in ~/ssh/ for windows users:
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                self.logger.debug("Trying discovered key %s in %s",
                                  hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if password is not None:
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if saved_exception is not None:
            # need pep-3134 to do this right
            raise AuthenticationError(repr(saved_exception))

        raise AuthenticationError("No authentication methods available")

    def run(self):
        chan = self._channel
        q = self._q

        def start_delim(data_len):
            return '\n#%s\n' % (data_len)

        try:
            s = selectors.DefaultSelector()
            s.register(chan, selectors.EVENT_READ)
            self.logger.debug('selector type = %s', s.__class__.__name__)
            while True:

                # Will wake up every TICK seconds to check if there is something
                # to send, more quickly if there is something to read (due to
                # select returning chan in the readable list).
                events = s.select(timeout=TICK)
                if events:
                    data = chan.recv(BUF_SIZE)
                    if data:
                        self._buffer.seek(0, os.SEEK_END)
                        self._buffer.write(data)
                        if self._base == NetconfBase.BASE_11:
                            self._parse11()
                        else:
                            self._parse10()
                    elif self._closing.is_set():
                        # End of session, expected
                        break
                    else:
                        # End of session, unexpected
                        raise SessionCloseError(self._buffer.getvalue())
                if not q.empty() and chan.send_ready():
                    self.logger.debug("Sending message")
                    data = q.get()
                    if self._base == NetconfBase.BASE_11:
                        data = "%s%s%s" % (start_delim(
                            len(data)), data, END_DELIM)
                    else:
                        data = "%s%s" % (data, MSG_DELIM)
                    self.logger.info("Sending:\n%s", data)
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(),
                                                    data)
                        data = data[n:]
        except Exception as e:
            self.logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(e)
            self.close()

    @property
    def host(self):
        """Host this session is connected to, or None if not connected."""
        if hasattr(self, '_host'):
            return self._host
        return None

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
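
Driving SSHSession directly (rather than through ncclient.manager) follows the order the methods above imply: construct it with a device handler, optionally load known hosts, connect, and close when done. A rough sketch, with placeholder host and credentials; constructing DefaultDeviceHandler with no arguments is an assumption here:

from ncclient.devices.default import DefaultDeviceHandler
from ncclient.transport.ssh import SSHSession

device_handler = DefaultDeviceHandler()      # assumption: no-arg construction
session = SSHSession(device_handler)
session.load_known_hosts()                   # ~/.ssh/known_hosts by default
session.connect("192.0.2.1", port=830,
                username="admin", password="admin")
try:
    print(session.connected)                 # True once _post_connect() has run
finally:
    session.close()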
Example #11
class DefaultXMLParser(object):
    def __init__(self, session):
        """
        DOM Parser

        :param session: ssh session object
        """
        self._session = session
        self._parsing_pos10 = 0
        self.logger = SessionLoggerAdapter(logger, {'session': self._session})

    def parse(self, data):
        """
        parse incoming RPC response from networking device.

        :param data: incoming RPC data from device
        :return: None
        """
        if data:
            self._session._buffer.seek(0, os.SEEK_END)
            self._session._buffer.write(data)
            if self._session._base == NetconfBase.BASE_11:
                self._parse11()
            else:
                self._parse10()

    def _parse10(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a chunk has been read it will not be
        considered again."""

        self.logger.debug("parsing netconf v1.0")
        buf = self._session._buffer
        buf.seek(self._parsing_pos10)
        if MSG_DELIM in buf.read().decode('UTF-8'):
            buf.seek(0)
            msg, _, remaining = buf.read().decode('UTF-8').partition(MSG_DELIM)
            msg = msg.strip()
            if sys.version < '3':
                self._session._dispatch_message(msg.encode())
            else:
                self._session._dispatch_message(msg)
            self._session._buffer = StringIO()
            self._parsing_pos10 = 0
            if len(remaining.strip()) > 0:
                # There could be another entire message in the
                # buffer, so we should try to parse again.
                if type(self._session.parser) != DefaultXMLParser:
                    self.logger.debug('send remaining data to SAX parser')
                    self._session.parser.parse(remaining.encode())
                else:
                    self.logger.debug(
                        'Trying another round of parsing since there is still data'
                    )
                    self._session._buffer.write(remaining.encode())
                    self._parse10()
        else:
            # handle case that MSG_DELIM is split over two chunks
            self._parsing_pos10 = buf.tell() - MSG_DELIM_LEN
            if self._parsing_pos10 < 0:
                self._parsing_pos10 = 0

    def _parse11(self):
        """Messages are split into chunks. Chunks and messages are delimited
        by the regex #RE_NC11_DELIM defined earlier in this file. Each
        time we get called here either a chunk delimiter or an
        end-of-message delimiter should be found iff there is enough
        data. If there is not enough data, we will wait for more. If a
        delimiter is found in the wrong place, a #NetconfFramingError
        will be raised."""

        self.logger.debug("_parse11: starting")

        # suck in whole string that we have (this is what we will work on in
        # this function) and initialize a couple of useful values
        self._session._buffer.seek(0, os.SEEK_SET)
        data = self._session._buffer.getvalue()
        data_len = len(data)
        start = 0
        self.logger.debug('_parse11: working with buffer of %d bytes',
                          data_len)
        while start < data_len:
            # match to see if we found at least some kind of delimiter
            self.logger.debug(
                '_parse11: matching from %d bytes from start of buffer', start)
            re_result = RE_NC11_DELIM.match(data[start:].decode('utf-8'))
            if not re_result:

                # not found any kind of delimiter just break; this should only
                # ever happen if we just have the first few characters of a
                # message such that we don't yet have a full delimiter
                self.logger.debug('_parse11: no delimiter found, buffer="%s"',
                                  data[start:].decode())
                break

            # save useful variables for reuse
            re_start = re_result.start()
            re_end = re_result.end()
            self.logger.debug('_parse11: regular expression start=%d, end=%d',
                              re_start, re_end)

            # If the regex doesn't start at the beginning of the buffer,
            # we're in trouble, so throw an error
            if re_start != 0:
                raise NetconfFramingError(
                    '_parse11: delimiter not at start of match buffer',
                    data[start:])

            if re_result.group(2):
                # we've found the end of the message, need to form up
                # whole message, save back remainder (if any) to buffer
                # and dispatch the message
                start += re_end
                message = ''.join(self._session._message_list)
                self._session._message_list = []
                self.logger.debug('_parse11: found end of message delimiter')
                self._session._dispatch_message(message)
                break

            elif re_result.group(1):
                # we've found a chunk delimiter, and group(1) is the digit
                # string that will tell us how many bytes past the end of
                # where it was found that we need to have available to
                # save the next chunk off
                self.logger.debug('_parse11: found chunk delimiter')
                digits = int(re_result.group(1))
                self.logger.debug('_parse11: chunk size %d bytes', digits)
                if (data_len - start) >= (re_end + digits):
                    # we have enough data for the chunk
                    fragment = textify(data[start + re_end:start + re_end +
                                            digits])
                    self._session._message_list.append(fragment)
                    start += re_end + digits
                    self.logger.debug('_parse11: appending %d bytes', digits)
                    self.logger.debug('_parse11: fragment = "%s"', fragment)
                else:
                    # we don't have enough bytes, just break out for now
                    # after updating start pointer to start of new chunk
                    start += re_start
                    self.logger.debug(
                        '_parse11: not enough data for chunk yet')
                    self.logger.debug('_parse11: setting start to %d', start)
                    break

        # Now out of the loop, need to see if we need to save back any content
        if start > 0:
            self.logger.debug(
                '_parse11: saving back rest of message after %d bytes, original size %d',
                start, data_len)
            self._session._buffer = StringIO(data[start:])
            if start < data_len:
                self.logger.debug(
                    '_parse11: still have data, may have another full message!'
                )
                self._parse11()
        self.logger.debug('_parse11: ending')
Example #12
File: ssh.py Project: emitch05/ncclient
class SSHSession(Session):

    "Implements a :rfc:`4742` NETCONF session over SSH."

    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        self._device_handler = device_handler
        self._message_list = []
        self._closing = threading.Event()
        self.parser = DefaultXMLParser(self)  # SAX or DOM parser

        self.logger = SessionLoggerAdapter(logger, {'session': self})

    def _dispatch_message(self, raw):
        # Provide basic response message
        self.logger.info("Received message from host")
        # Provide complete response from host at debug log level
        self.logger.debug("Received:\n%s", raw)
        return super(SSHSession, self)._dispatch_message(raw)

    def _parse(self):
        "Messages are delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes every time this method is called. Retains state across method calls and if a byte has been read it will not be considered again."
        return self.parser._parse10()

    def load_known_hosts(self, filename=None):
        """Load host keys from an openssh :file:`known_hosts`-style file. Can
        be called multiple times.

        If *filename* is not specified, looks in the default locations i.e. :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
        """

        if filename is None:
            filename = os.path.expanduser('~/.ssh/known_hosts')
            try:
                self._host_keys.load(filename)
            except IOError:
                # for windows
                filename = os.path.expanduser('~/ssh/known_hosts')
                try:
                    self._host_keys.load(filename)
                except IOError:
                    pass
        else:
            self._host_keys.load(filename)

    def close(self):
        self._closing.set()
        if self._transport.is_active():
            self._transport.close()

        # Wait for the transport thread to close.
        while self.is_alive() and (self is not threading.current_thread()):
            self.join(10)

        if self._channel:
            self._channel.close()
        self._channel = None
        self._connected = False

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(self,
                host,
                port=PORT_NETCONF_DEFAULT,
                timeout=None,
                unknown_host_cb=default_unknown_host_cb,
                username=None,
                password=None,
                key_filename=None,
                allow_agent=True,
                hostkey_verify=True,
                hostkey_b64=None,
                look_for_keys=True,
                ssh_config=None,
                sock_fd=None,
                bind_addr=None):
        """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.

        To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.

        *host* is the hostname or IP address to connect to

        *port* is by default 830 (PORT_NETCONF_DEFAULT), but some devices use the default SSH port of 22 so this may need to be specified

        *timeout* is an optional timeout for socket connect

        *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)

        *username* is the username to use for SSH authentication

        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it

        *key_filename* is a filename where the private key to be used can be found

        *allow_agent* enables querying SSH agent (if found) for keys

        *hostkey_verify* enables hostkey verification from ~/.ssh/known_hosts

        *hostkey_b64* if set, only connect when the server presents a public hostkey matching this value (obtain it from the server's /etc/ssh/ssh_host_*.pub files or via ssh-keyscan)

        *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)

        *ssh_config* enables parsing of an OpenSSH configuration file when set to its path (e.g. :file:`~/.ssh/config`) or to True (in which case :file:`~/.ssh/config` is used).

        *sock_fd* is an already open socket which shall be used for this connection. Useful for NETCONF outbound ssh. Use host=None together with a valid sock_fd number

        *bind_addr* is a (local) source IP address to use, must be reachable from the remote device.
        """
        if not (host or sock_fd):
            raise SSHError("Missing host or socket fd")

        self._host = host

        # Optionally, parse .ssh/config
        config = {}
        if ssh_config is True:
            ssh_config = "~/.ssh/config" if sys.platform != "win32" else "~/ssh/config"
        if ssh_config is not None:
            config = paramiko.SSHConfig()
            with open(os.path.expanduser(ssh_config)) as ssh_config_file_obj:
                config.parse(ssh_config_file_obj)

            # Save default Paramiko SSH port so it can be reverted
            paramiko_default_ssh_port = paramiko.config.SSH_PORT

            # Change the default SSH port to the port specified by the user so expand_variables
            # replaces %p with the passed-in port rather than 22 (the default paramiko.config.SSH_PORT)

            paramiko.config.SSH_PORT = port

            config = config.lookup(host)

            # paramiko.config.SSHConfig.expand_variables is called by lookup, so we can set the SSH port
            # back to the default
            paramiko.config.SSH_PORT = paramiko_default_ssh_port

            host = config.get("hostname", host)
            if username is None:
                username = config.get("user")
            if key_filename is None:
                key_filename = config.get("identityfile")
            if hostkey_verify:
                userknownhostsfile = config.get("userknownhostsfile")
                if userknownhostsfile:
                    self.load_known_hosts(
                        os.path.expanduser(userknownhostsfile))
            if timeout is None:
                timeout = config.get("connecttimeout")
                if timeout:
                    timeout = int(timeout)

        if username is None:
            username = getpass.getuser()

        if sock_fd is None:
            proxycommand = config.get("proxycommand")
            if proxycommand:
                self.logger.debug("Configuring Proxy. %s", proxycommand)
                if not isinstance(proxycommand, six.string_types):
                    proxycommand = [
                        os.path.expanduser(elem) for elem in proxycommand
                    ]
                else:
                    proxycommand = os.path.expanduser(proxycommand)
                sock = paramiko.proxy.ProxyCommand(proxycommand)
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                              socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    try:
                        sock = socket.socket(af, socktype, proto)
                        sock.settimeout(timeout)
                    except socket.error:
                        continue
                    try:
                        if bind_addr:
                            sock.bind((bind_addr, 0))
                        sock.connect(sa)
                    except socket.error:
                        sock.close()
                        continue
                    break
                else:
                    raise SSHError("Could not open socket to %s:%s" %
                                   (host, port))
        else:
            if sys.version_info[0] < 3:
                s = socket.fromfd(int(sock_fd), socket.AF_INET,
                                  socket.SOCK_STREAM)
                sock = socket.socket(socket.AF_INET,
                                     socket.SOCK_STREAM,
                                     _sock=s)
            else:
                sock = socket.fromfd(int(sock_fd), socket.AF_INET,
                                     socket.SOCK_STREAM)
            sock.settimeout(timeout)

        self._transport = paramiko.Transport(sock)
        self._transport.set_log_channel(logger.name)
        if config.get("compression") == 'yes':
            self._transport.use_compression()

        if hostkey_b64:
            # If we need to connect with a specific hostkey, negotiate for only its type
            hostkey_obj = None
            for key_cls in [
                    paramiko.DSSKey, paramiko.Ed25519Key, paramiko.RSAKey,
                    paramiko.ECDSAKey
            ]:
                try:
                    hostkey_obj = key_cls(data=base64.b64decode(hostkey_b64))
                except paramiko.SSHException:
                    # Not a key of this type - try the next
                    pass
            if not hostkey_obj:
                # We've tried all known host key types and haven't found a suitable one to use - bail
                raise SSHError(
                    "Couldn't find suitable paramiko key class for host key %s"
                    % hostkey_b64)
            self._transport._preferred_keys = [hostkey_obj.get_name()]
        elif self._host_keys:
            # Else set preferred host keys to those we possess for the host
            # (avoids situation where known_hosts contains a valid key for the host, but that key type is not selected during negotiation)
            known_host_keys_for_this_host = self._host_keys.lookup(host) or {}
            host_port = '[%s]:%s' % (host, port)
            known_host_keys_for_this_host.update(
                self._host_keys.lookup(host_port) or {})
            if known_host_keys_for_this_host:
                self._transport._preferred_keys = [
                    x.key.get_name()
                    for x in known_host_keys_for_this_host._entries
                ]

        # Connect
        try:
            self._transport.start_client()
        except paramiko.SSHException as e:
            raise SSHError('Negotiation failed: %s' % e)

        server_key_obj = self._transport.get_remote_server_key()
        fingerprint = _colonify(hexlify(server_key_obj.get_fingerprint()))

        if hostkey_verify:
            is_known_host = False

            # For looking up entries for non-standard SSH ports (i.e. other than 22) in known_hosts
            # we enclose the host in brackets and append the port number
            known_hosts_lookups = [host, '[%s]:%s' % (host, port)]

            if hostkey_b64:
                # If hostkey specified, remote host /must/ use that hostkey
                if (hostkey_obj.get_name() == server_key_obj.get_name()
                        and hostkey_obj.asbytes() == server_key_obj.asbytes()):
                    is_known_host = True
            else:
                # Check known_hosts
                is_known_host = any(
                    self._host_keys.check(lookup, server_key_obj)
                    for lookup in known_hosts_lookups)

            if not is_known_host and not unknown_host_cb(host, fingerprint):
                raise SSHUnknownHostError(known_hosts_lookups[0], fingerprint)

        # Authenticating with our private key/identity
        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, (str, bytes)):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        self._auth(username, password, key_filenames, allow_agent,
                   look_for_keys)

        self._connected = True  # there was no error authenticating
        self._closing.clear()

        # TODO: leopoul: Review, test, and if needed rewrite this part
        subsystem_names = self._device_handler.get_ssh_subsystem_names()
        for subname in subsystem_names:
            self._channel = self._transport.open_session()
            self._channel_id = self._channel.get_id()
            channel_name = "%s-subsystem-%s" % (subname, str(self._channel_id))
            self._channel.set_name(channel_name)
            try:
                self._channel.invoke_subsystem(subname)
            except paramiko.SSHException as e:
                self.logger.info("%s (subsystem request rejected)", e)
                handle_exception = self._device_handler.handle_connection_exceptions(
                    self)
                # Ignore the exception, since we continue to try the different
                # subsystem names until we find one that can connect.
                # have to handle exception for each vendor here
                if not handle_exception:
                    continue
            self._channel_name = self._channel.get_name()
            self._post_connect()
            # for further upcoming RPC responses, vendors can choose their
            # preferred parser, e.g. DOM or SAX
            self.parser = self._device_handler.get_xml_parser(self)
            return
        raise SSHError(
            "Could not open connection, possibly due to unacceptable"
            " SSH subsystem name.")

    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        saved_exception = None

        for key_filename in key_filenames:
            for cls in (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey,
                        paramiko.Ed25519Key):
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    self.logger.debug("Trying key %s from %s",
                                      hexlify(key.get_fingerprint()),
                                      key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    self.logger.debug("Trying SSH agent key %s",
                                      hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        keyfiles = []
        if look_for_keys:
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/.ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))
            # look in ~/ssh/ for windows users:
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                self.logger.debug("Trying discovered key %s in %s",
                                  hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if password is not None:
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if saved_exception is not None:
            # need pep-3134 to do this right
            raise AuthenticationError(repr(saved_exception))

        raise AuthenticationError("No authentication methods available")

    def run(self):
        chan = self._channel
        q = self._q

        def start_delim(data_len):
            return '\n#%s\n' % (data_len)

        try:
            s = selectors.DefaultSelector()
            s.register(chan, selectors.EVENT_READ)
            self.logger.debug('selector type = %s', s.__class__.__name__)
            while True:

                # Wake up every TICK seconds to check whether there is
                # something to send, or sooner if there is something to read
                # (select returns the channel in the readable list).
                events = s.select(timeout=TICK)
                if events:
                    data = chan.recv(BUF_SIZE)
                    if data:
                        try:
                            self.parser.parse(data)
                        except SAXFilterXMLNotFoundError:
                            self.logger.debug(
                                'switching from sax to dom parsing')
                            self.parser = DefaultXMLParser(self)
                            self.parser.parse(data)
                    elif self._closing.is_set():
                        # End of session, expected
                        break
                    else:
                        # End of session, unexpected
                        raise SessionCloseError(self._buffer.getvalue())
                if not q.empty() and chan.send_ready():
                    self.logger.debug("Sending message")
                    data = q.get()
                    if self._base == NetconfBase.BASE_11:
                        data = "%s%s%s" % (start_delim(
                            len(data)), data, END_DELIM)
                    else:
                        data = "%s%s" % (data, MSG_DELIM)
                    self.logger.info("Sending:\n%s", data)
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(),
                                                    data)
                        data = data[n:]
        except Exception as e:
            self.logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(e)
            self.close()

    @property
    def host(self):
        """Host this session is connected to, or None if not connected."""
        if hasattr(self, '_host'):
            return self._host
        return None

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
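The run() loop above applies the framing required by the negotiated base version before writing to the channel. Below is a minimal standalone sketch of the two framings; frame_10/frame_11 are hypothetical helper names, and MSG_DELIM/END_DELIM are assumed to hold their usual RFC 6242 values since their definitions live elsewhere in this module.

# Illustrative only: mirrors the framing built in run() above.
MSG_DELIM = "]]>]]>"     # assumed base:1.0 end-of-message delimiter
END_DELIM = "\n##\n"     # assumed base:1.1 end-of-chunks delimiter

def frame_10(msg):
    # base:1.0 framing: the message followed by ]]>]]>
    return "%s%s" % (msg, MSG_DELIM)

def frame_11(msg):
    # base:1.1 framing: a single chunk header "\n#<length>\n", the data,
    # then the end-of-chunks delimiter
    return "\n#%d\n%s%s" % (len(msg), msg, END_DELIM)

print(repr(frame_11("<rpc/>")))   # '\n#6\n<rpc/>\n##\n'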
示例#13
0
class SSHSession(Session):

    "Implements a :rfc:`4742` NETCONF session over SSH."

    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        # parsing-related, see _parse()
        self._device_handler = device_handler
        self._parsing_state10 = 0
        self._parsing_pos10 = 0
        self._parsing_pos11 = 0
        self._parsing_state11 = 0
        self._expchunksize = 0
        self._curchunksize = 0
        self._inendpos = 0
        self._size_num_list = []
        self._message_list = []
        self._closing = threading.Event()

        self.logger = SessionLoggerAdapter(logger, {'session': self})

    def _dispatch_message(self, raw):
        self.logger.info("Received:\n%s", raw)
        return super(SSHSession, self)._dispatch_message(raw)

    def _parse(self):
        "Messages ae delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes everytime this method is called. Retains state across method calls and if a byte has been read it will not be considered again."
        return self._parse10()

    def _parse10(self):

        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a chunk has been read it will not be
        considered again."""

        self.logger.debug("parsing netconf v1.0")
        buf = self._buffer
        buf.seek(self._parsing_pos10)
        if MSG_DELIM in buf.read().decode('UTF-8'):
            buf.seek(0)
            msg, _, remaining = buf.read().decode('UTF-8').partition(MSG_DELIM)
            msg = msg.strip()
            if sys.version < '3':
                self._dispatch_message(msg.encode())
            else:
                self._dispatch_message(msg)
            # create new buffer which contains remaining of old buffer
            self._buffer = StringIO()
            self._buffer.write(remaining.encode())
            self._parsing_pos10 = 0
            if len(remaining) > 0:
                # There could be another entire message in the
                # buffer, so we should try to parse again.
                self.logger.debug('Trying another round of parsing since there is still data')
                self._parse10()
        else:
            # handle case that MSG_DELIM is split over two chunks
            self._parsing_pos10 = buf.tell() - MSG_DELIM_LEN
            if self._parsing_pos10 < 0:
                self._parsing_pos10 = 0

    def _parse11(self):

        """Messages are split into chunks. Chunks and messages are delimited
        by the regex #RE_NC11_DELIM defined earlier in this file. Each
        time we get called here either a chunk delimiter or an
        end-of-message delimiter should be found iff there is enough
        data. If there is not enough data, we will wait for more. If a
        delimiter is found in the wrong place, a #NetconfFramingError
        will be raised."""
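        # Illustrative input (assumed wire format per RFC 6242): a base:1.1
        # message arrives as one or more chunks "\n#<length>\n<length bytes>"
        # and is terminated by the end-of-chunks delimiter "\n##\n", e.g.
        #     \n#4\n<rpc\n#19\n message-id="101"/>\n##\n
        # Each chunk body is appended to self._message_list; the joined list
        # is dispatched as a single message when the terminator is seen.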

        self.logger.debug("_parse11: starting")

        # suck in whole string that we have (this is what we will work on in
        # this function) and initialize a couple of useful values
        self._buffer.seek(0, os.SEEK_SET)
        data = self._buffer.getvalue()
        data_len = len(data)
        start = 0
        self.logger.debug('_parse11: working with buffer of %d bytes', data_len)
        while start < data_len:
            # match to see if we found at least some kind of delimiter
            self.logger.debug('_parse11: matching from %d bytes from start of buffer', start)
            re_result = RE_NC11_DELIM.match(data[start:])
            if not re_result:

                # not found any kind of delimiter just break; this should only
                # ever happen if we just have the first few characters of a
                # message such that we don't yet have a full delimiter
                self.logger.debug('_parse11: no delimiter found, buffer="%s"', data[start:].decode())
                break

            # save useful variables for reuse
            re_start = re_result.start()
            re_end = re_result.end()
            self.logger.debug('_parse11: regular expression start=%d, end=%d', re_start, re_end)

            # If the regex doesn't start at the beginning of the buffer,
            # we're in trouble, so throw an error
            if re_start != 0:
                raise NetconfFramingError('_parse11: delimiter not at start of match buffer', data[start:])

            if re_result.group(2):
                # we've found the end of the message, need to form up
                # whole message, save back remainder (if any) to buffer
                # and dispatch the message
                start += re_end
                message = ''.join(self._message_list)
                self._message_list = []
                self.logger.debug('_parse11: found end of message delimiter')
                self._dispatch_message(message)
                break

            elif re_result.group(1):
                # we've found a chunk delimiter, and group(1) is the digit
                # string that will tell us how many bytes past the end of
                # where it was found that we need to have available to
                # save the next chunk off
                self.logger.debug('_parse11: found chunk delimiter')
                digits = int(re_result.group(1))
                self.logger.debug('_parse11: chunk size %d bytes', digits)
                if (data_len-start) >= (re_end + digits):
                    # we have enough data for the chunk
                    fragment = textify(data[start+re_end:start+re_end+digits])
                    self._message_list.append(fragment)
                    start += re_end + digits
                    self.logger.debug('_parse11: appending %d bytes', digits)
                    self.logger.debug('_parse11: fragment = "%s"', fragment)
                else:
                    # we don't have enough bytes, just break out for now
                    # after updating start pointer to start of new chunk
                    start += re_start
                    self.logger.debug('_parse11: not enough data for chunk yet')
                    self.logger.debug('_parse11: setting start to %d', start)
                    break

        # Now out of the loop, need to see if we need to save back any content
        if start > 0:
            self.logger.debug(
                '_parse11: saving back rest of message after %d bytes, original size %d',
                start, data_len)
            self._buffer = StringIO(data[start:])
            if start < data_len:
                self.logger.debug('_parse11: still have data, may have another full message!')
                self._parse11()
        self.logger.debug('_parse11: ending')

    def load_known_hosts(self, filename=None):

        """Load host keys from an openssh :file:`known_hosts`-style file. Can
        be called multiple times.

        If *filename* is not specified, looks in the default locations, i.e. :file:`~/.ssh/known_hosts` and, on Windows, :file:`~/ssh/known_hosts`.
        """

        if filename is None:
            filename = os.path.expanduser('~/.ssh/known_hosts')
            try:
                self._host_keys.load(filename)
            except IOError:
                # for windows
                filename = os.path.expanduser('~/ssh/known_hosts')
                try:
                    self._host_keys.load(filename)
                except IOError:
                    pass
        else:
            self._host_keys.load(filename)

    def close(self):
        self._closing.set()
        if self._transport.is_active():
            self._transport.close()

        # Wait for the transport thread to close.
        while self.is_alive() and (self is not threading.current_thread()):
            self.join(10)

        if self._channel:
            self._channel.close()
        self._channel = None
        self._connected = False

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(
            self,
            host,
            port                = PORT_NETCONF_DEFAULT,
            timeout             = None,
            unknown_host_cb     = default_unknown_host_cb,
            username            = None,
            password            = None,
            key_filename        = None,
            allow_agent         = True,
            hostkey_verify      = True,
            hostkey_b64         = None,
            look_for_keys       = True,
            ssh_config          = None,
            sock_fd             = None):

        """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.

        To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.

        *host* is the hostname or IP address to connect to

        *port* is by default 830 (PORT_NETCONF_DEFAULT), but some devices use the default SSH port of 22 (PORT_SSH_DEFAULT) so this may need to be specified

        *timeout* is an optional timeout for socket connect

        *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)

        *username* is the username to use for SSH authentication

        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it

        *key_filename* is a filename where the private key to be used can be found

        *allow_agent* enables querying SSH agent (if found) for keys

        *hostkey_verify* enables hostkey verification from ~/.ssh/known_hosts

        *hostkey_b64* only connect when the server presents a public host key matching this base64-encoded value (obtain it from the server's /etc/ssh/ssh_host_*pub files or via ssh-keyscan)

        *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)

        *ssh_config* enables parsing of an OpenSSH configuration file: set it to the file's path, e.g. :file:`~/.ssh/config`, or to `True` to use :file:`~/.ssh/config`.

        *sock_fd* is the file descriptor of an already open socket, which shall be used for this connection. Useful for NETCONF outbound SSH. Use host=None together with a valid sock_fd number
        """
        if not (host or sock_fd):
            raise SSHError("Missing host or socket fd")

        self._host = host

        # Optionally, parse .ssh/config
        config = {}
        if ssh_config is True:
            ssh_config = "~/.ssh/config" if sys.platform != "win32" else "~/ssh/config"
        if ssh_config is not None:
            config = paramiko.SSHConfig()
            config.parse(open(os.path.expanduser(ssh_config)))
            
            # Save default Paramiko SSH port so it can be reverted
            paramiko_default_ssh_port = paramiko.config.SSH_PORT
            
            # Change the default SSH port to the port specified by the user so expand_variables
            # replaces %p with the passed-in port rather than 22 (the default paramiko.config.SSH_PORT)

            paramiko.config.SSH_PORT = port

            config = config.lookup(host)

            # paramiko.config.SSHConfig.expand_variables is called by lookup, so we can set the SSH port
            # back to the default
            paramiko.config.SSH_PORT = paramiko_default_ssh_port

            host = config.get("hostname", host)
            if username is None:
                username = config.get("user")
            if key_filename is None:
                key_filename = config.get("identityfile")
            if hostkey_verify:
                userknownhostsfile = config.get("userknownhostsfile")
                if userknownhostsfile:
                    self.load_known_hosts(os.path.expanduser(userknownhostsfile))
            if timeout is None:
                timeout = config.get("connecttimeout")
                if timeout:
                    timeout = int(timeout)

        if username is None:
            username = getpass.getuser()

        if sock_fd is None:
            proxycommand = config.get("proxycommand")
            if proxycommand:
                self.logger.debug("Configuring Proxy. %s", proxycommand)
                if not isinstance(proxycommand, six.string_types):
                    proxycommand = [os.path.expanduser(elem) for elem in proxycommand]
                else:
                    proxycommand = os.path.expanduser(proxycommand)
                sock = paramiko.proxy.ProxyCommand(proxycommand)
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    try:
                        sock = socket.socket(af, socktype, proto)
                        sock.settimeout(timeout)
                    except socket.error:
                        continue
                    try:
                        sock.connect(sa)
                    except socket.error:
                        sock.close()
                        continue
                    break
                else:
                    raise SSHError("Could not open socket to %s:%s" % (host, port))
        else:
            if sys.version_info[0] < 3:
                s = socket.fromfd(int(sock_fd), socket.AF_INET, socket.SOCK_STREAM)
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, _sock=s)
            else:
                sock = socket.fromfd(int(sock_fd), socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(timeout)

        self._transport = paramiko.Transport(sock)
        self._transport.set_log_channel(logger.name)
        if config.get("compression") == 'yes':
            self._transport.use_compression()

        if hostkey_b64:
            # If we need to connect with a specific hostkey, negotiate for only its type
            hostkey_obj = None
            for key_cls in [paramiko.DSSKey, paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]:
                try:
                    hostkey_obj = key_cls(data=base64.b64decode(hostkey_b64))
                except paramiko.SSHException:
                    # Not a key of this type - try the next
                    pass
            if not hostkey_obj:
                # We've tried all known host key types and haven't found a suitable one to use - bail
                raise SSHError("Couldn't find suitable paramiko key class for host key %s" % hostkey_b64)
            self._transport._preferred_keys = [hostkey_obj.get_name()]
        elif self._host_keys:
            # Else set preferred host keys to those we possess for the host
            # (avoids situation where known_hosts contains a valid key for the host, but that key type is not selected during negotiation)
            if port == PORT_SSH_DEFAULT:
                known_hosts_lookup = host
            else:
                known_hosts_lookup = '[%s]:%s' % (host, port)
            known_host_keys_for_this_host = self._host_keys.lookup(known_hosts_lookup)
            if known_host_keys_for_this_host:
                self._transport._preferred_keys = [x.key.get_name() for x in known_host_keys_for_this_host._entries]

        # Connect
        try:
            self._transport.start_client()
        except paramiko.SSHException as e:
            raise SSHError('Negotiation failed: %s' % e)

        server_key_obj = self._transport.get_remote_server_key()
        fingerprint = _colonify(hexlify(server_key_obj.get_fingerprint()))

        if hostkey_verify:
            is_known_host = False

            # For looking up entries for nonstandard (non-22) SSH ports in known_hosts
            # we enclose the host in brackets and append the port number
            if port == PORT_SSH_DEFAULT:
                known_hosts_lookup = host
            else:
                known_hosts_lookup = '[%s]:%s' % (host, port)

            if hostkey_b64:
                # If hostkey specified, remote host /must/ use that hostkey
                if(hostkey_obj.get_name() == server_key_obj.get_name() and hostkey_obj.asbytes() == server_key_obj.asbytes()):
                    is_known_host = True
            else:
                # Check known_hosts
                is_known_host = self._host_keys.check(known_hosts_lookup, server_key_obj)

            if not is_known_host and not unknown_host_cb(host, fingerprint):
                raise SSHUnknownHostError(known_hosts_lookup, fingerprint)

        # Authenticating with our private key/identity
        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, (str, bytes)):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        self._auth(username, password, key_filenames, allow_agent, look_for_keys)

        self._connected = True      # there was no error authenticating
        self._closing.clear()

        # TODO: leopoul: Review, test, and if needed rewrite this part
        subsystem_names = self._device_handler.get_ssh_subsystem_names()
        for subname in subsystem_names:
            self._channel = self._transport.open_session()
            self._channel_id = self._channel.get_id()
            channel_name = "%s-subsystem-%s" % (subname, str(self._channel_id))
            self._channel.set_name(channel_name)
            try:
                self._channel.invoke_subsystem(subname)
            except paramiko.SSHException as e:
                self.logger.info("%s (subsystem request rejected)", e)
                handle_exception = self._device_handler.handle_connection_exceptions(self)
                # Ignore the exception, since we continue to try the different
                # subsystem names until we find one that can connect.
                # have to handle exception for each vendor here
                if not handle_exception:
                    continue
            self._channel_name = self._channel.get_name()
            self._post_connect()
            return
        raise SSHError("Could not open connection, possibly due to unacceptable"
                       " SSH subsystem name.")

    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        saved_exception = None

        for key_filename in key_filenames:
            for cls in (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey):
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    self.logger.debug("Trying key %s from %s",
                                      hexlify(key.get_fingerprint()),
                                      key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    self.logger.debug("Trying SSH agent key %s",
                                      hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        keyfiles = []
        if look_for_keys:
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/.ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))
            # look in ~/ssh/ for windows users:
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                self.logger.debug("Trying discovered key %s in %s",
                                  hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if password is not None:
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if saved_exception is not None:
            # need pep-3134 to do this right
            raise AuthenticationError(repr(saved_exception))

        raise AuthenticationError("No authentication methods available")

    def run(self):
        chan = self._channel
        q = self._q

        def start_delim(data_len): return '\n#%s\n' % (data_len)

        try:
            s = selectors.DefaultSelector()
            s.register(chan, selectors.EVENT_READ)
            self.logger.debug('selector type = %s', s.__class__.__name__)
            while True:

                # Wake up every TICK seconds to check whether there is
                # something to send, or sooner if there is something to read
                # (select returns the channel in the readable list).
                events = s.select(timeout=TICK)
                if events:
                    data = chan.recv(BUF_SIZE)
                    if data:
                        self._buffer.seek(0, os.SEEK_END)
                        self._buffer.write(data)
                        if self._base == NetconfBase.BASE_11:
                            self._parse11()
                        else:
                            self._parse10()
                    elif self._closing.is_set():
                        # End of session, expected
                        break
                    else:
                        # End of session, unexpected
                        raise SessionCloseError(self._buffer.getvalue())
                if not q.empty() and chan.send_ready():
                    self.logger.debug("Sending message")
                    data = q.get()
                    if self._base == NetconfBase.BASE_11:
                        data = "%s%s%s" % (start_delim(len(data)), data, END_DELIM)
                    else:
                        data = "%s%s" % (data, MSG_DELIM)
                    self.logger.info("Sending:\n%s", data)
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(), data)
                        data = data[n:]
        except Exception as e:
            self.logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(e)
            self.close()

    @property
    def host(self):
        """Host this session is connected to, or None if not connected."""
        if hasattr(self, '_host'):
            return self._host
        return None

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
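A hedged usage sketch of the class above: the usual entry point is ncclient.manager.connect(), which constructs an SSHSession and forwards these same keyword arguments to connect(). The host and credentials below are placeholders only.

# Illustrative only; replace the placeholder host and credentials.
from ncclient import manager

with manager.connect(host="192.0.2.1",
                     port=830,
                     username="admin",
                     password="admin",
                     hostkey_verify=False) as m:
    # m wraps the connected session; capabilities come from the <hello> exchange
    for capability in m.server_capabilities:
        print(capability)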
示例#14
0
class Session(Thread):

    "Base class for use by transport protocol implementations."

    def __init__(self, capabilities):
        Thread.__init__(self)
        self.setDaemon(True)
        self._listeners = set()
        self._lock = Lock()
        self.setName('session')
        self._q = Queue()
        self._notification_q = Queue()
        self._client_capabilities = capabilities
        self._server_capabilities = None # yet
        self._base = NetconfBase.BASE_10
        self._id = None # session-id
        self._connected = False # to be set/cleared by subclass implementation
        self.logger = SessionLoggerAdapter(logger, {'session': self})
        self.logger.debug('%r created: client_capabilities=%r',
                          self, self._client_capabilities)
        self._device_handler = None # Should be set by child class

    def _dispatch_message(self, raw):
        try:
            root = parse_root(raw)
        except Exception as e:
            device_handled_raw=self._device_handler.handle_raw_dispatch(raw)
            if isinstance(device_handled_raw, str):
                root = parse_root(device_handled_raw)
            elif isinstance(device_handled_raw, Exception):
                self._dispatch_error(device_handled_raw)
                return
            else:
                self.logger.error('error parsing dispatch message: %s', e)
                return
        with self._lock:
            listeners = list(self._listeners)
        for l in listeners:
            self.logger.debug('dispatching message to %r: %s', l, raw)
            l.callback(root, raw) # no try-except; fail loudly if you must!

    def _dispatch_error(self, err):
        with self._lock:
            listeners = list(self._listeners)
        for l in listeners:
            self.logger.debug('dispatching error to %r', l)
            try: # here we can be more considerate with catching exceptions
                l.errback(err)
            except Exception as e:
                self.logger.warning('error dispatching to %r: %r', l, e)

    def _post_connect(self):
        "Greeting stuff"
        init_event = Event()
        error = [None] # so that err_cb can bind error[0]. just how it is.
        # callbacks
        def ok_cb(id, capabilities):
            self._id = id
            self._server_capabilities = capabilities
            init_event.set()
        def err_cb(err):
            error[0] = err
            init_event.set()
        self.add_listener(NotificationHandler(self._notification_q))
        listener = HelloHandler(ok_cb, err_cb)
        self.add_listener(listener)
        self.send(HelloHandler.build(self._client_capabilities, self._device_handler))
        self.logger.debug('starting main loop')
        self.start()
        # we expect the server's hello message; if the server doesn't respond within 60 seconds, raise an exception
        init_event.wait(60)
        if not init_event.is_set():
            raise SessionError("Capability exchange timed out")
        # received hello message or an error happened
        self.remove_listener(listener)
        if error[0]:
            raise error[0]
        #if ':base:1.0' not in self.server_capabilities:
        #    raise MissingCapabilityError(':base:1.0')
        if 'urn:ietf:params:netconf:base:1.1' in self._server_capabilities and 'urn:ietf:params:netconf:base:1.1' in self._client_capabilities:
            self.logger.debug("After 'hello' message selecting netconf:base:1.1 for encoding")
            self._base = NetconfBase.BASE_11
        self.logger.info('initialized: session-id=%s | server_capabilities=%s',
                         self._id, self._server_capabilities)

    def add_listener(self, listener):
        """Register a listener that will be notified of incoming messages and
        errors.

        :type listener: :class:`SessionListener`
        """
        self.logger.debug('installing listener %r', listener)
        if not isinstance(listener, SessionListener):
            raise SessionError("Listener must be a SessionListener type")
        with self._lock:
            self._listeners.add(listener)

    def remove_listener(self, listener):
        """Unregister some listener; ignore if the listener was never
        registered.

        :type listener: :class:`SessionListener`
        """
        self.logger.debug('discarding listener %r', listener)
        with self._lock:
            self._listeners.discard(listener)

    def get_listener_instance(self, cls):
        """If a listener of the specified type is registered, returns the
        instance.

        :type cls: :class:`SessionListener`
        """
        with self._lock:
            for listener in self._listeners:
                if isinstance(listener, cls):
                    return listener

    def connect(self, *args, **kwds): # subclass implements
        raise NotImplementedError

    def run(self): # subclass implements
        raise NotImplementedError

    def send(self, message):
        """Send the supplied *message* (xml string) to NETCONF server."""
        if not self.connected:
            raise TransportError('Not connected to NETCONF server')
        self.logger.debug('queueing %s', message)
        self._q.put(message)

    def scp(self):
        raise NotImplementedError
    ### Properties

    def take_notification(self, block, timeout):
        try:
            return self._notification_q.get(block, timeout)
        except Empty:
            return None

    @property
    def connected(self):
        "Connection status of the session."
        return self._connected

    @property
    def client_capabilities(self):
        "Client's :class:`Capabilities`"
        return self._client_capabilities

    @property
    def server_capabilities(self):
        "Server's :class:`Capabilities`"
        return self._server_capabilities

    @property
    def id(self):
        """A string representing the `session-id`. If the session has not been initialized it will be `None`"""
        return self._id
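add_listener() above requires a SessionListener, and _dispatch_message()/_dispatch_error() invoke its callback() and errback() methods. Below is a minimal sketch of a custom listener; PrintingListener is a hypothetical name, and the import path assumes SessionListener is exported from ncclient.transport.

# Illustrative only: a listener matching the callback/errback contract used
# by _dispatch_message() and _dispatch_error() above.
from ncclient.transport import SessionListener

class PrintingListener(SessionListener):

    def callback(self, root, raw):
        # root describes the parsed message root, raw is the XML text
        print("received message:\n%s" % raw)

    def errback(self, err):
        print("transport error: %r" % err)

# session.add_listener(PrintingListener())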
示例#15
0
文件: rpc.py 项目: siming85/ncclient
class RPC(object):

    """Base class for all operations, directly corresponding to *rpc* requests. Handles making the request, and taking delivery of the reply."""

    DEPENDS = []
    """Subclasses can specify their dependencies on capabilities as a list of URI's or abbreviated names, e.g. ':writable-running'. These are verified at the time of instantiation. If the capability is not available, :exc:`MissingCapabilityError` is raised."""

    REPLY_CLS = RPCReply
    "By default :class:`RPCReply`. Subclasses can specify a :class:`RPCReply` subclass."


    def __init__(self, session, device_handler, async_mode=False, timeout=30, raise_mode=RaiseMode.NONE):
        """
        *session* is the :class:`~ncclient.transport.Session` instance

        *device_handler" is the :class:`~ncclient.devices.*.*DeviceHandler` instance

        *async* specifies whether the request is to be made asynchronously, see :attr:`is_async`

        *timeout* is the timeout for a synchronous request, see :attr:`timeout`

        *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`
        """
        self._session = session
        try:
            for cap in self.DEPENDS:
                self._assert(cap)
        except AttributeError:
            pass
        self._async = async_mode
        self._timeout = timeout
        self._raise_mode = raise_mode
        self._id = uuid4().urn # Keeps things simple instead of having a class attr with running ID that has to be locked
        self._listener = RPCReplyListener(session, device_handler)
        self._listener.register(self._id, self)
        self._reply = None
        self._error = None
        self._event = Event()
        self._device_handler = device_handler
        self.logger = SessionLoggerAdapter(logger, {'session': session})


    def _wrap(self, subele):
        # internal use
        ele = new_ele("rpc", {"message-id": self._id},
                      **self._device_handler.get_xml_extra_prefix_kwargs())
        ele.append(subele)
        #print to_xml(ele)
        return to_xml(ele)

    def _request(self, op):
        """Implementations of :meth:`request` call this method to send the request and process the reply.

        In synchronous mode, blocks until the reply is received and returns :class:`RPCReply`. Depending on the :attr:`raise_mode` a `rpc-error` element in the reply may lead to an :exc:`RPCError` exception.

        In asynchronous mode, returns immediately, returning `self`. The :attr:`event` attribute will be set when the reply has been received (see :attr:`reply`) or an error occurred (see :attr:`error`).

        *op* is the operation to be requested as an :class:`~xml.etree.ElementTree.Element`
        """
        self.logger.info('Requesting %r', self.__class__.__name__)
        req = self._wrap(op)
        self._session.send(req)
        if self._async:
            self.logger.debug('Async request, returning %r', self)
            return self
        else:
            self.logger.debug('Sync request, will wait for timeout=%r', self._timeout)
            self._event.wait(self._timeout)
            if self._event.isSet():
                if self._error:
                    # Error that prevented reply delivery
                    raise self._error
                self._reply.parse()
                if self._reply.error is not None and not self._device_handler.is_rpc_error_exempt(self._reply.error.message):
                    # <rpc-error>'s [ RPCError ]

                    if self._raise_mode == RaiseMode.ALL or (self._raise_mode == RaiseMode.ERRORS and self._reply.error.severity == "error"):
                        errlist = []
                        errors = self._reply.errors
                        if len(errors) > 1:
                            raise RPCError(to_ele(self._reply._raw), errs=errors)
                        else:
                            raise self._reply.error
                if self._device_handler.transform_reply():
                    return NCElement(self._reply, self._device_handler.transform_reply())
                else:
                    return self._reply
            else:
                raise TimeoutExpiredError('ncclient timed out while waiting for an rpc reply.')

    def request(self):
        """Subclasses must implement this method. Typically only the request needs to be built as an
        :class:`~xml.etree.ElementTree.Element` and everything else can be handed off to
        :meth:`_request`."""
        pass

    def _assert(self, capability):
        """Subclasses can use this method to verify that a capability is available with the NETCONF
        server, before making a request that requires it. A :exc:`MissingCapabilityError` will be
        raised if the capability is not available."""
        if capability not in self._session.server_capabilities:
            raise MissingCapabilityError('Server does not support [%s]' % capability)

    def deliver_reply(self, raw):
        # internal use
        self._reply = self.REPLY_CLS(raw)
        self._event.set()

    def deliver_error(self, err):
        # internal use
        self._error = err
        self._event.set()

    @property
    def reply(self):
        ":class:`RPCReply` element if reply has been received or `None`"
        return self._reply

    @property
    def error(self):
        """:exc:`Exception` type if an error occured or `None`.

        .. note::
            This represents an error which prevented a reply from being received. An *rpc-error*
            does not fall in that category -- see `RPCReply` for that.
        """
        return self._error

    @property
    def id(self):
        "The *message-id* for this RPC."
        return self._id

    @property
    def session(self):
        "The `~ncclient.transport.Session` object associated with this RPC."
        return self._session

    @property
    def event(self):
        """:class:`~threading.Event` that is set when reply has been received or when an error preventing
        delivery of the reply occurs.
        """
        return self._event

    def __set_async(self, async_mode=True):
        self._async = async_mode
        if async_mode and not self._session.can_pipeline:
            raise UserWarning('Asynchronous mode not supported for this device/session')

    def __set_raise_mode(self, mode):
        assert(mode in (RaiseMode.NONE, RaiseMode.ERRORS, RaiseMode.ALL))
        self._raise_mode = mode

    def __set_timeout(self, timeout):
        self._timeout = timeout

    raise_mode = property(fget=lambda self: self._raise_mode, fset=__set_raise_mode)
    """Depending on this exception raising mode, an `rpc-error` in the reply may be raised as an :exc:`RPCError` exception. Valid values are the constants defined in :class:`RaiseMode`. """

    is_async = property(fget=lambda self: self._async, fset=__set_async)
    """Specifies whether this RPC will be / was requested asynchronously. By default RPC's are synchronous."""

    timeout = property(fget=lambda self: self._timeout, fset=__set_timeout)
    """Timeout in seconds for synchronous waiting defining how long the RPC request will block on a reply before raising :exc:`TimeoutExpiredError`.
示例#16
0
class RPC(object):
    """Base class for all operations, directly corresponding to *rpc* requests. Handles making the request, and taking delivery of the reply."""

    DEPENDS = []
    """Subclasses can specify their dependencies on capabilities as a list of URI's or abbreviated names, e.g. ':writable-running'. These are verified at the time of instantiation. If the capability is not available, :exc:`MissingCapabilityError` is raised."""

    REPLY_CLS = RPCReply
    "By default :class:`RPCReply`. Subclasses can specify a :class:`RPCReply` subclass."

    def __init__(self,
                 session,
                 device_handler,
                 async_mode=False,
                 timeout=30,
                 raise_mode=RaiseMode.NONE,
                 huge_tree=False):
        """
        *session* is the :class:`~ncclient.transport.Session` instance

        *device_handler" is the :class:`~ncclient.devices.*.*DeviceHandler` instance

        *async* specifies whether the request is to be made asynchronously, see :attr:`is_async`

        *timeout* is the timeout for a synchronous request, see :attr:`timeout`

        *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`

        *huge_tree* parse xml with huge_tree support (e.g. for large text config retrieval), see :attr:`huge_tree`
        """
        self._session = session
        try:
            for cap in self.DEPENDS:
                self._assert(cap)
        except AttributeError:
            pass
        self._async = async_mode
        self._timeout = timeout
        self._raise_mode = raise_mode
        self._huge_tree = huge_tree
        self._id = uuid4(
        ).urn  # Keeps things simple instead of having a class attr with running ID that has to be locked
        self._listener = RPCReplyListener(session, device_handler)
        self._listener.register(self._id, self)
        self._reply = None
        self._error = None
        self._event = Event()
        self._device_handler = device_handler
        self.logger = SessionLoggerAdapter(logger, {'session': session})

    def _wrap(self, subele):
        # internal use
        ele = new_ele("rpc", {"message-id": self._id},
                      **self._device_handler.get_xml_extra_prefix_kwargs())
        ele.append(subele)
        return to_xml(ele)

    def _request(self, op):
        """Implementations of :meth:`request` call this method to send the request and process the reply.

        In synchronous mode, blocks until the reply is received and returns :class:`RPCReply`. Depending on the :attr:`raise_mode` a `rpc-error` element in the reply may lead to an :exc:`RPCError` exception.

        In asynchronous mode, returns immediately, returning `self`. The :attr:`event` attribute will be set when the reply has been received (see :attr:`reply`) or an error occurred (see :attr:`error`).

        *op* is the operation to be requested as an :class:`~xml.etree.ElementTree.Element`
        """
        self.logger.info('Requesting %r', self.__class__.__name__)
        req = self._wrap(op)
        self._session.send(req)
        if self._async:
            self.logger.debug('Async request, returning %r', self)
            return self
        else:
            self.logger.debug('Sync request, will wait for timeout=%r',
                              self._timeout)
            self._event.wait(self._timeout)
            if self._event.isSet():
                if self._error:
                    # Error that prevented reply delivery
                    raise self._error
                self._reply.parse()
                if self._reply.error is not None and not self._device_handler.is_rpc_error_exempt(
                        self._reply.error.message):
                    # <rpc-error>'s [ RPCError ]

                    if self._raise_mode == RaiseMode.ALL or (
                            self._raise_mode == RaiseMode.ERRORS
                            and self._reply.error.severity == "error"):
                        errlist = []
                        errors = self._reply.errors
                        if len(errors) > 1:
                            raise RPCError(to_ele(self._reply._raw),
                                           errs=errors)
                        else:
                            raise self._reply.error
                if self._device_handler.transform_reply():
                    return NCElement(self._reply,
                                     self._device_handler.transform_reply(),
                                     huge_tree=self._huge_tree)
                else:
                    return self._reply
            else:
                raise TimeoutExpiredError(
                    'ncclient timed out while waiting for an rpc reply.')

    def request(self):
        """Subclasses must implement this method. Typically only the request needs to be built as an
        :class:`~xml.etree.ElementTree.Element` and everything else can be handed off to
        :meth:`_request`."""
        pass

    def _assert(self, capability):
        """Subclasses can use this method to verify that a capability is available with the NETCONF
        server, before making a request that requires it. A :exc:`MissingCapabilityError` will be
        raised if the capability is not available."""
        if capability not in self._session.server_capabilities:
            raise MissingCapabilityError('Server does not support [%s]' %
                                         capability)

    def deliver_reply(self, raw):
        # internal use
        self._reply = self.REPLY_CLS(raw, huge_tree=self._huge_tree)

        # Set the reply_parsing_error transform outside the constructor, to keep compatibility for
        # third party reply classes outside of ncclient
        self._reply.set_parsing_error_transform(
            self._device_handler.reply_parsing_error_transform(self.REPLY_CLS))

        self._event.set()

    def deliver_error(self, err):
        # internal use
        self._error = err
        self._event.set()

    @property
    def reply(self):
        ":class:`RPCReply` element if reply has been received or `None`"
        return self._reply

    @property
    def error(self):
        """:exc:`Exception` type if an error occured or `None`.

        .. note::
            This represents an error which prevented a reply from being received. An *rpc-error*
            does not fall in that category -- see `RPCReply` for that.
        """
        return self._error

    @property
    def id(self):
        "The *message-id* for this RPC."
        return self._id

    @property
    def session(self):
        "The `~ncclient.transport.Session` object associated with this RPC."
        return self._session

    @property
    def event(self):
        """:class:`~threading.Event` that is set when reply has been received or when an error preventing
        delivery of the reply occurs.
        """
        return self._event

    def __set_async(self, async_mode=True):
        self._async = async_mode
        if async_mode and not self._session.can_pipeline:
            raise UserWarning(
                'Asynchronous mode not supported for this device/session')

    def __set_raise_mode(self, mode):
        assert (mode in (RaiseMode.NONE, RaiseMode.ERRORS, RaiseMode.ALL))
        self._raise_mode = mode

    def __set_timeout(self, timeout):
        self._timeout = timeout

    raise_mode = property(fget=lambda self: self._raise_mode,
                          fset=__set_raise_mode)
    """Depending on this exception raising mode, an `rpc-error` in the reply may be raised as an :exc:`RPCError` exception. Valid values are the constants defined in :class:`RaiseMode`. """

    is_async = property(fget=lambda self: self._async, fset=__set_async)
    """Specifies whether this RPC will be / was requested asynchronously. By default RPC's are synchronous."""

    timeout = property(fget=lambda self: self._timeout, fset=__set_timeout)
    """Timeout in seconds for synchronous waiting defining how long the RPC request will block on a reply before raising :exc:`TimeoutExpiredError`.

    Irrelevant for asynchronous usage.
    """

    @property
    def huge_tree(self):
        """Whether `huge_tree` support for XML parsing of RPC replies is enabled (default=False)"""
        return self._huge_tree

    @huge_tree.setter
    def huge_tree(self, x):
        self._huge_tree = x
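The huge_tree flag above is simply forwarded to the reply parser (see deliver_reply()). A short usage sketch follows; op is a hypothetical name for an already-constructed operation instance.

# Illustrative only: allow very deep or large XML trees when parsing the reply.
op.huge_tree = True     # forwarded as REPLY_CLS(raw, huge_tree=True) in deliver_reply()
reply = op.request()    # arguments depend on the concrete operation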
示例#17
0
class Session(Thread):

    "Base class for use by transport protocol implementations."

    def __init__(self, capabilities):
        Thread.__init__(self)
        self.setDaemon(True)
        self._listeners = set()
        self._lock = Lock()
        self.setName('session')
        self._q = Queue()
        self._notification_q = Queue()
        self._client_capabilities = capabilities
        self._server_capabilities = None  # yet
        self._base = NetconfBase.BASE_10
        self._id = None  # session-id
        self._connected = False  # to be set/cleared by subclass implementation
        self.logger = SessionLoggerAdapter(logger, {'session': self})
        self.logger.debug('%r created: client_capabilities=%r', self,
                          self._client_capabilities)
        self._device_handler = None  # Should be set by child class

    def _dispatch_message(self, raw):
        try:
            root = parse_root(raw)
        except Exception as e:
            device_handled_raw = self._device_handler.handle_raw_dispatch(raw)
            if isinstance(device_handled_raw, str):
                root = parse_root(device_handled_raw)
            elif isinstance(device_handled_raw, Exception):
                self._dispatch_error(device_handled_raw)
                return
            else:
                self.logger.error('error parsing dispatch message: %s', e)
                return
        self.logger.debug('dispatching message to different listeners: %s',
                          raw)
        with self._lock:
            listeners = list(self._listeners)
        for l in listeners:
            self.logger.debug('dispatching message to listener: %r', l)
            l.callback(root, raw)  # no try-except; fail loudly if you must!

    def _dispatch_error(self, err):
        with self._lock:
            listeners = list(self._listeners)
        for l in listeners:
            self.logger.debug('dispatching error to %r', l)
            try:  # here we can be more considerate with catching exceptions
                l.errback(err)
            except Exception as e:
                self.logger.warning('error dispatching to %r: %r', l, e)

    def _post_connect(self):
        "Greeting stuff"
        init_event = Event()
        error = [None]  # so that err_cb can bind error[0]. just how it is.

        # callbacks
        def ok_cb(id, capabilities):
            self._id = id
            self._server_capabilities = capabilities
            init_event.set()

        def err_cb(err):
            error[0] = err
            init_event.set()

        self.add_listener(NotificationHandler(self._notification_q))
        listener = HelloHandler(ok_cb, err_cb)
        self.add_listener(listener)
        self.send(
            HelloHandler.build(self._client_capabilities,
                               self._device_handler))
        self.logger.debug('starting main loop')
        self.start()
        # we expect the server's hello message; if the server doesn't respond within 60 seconds, raise an exception
        init_event.wait(60)
        if not init_event.is_set():
            raise SessionError("Capability exchange timed out")
        # received hello message or an error happened
        self.remove_listener(listener)
        if error[0]:
            raise error[0]
        #if ':base:1.0' not in self.server_capabilities:
        #    raise MissingCapabilityError(':base:1.0')
        if 'urn:ietf:params:netconf:base:1.1' in self._server_capabilities and 'urn:ietf:params:netconf:base:1.1' in self._client_capabilities:
            self.logger.debug(
                "After 'hello' message selecting netconf:base:1.1 for encoding"
            )
            self._base = NetconfBase.BASE_11
        self.logger.info('initialized: session-id=%s | server_capabilities=%s',
                         self._id, self._server_capabilities)

    def add_listener(self, listener):
        """Register a listener that will be notified of incoming messages and
        errors.

        :type listener: :class:`SessionListener`
        """
        self.logger.debug('installing listener %r', listener)
        if not isinstance(listener, SessionListener):
            raise SessionError("Listener must be a SessionListener type")
        with self._lock:
            self._listeners.add(listener)

    def remove_listener(self, listener):
        """Unregister some listener; ignore if the listener was never
        registered.

        :type listener: :class:`SessionListener`
        """
        self.logger.debug('discarding listener %r', listener)
        with self._lock:
            self._listeners.discard(listener)

    def get_listener_instance(self, cls):
        """If a listener of the specified type is registered, returns the
        instance.

        :type cls: :class:`SessionListener`
        """
        with self._lock:
            for listener in self._listeners:
                if isinstance(listener, cls):
                    return listener

    def connect(self, *args, **kwds):  # subclass implements
        raise NotImplementedError

    def run(self):  # subclass implements
        raise NotImplementedError

    def send(self, message):
        """Send the supplied *message* (xml string) to NETCONF server."""
        if not self.connected:
            raise TransportError('Not connected to NETCONF server')
        self.logger.debug('queueing %s', message)
        self._q.put(message)

    def scp(self):
        raise NotImplementedError

    ### Properties

    def take_notification(self, block, timeout):
        try:
            return self._notification_q.get(block, timeout)
        except Empty:
            return None

    @property
    def connected(self):
        "Connection status of the session."
        return self._connected

    @property
    def client_capabilities(self):
        "Client's :class:`Capabilities`"
        return self._client_capabilities

    @property
    def server_capabilities(self):
        "Server's :class:`Capabilities`"
        return self._server_capabilities

    @property
    def id(self):
        """A string representing the `session-id`. If the session has not been initialized it will be `None`"""
        return self._id
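take_notification() above simply reads from the internal notification queue. A short usage sketch follows; session stands for any connected Session subclass instance.

# Illustrative only: wait up to 10 seconds for a queued notification;
# take_notification() returns None if the queue stays empty.
notification = session.take_notification(block=True, timeout=10)
if notification is not None:
    print(notification)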