def __init__(self, session, device_handler, async_mode=False, timeout=30, raise_mode=RaiseMode.NONE):
    """Create an RPC bound to *session* and register it for replies.

    *session* is the :class:`~ncclient.transport.Session` instance

    *device_handler* is the :class:`~ncclient.devices.*.*DeviceHandler` instance

    *async_mode* specifies whether the request is to be made asynchronously, see :attr:`is_async`

    *timeout* is the timeout for a synchronous request, see :attr:`timeout`

    *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`
    """
    self._session = session
    # Run self._assert() on every capability listed in DEPENDS (presumably
    # raising when the server lacks one). Any AttributeError -- e.g. no
    # DEPENDS attribute at all -- is ignored.
    try:
        for cap in self.DEPENDS:
            self._assert(cap)
    except AttributeError:
        pass
    self._async = async_mode
    self._timeout = timeout
    self._raise_mode = raise_mode
    self._id = uuid4().urn  # Keeps things simple instead of having a class attr with running ID that has to be locked
    # The listener is given this RPC under its message-id so a matching
    # reply can be routed back here.
    self._listener = RPCReplyListener(session, device_handler)
    self._listener.register(self._id, self)
    self._reply = None
    self._error = None
    self._event = Event()
    self._device_handler = device_handler
    self.logger = SessionLoggerAdapter(logger, {'session': session})
def __init__(self, device_handler):
    """Prepare a not-yet-connected SSH session for *device_handler*.

    The client capabilities handed to :class:`Session` are built from
    ``device_handler.get_capabilities()``; every transport, channel and
    parsing field starts out empty/zero.
    """
    Session.__init__(self, Capabilities(device_handler.get_capabilities()))
    self._device_handler = device_handler

    # SSH transport / channel state
    self._host = None
    self._host_keys = paramiko.HostKeys()
    self._transport = self._channel = None
    self._channel_id = self._channel_name = None
    self._connected = False

    # receive buffer and parsing-related bookkeeping, see _parse()
    self._buffer = StringIO()
    self._parsing_state10 = self._parsing_pos10 = 0
    self._parsing_state11 = self._parsing_pos11 = 0
    self._expchunksize = self._curchunksize = self._inendpos = 0
    self._size_num_list = []
    self._message_list = []

    self._closing = threading.Event()
    self.logger = SessionLoggerAdapter(logger, {'session': self})
def __init__(self, session):
    """
    Create a DOM parser attached to an SSH session.

    :param session: ssh session object the parser works on behalf of
    """
    self._parsing_pos10 = 0  # where the next MSG_DELIM search starts
    self._session = session
    self.logger = SessionLoggerAdapter(logger, {'session': session})
def __init__(self, session, device_handler, async_mode=False, timeout=30, raise_mode=RaiseMode.NONE):
    """Create an RPC bound to *session* and register it for replies.

    *session* is the :class:`~ncclient.transport.Session` instance

    *device_handler* is the :class:`~ncclient.devices.*.*DeviceHandler` instance

    *async_mode* specifies whether the request is to be made asynchronously, see :attr:`is_async`

    *timeout* is the timeout for a synchronous request, see :attr:`timeout`

    *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`
    """
    self._session = session
    # Run self._assert() on every capability listed in DEPENDS (presumably
    # raising when the server lacks one). Any AttributeError -- e.g. no
    # DEPENDS attribute at all -- is ignored.
    try:
        for cap in self.DEPENDS:
            self._assert(cap)
    except AttributeError:
        pass
    self._async = async_mode
    self._timeout = timeout
    self._raise_mode = raise_mode
    self._id = uuid4().urn  # Keeps things simple instead of having a class attr with running ID that has to be locked
    # The listener is given this RPC under its message-id so a matching
    # reply can be routed back here.
    self._listener = RPCReplyListener(session, device_handler)
    self._listener.register(self._id, self)
    self._reply = None
    self._error = None
    self._event = Event()
    self._device_handler = device_handler
    self.logger = SessionLoggerAdapter(logger, {'session': session})
def __init__(self, device_handler):
    """Prepare a not-yet-connected SSH session for *device_handler*.

    The client capabilities handed to :class:`Session` come from
    ``device_handler.get_capabilities()``; received data is handled by
    ``self.parser``, which starts out as a :class:`DefaultXMLParser`.
    """
    Session.__init__(self, Capabilities(device_handler.get_capabilities()))
    self._device_handler = device_handler

    # SSH transport / channel state
    self._host = None
    self._host_keys = paramiko.HostKeys()
    self._transport = self._channel = None
    self._channel_id = self._channel_name = None
    self._connected = False

    # receive buffer and chunk accumulator used by the parser
    self._buffer = StringIO()
    self._message_list = []
    self._closing = threading.Event()

    # SAX or DOM parser; the default one is installed here
    self.parser = DefaultXMLParser(self)
    self.logger = SessionLoggerAdapter(logger, {'session': self})
def __init__(self, capabilities):
    """Initialize the session thread and its shared state.

    *capabilities* is stored as the client capabilities. The server
    capabilities and the session-id start as None and the session is
    marked not connected; the device handler is left for a child class
    to set.
    """
    Thread.__init__(self)
    # Thread.setDaemon()/Thread.setName() are deprecated since Python 3.10;
    # assigning the ``daemon``/``name`` attributes is equivalent.
    self.daemon = True
    self._listeners = set()
    self._lock = Lock()
    self.name = 'session'
    self._q = Queue()
    self._notification_q = Queue()
    self._client_capabilities = capabilities
    self._server_capabilities = None  # yet
    self._base = NetconfBase.BASE_10
    self._id = None  # session-id
    self._connected = False  # to be set/cleared by subclass implementation
    self.logger = SessionLoggerAdapter(logger, {'session': self})
    self.logger.debug('%r created: client_capabilities=%r',
                      self, self._client_capabilities)
    self._device_handler = None  # Should be set by child class
def __new__(cls, session, device_handler):
    """Return the listener registered on *session* for *cls*.

    On first use a new instance is created, initialized and added to the
    session's listeners; later calls hand back that same instance. The
    class-wide ``creation_lock`` serializes the lookup-or-create step.
    """
    with RPCReplyListener.creation_lock:
        listener = session.get_listener_instance(cls)
        if listener is not None:
            return listener
        listener = object.__new__(cls)
        listener._lock = Lock()
        listener._id2rpc = {}  # message-id -> pending RPC
        listener._device_handler = device_handler
        session.add_listener(listener)
        listener.logger = SessionLoggerAdapter(logger, {'session': session})
        return listener
class SSHSession(Session):
    "Implements a :rfc:`4742` NETCONF session over SSH."

    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        # parsing-related, see _parse()
        self._device_handler = device_handler
        self._parsing_state10 = 0
        self._parsing_pos10 = 0
        self._parsing_pos11 = 0
        self._parsing_state11 = 0
        self._expchunksize = 0
        self._curchunksize = 0
        self._inendpos = 0
        self._size_num_list = []
        # raw (still undecoded) chunk payloads of the NETCONF 1.1 message
        # currently being assembled by _parse11()
        self._message_list = []
        self._closing = threading.Event()
        self.logger = SessionLoggerAdapter(logger, {'session': self})

    def _dispatch_message(self, raw):
        self.logger.info("Received:\n%s", raw)
        return super(SSHSession, self)._dispatch_message(raw)

    def _parse(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a byte has been read it will not be
        considered again."""
        return self._parse10()

    def _parse10(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a chunk has been read it will not be
        considered again."""
        self.logger.debug("parsing netconf v1.0")
        delim = MSG_DELIM.encode('UTF-8')
        buf = self._buffer
        buf.seek(self._parsing_pos10)
        # Search the raw bytes: decoding first would raise whenever a recv()
        # boundary splits a multi-byte UTF-8 character.
        if delim in buf.read():
            buf.seek(0)
            msg, _, remaining = buf.read().partition(delim)
            # Only the complete message is decoded; `remaining` stays bytes
            # because it may end in the middle of a character.
            msg = msg.decode('UTF-8').strip()
            if sys.version_info[0] < 3:
                self._dispatch_message(msg.encode('UTF-8'))
            else:
                self._dispatch_message(msg)
            # create new buffer which contains remaining of old buffer
            self._buffer = StringIO()
            self._buffer.write(remaining)
            self._parsing_pos10 = 0
            if len(remaining) > 0:
                # There could be another entire message in the
                # buffer, so we should try to parse again.
                self.logger.debug(
                    'Trying another round of parsing since there is still data')
                self._parse10()
        else:
            # handle case that MSG_DELIM is split over two chunks
            self._parsing_pos10 = buf.tell() - MSG_DELIM_LEN
            if self._parsing_pos10 < 0:
                self._parsing_pos10 = 0

    def _parse11(self):
        """Messages are split into chunks. Chunks and messages are delimited
        by the regex #RE_NC11_DELIM defined earlier in this file. Each
        time we get called here either a chunk delimiter or an
        end-of-message delimiter should be found iff there is enough
        data. If there is not enough data, we will wait for more. If a
        delimiter is found in the wrong place, a #NetconfFramingError
        will be raised."""

        self.logger.debug("_parse11: starting")

        # suck in whole string that we have (this is what we will work on in
        # this function) and initialize a couple of useful values
        self._buffer.seek(0, os.SEEK_SET)
        data = self._buffer.getvalue()
        data_len = len(data)
        start = 0
        self.logger.debug('_parse11: working with buffer of %d bytes', data_len)
        while start < data_len:
            # match to see if we found at least some kind of delimiter
            self.logger.debug(
                '_parse11: matching from %d bytes from start of buffer', start)
            # Only the ASCII delimiter at the front has to be recognised;
            # errors='replace' keeps a character cut off at the end of the
            # buffer (the rest is still in flight) from raising here.
            text = data[start:].decode('utf-8', 'replace')
            re_result = RE_NC11_DELIM.match(text)
            if not re_result:

                # not found any kind of delimiter just break; this should only
                # ever happen if we just have the first few characters of a
                # message such that we don't yet have a full delimiter
                self.logger.debug('_parse11: no delimiter found, buffer="%s"', text)
                break

            # save useful variables for reuse
            re_start = re_result.start()
            re_end = re_result.end()
            self.logger.debug('_parse11: regular expression start=%d, end=%d',
                              re_start, re_end)

            # If the regex doesn't start at the beginning of the buffer,
            # we're in trouble, so throw an error
            if re_start != 0:
                raise NetconfFramingError(
                    '_parse11: delimiter not at start of match buffer',
                    data[start:])

            if re_result.group(2):
                # we've found the end of the message, need to form up
                # whole message, save back remainder (if any) to buffer
                # and dispatch the message
                start += re_end
                # chunks may split a multi-byte character, so decode only
                # once the whole message has been collected
                message = textify(b''.join(self._message_list))
                self._message_list = []
                self.logger.debug('_parse11: found end of message delimiter')
                self._dispatch_message(message)
                break

            elif re_result.group(1):
                # we've found a chunk delimiter, and group(1) is the digit
                # string that will tell us how many bytes past the end of
                # where it was found that we need to have available to
                # save the next chunk off
                self.logger.debug('_parse11: found chunk delimiter')
                digits = int(re_result.group(1))
                self.logger.debug('_parse11: chunk size %d bytes', digits)
                if (data_len - start) >= (re_end + digits):
                    # we have enough data for the chunk; keep it as raw bytes
                    fragment = data[start + re_end:start + re_end + digits]
                    self._message_list.append(fragment)
                    start += re_end + digits
                    self.logger.debug('_parse11: appending %d bytes', digits)
                    self.logger.debug('_parse11: fragment = "%s"', fragment)
                else:
                    # we don't have enough bytes, just break out for now;
                    # start already points at this chunk's delimiter
                    # (re_start is 0, checked above)
                    self.logger.debug(
                        '_parse11: not enough data for chunk yet')
                    self.logger.debug('_parse11: setting start to %d', start)
                    break

        # Now out of the loop, need to see if we need to save back any content
        if start > 0:
            self.logger.debug(
                '_parse11: saving back rest of message after %d bytes, original size %d',
                start, data_len)
            self._buffer = StringIO(data[start:])
            if start < data_len:
                self.logger.debug(
                    '_parse11: still have data, may have another full message!')
                self._parse11()
        self.logger.debug('_parse11: ending')

    def load_known_hosts(self, filename=None):
        """Load host keys from an openssh :file:`known_hosts`-style file. Can
        be called multiple times.

        If *filename* is not specified, looks in the default locations i.e. :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
        """

        if filename is None:
            filename = os.path.expanduser('~/.ssh/known_hosts')
            try:
                self._host_keys.load(filename)
            except IOError:
                # for windows
                filename = os.path.expanduser('~/ssh/known_hosts')
                try:
                    self._host_keys.load(filename)
                except IOError:
                    pass
        else:
            self._host_keys.load(filename)

    def close(self):
        self._closing.set()
        # _transport is still None when close() runs before connect()
        if self._transport is not None and self._transport.is_active():
            self._transport.close()

        # Wait for the transport thread to close.
        while self.is_alive() and (self is not threading.current_thread()):
            self.join(10)

        if self._channel:
            self._channel.close()
        self._channel = None
        self._connected = False

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(
            self,
            host,
            port=PORT_NETCONF_DEFAULT,
            timeout=None,
            unknown_host_cb=default_unknown_host_cb,
            username=None,
            password=None,
            key_filename=None,
            allow_agent=True,
            hostkey_verify=True,
            hostkey_b64=None,
            look_for_keys=True,
            ssh_config=None,
            sock_fd=None):

        """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.

        To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.

        *host* is the hostname or IP address to connect to

        *port* is by default 830 (PORT_NETCONF_DEFAULT), but some devices use the default SSH port of 22 (PORT_SSH_DEFAULT) so this may need to be specified

        *timeout* is an optional timeout for socket connect

        *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)

        *username* is the username to use for SSH authentication

        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it

        *key_filename* is a filename where the private key to be used can be found

        *allow_agent* enables querying SSH agent (if found) for keys

        *hostkey_verify* enables hostkey verification from ~/.ssh/known_hosts

        *hostkey_b64* only connect when server presents a public hostkey matching this (obtain from server /etc/ssh/ssh_host_*pub or ssh-keyscan)

        *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)

        *ssh_config* enables parsing of an OpenSSH configuration file, if set to its path, e.g. :file:`~/.ssh/config` or to True (in this case, use :file:`~/.ssh/config`, or :file:`~/ssh/config` on Windows).

        *sock_fd* is an already open socket which shall be used for this connection. Useful for NETCONF outbound ssh. Use host=None together with a valid sock_fd number
        """
        if not (host or sock_fd):
            raise SSHError("Missing host or socket fd")

        self._host = host

        # Optionally, parse .ssh/config
        config = {}
        if ssh_config is True:
            ssh_config = "~/.ssh/config" if sys.platform != "win32" else "~/ssh/config"
        if ssh_config is not None:
            config = paramiko.SSHConfig()
            # close the config file once parsed instead of leaking the handle
            with open(os.path.expanduser(ssh_config)) as ssh_config_file:
                config.parse(ssh_config_file)

            # Save default Paramiko SSH port so it can be reverted
            paramiko_default_ssh_port = paramiko.config.SSH_PORT

            # Change the default SSH port to the port specified by the user so expand_variables
            # replaces %p with the passed in port rather than 22 (the default paramiko.config.SSH_PORT)
            paramiko.config.SSH_PORT = port

            try:
                config = config.lookup(host)
            finally:
                # paramiko.config.SSHconfig::expand_variables is called by lookup so we can set the SSH port
                # back to the default (even if lookup raised)
                paramiko.config.SSH_PORT = paramiko_default_ssh_port

            host = config.get("hostname", host)
            if username is None:
                username = config.get("user")
            if key_filename is None:
                key_filename = config.get("identityfile")
            if hostkey_verify:
                userknownhostsfile = config.get("userknownhostsfile")
                if userknownhostsfile:
                    self.load_known_hosts(os.path.expanduser(userknownhostsfile))

        if username is None:
            username = getpass.getuser()

        if sock_fd is None:
            if config.get("proxycommand"):
                self.logger.debug("Configuring Proxy. %s", config.get("proxycommand"))
                sock = paramiko.proxy.ProxyCommand(config.get("proxycommand"))
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    try:
                        sock = socket.socket(af, socktype, proto)
                        sock.settimeout(timeout)
                    except socket.error:
                        continue
                    try:
                        sock.connect(sa)
                    except socket.error:
                        sock.close()
                        continue
                    break
                else:
                    raise SSHError("Could not open socket to %s:%s" % (host, port))
        else:
            if sys.version_info[0] < 3:
                s = socket.fromfd(int(sock_fd), socket.AF_INET, socket.SOCK_STREAM)
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, _sock=s)
            else:
                sock = socket.fromfd(int(sock_fd), socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(timeout)

        self._transport = paramiko.Transport(sock)
        self._transport.set_log_channel(logger.name)
        if config.get("compression") == 'yes':
            self._transport.use_compression()

        if hostkey_b64:
            # If we need to connect with a specific hostkey, negotiate for only its type
            hostkey_obj = None
            for key_cls in [paramiko.DSSKey, paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]:
                try:
                    hostkey_obj = key_cls(data=base64.b64decode(hostkey_b64))
                except paramiko.SSHException:
                    # Not a key of this type - try the next
                    pass
                else:
                    # Parsed successfully - no need to try the other types
                    break
            if not hostkey_obj:
                # We've tried all known host key types and haven't found a suitable one to use - bail
                raise SSHError("Couldn't find suitable paramiko key class for host key %s" % hostkey_b64)
            self._transport._preferred_keys = [hostkey_obj.get_name()]
        elif self._host_keys:
            # Else set preferred host keys to those we possess for the host
            # (avoids situation where known_hosts contains a valid key for the host, but that key type is not selected during negotiation)
            if port == PORT_SSH_DEFAULT:
                known_hosts_lookup = host
            else:
                known_hosts_lookup = '[%s]:%s' % (host, port)
            known_host_keys_for_this_host = self._host_keys.lookup(known_hosts_lookup)
            if known_host_keys_for_this_host:
                self._transport._preferred_keys = [x.key.get_name() for x in known_host_keys_for_this_host._entries]

        # Connect
        try:
            self._transport.start_client()
        except paramiko.SSHException as e:
            raise SSHError('Negotiation failed: %s' % e)

        server_key_obj = self._transport.get_remote_server_key()
        fingerprint = _colonify(hexlify(server_key_obj.get_fingerprint()))

        if hostkey_verify:
            is_known_host = False

            # For looking up entries for nonstandard (22) ssh ports in known_hosts
            # we enclose host in brackets and append port number
            if port == PORT_SSH_DEFAULT:
                known_hosts_lookup = host
            else:
                known_hosts_lookup = '[%s]:%s' % (host, port)

            if hostkey_b64:
                # If hostkey specified, remote host /must/ use that hostkey
                if (hostkey_obj.get_name() == server_key_obj.get_name() and
                        hostkey_obj.asbytes() == server_key_obj.asbytes()):
                    is_known_host = True
            else:
                # Check known_hosts
                is_known_host = self._host_keys.check(known_hosts_lookup, server_key_obj)

            if not is_known_host and not unknown_host_cb(host, fingerprint):
                raise SSHUnknownHostError(known_hosts_lookup, fingerprint)

        # Authenticating with our private key/identity
        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, (str, bytes)):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        self._auth(username, password, key_filenames, allow_agent, look_for_keys)

        self._connected = True  # there was no error authenticating
        self._closing.clear()

        # TODO: leopoul: Review, test, and if needed rewrite this part
        subsystem_names = self._device_handler.get_ssh_subsystem_names()
        for subname in subsystem_names:
            self._channel = self._transport.open_session()
            self._channel_id = self._channel.get_id()
            channel_name = "%s-subsystem-%s" % (subname, str(self._channel_id))
            self._channel.set_name(channel_name)
            try:
                self._channel.invoke_subsystem(subname)
            except paramiko.SSHException as e:
                self.logger.info("%s (subsystem request rejected)", e)
                handle_exception = self._device_handler.handle_connection_exceptions(self)
                # Ignore the exception, since we continue to try the different
                # subsystem names until we find one that can connect.
                # have to handle exception for each vendor here
                if not handle_exception:
                    continue
            self._channel_name = self._channel.get_name()
            self._post_connect()
            return
        raise SSHError("Could not open connection, possibly due to unacceptable"
                       " SSH subsystem name.")

    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        """Authenticate the transport, trying in order: explicit key files,
        SSH agent keys, discovered default key files, then the password.
        Raises AuthenticationError with the last failure if none succeeds."""
        saved_exception = None
        # key types tried for every private key file (Ed25519 included, as
        # paramiko.Ed25519Key is already relied upon for host keys)
        key_classes = (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey,
                       paramiko.Ed25519Key)

        for key_filename in key_filenames:
            for cls in key_classes:
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    self.logger.debug("Trying key %s from %s",
                                      hexlify(key.get_fingerprint()),
                                      key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    self.logger.debug("Trying SSH agent key %s",
                                      hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        keyfiles = []
        if look_for_keys:
            # ~/.ssh/ is the usual location; ~/ssh/ is checked for windows users
            default_keys = ((paramiko.RSAKey, "id_rsa"),
                            (paramiko.DSSKey, "id_dsa"),
                            (paramiko.ECDSAKey, "id_ecdsa"),
                            (paramiko.Ed25519Key, "id_ed25519"))
            for directory in ("~/.ssh", "~/ssh"):
                for cls, basename in default_keys:
                    path = os.path.expanduser("%s/%s" % (directory, basename))
                    if os.path.isfile(path):
                        keyfiles.append((cls, path))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                self.logger.debug("Trying discovered key %s in %s",
                                  hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if password is not None:
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if saved_exception is not None:
            # need pep-3134 to do this right
            raise AuthenticationError(repr(saved_exception))

        raise AuthenticationError("No authentication methods available")

    def run(self):
        chan = self._channel
        q = self._q

        def start_delim(data_len):
            return '\n#%s\n' % (data_len)

        try:
            s = selectors.DefaultSelector()
            s.register(chan, selectors.EVENT_READ)
            self.logger.debug('selector type = %s', s.__class__.__name__)
            while True:
                # Will wakeup every TICK seconds to check if something
                # to send, more quickly if something to read (due to
                # select returning chan in readable list).
                events = s.select(timeout=TICK)
                if events:
                    data = chan.recv(BUF_SIZE)
                    if data:
                        self._buffer.seek(0, os.SEEK_END)
                        self._buffer.write(data)
                        if self._base == NetconfBase.BASE_11:
                            self._parse11()
                        else:
                            self._parse10()
                    elif self._closing.is_set():
                        # End of session, expected
                        break
                    else:
                        # End of session, unexpected
                        raise SessionCloseError(self._buffer.getvalue())
                if not q.empty() and chan.send_ready():
                    self.logger.debug("Sending message")
                    data = q.get()
                    if self._base == NetconfBase.BASE_11:
                        # RFC 6242 chunk sizes count octets, not characters
                        payload = data if isinstance(data, bytes) else data.encode('UTF-8')
                        data = "%s%s%s" % (start_delim(len(payload)), data, END_DELIM)
                    else:
                        data = "%s%s" % (data, MSG_DELIM)
                    self.logger.info("Sending:\n%s", data)
                    # send() reports how many *bytes* went out, so slice
                    # the encoded form rather than the text
                    if not isinstance(data, bytes):
                        data = data.encode('UTF-8')
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(), data)
                        data = data[n:]
        except Exception as e:
            self.logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(e)
            self.close()

    @property
    def host(self):
        """Host this session is connected to, or None if not connected."""
        if hasattr(self, '_host'):
            return self._host
        return None

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
class DefaultXMLParser(object): def __init__(self, session): """ DOM Parser :param session: ssh session object """ self._session = session self._parsing_pos10 = 0 self.logger = SessionLoggerAdapter(logger, {'session': self._session}) def parse(self, data): """ parse incoming RPC response from networking device. :param data: incoming RPC data from device :return: None """ if data: self._session._buffer.seek(0, os.SEEK_END) self._session._buffer.write(data) if self._session._base == NetconfBase.BASE_11: self._parse11() else: self._parse10() def _parse10(self): """Messages are delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes everytime this method is called. Retains state across method calls and if a chunk has been read it will not be considered again.""" self.logger.debug("parsing netconf v1.0") buf = self._session._buffer buf.seek(self._parsing_pos10) if MSG_DELIM in buf.read().decode('UTF-8'): buf.seek(0) msg, _, remaining = buf.read().decode('UTF-8').partition(MSG_DELIM) msg = msg.strip() if sys.version < '3': self._session._dispatch_message(msg.encode()) else: self._session._dispatch_message(msg) self._session._buffer = StringIO() self._parsing_pos10 = 0 if len(remaining.strip()) > 0: # There could be another entire message in the # buffer, so we should try to parse again. if type(self._session.parser) != DefaultXMLParser: self.logger.debug('send remaining data to SAX parser') self._session.parser.parse(remaining.encode()) else: self.logger.debug( 'Trying another round of parsing since there is still data' ) self._session._buffer.write(remaining.encode()) self._parse10() else: # handle case that MSG_DELIM is split over two chunks self._parsing_pos10 = buf.tell() - MSG_DELIM_LEN if self._parsing_pos10 < 0: self._parsing_pos10 = 0 def _parse11(self): """Messages are split into chunks. Chunks and messages are delimited by the regex #RE_NC11_DELIM defined earlier in this file. 
Each time we get called here either a chunk delimiter or an end-of-message delimiter should be found iff there is enough data. If there is not enough data, we will wait for more. If a delimiter is found in the wrong place, a #NetconfFramingError will be raised.""" self.logger.debug("_parse11: starting") # suck in whole string that we have (this is what we will work on in # this function) and initialize a couple of useful values self._session._buffer.seek(0, os.SEEK_SET) data = self._session._buffer.getvalue() data_len = len(data) start = 0 self.logger.debug('_parse11: working with buffer of %d bytes', data_len) while True and start < data_len: # match to see if we found at least some kind of delimiter self.logger.debug( '_parse11: matching from %d bytes from start of buffer', start) re_result = RE_NC11_DELIM.match(data[start:].decode('utf-8')) if not re_result: # not found any kind of delimiter just break; this should only # ever happen if we just have the first few characters of a # message such that we don't yet have a full delimiter self.logger.debug('_parse11: no delimiter found, buffer="%s"', data[start:].decode()) break # save useful variables for reuse re_start = re_result.start() re_end = re_result.end() self.logger.debug('_parse11: regular expression start=%d, end=%d', re_start, re_end) # If the regex doesn't start at the beginning of the buffer, # we're in trouble, so throw an error if re_start != 0: raise NetconfFramingError( '_parse11: delimiter not at start of match buffer', data[start:]) if re_result.group(2): # we've found the end of the message, need to form up # whole message, save back remainder (if any) to buffer # and dispatch the message start += re_end message = ''.join(self._session._message_list) self._session._message_list = [] self.logger.debug('_parse11: found end of message delimiter') self._session._dispatch_message(message) break elif re_result.group(1): # we've found a chunk delimiter, and group(2) is the digit # string that will tell 
us how many bytes past the end of # where it was found that we need to have available to # save the next chunk off self.logger.debug('_parse11: found chunk delimiter') digits = int(re_result.group(1)) self.logger.debug('_parse11: chunk size %d bytes', digits) if (data_len - start) >= (re_end + digits): # we have enough data for the chunk fragment = textify(data[start + re_end:start + re_end + digits]) self._session._message_list.append(fragment) start += re_end + digits self.logger.debug('_parse11: appending %d bytes', digits) self.logger.debug('_parse11: fragment = "%s"', fragment) else: # we don't have enough bytes, just break out for now # after updating start pointer to start of new chunk start += re_start self.logger.debug( '_parse11: not enough data for chunk yet') self.logger.debug('_parse11: setting start to %d', start) break # Now out of the loop, need to see if we need to save back any content if start > 0: self.logger.debug( '_parse11: saving back rest of message after %d bytes, original size %d', start, data_len) self._session._buffer = StringIO(data[start:]) if start < data_len: self.logger.debug( '_parse11: still have data, may have another full message!' ) self._parse11() self.logger.debug('_parse11: ending')
class SSHSession(Session):
    """Implements a :rfc:`4742` NETCONF session over SSH.

    The session runs as a thread (see :meth:`run`) that reads from the SSH
    channel, feeds received bytes to :attr:`parser`, and writes queued
    outgoing messages framed for the negotiated NETCONF base version.
    """

    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        self._device_handler = device_handler
        self._message_list = []
        self._closing = threading.Event()
        # Default parser until connect() asks the device handler for its
        # preferred one.
        self.parser = DefaultXMLParser(self)  # SAX or DOM parser
        self.logger = SessionLoggerAdapter(logger, {'session': self})

    def _dispatch_message(self, raw):
        """Log a received message, then hand it to the base class dispatcher."""
        # Provide basic response message
        self.logger.info("Received message from host")
        # Provide complete response from host at debug log level
        self.logger.debug("Received:\n%s", raw)
        return super(SSHSession, self)._dispatch_message(raw)

    def _parse(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a byte has been read it will not be
        considered again.

        Delegates to the current parser's ``_parse10``.
        """
        return self.parser._parse10()

    def load_known_hosts(self, filename=None):
        """Load host keys from an openssh :file:`known_hosts`-style file. Can
        be called multiple times.

        If *filename* is not specified, looks in the default locations i.e.
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
        A missing default file is silently ignored; an explicitly given file
        that cannot be read raises.
        """
        if filename is None:
            filename = os.path.expanduser('~/.ssh/known_hosts')
            try:
                self._host_keys.load(filename)
            except IOError:
                # for windows
                filename = os.path.expanduser('~/ssh/known_hosts')
                try:
                    self._host_keys.load(filename)
                except IOError:
                    pass
        else:
            self._host_keys.load(filename)

    def close(self):
        """Close the transport and channel and mark the session disconnected.

        Safe to call before :meth:`connect` has created a transport.
        """
        self._closing.set()
        if self._transport is not None and self._transport.is_active():
            self._transport.close()

        # Wait for this session's thread (which executes run()) to finish,
        # unless close() is being called from that very thread.
        while self.is_alive() and (self is not threading.current_thread()):
            self.join(10)

        if self._channel:
            self._channel.close()
        self._channel = None
        self._connected = False

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(
            self,
            host,
            port=PORT_NETCONF_DEFAULT,
            timeout=None,
            unknown_host_cb=default_unknown_host_cb,
            username=None,
            password=None,
            key_filename=None,
            allow_agent=True,
            hostkey_verify=True,
            hostkey_b64=None,
            look_for_keys=True,
            ssh_config=None,
            sock_fd=None,
            bind_addr=None):
        """Connect via SSH and initialize the NETCONF session. First attempts
        the publickey authentication method and then password authentication.

        To disable attempting publickey authentication altogether, call with
        *allow_agent* and *look_for_keys* as `False`.

        *host* is the hostname or IP address to connect to

        *port* is by default 830 (PORT_NETCONF_DEFAULT), but some devices use
        the default SSH port of 22 so this may need to be specified

        *timeout* is an optional timeout for socket connect

        *unknown_host_cb* is called when the server host key is not
        recognized. It takes two arguments, the hostname and the fingerprint
        (see the signature of :func:`default_unknown_host_cb`)

        *username* is the username to use for SSH authentication

        *password* is the password used if using password authentication, or
        the passphrase to use for unlocking keys that require it

        *key_filename* is a filename (or a list of filenames) where the
        private key to be used can be found

        *allow_agent* enables querying SSH agent (if found) for keys

        *hostkey_verify* enables hostkey verification from ~/.ssh/known_hosts

        *hostkey_b64* only connect when server presents a public hostkey
        matching this (obtain from server /etc/ssh/ssh_host_*pub or
        ssh-keyscan)

        *look_for_keys* enables looking in the usual locations for ssh keys
        (e.g. :file:`~/.ssh/id_*`)

        *ssh_config* enables parsing of an OpenSSH configuration file, if set
        to its path, e.g. :file:`~/.ssh/config` or to True (in this case, use
        :file:`~/.ssh/config`).

        *sock_fd* is an already open socket which shall be used for this
        connection. Useful for NETCONF outbound ssh. Use host=None together
        with a valid sock_fd number

        *bind_addr* is a (local) source IP address to use, must be reachable
        from the remote device.

        Raises :exc:`SSHError` on socket/negotiation/subsystem failures and
        :exc:`SSHUnknownHostError` when host key verification is rejected.
        """
        if not (host or sock_fd):
            raise SSHError("Missing host or socket fd")

        self._host = host

        # Optionally, parse .ssh/config
        config = {}
        if ssh_config is True:
            ssh_config = "~/.ssh/config" if sys.platform != "win32" else "~/ssh/config"
        if ssh_config is not None:
            config = paramiko.SSHConfig()
            with open(os.path.expanduser(ssh_config)) as ssh_config_file_obj:
                config.parse(ssh_config_file_obj)

            # Temporarily change paramiko's default SSH port to the port
            # specified by the user so expand_variables (called by lookup)
            # replaces %p with that port rather than 22. This is
            # module-global state, so restore it even if lookup raises.
            paramiko_default_ssh_port = paramiko.config.SSH_PORT
            paramiko.config.SSH_PORT = port
            try:
                config = config.lookup(host)
            finally:
                paramiko.config.SSH_PORT = paramiko_default_ssh_port

            host = config.get("hostname", host)
            if username is None:
                username = config.get("user")
            if key_filename is None:
                key_filename = config.get("identityfile")
            if hostkey_verify:
                userknownhostsfile = config.get("userknownhostsfile")
                if userknownhostsfile:
                    self.load_known_hosts(
                        os.path.expanduser(userknownhostsfile))
            if timeout is None:
                timeout = config.get("connecttimeout")
                if timeout:
                    timeout = int(timeout)

        if username is None:
            username = getpass.getuser()

        if sock_fd is None:
            proxycommand = config.get("proxycommand")
            if proxycommand:
                self.logger.debug("Configuring Proxy. %s", proxycommand)
                if not isinstance(proxycommand, six.string_types):
                    proxycommand = [
                        os.path.expanduser(elem) for elem in proxycommand
                    ]
                else:
                    proxycommand = os.path.expanduser(proxycommand)
                sock = paramiko.proxy.ProxyCommand(proxycommand)
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                              socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    try:
                        sock = socket.socket(af, socktype, proto)
                        sock.settimeout(timeout)
                    except socket.error:
                        continue
                    try:
                        if bind_addr:
                            sock.bind((bind_addr, 0))
                        sock.connect(sa)
                    except socket.error:
                        sock.close()
                        continue
                    break
                else:
                    raise SSHError("Could not open socket to %s:%s" %
                                   (host, port))
        else:
            if sys.version_info[0] < 3:
                s = socket.fromfd(int(sock_fd), socket.AF_INET,
                                  socket.SOCK_STREAM)
                sock = socket.socket(socket.AF_INET,
                                     socket.SOCK_STREAM,
                                     _sock=s)
            else:
                sock = socket.fromfd(int(sock_fd), socket.AF_INET,
                                     socket.SOCK_STREAM)
            sock.settimeout(timeout)

        self._transport = paramiko.Transport(sock)
        self._transport.set_log_channel(logger.name)
        if config.get("compression") == 'yes':
            self._transport.use_compression()

        if hostkey_b64:
            # If we need to connect with a specific hostkey, negotiate for only its type
            hostkey_data = base64.b64decode(hostkey_b64)
            hostkey_obj = None
            for key_cls in [
                    paramiko.DSSKey, paramiko.Ed25519Key, paramiko.RSAKey,
                    paramiko.ECDSAKey
            ]:
                try:
                    hostkey_obj = key_cls(data=hostkey_data)
                except paramiko.SSHException:
                    # Not a key of this type - try the next
                    continue
                # First class that accepts the data wins; no need to probe further
                break
            if not hostkey_obj:
                # We've tried all known host key types and haven't found a suitable one to use - bail
                raise SSHError(
                    "Couldn't find suitable paramiko key class for host key %s"
                    % hostkey_b64)
            self._transport._preferred_keys = [hostkey_obj.get_name()]
        elif self._host_keys:
            # Else set preferred host keys to those we possess for the host
            # (avoids situation where known_hosts contains a valid key for the
            # host, but that key type is not selected during negotiation).
            # Collect key type names for both "host" and "[host]:port" without
            # calling update() on the lookup result: that would write entries
            # back into self._host_keys, and fails with AttributeError when
            # only the "[host]:port" form is known.
            preferred_keys = []
            for lookup in (host, '[%s]:%s' % (host, port)):
                for key_type in (self._host_keys.lookup(lookup) or ()):
                    if key_type not in preferred_keys:
                        preferred_keys.append(key_type)
            if preferred_keys:
                self._transport._preferred_keys = preferred_keys

        # Connect
        try:
            self._transport.start_client()
        except paramiko.SSHException as e:
            raise SSHError('Negotiation failed: %s' % e)

        server_key_obj = self._transport.get_remote_server_key()
        fingerprint = _colonify(hexlify(server_key_obj.get_fingerprint()))

        if hostkey_verify:
            is_known_host = False

            # For looking up entries for non-default (non-22) ssh ports in
            # known_hosts we enclose host in brackets and append port number
            known_hosts_lookups = [host, '[%s]:%s' % (host, port)]

            if hostkey_b64:
                # If hostkey specified, remote host /must/ use that hostkey
                if (hostkey_obj.get_name() == server_key_obj.get_name()
                        and hostkey_obj.asbytes() == server_key_obj.asbytes()):
                    is_known_host = True
            else:
                # Check known_hosts
                is_known_host = any(
                    self._host_keys.check(lookup, server_key_obj)
                    for lookup in known_hosts_lookups)

            if not is_known_host and not unknown_host_cb(host, fingerprint):
                raise SSHUnknownHostError(known_hosts_lookups[0], fingerprint)

        # Authenticating with our private key/identity
        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, (str, bytes)):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        self._auth(username, password, key_filenames, allow_agent,
                   look_for_keys)

        self._connected = True  # there was no error authenticating
        self._closing.clear()

        # TODO: leopoul: Review, test, and if needed rewrite this part
        subsystem_names = self._device_handler.get_ssh_subsystem_names()
        for subname in subsystem_names:
            self._channel = self._transport.open_session()
            self._channel_id = self._channel.get_id()
            channel_name = "%s-subsystem-%s" % (subname, str(self._channel_id))
            self._channel.set_name(channel_name)
            try:
                self._channel.invoke_subsystem(subname)
            except paramiko.SSHException as e:
                self.logger.info("%s (subsystem request rejected)", e)
                handle_exception = self._device_handler.handle_connection_exceptions(
                    self)
                # Ignore the exception, since we continue to try the different
                # subsystem names until we find one that can connect.
                # have to handle exception for each vendor here
                if not handle_exception:
                    continue
            self._channel_name = self._channel.get_name()
            self._post_connect()
            # for further upcoming RPC responses, vendor can chose their
            # choice of parser. Say DOM or SAX
            self.parser = self._device_handler.get_xml_parser(self)
            return
        raise SSHError(
            "Could not open connection, possibly due to unacceptable"
            " SSH subsystem name.")

    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        """Authenticate the transport, returning on the first success.

        Order tried: each explicit key file (with each supported key class),
        SSH agent keys, key files discovered under ~/.ssh and ~/ssh, then
        *password*. Raises :exc:`AuthenticationError` carrying the repr of
        the last failure, or a generic message if nothing was attempted.
        """
        saved_exception = None

        for key_filename in key_filenames:
            for cls in (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey,
                        paramiko.Ed25519Key):
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    self.logger.debug("Trying key %s from %s",
                                      hexlify(key.get_fingerprint()),
                                      key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    self.logger.debug("Trying SSH agent key %s",
                                      hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        keyfiles = []
        if look_for_keys:
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/.ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))
            # look in ~/ssh/ for windows users:
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                self.logger.debug("Trying discovered key %s in %s",
                                  hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if password is not None:
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if saved_exception is not None:
            # need pep-3134 to do this right
            raise AuthenticationError(repr(saved_exception))

        raise AuthenticationError("No authentication methods available")

    def run(self):
        """Session thread main loop.

        Waits up to TICK seconds for the channel to become readable, feeds
        received data to :attr:`parser`, and sends queued outgoing messages
        framed for the negotiated NETCONF base. Any exception is dispatched
        to listeners and the session is closed.
        """
        chan = self._channel
        q = self._q

        def start_delim(data_len):
            return '\n#%s\n' % (data_len)

        def to_bytes(text):
            # Chunk sizes count octets and chan.send() reports bytes written,
            # so the outgoing message is handled as UTF-8 encoded bytes.
            return text if isinstance(text, bytes) else text.encode('utf-8')

        try:
            s = selectors.DefaultSelector()
            s.register(chan, selectors.EVENT_READ)
            self.logger.debug('selector type = %s', s.__class__.__name__)
            while True:
                # Will wakeup every TICK seconds to check if something
                # to send, more quickly if something to read (due to
                # select returning chan in readable list).
                events = s.select(timeout=TICK)
                if events:
                    data = chan.recv(BUF_SIZE)
                    if data:
                        try:
                            self.parser.parse(data)
                        except SAXFilterXMLNotFoundError:
                            self.logger.debug(
                                'switching from sax to dom parsing')
                            self.parser = DefaultXMLParser(self)
                            self.parser.parse(data)
                    elif self._closing.is_set():
                        # End of session, expected
                        break
                    else:
                        # End of session, unexpected
                        raise SessionCloseError(self._buffer.getvalue())
                if not q.empty() and chan.send_ready():
                    self.logger.debug("Sending message")
                    data = q.get()
                    if self._base == NetconfBase.BASE_11:
                        # The chunk-size header counts octets, not characters.
                        data = "%s%s%s" % (start_delim(
                            len(to_bytes(data))), data, END_DELIM)
                    else:
                        data = "%s%s" % (data, MSG_DELIM)
                    self.logger.info("Sending:\n%s", data)
                    data = to_bytes(data)
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(),
                                                    data)
                        data = data[n:]
        except Exception as e:
            self.logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(e)
            self.close()

    @property
    def host(self):
        """Host this session is connected to, or None if not connected."""
        if hasattr(self, '_host'):
            return self._host
        return None

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
class SSHSession(Session):
    """Implements a :rfc:`4742` NETCONF session over SSH.

    The session runs as a thread (see :meth:`run`) that appends received
    bytes to ``self._buffer`` and parses them with :meth:`_parse10` or
    :meth:`_parse11` depending on the negotiated NETCONF base, and writes
    queued outgoing messages with the matching framing.
    """

    def __init__(self, device_handler):
        capabilities = Capabilities(device_handler.get_capabilities())
        Session.__init__(self, capabilities)
        self._host = None
        self._host_keys = paramiko.HostKeys()
        self._transport = None
        self._connected = False
        self._channel = None
        self._channel_id = None
        self._channel_name = None
        self._buffer = StringIO()
        # parsing-related, see _parse()
        self._device_handler = device_handler
        # Of the following, only _parsing_pos10 and _message_list are read
        # by methods in this class.
        self._parsing_state10 = 0
        self._parsing_pos10 = 0
        self._parsing_pos11 = 0
        self._parsing_state11 = 0
        self._expchunksize = 0
        self._curchunksize = 0
        self._inendpos = 0
        self._size_num_list = []
        self._message_list = []
        self._closing = threading.Event()
        self.logger = SessionLoggerAdapter(logger, {'session': self})

    def _dispatch_message(self, raw):
        """Log a received message, then hand it to the base class dispatcher."""
        self.logger.info("Received:\n%s", raw)
        return super(SSHSession, self)._dispatch_message(raw)

    def _parse(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a byte has been read it will not be
        considered again."""
        return self._parse10()

    def _parse10(self):
        """Messages are delimited by MSG_DELIM. The buffer could have grown by
        a maximum of BUF_SIZE bytes every time this method is called. Retains
        state across method calls and if a chunk has been read it will not be
        considered again.

        Every complete message in the buffer is dispatched through
        :meth:`_dispatch_message`; a trailing partial message stays in
        ``self._buffer`` for the next call.
        """
        self.logger.debug("parsing netconf v1.0")
        # Search raw bytes rather than decoded text: the buffer may end (and
        # _parsing_pos10 may point) inside a multi-byte UTF-8 sequence, which
        # would make decoding fail. MSG_DELIM is ASCII, so it can never occur
        # inside a multi-byte sequence.
        delim = MSG_DELIM.encode('UTF-8')
        while True:
            buf = self._buffer
            buf.seek(self._parsing_pos10)
            if delim not in buf.read():
                # handle case that MSG_DELIM is split over two chunks
                self._parsing_pos10 = max(buf.tell() - MSG_DELIM_LEN, 0)
                return
            raw_msg, _, remaining = buf.getvalue().partition(delim)
            msg = raw_msg.decode('UTF-8').strip()
            if sys.version < '3':
                self._dispatch_message(msg.encode())
            else:
                self._dispatch_message(msg)
            # create new buffer which contains remaining of old buffer
            self._buffer = StringIO()
            self._buffer.write(remaining)
            self._parsing_pos10 = 0
            if not remaining:
                return
            # There could be another entire message in the
            # buffer, so we should try to parse again.
            self.logger.debug('Trying another round of parsing since there is still data')

    def _parse11(self):
        """Messages are split into chunks. Chunks and messages are delimited
        by the regex #RE_NC11_DELIM defined earlier in this file. Each
        time we get called here either a chunk delimiter or an
        end-of-message delimiter should be found iff there is enough
        data. If there is not enough data, we will wait for more. If a
        delimiter is found in the wrong place, a #NetconfFramingError
        will be raised.

        Chunk payloads are collected as raw bytes and decoded only once the
        whole message is assembled, because a chunk boundary may fall in the
        middle of a multi-byte character."""
        self.logger.debug("_parse11: starting")

        # suck in whole string that we have (this is what we will work on in
        # this function) and initialize a couple of useful values
        self._buffer.seek(0, os.SEEK_SET)
        data = self._buffer.getvalue()
        data_len = len(data)
        start = 0
        self.logger.debug('_parse11: working with buffer of %d bytes', data_len)
        while start < data_len:
            # match to see if we found at least some kind of delimiter
            self.logger.debug('_parse11: matching from %d bytes from start of buffer', start)
            re_result = RE_NC11_DELIM.match(data[start:])
            if not re_result:
                # not found any kind of delimiter just break; this should only
                # ever happen if we just have the first few characters of a
                # message such that we don't yet have a full delimiter
                self.logger.debug('_parse11: no delimiter found, buffer="%s"', data[start:].decode())
                break

            # save useful variables for reuse
            re_start = re_result.start()
            re_end = re_result.end()
            self.logger.debug('_parse11: regular expression start=%d, end=%d', re_start, re_end)

            # If the regex doesn't start at the beginning of the buffer,
            # we're in trouble, so throw an error
            if re_start != 0:
                raise NetconfFramingError('_parse11: delimiter not at start of match buffer', data[start:])

            if re_result.group(2):
                # we've found the end of the message, need to form up
                # whole message, save back remainder (if any) to buffer
                # and dispatch the message
                start += re_end
                message = textify(b''.join(self._message_list))
                self._message_list = []
                self.logger.debug('_parse11: found end of message delimiter')
                self._dispatch_message(message)
                break

            elif re_result.group(1):
                # we've found a chunk delimiter, and group(1) is the digit
                # string that will tell us how many bytes past the end of
                # where it was found that we need to have available to
                # save the next chunk off
                self.logger.debug('_parse11: found chunk delimiter')
                digits = int(re_result.group(1))
                self.logger.debug('_parse11: chunk size %d bytes', digits)
                if (data_len - start) >= (re_end + digits):
                    # we have enough data for the chunk; keep it as bytes
                    fragment = data[start + re_end:start + re_end + digits]
                    self._message_list.append(fragment)
                    start += re_end + digits
                    self.logger.debug('_parse11: appending %d bytes', digits)
                    self.logger.debug('_parse11: fragment = "%s"', fragment)
                else:
                    # we don't have enough bytes, just break out for now
                    # after updating start pointer to start of new chunk
                    start += re_start
                    self.logger.debug('_parse11: not enough data for chunk yet')
                    self.logger.debug('_parse11: setting start to %d', start)
                    break

        # Now out of the loop, need to see if we need to save back any content
        if start > 0:
            self.logger.debug(
                '_parse11: saving back rest of message after %d bytes, original size %d',
                start, data_len)
            self._buffer = StringIO(data[start:])
            if start < data_len:
                self.logger.debug('_parse11: still have data, may have another full message!')
                self._parse11()
        self.logger.debug('_parse11: ending')

    def load_known_hosts(self, filename=None):
        """Load host keys from an openssh :file:`known_hosts`-style file. Can
        be called multiple times.

        If *filename* is not specified, looks in the default locations i.e.
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
        A missing default file is silently ignored; an explicitly given file
        that cannot be read raises.
        """
        if filename is None:
            filename = os.path.expanduser('~/.ssh/known_hosts')
            try:
                self._host_keys.load(filename)
            except IOError:
                # for windows
                filename = os.path.expanduser('~/ssh/known_hosts')
                try:
                    self._host_keys.load(filename)
                except IOError:
                    pass
        else:
            self._host_keys.load(filename)

    def close(self):
        """Close the transport and channel and mark the session disconnected.

        Safe to call before :meth:`connect` has created a transport.
        """
        self._closing.set()
        if self._transport is not None and self._transport.is_active():
            self._transport.close()

        # Wait for this session's thread (which executes run()) to finish,
        # unless close() is being called from that very thread.
        while self.is_alive() and (self is not threading.current_thread()):
            self.join(10)

        if self._channel:
            self._channel.close()
        self._channel = None
        self._connected = False

    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
    def connect(
            self,
            host,
            port=PORT_NETCONF_DEFAULT,
            timeout=None,
            unknown_host_cb=default_unknown_host_cb,
            username=None,
            password=None,
            key_filename=None,
            allow_agent=True,
            hostkey_verify=True,
            hostkey_b64=None,
            look_for_keys=True,
            ssh_config=None,
            sock_fd=None):
        """Connect via SSH and initialize the NETCONF session. First attempts
        the publickey authentication method and then password authentication.

        To disable attempting publickey authentication altogether, call with
        *allow_agent* and *look_for_keys* as `False`.

        *host* is the hostname or IP address to connect to

        *port* is by default 830 (PORT_NETCONF_DEFAULT), but some devices use
        the default SSH port of 22 (PORT_SSH_DEFAULT) so this may need to be
        specified

        *timeout* is an optional timeout for socket connect

        *unknown_host_cb* is called when the server host key is not
        recognized. It takes two arguments, the hostname and the fingerprint
        (see the signature of :func:`default_unknown_host_cb`)

        *username* is the username to use for SSH authentication

        *password* is the password used if using password authentication, or
        the passphrase to use for unlocking keys that require it

        *key_filename* is a filename (or a list of filenames) where the
        private key to be used can be found

        *allow_agent* enables querying SSH agent (if found) for keys

        *hostkey_verify* enables hostkey verification from ~/.ssh/known_hosts

        *hostkey_b64* only connect when server presents a public hostkey
        matching this (obtain from server /etc/ssh/ssh_host_*pub or
        ssh-keyscan)

        *look_for_keys* enables looking in the usual locations for ssh keys
        (e.g. :file:`~/.ssh/id_*`)

        *ssh_config* enables parsing of an OpenSSH configuration file, if set
        to its path, e.g. :file:`~/.ssh/config` or to True (in this case, use
        :file:`~/.ssh/config`).

        *sock_fd* is an already open socket which shall be used for this
        connection. Useful for NETCONF outbound ssh. Use host=None together
        with a valid sock_fd number

        Raises :exc:`SSHError` on socket/negotiation/subsystem failures and
        :exc:`SSHUnknownHostError` when host key verification is rejected.
        """
        if not (host or sock_fd):
            raise SSHError("Missing host or socket fd")

        self._host = host

        # Optionally, parse .ssh/config
        config = {}
        if ssh_config is True:
            ssh_config = "~/.ssh/config" if sys.platform != "win32" else "~/ssh/config"
        if ssh_config is not None:
            config = paramiko.SSHConfig()
            with open(os.path.expanduser(ssh_config)) as ssh_config_file_obj:
                config.parse(ssh_config_file_obj)

            # Temporarily change paramiko's default SSH port to the port
            # specified by the user so expand_variables (called by lookup)
            # replaces %p with that port rather than 22. This is
            # module-global state, so restore it even if lookup raises.
            paramiko_default_ssh_port = paramiko.config.SSH_PORT
            paramiko.config.SSH_PORT = port
            try:
                config = config.lookup(host)
            finally:
                paramiko.config.SSH_PORT = paramiko_default_ssh_port

            host = config.get("hostname", host)
            if username is None:
                username = config.get("user")
            if key_filename is None:
                key_filename = config.get("identityfile")
            if hostkey_verify:
                userknownhostsfile = config.get("userknownhostsfile")
                if userknownhostsfile:
                    self.load_known_hosts(os.path.expanduser(userknownhostsfile))
            if timeout is None:
                timeout = config.get("connecttimeout")
                if timeout:
                    timeout = int(timeout)

        if username is None:
            username = getpass.getuser()

        if sock_fd is None:
            proxycommand = config.get("proxycommand")
            if proxycommand:
                self.logger.debug("Configuring Proxy. %s", proxycommand)
                if not isinstance(proxycommand, six.string_types):
                    proxycommand = [os.path.expanduser(elem) for elem in proxycommand]
                else:
                    proxycommand = os.path.expanduser(proxycommand)
                sock = paramiko.proxy.ProxyCommand(proxycommand)
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    try:
                        sock = socket.socket(af, socktype, proto)
                        sock.settimeout(timeout)
                    except socket.error:
                        continue
                    try:
                        sock.connect(sa)
                    except socket.error:
                        sock.close()
                        continue
                    break
                else:
                    raise SSHError("Could not open socket to %s:%s" % (host, port))
        else:
            if sys.version_info[0] < 3:
                s = socket.fromfd(int(sock_fd), socket.AF_INET, socket.SOCK_STREAM)
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, _sock=s)
            else:
                sock = socket.fromfd(int(sock_fd), socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(timeout)

        self._transport = paramiko.Transport(sock)
        self._transport.set_log_channel(logger.name)
        if config.get("compression") == 'yes':
            self._transport.use_compression()

        # For looking up entries for non-default (non-22) ssh ports in
        # known_hosts we enclose host in brackets and append port number
        if port == PORT_SSH_DEFAULT:
            known_hosts_lookup = host
        else:
            known_hosts_lookup = '[%s]:%s' % (host, port)

        if hostkey_b64:
            # If we need to connect with a specific hostkey, negotiate for only its type
            hostkey_data = base64.b64decode(hostkey_b64)
            hostkey_obj = None
            for key_cls in [paramiko.DSSKey, paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]:
                try:
                    hostkey_obj = key_cls(data=hostkey_data)
                except paramiko.SSHException:
                    # Not a key of this type - try the next
                    continue
                # First class that accepts the data wins; no need to probe further
                break
            if not hostkey_obj:
                # We've tried all known host key types and haven't found a suitable one to use - bail
                raise SSHError("Couldn't find suitable paramiko key class for host key %s" % hostkey_b64)
            self._transport._preferred_keys = [hostkey_obj.get_name()]
        elif self._host_keys:
            # Else set preferred host keys to those we possess for the host
            # (avoids situation where known_hosts contains a valid key for the host, but that key type is not selected during negotiation)
            known_host_keys_for_this_host = self._host_keys.lookup(known_hosts_lookup)
            if known_host_keys_for_this_host:
                self._transport._preferred_keys = [x.key.get_name() for x in known_host_keys_for_this_host._entries]

        # Connect
        try:
            self._transport.start_client()
        except paramiko.SSHException as e:
            raise SSHError('Negotiation failed: %s' % e)

        server_key_obj = self._transport.get_remote_server_key()
        fingerprint = _colonify(hexlify(server_key_obj.get_fingerprint()))

        if hostkey_verify:
            is_known_host = False

            if hostkey_b64:
                # If hostkey specified, remote host /must/ use that hostkey
                if (hostkey_obj.get_name() == server_key_obj.get_name()
                        and hostkey_obj.asbytes() == server_key_obj.asbytes()):
                    is_known_host = True
            else:
                # Check known_hosts
                is_known_host = self._host_keys.check(known_hosts_lookup, server_key_obj)

            if not is_known_host and not unknown_host_cb(host, fingerprint):
                raise SSHUnknownHostError(known_hosts_lookup, fingerprint)

        # Authenticating with our private key/identity
        if key_filename is None:
            key_filenames = []
        elif isinstance(key_filename, (str, bytes)):
            key_filenames = [key_filename]
        else:
            key_filenames = key_filename

        self._auth(username, password, key_filenames, allow_agent, look_for_keys)

        self._connected = True  # there was no error authenticating
        self._closing.clear()

        # TODO: leopoul: Review, test, and if needed rewrite this part
        subsystem_names = self._device_handler.get_ssh_subsystem_names()
        for subname in subsystem_names:
            self._channel = self._transport.open_session()
            self._channel_id = self._channel.get_id()
            channel_name = "%s-subsystem-%s" % (subname, str(self._channel_id))
            self._channel.set_name(channel_name)
            try:
                self._channel.invoke_subsystem(subname)
            except paramiko.SSHException as e:
                self.logger.info("%s (subsystem request rejected)", e)
                handle_exception = self._device_handler.handle_connection_exceptions(self)
                # Ignore the exception, since we continue to try the different
                # subsystem names until we find one that can connect.
                # have to handle exception for each vendor here
                if not handle_exception:
                    continue
            self._channel_name = self._channel.get_name()
            self._post_connect()
            return
        raise SSHError("Could not open connection, possibly due to unacceptable"
                       " SSH subsystem name.")

    def _auth(self, username, password, key_filenames, allow_agent,
              look_for_keys):
        """Authenticate the transport, returning on the first success.

        Order tried: each explicit key file (with each supported key class),
        SSH agent keys, key files discovered under ~/.ssh and ~/ssh, then
        *password*. Raises :exc:`AuthenticationError` carrying the repr of
        the last failure, or a generic message if nothing was attempted.
        """
        saved_exception = None

        for key_filename in key_filenames:
            for cls in (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey,
                        paramiko.Ed25519Key):
                try:
                    key = cls.from_private_key_file(key_filename, password)
                    self.logger.debug("Trying key %s from %s",
                                      hexlify(key.get_fingerprint()),
                                      key_filename)
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        if allow_agent:
            for key in paramiko.Agent().get_keys():
                try:
                    self.logger.debug("Trying SSH agent key %s",
                                      hexlify(key.get_fingerprint()))
                    self._transport.auth_publickey(username, key)
                    return
                except Exception as e:
                    saved_exception = e
                    self.logger.debug(e)

        keyfiles = []
        if look_for_keys:
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/.ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))
            # look in ~/ssh/ for windows users:
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
            ecdsa_key = os.path.expanduser("~/ssh/id_ecdsa")
            if os.path.isfile(rsa_key):
                keyfiles.append((paramiko.RSAKey, rsa_key))
            if os.path.isfile(dsa_key):
                keyfiles.append((paramiko.DSSKey, dsa_key))
            if os.path.isfile(ecdsa_key):
                keyfiles.append((paramiko.ECDSAKey, ecdsa_key))

        for cls, filename in keyfiles:
            try:
                key = cls.from_private_key_file(filename, password)
                self.logger.debug("Trying discovered key %s in %s",
                                  hexlify(key.get_fingerprint()), filename)
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if password is not None:
            try:
                self._transport.auth_password(username, password)
                return
            except Exception as e:
                saved_exception = e
                self.logger.debug(e)

        if saved_exception is not None:
            # need pep-3134 to do this right
            raise AuthenticationError(repr(saved_exception))

        raise AuthenticationError("No authentication methods available")

    def run(self):
        """Session thread main loop.

        Waits up to TICK seconds for the channel to become readable, appends
        received bytes to ``self._buffer`` and parses them according to the
        negotiated NETCONF base, and sends queued outgoing messages with the
        matching framing. Any exception is dispatched to listeners and the
        session is closed.
        """
        chan = self._channel
        q = self._q

        def start_delim(data_len):
            return '\n#%s\n' % (data_len)

        def to_bytes(text):
            # Chunk sizes count octets and chan.send() reports bytes written,
            # so the outgoing message is handled as UTF-8 encoded bytes.
            return text if isinstance(text, bytes) else text.encode('utf-8')

        try:
            s = selectors.DefaultSelector()
            s.register(chan, selectors.EVENT_READ)
            self.logger.debug('selector type = %s', s.__class__.__name__)
            while True:
                # Will wakeup every TICK seconds to check if something
                # to send, more quickly if something to read (due to
                # select returning chan in readable list).
                events = s.select(timeout=TICK)
                if events:
                    data = chan.recv(BUF_SIZE)
                    if data:
                        self._buffer.seek(0, os.SEEK_END)
                        self._buffer.write(data)
                        if self._base == NetconfBase.BASE_11:
                            self._parse11()
                        else:
                            self._parse10()
                    elif self._closing.is_set():
                        # End of session, expected
                        break
                    else:
                        # End of session, unexpected
                        raise SessionCloseError(self._buffer.getvalue())
                if not q.empty() and chan.send_ready():
                    self.logger.debug("Sending message")
                    data = q.get()
                    if self._base == NetconfBase.BASE_11:
                        # The chunk-size header counts octets, not characters.
                        data = "%s%s%s" % (start_delim(len(to_bytes(data))), data, END_DELIM)
                    else:
                        data = "%s%s" % (data, MSG_DELIM)
                    self.logger.info("Sending:\n%s", data)
                    data = to_bytes(data)
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._buffer.getvalue(), data)
                        data = data[n:]
        except Exception as e:
            self.logger.debug("Broke out of main loop, error=%r", e)
            self._dispatch_error(e)
            self.close()

    @property
    def host(self):
        """Host this session is connected to, or None if not connected."""
        if hasattr(self, '_host'):
            return self._host
        return None

    @property
    def transport(self):
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
        return self._transport
class Session(Thread):

    """Base class for use by transport protocol implementations.

    Subclasses implement :meth:`connect` and :meth:`run`. Outgoing messages
    are queued by :meth:`send`; incoming messages are parsed and dispatched
    to the registered :class:`SessionListener` instances.
    """

    def __init__(self, capabilities):
        Thread.__init__(self)
        # Attribute assignment instead of setDaemon()/setName(), which are
        # deprecated since Python 3.10 (the attributes exist since 2.6).
        self.daemon = True
        self.name = 'session'
        self._listeners = set()
        self._lock = Lock()
        self._q = Queue()
        self._notification_q = Queue()
        self._client_capabilities = capabilities
        self._server_capabilities = None  # yet
        self._base = NetconfBase.BASE_10
        self._id = None  # session-id
        self._connected = False  # to be set/cleared by subclass implementation
        self.logger = SessionLoggerAdapter(logger, {'session': self})
        self.logger.debug('%r created: client_capabilities=%r',
                          self, self._client_capabilities)
        self._device_handler = None  # Should be set by child class

    def _dispatch_message(self, raw):
        """Parse the incoming *raw* message and pass it to every listener.

        If parsing fails, the device handler's ``handle_raw_dispatch`` gets a
        chance to repair it:

        - a returned ``str`` is parsed instead;
        - a returned ``Exception`` is dispatched as an error;
        - anything else drops the message after logging the parse error.
        """
        try:
            root = parse_root(raw)
        except Exception as e:
            device_handled_raw = self._device_handler.handle_raw_dispatch(raw)
            if isinstance(device_handled_raw, str):
                root = parse_root(device_handled_raw)
            elif isinstance(device_handled_raw, Exception):
                self._dispatch_error(device_handled_raw)
                return
            else:
                self.logger.error('error parsing dispatch message: %s', e)
                return
        # Log the (possibly large) raw message once, not once per listener.
        self.logger.debug('dispatching message to different listeners: %s', raw)
        # Snapshot the listeners so the lock is not held while callbacks run.
        with self._lock:
            listeners = list(self._listeners)
        for listener in listeners:
            self.logger.debug('dispatching message to listener: %r', listener)
            listener.callback(root, raw)  # no try-except; fail loudly if you must!

    def _dispatch_error(self, err):
        """Pass *err* to the ``errback`` of every registered listener.

        An exception raised by one listener is logged as a warning and does
        not prevent delivery to the remaining listeners.
        """
        with self._lock:
            listeners = list(self._listeners)
        for listener in listeners:
            self.logger.debug('dispatching error to %r', listener)
            try:  # here we can be more considerate with catching exceptions
                listener.errback(err)
            except Exception as e:
                self.logger.warning('error dispatching to %r: %r', listener, e)

    def _post_connect(self, timeout=60):
        """Perform the <hello> capability exchange.

        The method proceeds as follows:

        - register a notification handler and a hello handler;
        - send the client <hello> and start the session thread;
        - wait up to *timeout* seconds (default 60) for the server's <hello>.

        On success it stores the session-id and the server capabilities.
        netconf:base:1.1 framing is selected when both sides advertise it.

        :raises SessionError: if the server's hello does not arrive within *timeout*.
        Any error reported through the hello handler is re-raised as-is.
        """
        init_event = Event()
        error = [None]  # so that err_cb can bind error[0]. just how it is.

        # callbacks
        def ok_cb(id, capabilities):
            self._id = id
            self._server_capabilities = capabilities
            init_event.set()

        def err_cb(err):
            error[0] = err
            init_event.set()

        self.add_listener(NotificationHandler(self._notification_q))
        listener = HelloHandler(ok_cb, err_cb)
        self.add_listener(listener)
        self.send(HelloHandler.build(self._client_capabilities, self._device_handler))
        self.logger.debug('starting main loop')
        self.start()
        # Wait for the server's hello message (or an error); give up after *timeout* seconds.
        init_event.wait(timeout)
        # The hello handler is only needed for the exchange. Unregister it on
        # every path, including a timeout, so that it does not linger.
        self.remove_listener(listener)
        if not init_event.is_set():
            raise SessionError("Capability exchange timed out")
        # received hello message or an error happened
        if error[0]:
            raise error[0]
        if ('urn:ietf:params:netconf:base:1.1' in self._server_capabilities and
                'urn:ietf:params:netconf:base:1.1' in self._client_capabilities):
            self.logger.debug("After 'hello' message selecting netconf:base:1.1 for encoding")
            self._base = NetconfBase.BASE_11
        self.logger.info('initialized: session-id=%s | server_capabilities=%s',
                         self._id, self._server_capabilities)

    def add_listener(self, listener):
        """Register a listener that will be notified of incoming messages and
        errors.

        :type listener: :class:`SessionListener`
        :raises SessionError: if *listener* is not a :class:`SessionListener`.
        """
        self.logger.debug('installing listener %r', listener)
        if not isinstance(listener, SessionListener):
            raise SessionError("Listener must be a SessionListener type")
        with self._lock:
            self._listeners.add(listener)

    def remove_listener(self, listener):
        """Unregister some listener; ignore if the listener was never
        registered.

        :type listener: :class:`SessionListener`
        """
        self.logger.debug('discarding listener %r', listener)
        with self._lock:
            self._listeners.discard(listener)

    def get_listener_instance(self, cls):
        """If a listener of the specified type is registered, returns the
        instance; otherwise returns `None`.

        :type cls: :class:`SessionListener`
        """
        with self._lock:
            for listener in self._listeners:
                if isinstance(listener, cls):
                    return listener
        return None

    def connect(self, *args, **kwds):  # subclass implements
        raise NotImplementedError

    def run(self):  # subclass implements
        raise NotImplementedError

    def send(self, message):
        """Send the supplied *message* (xml string) to NETCONF server.

        The message is only queued here; the subclass's :meth:`run` loop
        drains the queue.

        :raises TransportError: if the session is not connected.
        """
        if not self.connected:
            raise TransportError('Not connected to NETCONF server')
        self.logger.debug('queueing %s', message)
        self._q.put(message)

    def scp(self):
        raise NotImplementedError

    # Properties

    def take_notification(self, block, timeout):
        """Pop the next queued notification, or return `None` if none is
        available (``Queue.get`` semantics for *block* / *timeout*)."""
        try:
            return self._notification_q.get(block, timeout)
        except Empty:
            return None

    @property
    def connected(self):
        "Connection status of the session."
        return self._connected

    @property
    def client_capabilities(self):
        "Client's :class:`Capabilities`"
        return self._client_capabilities

    @property
    def server_capabilities(self):
        "Server's :class:`Capabilities`"
        return self._server_capabilities

    @property
    def id(self):
        """A string representing the `session-id`. If the session has not been initialized it will be `None`"""
        return self._id
class RPC(object):

    """Base class for all operations, directly corresponding to *rpc* requests. Handles making the request, and taking delivery of the reply."""

    DEPENDS = []
    """Subclasses can specify their dependencies on capabilities as a list of URI's or abbreviated names, e.g. ':writable-running'. These are verified at the time of instantiation. If the capability is not available, :exc:`MissingCapabilityError` is raised."""

    REPLY_CLS = RPCReply
    "By default :class:`RPCReply`. Subclasses can specify a :class:`RPCReply` subclass."

    def __init__(self, session, device_handler, async_mode=False, timeout=30, raise_mode=RaiseMode.NONE):
        """
        *session* is the :class:`~ncclient.transport.Session` instance

        *device_handler* is the :class:`~ncclient.devices.*.*DeviceHandler` instance

        *async_mode* specifies whether the request is to be made asynchronously, see :attr:`is_async`

        *timeout* is the timeout for a synchronous request, see :attr:`timeout`

        *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`
        """
        self._session = session
        # Verify the declared capability dependencies up front.
        # NOTE(review): AttributeError is swallowed, presumably for session
        # objects without a ``server_capabilities`` attribute -- confirm.
        try:
            for cap in self.DEPENDS:
                self._assert(cap)
        except AttributeError:
            pass
        self._async = async_mode
        self._timeout = timeout
        self._raise_mode = raise_mode
        self._id = uuid4().urn  # Keeps things simple instead of having a class attr with running ID that has to be locked
        self._listener = RPCReplyListener(session, device_handler)
        self._listener.register(self._id, self)
        self._reply = None
        self._error = None
        self._event = Event()
        self._device_handler = device_handler
        self.logger = SessionLoggerAdapter(logger, {'session': session})

    def _wrap(self, subele):
        """Wrap *subele* in an ``<rpc>`` element carrying this RPC's
        message-id and return it serialized as XML (internal use)."""
        ele = new_ele("rpc", {"message-id": self._id},
                      **self._device_handler.get_xml_extra_prefix_kwargs())
        ele.append(subele)
        return to_xml(ele)

    def _request(self, op):
        """Implementations of :meth:`request` call this method to send the request and process the reply.

        In synchronous mode, blocks until the reply is received and returns :class:`RPCReply`. Depending on the :attr:`raise_mode` a `rpc-error` element in the reply may lead to an :exc:`RPCError` exception.

        In asynchronous mode, returns immediately, returning `self`. The :attr:`event` attribute will be set when the reply has been received (see :attr:`reply`) or an error occurred (see :attr:`error`).

        *op* is the operation to be requested as an :class:`~xml.etree.ElementTree.Element`

        :raises TimeoutExpiredError: if no reply arrives within :attr:`timeout` (synchronous mode).
        """
        self.logger.info('Requesting %r', self.__class__.__name__)
        req = self._wrap(op)
        self._session.send(req)
        if self._async:
            self.logger.debug('Async request, returning %r', self)
            return self
        self.logger.debug('Sync request, will wait for timeout=%r', self._timeout)
        self._event.wait(self._timeout)
        if not self._event.is_set():
            raise TimeoutExpiredError('ncclient timed out while waiting for an rpc reply.')
        if self._error:
            # Error that prevented reply delivery
            raise self._error
        self._reply.parse()
        if (self._reply.error is not None and
                not self._device_handler.is_rpc_error_exempt(self._reply.error.message)):
            # <rpc-error>'s [ RPCError ]
            if (self._raise_mode == RaiseMode.ALL or
                    (self._raise_mode == RaiseMode.ERRORS and self._reply.error.severity == "error")):
                errors = self._reply.errors
                if len(errors) > 1:
                    raise RPCError(to_ele(self._reply._raw), errs=errors)
                raise self._reply.error
        transform = self._device_handler.transform_reply()
        if transform:
            return NCElement(self._reply, transform)
        return self._reply

    def request(self):
        """Subclasses must implement this method. Typically only the request needs to be built as an
        :class:`~xml.etree.ElementTree.Element` and everything else can be handed off to
        :meth:`_request`."""
        pass

    def _assert(self, capability):
        """Subclasses can use this method to verify that a capability is available with the NETCONF
        server, before making a request that requires it. A :exc:`MissingCapabilityError` will be
        raised if the capability is not available."""
        if capability not in self._session.server_capabilities:
            raise MissingCapabilityError('Server does not support [%s]' % capability)

    def deliver_reply(self, raw):
        """Store the *raw* reply wrapped in :attr:`REPLY_CLS` and wake up waiters (internal use)."""
        self._reply = self.REPLY_CLS(raw)
        self._event.set()

    def deliver_error(self, err):
        """Store an error that prevented reply delivery and wake up waiters (internal use)."""
        self._error = err
        self._event.set()

    @property
    def reply(self):
        ":class:`RPCReply` element if reply has been received or `None`"
        return self._reply

    @property
    def error(self):
        """:exc:`Exception` type if an error occurred or `None`.

        .. note::
            This represents an error which prevented a reply from being received. An *rpc-error*
            does not fall in that category -- see `RPCReply` for that.
        """
        return self._error

    @property
    def id(self):
        "The *message-id* for this RPC."
        return self._id

    @property
    def session(self):
        "The `~ncclient.transport.Session` object associated with this RPC."
        return self._session

    @property
    def event(self):
        """:class:`~threading.Event` that is set when reply has been received or when an error preventing
        delivery of the reply occurs.
        """
        return self._event

    def __set_async(self, async_mode=True):
        # Validate before assigning so that a rejected switch to asynchronous
        # mode leaves the RPC's mode unchanged.
        if async_mode and not self._session.can_pipeline:
            raise UserWarning('Asynchronous mode not supported for this device/session')
        self._async = async_mode

    def __set_raise_mode(self, mode):
        assert mode in (RaiseMode.NONE, RaiseMode.ERRORS, RaiseMode.ALL)
        self._raise_mode = mode

    def __set_timeout(self, timeout):
        self._timeout = timeout

    raise_mode = property(fget=lambda self: self._raise_mode, fset=__set_raise_mode)
    """Depending on this exception raising mode, an `rpc-error` in the reply may be raised as an :exc:`RPCError` exception. Valid values are the constants defined in :class:`RaiseMode`. """

    is_async = property(fget=lambda self: self._async, fset=__set_async)
    """Specifies whether this RPC will be / was requested asynchronously. By default RPC's are synchronous."""

    timeout = property(fget=lambda self: self._timeout, fset=__set_timeout)
    """Timeout in seconds for synchronous waiting defining how long the RPC request will block on a reply before raising :exc:`TimeoutExpiredError`.

    Irrelevant for asynchronous usage.
    """
class RPC(object):

    """Base class for all operations, directly corresponding to *rpc* requests. Handles making the request, and taking delivery of the reply."""

    DEPENDS = []
    """Subclasses can specify their dependencies on capabilities as a list of URI's or abbreviated names, e.g. ':writable-running'. These are verified at the time of instantiation. If the capability is not available, :exc:`MissingCapabilityError` is raised."""

    REPLY_CLS = RPCReply
    "By default :class:`RPCReply`. Subclasses can specify a :class:`RPCReply` subclass."

    def __init__(self, session, device_handler, async_mode=False, timeout=30, raise_mode=RaiseMode.NONE, huge_tree=False):
        """
        *session* is the :class:`~ncclient.transport.Session` instance

        *device_handler* is the :class:`~ncclient.devices.*.*DeviceHandler` instance

        *async_mode* specifies whether the request is to be made asynchronously, see :attr:`is_async`

        *timeout* is the timeout for a synchronous request, see :attr:`timeout`

        *raise_mode* specifies the exception raising mode, see :attr:`raise_mode`

        *huge_tree* parse xml with huge_tree support (e.g. for large text config retrieval), see :attr:`huge_tree`
        """
        self._session = session
        # Verify the declared capability dependencies up front.
        # NOTE(review): AttributeError is swallowed, presumably for session
        # objects without a ``server_capabilities`` attribute -- confirm.
        try:
            for cap in self.DEPENDS:
                self._assert(cap)
        except AttributeError:
            pass
        self._async = async_mode
        self._timeout = timeout
        self._raise_mode = raise_mode
        self._huge_tree = huge_tree
        self._id = uuid4().urn  # Keeps things simple instead of having a class attr with running ID that has to be locked
        self._listener = RPCReplyListener(session, device_handler)
        self._listener.register(self._id, self)
        self._reply = None
        self._error = None
        self._event = Event()
        self._device_handler = device_handler
        self.logger = SessionLoggerAdapter(logger, {'session': session})

    def _wrap(self, subele):
        """Wrap *subele* in an ``<rpc>`` element carrying this RPC's
        message-id and return it serialized as XML (internal use)."""
        ele = new_ele("rpc", {"message-id": self._id},
                      **self._device_handler.get_xml_extra_prefix_kwargs())
        ele.append(subele)
        return to_xml(ele)

    def _request(self, op):
        """Implementations of :meth:`request` call this method to send the request and process the reply.

        In synchronous mode, blocks until the reply is received and returns :class:`RPCReply`. Depending on the :attr:`raise_mode` a `rpc-error` element in the reply may lead to an :exc:`RPCError` exception.

        In asynchronous mode, returns immediately, returning `self`. The :attr:`event` attribute will be set when the reply has been received (see :attr:`reply`) or an error occurred (see :attr:`error`).

        *op* is the operation to be requested as an :class:`~xml.etree.ElementTree.Element`

        :raises TimeoutExpiredError: if no reply arrives within :attr:`timeout` (synchronous mode).
        """
        self.logger.info('Requesting %r', self.__class__.__name__)
        req = self._wrap(op)
        self._session.send(req)
        if self._async:
            self.logger.debug('Async request, returning %r', self)
            return self
        self.logger.debug('Sync request, will wait for timeout=%r', self._timeout)
        self._event.wait(self._timeout)
        if not self._event.is_set():
            raise TimeoutExpiredError(
                'ncclient timed out while waiting for an rpc reply.')
        if self._error:
            # Error that prevented reply delivery
            raise self._error
        self._reply.parse()
        if (self._reply.error is not None and
                not self._device_handler.is_rpc_error_exempt(self._reply.error.message)):
            # <rpc-error>'s [ RPCError ]
            if (self._raise_mode == RaiseMode.ALL or
                    (self._raise_mode == RaiseMode.ERRORS and self._reply.error.severity == "error")):
                errors = self._reply.errors
                if len(errors) > 1:
                    raise RPCError(to_ele(self._reply._raw), errs=errors)
                raise self._reply.error
        transform = self._device_handler.transform_reply()
        if transform:
            return NCElement(self._reply, transform, huge_tree=self._huge_tree)
        return self._reply

    def request(self):
        """Subclasses must implement this method. Typically only the request needs to be built as an
        :class:`~xml.etree.ElementTree.Element` and everything else can be handed off to
        :meth:`_request`."""
        pass

    def _assert(self, capability):
        """Subclasses can use this method to verify that a capability is available with the NETCONF
        server, before making a request that requires it. A :exc:`MissingCapabilityError` will be
        raised if the capability is not available."""
        if capability not in self._session.server_capabilities:
            raise MissingCapabilityError('Server does not support [%s]' % capability)

    def deliver_reply(self, raw):
        """Store the *raw* reply wrapped in :attr:`REPLY_CLS` and wake up waiters (internal use)."""
        self._reply = self.REPLY_CLS(raw, huge_tree=self._huge_tree)

        # Set the reply_parsing_error transform outside the constructor, to keep compatibility for
        # third party reply classes outside of ncclient
        self._reply.set_parsing_error_transform(
            self._device_handler.reply_parsing_error_transform(self.REPLY_CLS))

        self._event.set()

    def deliver_error(self, err):
        """Store an error that prevented reply delivery and wake up waiters (internal use)."""
        self._error = err
        self._event.set()

    @property
    def reply(self):
        ":class:`RPCReply` element if reply has been received or `None`"
        return self._reply

    @property
    def error(self):
        """:exc:`Exception` type if an error occurred or `None`.

        .. note::
            This represents an error which prevented a reply from being received. An *rpc-error*
            does not fall in that category -- see `RPCReply` for that.
        """
        return self._error

    @property
    def id(self):
        "The *message-id* for this RPC."
        return self._id

    @property
    def session(self):
        "The `~ncclient.transport.Session` object associated with this RPC."
        return self._session

    @property
    def event(self):
        """:class:`~threading.Event` that is set when reply has been received or when an error preventing
        delivery of the reply occurs.
        """
        return self._event

    def __set_async(self, async_mode=True):
        # Validate before assigning so that a rejected switch to asynchronous
        # mode leaves the RPC's mode unchanged.
        if async_mode and not self._session.can_pipeline:
            raise UserWarning(
                'Asynchronous mode not supported for this device/session')
        self._async = async_mode

    def __set_raise_mode(self, mode):
        assert mode in (RaiseMode.NONE, RaiseMode.ERRORS, RaiseMode.ALL)
        self._raise_mode = mode

    def __set_timeout(self, timeout):
        self._timeout = timeout

    raise_mode = property(fget=lambda self: self._raise_mode,
                          fset=__set_raise_mode)
    """Depending on this exception raising mode, an `rpc-error` in the reply may be raised as an :exc:`RPCError` exception. Valid values are the constants defined in :class:`RaiseMode`. """

    is_async = property(fget=lambda self: self._async, fset=__set_async)
    """Specifies whether this RPC will be / was requested asynchronously. By default RPC's are synchronous."""

    timeout = property(fget=lambda self: self._timeout, fset=__set_timeout)
    """Timeout in seconds for synchronous waiting defining how long the RPC request will block on a reply before raising :exc:`TimeoutExpiredError`.

    Irrelevant for asynchronous usage.
    """

    @property
    def huge_tree(self):
        """Whether `huge_tree` support for XML parsing of RPC replies is enabled (default=False)"""
        return self._huge_tree

    @huge_tree.setter
    def huge_tree(self, x):
        self._huge_tree = x
class Session(Thread):

    """Base class for use by transport protocol implementations.

    Subclasses implement :meth:`connect` and :meth:`run`. Outgoing messages
    are queued by :meth:`send`; incoming messages are parsed and dispatched
    to the registered :class:`SessionListener` instances.
    """

    def __init__(self, capabilities):
        Thread.__init__(self)
        # Attribute assignment instead of setDaemon()/setName(), which are
        # deprecated since Python 3.10 (the attributes exist since 2.6).
        self.daemon = True
        self.name = 'session'
        self._listeners = set()
        self._lock = Lock()
        self._q = Queue()
        self._notification_q = Queue()
        self._client_capabilities = capabilities
        self._server_capabilities = None  # yet
        self._base = NetconfBase.BASE_10
        self._id = None  # session-id
        self._connected = False  # to be set/cleared by subclass implementation
        self.logger = SessionLoggerAdapter(logger, {'session': self})
        self.logger.debug('%r created: client_capabilities=%r', self,
                          self._client_capabilities)
        self._device_handler = None  # Should be set by child class

    def _dispatch_message(self, raw):
        """Parse the incoming *raw* message and pass it to every listener.

        If parsing fails, the device handler's ``handle_raw_dispatch`` gets a
        chance to repair it:

        - a returned ``str`` is parsed instead;
        - a returned ``Exception`` is dispatched as an error;
        - anything else drops the message after logging the parse error.
        """
        try:
            root = parse_root(raw)
        except Exception as e:
            device_handled_raw = self._device_handler.handle_raw_dispatch(raw)
            if isinstance(device_handled_raw, str):
                root = parse_root(device_handled_raw)
            elif isinstance(device_handled_raw, Exception):
                self._dispatch_error(device_handled_raw)
                return
            else:
                self.logger.error('error parsing dispatch message: %s', e)
                return
        # Log the (possibly large) raw message once, not once per listener.
        self.logger.debug('dispatching message to different listeners: %s', raw)
        # Snapshot the listeners so the lock is not held while callbacks run.
        with self._lock:
            listeners = list(self._listeners)
        for listener in listeners:
            self.logger.debug('dispatching message to listener: %r', listener)
            listener.callback(root, raw)  # no try-except; fail loudly if you must!

    def _dispatch_error(self, err):
        """Pass *err* to the ``errback`` of every registered listener.

        An exception raised by one listener is logged as a warning and does
        not prevent delivery to the remaining listeners.
        """
        with self._lock:
            listeners = list(self._listeners)
        for listener in listeners:
            self.logger.debug('dispatching error to %r', listener)
            try:  # here we can be more considerate with catching exceptions
                listener.errback(err)
            except Exception as e:
                self.logger.warning('error dispatching to %r: %r', listener, e)

    def _post_connect(self, timeout=60):
        """Perform the <hello> capability exchange.

        The method proceeds as follows:

        - register a notification handler and a hello handler;
        - send the client <hello> and start the session thread;
        - wait up to *timeout* seconds (default 60) for the server's <hello>.

        On success it stores the session-id and the server capabilities.
        netconf:base:1.1 framing is selected when both sides advertise it.

        :raises SessionError: if the server's hello does not arrive within *timeout*.
        Any error reported through the hello handler is re-raised as-is.
        """
        init_event = Event()
        error = [None]  # so that err_cb can bind error[0]. just how it is.

        # callbacks
        def ok_cb(id, capabilities):
            self._id = id
            self._server_capabilities = capabilities
            init_event.set()

        def err_cb(err):
            error[0] = err
            init_event.set()

        self.add_listener(NotificationHandler(self._notification_q))
        listener = HelloHandler(ok_cb, err_cb)
        self.add_listener(listener)
        self.send(
            HelloHandler.build(self._client_capabilities, self._device_handler))
        self.logger.debug('starting main loop')
        self.start()
        # Wait for the server's hello message (or an error); give up after *timeout* seconds.
        init_event.wait(timeout)
        # The hello handler is only needed for the exchange. Unregister it on
        # every path, including a timeout, so that it does not linger.
        self.remove_listener(listener)
        if not init_event.is_set():
            raise SessionError("Capability exchange timed out")
        # received hello message or an error happened
        if error[0]:
            raise error[0]
        if ('urn:ietf:params:netconf:base:1.1' in self._server_capabilities and
                'urn:ietf:params:netconf:base:1.1' in self._client_capabilities):
            self.logger.debug(
                "After 'hello' message selecting netconf:base:1.1 for encoding"
            )
            self._base = NetconfBase.BASE_11
        self.logger.info('initialized: session-id=%s | server_capabilities=%s',
                         self._id, self._server_capabilities)

    def add_listener(self, listener):
        """Register a listener that will be notified of incoming messages and
        errors.

        :type listener: :class:`SessionListener`
        :raises SessionError: if *listener* is not a :class:`SessionListener`.
        """
        self.logger.debug('installing listener %r', listener)
        if not isinstance(listener, SessionListener):
            raise SessionError("Listener must be a SessionListener type")
        with self._lock:
            self._listeners.add(listener)

    def remove_listener(self, listener):
        """Unregister some listener; ignore if the listener was never
        registered.

        :type listener: :class:`SessionListener`
        """
        self.logger.debug('discarding listener %r', listener)
        with self._lock:
            self._listeners.discard(listener)

    def get_listener_instance(self, cls):
        """If a listener of the specified type is registered, returns the
        instance; otherwise returns `None`.

        :type cls: :class:`SessionListener`
        """
        with self._lock:
            for listener in self._listeners:
                if isinstance(listener, cls):
                    return listener
        return None

    def connect(self, *args, **kwds):  # subclass implements
        raise NotImplementedError

    def run(self):  # subclass implements
        raise NotImplementedError

    def send(self, message):
        """Send the supplied *message* (xml string) to NETCONF server.

        The message is only queued here; the subclass's :meth:`run` loop
        drains the queue.

        :raises TransportError: if the session is not connected.
        """
        if not self.connected:
            raise TransportError('Not connected to NETCONF server')
        self.logger.debug('queueing %s', message)
        self._q.put(message)

    def scp(self):
        raise NotImplementedError

    # Properties

    def take_notification(self, block, timeout):
        """Pop the next queued notification, or return `None` if none is
        available (``Queue.get`` semantics for *block* / *timeout*)."""
        try:
            return self._notification_q.get(block, timeout)
        except Empty:
            return None

    @property
    def connected(self):
        "Connection status of the session."
        return self._connected

    @property
    def client_capabilities(self):
        "Client's :class:`Capabilities`"
        return self._client_capabilities

    @property
    def server_capabilities(self):
        "Server's :class:`Capabilities`"
        return self._server_capabilities

    @property
    def id(self):
        """A string representing the `session-id`. If the session has not been initialized it will be `None`"""
        return self._id