class BaseLoader(object):
    """
    Base class for loaders that locate a class defined in a given module.
    """

    # Loader name; used as the log prefix and in error messages
    name = None

    def __init__(self):
        self.logger = PrefixLoggerAdapter(logger, self.name)

    def find_class(self, module_name, base_cls, name):
        """
        Load subclass of *base_cls* from module

        Only classes defined in the module itself are considered (classes the
        module merely imports are skipped), and *base_cls* itself is never
        returned. When several classes qualify, the first one in ``dir()``
        (alphabetical) order wins.

        :param module_name: String containing module name
        :param base_cls: Base class
        :param name: object name (used only in the error message)
        :return: class reference or None
        """
        import importlib

        try:
            sm = importlib.import_module(module_name)
        except ImportError as e:
            self.logger.error("Failed to load %s %s: %s", self.name, name, e)
            return None
        for n in dir(sm):
            o = getattr(sm, n)
            if (
                inspect.isclass(o)
                # issubclass(X, X) is True -- do not return the base itself
                and o is not base_cls
                and issubclass(o, base_cls)
                # Skip classes the module only imports from elsewhere
                and o.__module__ == sm.__name__
            ):
                return o
        return None
class Engine(object):
    """
    Configuration validation engine for a single object.

    Parses the object's config into facts, loads them into a CLIPS
    environment, applies validation rules from the attached validation
    policies, synchronizes resulting facts with the ObjectFact collection
    and raises/clears the "Config | Policy Violation" alarm.
    """

    ILOCK = threading.Lock()
    AC_POLICY_VIOLATION = None

    def __init__(self, object):
        self.object = object
        self.logger = PrefixLoggerAdapter(logger, self.object.name)
        self.env = None
        self.templates = {}  # fact class -> template
        self.fcls = {}  # template -> Fact class
        self.facts = {}  # Index -> Fact
        self.rn = 0  # Rule number
        self.config = None  # Cached config
        self.interface_ranges = None
        # NOTE(review): the result is stored on the instance, so the
        # class-level lock only serializes the lookup; it is not a shared cache.
        with self.ILOCK:
            self.AC_POLICY_VIOLATION = AlarmClass.objects.filter(
                name="Config | Policy Violation"
            ).first()
            if not self.AC_POLICY_VIOLATION:
                self.logger.error(
                    "Alarm class 'Config | Policy Violation' is not found. Alarms cannot be raised"
                )

    def get_template(self, fact):
        """
        Return the CLIPS template for the fact's class, building it on first use.

        Also remembers the Python class of the fact so CLIPS facts can be
        converted back in sync_facts().
        """
        if fact.cls not in self.templates:
            self.logger.debug("Creating template %s", fact.cls)
            self.templates[fact.cls] = self.env.BuildTemplate(fact.cls, fact.get_template())
            self.fcls[fact.cls] = fact.__class__
            self.logger.debug("Define template %s", self.templates[fact.cls].PPForm())
        return self.templates[fact.cls]

    def get_rule_number(self):
        """Return the number of rules added via add_rule()"""
        return self.rn

    def assert_fact(self, fact):
        """
        Assert a Python fact into the CLIPS environment.

        Empty values (None, [], ()) are left at slot defaults; newlines in
        strings are escaped. On CLIPS error the problem is logged and the
        fact is not registered in self.facts.
        """
        f = self.get_template(fact).BuildFact()
        f.AssignSlotDefaults()
        for k, v in fact.iter_factitems():
            if v is None or v == [] or v == tuple():
                continue
            if isinstance(v, str):
                v = v.replace("\n", "\\n")
            f.Slots[k] = v
        try:
            f.Assert()
        except clips.ClipsError as e:
            self.logger.error("Could not assert: %s", f.PPForm())
            self.logger.error("CLIPS Error: %s\n%s", e, clips.ErrorStream.Read())
            return
        self.facts[f.Index] = fact
        self.logger.debug("Assert %s", f.PPForm())

    def learn(self, gen):
        """
        Learn sequence of facts

        Facts carrying a non-empty ``managed_object`` are bound first.

        :return: number of facts processed
        """
        n = 0
        for f in gen:
            if hasattr(f, "managed_object") and f.managed_object is not None:
                f.bind()
            # @todo: Custom bindings from solutions
            self.assert_fact(f)
            n += 1
        return n

    def iter_errors(self):
        """
        Generator yielding known errors

        Each CLIPS "error" fact is converted to an Error; when its "obj"
        slot refers to a fact known in self.facts, the Python fact is used.
        """
        try:
            e = self.templates["error"].InitialFact()
        except TypeError:
            # PEP 479: raising StopIteration inside a generator turns into
            # RuntimeError on Python 3.7+, so just finish the generator
            return
        while e:
            obj = None
            if "obj" in e.Slots.keys():
                obj = e.Slots["obj"]
                if hasattr(obj, "Index") and obj.Index in self.facts:
                    # obj is a fact: resolve to the registered Python fact
                    obj = self.facts[obj.Index]
            error = Error(e.Slots["type"], obj=obj, msg=e.Slots["msg"])
            if e.Index not in self.facts:
                self.facts[e.Index] = error
            yield error
            e = e.Next()

    def iter_roles(self):
        """
        Generator yielding role fact
        """
        try:
            e = self.templates["role"].InitialFact()
        except TypeError:
            # See iter_errors(): StopIteration must not be raised in a generator
            return
        while e:
            role = Role(e.Slots["name"])
            if e.Index not in self.facts:
                self.facts[e.Index] = role
            yield role
            e = e.Next()

    def run(self):
        """
        Run engine round

        :returns: Number of matched rules
        """
        return self.env.Run()

    def add_rule(self, expr):
        """Build CLIPS rule from expression and bump the rule counter"""
        self.env.Build(expr)
        self.rn += 1

    def check(self):
        """Run a full check inside a fresh CLIPS environment"""
        with CLIPSEnv() as env:
            self.setup_env(env)
            self._check()

    def _check(self):
        """
        Perform object configuration check
        """
        self.logger.info("Checking %s", self.object)
        parser = self.object.get_parser()
        self.config = self.object.config.read()
        if not self.config:
            self.logger.error("No config for %s. Giving up", self.object)
            return
        # Parse facts
        self.logger.debug("Parsing facts")
        facts = list(parser.parse(self.config))
        self.logger.debug("%d facts are extracted", len(facts))
        self.interface_ranges = parser.interface_ranges
        self.logger.debug("%d interface sections detected", len(self.interface_ranges))
        # Define default templates
        self.get_template(Error(None))
        self.get_template(Role(None))
        # Learn facts
        self.logger.debug("Learning facts")
        self.learn(facts)
        self.logger.debug("Learning complete")
        # Install rules
        rules = []
        for r in self.get_rules():
            if not r.is_applicable():
                continue
            self.logger.debug("Using validation rule: %s", r.rule.name)
            try:
                cfg = r.get_config()
                r.prepare(**cfg)
            except clips.ClipsError as e:
                self.logger.error("CLIPS Error: %s\n%s", e, clips.ErrorStream.Read())
                continue
            except Exception:
                error_report()
                continue
            rules.append((r, cfg))
        # Run python validators
        for r, cfg in rules:
            r.check(**cfg)
        # Run CLIPS engine (single round)
        # @todo: Check for commands
        self.logger.debug("Running engine")
        n = self.run()
        self.logger.debug("%d rules matched", n)
        # Extract errors
        for e in self.iter_errors():
            self.logger.info("Error found: %s", e)
        # Store object's facts
        self.sync_facts()
        # Manage related alarms
        if self.AC_POLICY_VIOLATION:
            self.sync_alarms()

    def _get_rule_settings(self, ps, scope):
        """
        Process PolicySettings object and returns a list of
        (validator class, config)

        Only active policy items, policies and rules applicable to the
        object whose handler's SCOPE intersects *scope* are returned.
        """
        r = []
        for pi in ps.policies:
            policy = pi.policy
            if not pi.is_active or not policy.is_active:
                continue
            for ri in policy.rules:
                if not ri.is_active:
                    continue
                rule = ri.rule
                if rule.is_active and rule.is_applicable_for(self.object):
                    vc = get_handler(rule.handler)
                    if vc and bool(vc.SCOPE & scope):
                        r.append((vc, rule))
        return r

    def _get_rules(self, model, id, scope, obj=None):
        """
        Instantiate validators from policy settings attached to (model, id)
        """
        ps = ValidationPolicySettings.objects.filter(model_id=model, object_id=str(id)).first()
        if not ps or not ps.policies:
            return []
        return [vc(self, obj, rule.config, scope, rule) for vc, rule in self._get_rule_settings(ps, scope)]

    def _get_profile_rules(self, profile_items, scope):
        """
        Instantiate validators from policy settings attached to interface profiles

        :param profile_items: profile -> list of (sub)interfaces using it
        :param scope: validator scope applied to every (sub)interface
        :return: list of validator instances
        """
        r = []
        for p, items in profile_items.items():
            ps = ValidationPolicySettings.objects.filter(
                model_id="inv.InterfaceProfile", object_id=str(p.id)
            ).first()
            if not ps or not ps.policies:
                continue
            rs = self._get_rule_settings(ps, scope)
            for item in items:
                r += [vc(self, item, rule.config, scope, rule) for vc, rule in rs]
        return r

    def get_rules(self):
        """
        Collect validators for the object, its profile, its interfaces and
        subinterfaces, and their interface profiles
        """
        r = []
        # Object profile rules
        if self.object.object_profile:
            r += self._get_rules(
                "sa.ManagedObjectProfile",
                self.object.object_profile.id,
                BaseValidator.OBJECT,
                self.object,
            )
        # Object rules
        r += self._get_rules("sa.ManagedObject", self.object.id, BaseValidator.OBJECT, self.object)
        # Interface rules
        profile_interfaces = defaultdict(list)
        for i in InvInterface.objects.filter(managed_object=self.object.id):
            if i.profile:
                profile_interfaces[i.profile].append(i)
            r += self._get_rules("inv.Interface", i.id, BaseValidator.INTERFACE, i)
        # Interface profile rules
        r += self._get_profile_rules(profile_interfaces, BaseValidator.INTERFACE)
        # Subinterface profile rules
        profile_subinterfaces = defaultdict(list)
        for si in InvSubInterface.objects.filter(managed_object=self.object.id):
            p = si.get_profile()
            if p:
                profile_subinterfaces[p].append(si)
        r += self._get_profile_rules(profile_subinterfaces, BaseValidator.SUBINTERFACE)
        return r

    def get_fact_uuid(self, fact):
        """
        Stable UUID5 of the fact: object id, fact class and its ID attributes
        """
        r = [str(self.object.id), fact.cls] + [str(getattr(fact, n)) for n in fact.ID]
        return uuid.uuid5(uuid.NAMESPACE_URL, "-".join(r))

    def get_fact_attrs(self, fact):
        """Return fact items as a dict"""
        return dict(fact.iter_factitems())

    def sync_facts(self):
        """
        Retrieve known facts and synchronize with database
        """
        self.logger.debug("Synchronizing facts")
        # Get facts from CLIPS
        self.logger.debug("Extracting facts")
        e_facts = {}  # uuid -> fact
        try:
            f = self.env.InitialFact()
        except clips.ClipsError:
            return  # No facts
        while f:
            if f.Template and f.Template.Name in self.templates:
                args = {}
                for k in f.Slots.keys():
                    v = f.Slots[k]
                    if v == clips.Nil:
                        v = None
                    args[str(k)] = v
                fi = self.fcls[f.Template.Name](**args)
                # self.facts maps Index -> Python fact, not the CLIPS fact
                self.facts[f.Index] = fi
                e_facts[self.get_fact_uuid(fi)] = fi
            f = f.Next()
        # Get facts from database
        now = datetime.datetime.now()
        collection = ObjectFact._get_collection()
        bulk = []
        new_facts = set(e_facts)
        for f in collection.find({"object": self.object.id}):
            if f["_id"] in e_facts:
                fact = e_facts[f["_id"]]
                f_attrs = self.get_fact_attrs(fact)
                if f_attrs != f["attrs"]:
                    # Changed facts
                    self.logger.debug(
                        "Fact %s has been changed: %s -> %s", f["_id"], f["attrs"], f_attrs
                    )
                    bulk.append(
                        UpdateOne(
                            {"_id": f["_id"]},
                            {"$set": {"attrs": f_attrs, "changed": now, "label": smart_text(fact)}},
                        )
                    )
                # Already stored (changed or not) -- not a new fact
                new_facts.remove(f["_id"])
            else:
                # Removed fact
                self.logger.debug("Fact %s has been removed", f["_id"])
                bulk.append(DeleteOne({"_id": f["_id"]}))
        # New facts
        for f in new_facts:
            fact = e_facts[f]
            f_attrs = self.get_fact_attrs(fact)
            self.logger.debug("Creating fact %s: %s", f, f_attrs)
            bulk.append(
                InsertOne(
                    {
                        "_id": f,
                        "object": self.object.id,
                        "cls": fact.cls,
                        "label": smart_text(fact),
                        "attrs": f_attrs,
                        "introduced": now,
                        "changed": now,
                    }
                )
            )
        if bulk:
            self.logger.debug("Commiting changes to database")
            try:
                collection.bulk_write(bulk)
                self.logger.debug("Database has been synced")
            except BulkWriteError as e:
                self.logger.error("Bulk write error: '%s'", e.details)
                self.logger.error("Stopping check")
        else:
            self.logger.debug("Nothing changed")

    def compile_query(self, **kwargs):
        """
        Build predicate: True when every kwargs attribute of x equals the
        given value (missing attributes compare as None)
        """

        def wrap(x):
            return all(getattr(x, k, None) == v for k, v in kwargs.items())

        return wrap

    def find(self, **kwargs):
        """
        Search facts for match. Returns a list of matching facts
        """
        q = self.compile_query(**kwargs)
        return [f for f in self.facts.values() if q(f)]

    def find_one(self, **kwargs):
        """
        Search for first matching fact. Returns fact or None
        """
        q = self.compile_query(**kwargs)
        return next((f for f in self.facts.values() if q(f)), None)

    def sync_alarms(self):
        """
        Raise/close related alarms
        """
        # Check errors are exists
        n_errors = sum(1 for _ in self.iter_errors())
        alarm = ActiveAlarm.objects.filter(
            alarm_class=self.AC_POLICY_VIOLATION.id, managed_object=self.object.id
        ).first()
        if n_errors:
            if not alarm:
                self.logger.info("Raise alarm")
                # Raise alarm
                # NOTE(review): the new alarm is not saved here explicitly;
                # presumably log_message() persists it -- confirm.
                alarm = ActiveAlarm(
                    timestamp=datetime.datetime.now(),
                    managed_object=self.object,
                    alarm_class=self.AC_POLICY_VIOLATION,
                    severity=2000,  # WARNING
                )
            alarm.log_message("%d errors has been found" % n_errors)
        elif alarm:
            # Clear alarm
            self.logger.info("Clear alarm")
            alarm.clear_alarm("No errors has been registered")

    def setup_env(self, env):
        """
        Install additional CLIPS functions
        """
        self.logger.debug("Setting up CLIPS environment")
        self.env = env
        # Create wrappers
        self.logger.debug("Install function: match-re")
        env.BuildFunction("match-re", "?rx ?s", "(return (python-call py-match-re ?rx ?s))")
class ProfileChecker(object):
    """
    Detect device profile by probing the device (SNMP GET, HTTP/HTTPS GET
    via the activator service) and matching answers against ProfileCheckRule
    records, processed in order of preference.
    """

    base_logger = logging.getLogger("profilechecker")
    # Rules loaded from database, cached for 60 seconds
    _rules_cache = cachetools.TTLCache(10, ttl=60)
    # regexp string -> compiled pattern
    _re_cache = {}

    def __init__(
        self,
        address=None,
        pool=None,
        logger=None,
        snmp_community=None,
        calling_service="profilechecker",
        snmp_version=None,
    ):
        self.address = address
        self.pool = pool
        self.logger = PrefixLoggerAdapter(
            logger or self.base_logger, "%s][%s" % (self.pool or "", self.address or "")
        )
        self.result_cache = {}  # (method, param) -> result
        self.error = None
        self.snmp_community = snmp_community
        self.calling_service = calling_service
        # A missing/empty snmp_version falls back to SNMP v2c, so SNMP checks
        # are skipped only when no community is configured
        self.snmp_version = snmp_version or [SNMP_v2c]
        self.ignoring_snmp = False
        if not self.snmp_community:
            self.logger.error("No SNMP credentials. Ignoring")
            self.ignoring_snmp = True

    def find_profile(self, method, param, result):
        """
        Find profile by method

        Rules of all preferences are merged (lower preference first) and
        matched against the already obtained *result*.

        :param method: Fingerprint getting method
        :param param: Method params
        :param result: Getting params result
        :return: Matched profile or None
        """
        merged = defaultdict(list)
        rules = self.get_rules()
        for pref in sorted(rules):
            for key, items in rules[pref].items():
                merged[key] += items
        if (method, param) not in merged:
            self.logger.warning("Not find rule for method: %s %s", method, param)
            return None
        for match_method, value, action, profile, rname in merged[(method, param)]:
            if self.is_match(result, match_method, value):
                self.logger.info("Matched profile: %s (%s)", profile, rname)
                # @todo: process MAYBE rule
                return profile
        return None

    def get_profile(self):
        """
        Returns profile for object, or None when not known

        On failure a human-readable reason is stored in self.error.
        """
        snmp_result = ""
        http_result = ""
        for ruleset in self.iter_rules():
            for (method, param), actions in ruleset:
                try:
                    result = self.do_check(method, param)
                    if not result:
                        continue
                    if "snmp" in method:
                        snmp_result = result
                    if "http" in method:
                        http_result = result
                    for match_method, value, action, profile, rname in actions:
                        if self.is_match(result, match_method, value):
                            self.logger.info("Matched profile: %s (%s)", profile, rname)
                            # @todo: process MAYBE rule
                            return profile
                except NOCError as e:
                    self.logger.error(e.message)
                    self.error = str(e.message)
                    return None
        if snmp_result or http_result:
            self.error = "Not find profile for OID: %s or HTTP string: %s" % (
                snmp_result,
                http_result,
            )
        else:
            # Neither SNMP nor HTTP checks returned any data
            self.error = "Cannot fetch snmp data, check device for SNMP access"
        self.logger.info("Cannot detect profile: %s", self.error)
        return None

    def get_error(self):
        """
        Get error message

        :return: Error set by the last get_profile() call, or None
        """
        return self.error

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_rules_cache"), lock=lambda _: rules_lock)
    def get_profile_check_rules(cls):
        """Load all ProfileCheckRule records ordered by preference (cached)"""
        return list(ProfileCheckRule.objects.all().order_by("preference"))

    def get_rules(self):
        """
        Load ProfileCheckRules and return a list, grouped by preferences
        [{
            (method, param) -> [(
                    match_method,
                    value,
                    action,
                    profile,
                    rule_name
                ), ...]

        }]

        SNMP rules are skipped when SNMP is ignored.
        """
        self.logger.info('Compiling "Profile Check rules"')
        d = {}  # preference -> (method, param) -> [rule, ..]
        for r in self.get_profile_check_rules():
            if "snmp" in r.method and self.ignoring_snmp:
                continue
            by_key = d.setdefault(r.preference, {})
            by_key.setdefault((r.method, r.param), []).append(
                (r.match_method, r.value, r.action, r.profile, r.name)
            )
        return d

    def iter_rules(self):
        """Yield [((method, param), rules), ...] per preference, ascending"""
        d = self.get_rules()
        for p in sorted(d):
            yield list(d[p].items())

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_re_cache"))
    def get_re(cls, regexp):
        """Return compiled regular expression (cached)"""
        return re.compile(regexp)

    def do_check(self, method, param):
        """
        Perform check

        Dispatches to check_<method>(param); results are cached per
        (method, param) for the lifetime of the checker.
        """
        self.logger.debug("do_check(%s, %s)", method, param)
        if (method, param) in self.result_cache:
            self.logger.debug("Using cached value")
            return self.result_cache[method, param]
        h = getattr(self, "check_%s" % method, None)
        if not h:
            self.logger.error("Invalid check method '%s'. Ignoring", method)
            return None
        result = h(param)
        self.result_cache[method, param] = result
        return result

    def check_snmp_v2c_get(self, param):
        """
        Perform SNMP v2c GET. Param is OID or symbolic name

        Every configured SNMP version is tried in order; the first
        non-empty answer is returned.

        :raises NOCError: on unsupported SNMP version
        """
        try:
            param = mib[param]
        except KeyError:
            self.logger.error("Cannot resolve OID '%s'. Ignoring", param)
            return None
        for v in self.snmp_version:
            if v == SNMP_v1:
                r = self.snmp_v1_get(param)
            elif v == SNMP_v2c:
                r = self.snmp_v2c_get(param)
            else:
                raise NOCError(msg="Unsupported SNMP version")
            if r:
                return r
        return None

    def check_http_get(self, param):
        """
        Perform HTTP GET check. Param can be URL path or :<port>/<path>
        """
        url = "http://%s%s" % (self.address, param)
        return self.http_get(url)

    def check_https_get(self, param):
        """
        Perform HTTPS GET check. Param can be URL path or :<port>/<path>
        """
        url = "https://%s%s" % (self.address, param)
        return self.https_get(url)

    def is_match(self, result, method, value):
        """
        Returns True when result matches value

        Supported methods: "eq", "contains", "re" (regex search).
        """
        if method == "eq":
            return result == value
        elif method == "contains":
            return value in result
        elif method == "re":
            return bool(self.get_re(value).search(result))
        else:
            self.logger.error("Invalid match method '%s'. Ignoring", method)
            return False

    def _activator(self):
        """Open synchronous RPC client to the activator of our pool"""
        return open_sync_rpc("activator", pool=self.pool, calling_service=self.calling_service)

    def snmp_v1_get(self, param):
        """
        Perform SNMP v1 request. May be overridden for testing

        :param param: OID
        :return: Result, or None on RPC error
        """
        self.logger.info("SNMP v1 GET: %s", param)
        try:
            return self._activator().snmp_v1_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def snmp_v2c_get(self, param):
        """
        Perform SNMP v2c request. May be overridden for testing

        :param param: OID
        :return: Result, or None on RPC error
        """
        self.logger.info("SNMP v2c GET: %s", param)
        try:
            return self._activator().snmp_v2c_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def http_get(self, url):
        """
        Perform HTTP request. May be overridden for testing

        :param url: Request URL
        :return: Result, or None on RPC error
        """
        self.logger.info("HTTP Request: %s", url)
        try:
            return self._activator().http_get(url, True)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def https_get(self, url):
        """
        Perform HTTP request. May be overridden for testing

        Currently delegates to http_get() with the same URL.

        :param url: Request URL
        :return: Result, or None on RPC error
        """
        return self.http_get(url)
class BaseScript(object, metaclass=BaseScriptMetaclass): """ Service Activation script base class """ # Script name in form of <vendor>.<system>.<name> name = None # Default script timeout TIMEOUT = config.script.timeout # Default session timeout SESSION_IDLE_TIMEOUT = config.script.session_idle_timeout # Default access preferene DEFAULT_ACCESS_PREFERENCE = "SC" # Enable call cache # If True, script result will be cached and reused # during lifetime of parent script cache = False # Implemented interface interface = None # Scripts required by generic script. # For common scripts - empty list # For generics - list of pairs (script_name, interface) requires = [] # base_logger = logging.getLogger(name or "script") # _x_seq = itertools.count() # Sessions cli_session_store = SessionStore() mml_session_store = SessionStore() rtsp_session_store = SessionStore() # In session mode when active CLI session exists # * True -- reuse session # * False -- close session and run new without session context reuse_cli_session = True # In session mode: # Should we keep CLI session for reuse by next script # * True - keep CLI session for next script # * False - close CLI session keep_cli_session = True # Script-level matchers. 
# Override profile one matchers = {} # Error classes shortcuts ScriptError = ScriptError CLISyntaxError = CLISyntaxError CLIOperationError = CLIOperationError NotSupportedError = NotSupportedError UnexpectedResultError = UnexpectedResultError hexbin = { "0": "0000", "1": "0001", "2": "0010", "3": "0011", "4": "0100", "5": "0101", "6": "0110", "7": "0111", "8": "1000", "9": "1001", "a": "1010", "b": "1011", "c": "1100", "d": "1101", "e": "1110", "f": "1111", } cli_protocols = { "telnet": "noc.core.script.cli.telnet.TelnetCLI", "ssh": "noc.core.script.cli.ssh.SSHCLI", "beef": "noc.core.script.cli.beef.BeefCLI", } mml_protocols = {"telnet": "noc.core.script.mml.telnet.TelnetMML"} rtsp_protocols = {"tcp": "noc.core.script.rtsp.base.RTSPBase"} # Override access preferences for script # S - always try SNMP first # C - always try CLI first # None - use default preferences always_prefer = None def __init__( self, service, credentials, args=None, capabilities=None, version=None, parent=None, timeout=None, name=None, session=None, session_idle_timeout=None, ): self.service = service self.tos = config.activator.tos self.pool = config.pool self.parent = parent self._motd = None name = name or self.name self.logger = PrefixLoggerAdapter( self.base_logger, "%s] [%s" % (self.name, credentials.get("address", "-"))) if self.parent: self.profile = self.parent.profile else: self.profile = profile_loader.get_profile(".".join( name.split(".")[:2]))() self.credentials = credentials or {} if self.is_beefed: self.credentials["snmp_ro"] = "public" # For core.snmp.base check self.version = version or {} self.capabilities = capabilities or {} self.timeout = timeout or self.get_timeout() self.start_time = None self._interface = self.interface() self.args = self.clean_input(args) if args else {} self.cli_stream = None self.mml_stream = None self.rtsp_stream = None self._snmp: Optional[SNMP] = None self._http: Optional[HTTP] = None self.to_disable_pager = not self.parent and 
self.profile.command_disable_pager self.scripts = ScriptsHub(self) # Store session id self.session = session self.session_idle_timeout = session_idle_timeout or self.SESSION_IDLE_TIMEOUT # Cache CLI and SNMP calls, if set self.is_cached = False # Suitable only when self.parent is None. # Cached results for scripts marked with "cache" self.call_cache = {} # Suitable only when self.parent is None # Cached results of self.cli calls self.cli_cache = {} # self.http_cache = {} self.partial_result = None # @todo: Get native encoding from ManagedObject self.native_encoding = "utf8" # Tracking self.to_track = False self.cli_tracked_data = {} # command -> [packets] self.cli_tracked_command = None # state -> [..] self.cli_fsm_tracked_data = {} # if not parent and version and not name.endswith(".get_version"): self.logger.debug("Filling get_version cache with %s", version) s = name.split(".") self.set_cache("%s.%s.get_version" % (s[0], s[1]), {}, version) if (self.is_beefed and not parent and not name.endswith(".get_capabilities") and not name.endswith(".get_version")): self.capabilities = self.scripts.get_capabilities() self.logger.info("Filling capabilities with %s", self.capabilities) # Fill matchers if not self.name.endswith(".get_version"): self.apply_matchers() # if self.profile.setup_script: self.profile.setup_script(self) def __call__(self, *args, **kwargs): self.args = kwargs return self.run() @property def snmp(self) -> SNMP: if not self._snmp: if self.parent: self._snmp = self.root.snmp elif self.is_beefed: self._snmp = BeefSNMP(self) else: snmp_rate_limit = self.credentials.get("snmp_rate_limit", None) or None if snmp_rate_limit is None: snmp_rate_limit = self.profile.get_snmp_rate_limit(self) self._snmp = SNMP(self, rate=snmp_rate_limit) return self._snmp @property def http(self) -> HTTP: if not self._http: if self.parent: self._http = self.root.http else: self._http = HTTP(self) return self._http def apply_matchers(self): """ Process matchers and apply is_XXX 
properties :return: """ def get_matchers(c, matchers): return {m: match(c, matchers[m]) for m in matchers} # Match context # @todo: Add capabilities ctx = self.version or {} if self.capabilities: ctx["caps"] = self.capabilities # Calculate matches v = get_matchers(ctx, self.profile.matchers) v.update(get_matchers(ctx, self.matchers)) # for k in v: self.logger.debug("%s = %s", k, v[k]) setattr(self, k, v[k]) def clean_input(self, args): """ Cleanup input parameters against interface """ return self._interface.script_clean_input(self.profile, **args) def clean_output(self, result): """ Clean script result against interface """ return self._interface.script_clean_result(self.profile, result) def run(self): """ Run script """ with Span(server="activator", service=self.name, in_label=self.credentials.get("address")): self.start_time = perf_counter() self.logger.debug("Running. Input arguments: %s, timeout %s", self.args, self.timeout) # Use cached result when available cache_hit = False if self.cache and self.parent: try: result = self.get_cache(self.name, self.args) self.logger.info("Using cached result") cache_hit = True except KeyError: pass # Execute script if not cache_hit: try: result = self.execute(**self.args) if self.cache and self.parent and result: self.logger.info("Caching result") self.set_cache(self.name, self.args, result) finally: if not self.parent: # Close SNMP socket when necessary self.close_snmp() # Close CLI socket when necessary self.close_cli_stream() # Close MML socket when necessary self.close_mml_stream() # Close RTSP socket when necessary self.close_rtsp_stream() # Close HTTP Client self.http.close() # Clean result result = self.clean_output(result) self.logger.debug("Result: %s", result) runtime = perf_counter() - self.start_time self.logger.info("Complete (%.2fms)", runtime * 1000) return result @classmethod def compile_match_filter(cls, *args, **kwargs): # pylint: disable=undefined-variable """ Compile arguments into version check function 
Returns callable accepting self and version hash arguments """ c = [lambda self, x, g=f: g(x) for f in args] for k, v in kwargs.items(): # Split to field name and lookup operator if "__" in k: f, o = k.split("__") else: f = k o = "exact" # Check field name if f not in ("vendor", "platform", "version", "image"): raise Exception("Invalid field '%s'" % f) # Compile lookup functions if o == "exact": c += [lambda self, x, f=f, v=v: x[f] == v] elif o == "iexact": c += [lambda self, x, f=f, v=v: x[f].lower() == v.lower()] elif o == "startswith": c += [lambda self, x, f=f, v=v: x[f].startswith(v)] elif o == "istartswith": c += [ lambda self, x, f=f, v=v: x[f].lower().startswith(v.lower( )) ] elif o == "endswith": c += [lambda self, x, f=f, v=v: x[f].endswith(v)] elif o == "iendswith": c += [ lambda self, x, f=f, v=v: x[f].lower().endswith(v.lower()) ] elif o == "contains": c += [lambda self, x, f=f, v=v: v in x[f]] elif o == "icontains": c += [lambda self, x, f=f, v=v: v.lower() in x[f].lower()] elif o == "in": c += [lambda self, x, f=f, v=v: x[f] in v] elif o == "regex": c += [ lambda self, x, f=f, v=re.compile(v): v.search(x[f]) is not None ] elif o == "iregex": c += [ lambda self, x, f=f, v=re.compile(v, re.IGNORECASE): v. 
search(x[f]) is not None ] elif o == "isempty": # Empty string or null c += [lambda self, x, f=f, v=v: not x[f] if v else x[f]] elif f == "version": if o == "lt": # < c += [ lambda self, x, v=v: self.profile.cmp_version( x["version"], v) < 0 ] elif o == "lte": # <= c += [ lambda self, x, v=v: self.profile.cmp_version( x["version"], v) <= 0 ] elif o == "gt": # > c += [ lambda self, x, v=v: self.profile.cmp_version( x["version"], v) > 0 ] elif o == "gte": # >= c += [ lambda self, x, v=v: self.profile.cmp_version( x["version"], v) >= 0 ] else: raise Exception("Invalid lookup operation: %s" % o) else: raise Exception("Invalid lookup operation: %s" % o) # Combine expressions into single lambda return reduce( lambda x, y: lambda self, v, x=x, y=y: (x(self, v) and y(self, v)), # pylint: disable=undefined-variable c, lambda self, x: True, ) @classmethod def match(cls, *args, **kwargs): """ execute method decorator """ def wrap(f): # Append to the execute chain if hasattr(f, "_match"): old_filter = f._match # pylint: disable=undefined-variable f._match = lambda self, v, old_filter=old_filter, new_filter=new_filter: new_filter( self, v) or old_filter(self, v) else: f._match = new_filter f._seq = next(cls._x_seq) return f # Compile check function new_filter = cls.compile_match_filter(*args, **kwargs) # Return decorated function return wrap def match_version(self, *args, **kwargs): """ inline version for BaseScript.match """ if not self.version: self.version = self.scripts.get_version() return self.compile_match_filter(*args, **kwargs)(self, self.version) def execute(self, **kwargs): """ Default script behavior: Pass through _execute_chain and call appropriate handler """ if self._execute_chain and not self.name.endswith(".get_version"): # Deprecated @match chain self.logger.info( "WARNING: Using deprecated @BaseScript.match() decorator. 
" "Consider porting to the new matcher API") # Get version information if not self.version: self.version = self.scripts.get_version() # Find and execute proper handler for f in self._execute_chain: if f._match(self, self.version): return f(self, **kwargs) # Raise error raise self.NotSupportedError() else: # New SNMP/CLI API return self.call_method(cli_handler=self.execute_cli, snmp_handler=self.execute_snmp, **kwargs) def call_method(self, cli_handler=None, snmp_handler=None, fallback_handler=None, **kwargs): """ Call function depending on access_preference :param cli_handler: String or callable to call on CLI access method :param snmp_handler: String or callable to call on SNMP access method :param fallback_handler: String or callable to call if no access method matched :param kwargs: :return: """ # Select proper handler access_preference = self.get_access_preference() + "*" for m in access_preference: # Select proper handler if m == "C": handler = cli_handler elif m == "S": if self.has_snmp() or self.name.endswith(".get_version"): handler = snmp_handler else: self.logger.debug( "SNMP is not enabled. Passing to next method") continue elif m == "*": handler = fallback_handler else: raise self.NotSupportedError("Invalid access method '%s'" % m) # Resolve handler when necessary if isinstance(handler, str): handler = getattr(self, handler, None) if handler is None: self.logger.debug("No '%s' handler. Passing to next method" % m) continue # Call handler try: r = handler(**kwargs) if isinstance(r, PartialResult): if self.partial_result: self.partial_result.update(r.result) else: self.partial_result = r.result self.logger.debug( "Partial result: %r. Passing to next method", self.partial_result) else: return r except self.snmp.TimeOutError: self.logger.info("SNMP timeout. 
Passing to next method") if access_preference == "S*": self.logger.info("Last S method break by timeout.") raise self.snmp.TimeOutError except NotImplementedError: self.logger.debug( "Access method '%s' is not implemented. Passing to next method", m) raise self.NotSupportedError( "Access preference '%s' is not supported" % access_preference[:-1]) def execute_cli(self, **kwargs): """ Process script using CLI :param kwargs: :return: """ raise NotImplementedError("execute_cli() is not implemented") def execute_snmp(self, **kwargs): """ Process script using SNMP :param kwargs: :return: """ raise NotImplementedError("execute_snmp() is not implemented") def cleaned_config(self, config): """ Clean up config from all unnecessary trash """ return self.profile.cleaned_config(config) def strip_first_lines(self, text, lines=1): """ Strip first *lines* """ t = text.split("\n") if len(t) <= lines: return "" else: return "\n".join(t[lines:]) def expand_rangelist(self, s): """ Expand expressions like "1,2,5-7" to [1, 2, 5, 6, 7] """ result = {} for x in s.split(","): x = x.strip() if x == "": continue if "-" in x: left, right = [int(y) for y in x.split("-")] if left > right: x = right right = left left = x for i in range(left, right + 1): result[i] = None else: result[int(x)] = None return sorted(result.keys()) rx_detect_sep = re.compile(r"^(.*?)\d+$") def expand_interface_range(self, s): """ Convert interface range expression to a list of interfaces "Gi 1/1-3,Gi 1/7" -> ["Gi 1/1", "Gi 1/2", "Gi 1/3", "Gi 1/7"] "1:1-3" -> ["1:1", "1:2", "1:3"] "1:1-1:3" -> ["1:1", "1:2", "1:3"] :param s: Comma-separated list :return: """ r = set() for x in s.split(","): x = x.strip() if not x: continue if "-" in x: # Expand range f, t = [y.strip() for y in x.split("-")] # Detect common prefix match = self.rx_detect_sep.match(f) if not match: raise ValueError(x) prefix = match.group(1) # Detect range boundaries start = int(f[len(prefix):]) if is_int(t): stop = int(t) # Just integer else: if not 
t.startswith(prefix): raise ValueError(x) stop = int(t[len(prefix):]) # Prefixed if start > stop: raise ValueError(x) for i in range(start, stop + 1): r.add(prefix + str(i)) else: r.add(x) return sorted(r) def macs_to_ranges(self, macs): """ Converts list of macs to rangea :param macs: Iterable yielding mac addresses :returns: [(from, to), ..] """ r = [] for m in sorted(MAC(x) for x in macs): if r: if r[-1][1].shift(1) == m: # Expand last range r[-1][1] = m else: r += [[m, m]] else: r += [[m, m]] return [(str(x[0]), str(x[1])) for x in r] def hexstring_to_mac(self, s): """Convert a 6-octet string to MAC address""" return ":".join(["%02X" % ord(x) for x in s]) @property def root(self): """Get root script""" if self.parent: return self.parent.root else: return self def get_cache(self, key1, key2): """Get cached result or raise KeyError""" s = self.root return s.call_cache[repr(key1)][repr(key2)] def set_cache(self, key1, key2, value): """Set cached result""" key1 = repr(key1) key2 = repr(key2) s = self.root if key1 not in s.call_cache: s.call_cache[key1] = {} s.call_cache[key1][key2] = value def configure(self): """Returns configuration context""" return ConfigurationContextManager(self) def cached(self): """ Return cached context managed. All nested CLI and SNMP GET/GETNEXT calls will be cached. Usage: with self.cached(): self.cli(".....) 
self.scripts.script() """ return CacheContextManager(self) def enter_config(self): """Enter configuration mote""" if self.profile.command_enter_config: self.cli(self.profile.command_enter_config) def leave_config(self): """Leave configuration mode""" if self.profile.command_leave_config: self.cli(self.profile.command_leave_config) self.cli( "" ) # Guardian empty command to wait until configuration is finally written def save_config(self, immediately=False): """Save current config""" if immediately: if self.profile.command_save_config: self.cli(self.profile.command_save_config) else: self.schedule_to_save() def schedule_to_save(self): self.need_to_save = True if self.parent: self.parent.schedule_to_save() def set_motd(self, motd): self._motd = motd @property def motd(self): """ Return message of the day """ if self._motd: return self._motd return self.get_cli_stream().get_motd() def re_search(self, rx, s, flags=0): """ Match s against regular expression rx using re.search Raise UnexpectedResultError if regular expression is not matched. Returns match object. rx can be string or compiled regular expression """ if isinstance(rx, str): rx = re.compile(rx, flags) match = rx.search(s) if match is None: raise UnexpectedResultError() return match def re_match(self, rx, s, flags=0): """ Match s against regular expression rx using re.match Raise UnexpectedResultError if regular expression is not matched. Returns match object. 
rx can be string or compiled regular expression """ if isinstance(rx, str): rx = re.compile(rx, flags) match = rx.match(s) if match is None: raise UnexpectedResultError() return match _match_lines_cache = {} @classmethod def match_lines(cls, rx, s): k = id(rx) if k not in cls._match_lines_cache: _rx = [re.compile(line, re.IGNORECASE) for line in rx] cls._match_lines_cache[k] = _rx else: _rx = cls._match_lines_cache[k] ctx = {} idx = 0 r = _rx[0] for line in s.splitlines(): line = line.strip() match = r.search(line) if match: ctx.update(match.groupdict()) idx += 1 if idx == len(_rx): return ctx r = _rx[idx] return None def find_re(self, iter, s): """ Find first matching regular expression or raise Unexpected result error """ for r in iter: if r.search(s): return r raise UnexpectedResultError() def hex_to_bin(self, s): """ Convert hexadecimal string to boolean string. All non-hexadecimal characters are ignored :param s: Input string :return: Boolean string :rtype: str """ return "".join(self.hexbin[c] for c in "".join("%02x" % ord(d) for d in s)) def push_prompt_pattern(self, pattern): self.get_cli_stream().push_prompt_pattern(pattern) def pop_prompt_pattern(self): self.get_cli_stream().pop_prompt_pattern() def has_oid(self, oid): """ Check object responses to oid """ try: return bool(self.snmp.get(oid)) except self.snmp.TimeOutError: return False def get_timeout(self): return self.TIMEOUT def cli( self, cmd: str, command_submit: Optional[bytes] = None, bulk_lines: Any = None, list_re: Any = None, cached: bool = False, file: Optional[str] = None, ignore_errors: Any = False, allow_empty_response: Any = True, nowait: Any = False, obj_parser: Any = None, cmd_next: Any = None, cmd_stop: Any = None, ) -> str: """ Execute CLI command and return result. Initiate cli session when necessary. 
if list_re is None, return a string if list_re is regular expression object, return a list of dicts (group name -> value), one dict per matched line :param cmd: CLI command to execute :param command_submit: Optional suffix to submit command. Profile's one used by default :param bulk_lines: :param list_re: :param cached: True if result of execution may be cached :param file: Path to the file containing debugging result :param ignore_errors: :param allow_empty_response: Allow empty output. If False - ignore prompt and wait output :param nowait: """ def format_result(result): if list_re: x = [] for line in result.splitlines(): match = list_re.match(line.strip()) if match: x += [match.groupdict()] return x else: return result if file: # Read from file with open(file) as f: return format_result(f.read()) if cached: # Cached result r = self.root.cli_cache.get(cmd) if r is not None: self.logger.debug("Use cached result") return format_result(r) # Effective command submit suffix if command_submit is None: command_submit = self.profile.command_submit # Encode submitted command submitted_cmd = smart_bytes( cmd, encoding=self.native_encoding) + command_submit # Run command stream = self.get_cli_stream() if self.to_track: self.cli_tracked_command = cmd r = stream.execute( submitted_cmd, obj_parser=obj_parser, cmd_next=cmd_next, cmd_stop=cmd_stop, ignore_errors=ignore_errors, allow_empty_response=allow_empty_response, ) if isinstance(r, bytes): r = smart_text(r, errors="ignore", encoding=self.native_encoding) if isinstance(r, str): # Check for syntax errors if not ignore_errors: # Then check for operation error if (self.profile.rx_pattern_operation_error and self.profile.rx_pattern_operation_error_str.search(r)): raise self.CLIOperationError(r) # Echo cancelation r = self.echo_cancelation(r, cmd) # Store cli cache when necessary if cached: self.root.cli_cache[cmd] = r return format_result(r) def echo_cancelation(self, r: str, cmd: str) -> str: """ Adaptive echo cancelation 
:param r: :param cmd: :return: """ if r[:4096].lstrip().startswith(cmd): r = r.lstrip() if r.startswith(cmd + "\n"): # Remove first line r = self.strip_first_lines(r.lstrip()) else: # Some switches, like ProCurve do not send \n after the echo r = r[len(cmd):] return r def get_cli_stream(self): if self.parent: return self.root.get_cli_stream() if not self.cli_stream and self.session: # Try to get cached session's CLI self.cli_stream = self.cli_session_store.get(self.session) if self.cli_stream: if self.to_reuse_cli_session(): self.logger.debug("Using cached session's CLI") self.cli_stream.set_script(self) else: self.logger.debug( "Script cannot reuse existing CLI session, starting new one" ) self.close_cli_stream() if not self.cli_stream: protocol = self.credentials.get("cli_protocol", "telnet") self.logger.debug("Open %s CLI", protocol) self.cli_stream = get_handler(self.cli_protocols[protocol])( self, tos=self.tos) # Store to the sessions if self.session: self.cli_session_store.put(self.session, self.cli_stream) self.cli_stream.setup_session() # Disable pager when necessary # @todo: Move to CLI if self.to_disable_pager: self.logger.debug("Disable paging") self.to_disable_pager = False if isinstance(self.profile.command_disable_pager, str): self.cli(self.profile.command_disable_pager, ignore_errors=True) elif isinstance(self.profile.command_disable_pager, list): for cmd in self.profile.command_disable_pager: self.cli(cmd, ignore_errors=True) else: raise UnexpectedResultError return self.cli_stream def close_cli_stream(self): if self.parent: return if self.cli_stream: if self.session and self.to_keep_cli_session(): # Return cli stream to pool self.cli_session_store.put(self.session, self.cli_stream, self.session_idle_timeout) else: self.cli_stream.shutdown_session() self.cli_stream.close() self.cli_stream = None def close_snmp(self): if self.parent: return if self._snmp: self._snmp.close() self._snmp = None def mml(self, cmd, **kwargs): """ Execute MML command and 
return result. Initiate MML session when necessary :param cmd: :param kwargs: :return: """ stream = self.get_mml_stream() r = stream.execute(cmd, **kwargs) return r def get_mml_stream(self): if self.parent: return self.root.get_mml_stream() if not self.mml_stream and self.session: # Try to get cached session's CLI self.mml_stream = self.mml_session_store.get(self.session) if self.mml_stream: if self.to_reuse_cli_session(): self.logger.debug("Using cached session's MML") self.mml_stream.set_script(self) else: self.logger.debug( "Script cannot reuse existing MML session, starting new one" ) self.close_mml_stream() if not self.mml_stream: protocol = self.credentials.get("cli_protocol", "telnet") self.logger.debug("Open %s MML", protocol) self.mml_stream = get_handler(self.mml_protocols[protocol])( self, tos=self.tos) # Store to the sessions if self.session: self.mml_session_store.put(self.session, self.mml_stream) return self.mml_stream def close_mml_stream(self): if self.parent: return if self.mml_stream: if self.session and self.to_keep_cli_session(): self.mml_session_store.put(self.session, self.mml_stream, self.session_idle_timeout) else: self.mml_stream.close() self.cli_stream = None def rtsp(self, method, path, **kwargs): """ Execute RTSP command and return result. 
Initiate RTSP session when necessary :param method: :param path: :param kwargs: :return: """ stream = self.get_rtsp_stream() r = stream.execute(path, method, **kwargs) return r def get_rtsp_stream(self): if self.parent: return self.root.get_rtsp_stream() if not self.rtsp_stream and self.session: # Try to get cached session's CLI self.rtsp_stream = self.rtsp_session_store.get(self.session) if self.rtsp_stream: if self.to_reuse_cli_session(): self.logger.debug("Using cached session's RTSP") self.rtsp_stream.set_script(self) else: self.logger.debug( "Script cannot reuse existing RTSP session, starting new one" ) self.close_rtsp_stream() if not self.rtsp_stream: protocol = "tcp" self.logger.debug("Open %s RTSP", protocol) self.rtsp_stream = get_handler(self.rtsp_protocols[protocol])( self, tos=self.tos) # Store to the sessions if self.session: self.rtsp_session_store.put(self.session, self.rtsp_stream) return self.rtsp_stream def close_rtsp_stream(self): if self.parent: return if self.rtsp_stream: if self.session and self.to_keep_cli_session(): self.rtsp_session_store.put(self.session, self.rtsp_stream, self.session_idle_timeout) else: self.rtsp_stream.close() self.cli_stream = None def close_current_session(self): if self.session: self.close_session(self.session) @classmethod def close_session(cls, session_id): """ Explicit session closing :return: """ cls.cli_session_store.remove(session_id, shutdown=True) cls.mml_session_store.remove(session_id, shutdown=True) cls.rtsp_session_store.remove(session_id, shutdown=True) def get_access_preference(self): preferred = self.get_always_preferred() r = self.credentials.get("access_preference", self.DEFAULT_ACCESS_PREFERENCE) if preferred and preferred in r: return preferred + "".join(x for x in r if x != preferred) return r def get_always_preferred(self): """ Return always preferred access method :return: """ return self.always_prefer def has_cli_access(self): return "C" in self.get_access_preference() def 
has_snmp_access(self): return "S" in self.get_access_preference() and self.has_snmp() def has_cli_only_access(self): return self.has_cli_access() and not self.has_snmp_access() def has_snmp_only_access(self): return not self.has_cli_access() and self.has_snmp_access() def has_snmp(self): """ Check whether equipment has SNMP enabled """ if self.has_capability("SNMP", allow_zero=True): # If having SNMP caps - check it and credential return bool(self.credentials.get( "snmp_ro")) and self.has_capability("SNMP") else: # if SNMP caps not exist check credential return bool(self.credentials.get("snmp_ro")) def has_snmp_v1(self): return self.has_capability("SNMP | v1") def has_snmp_v2c(self): return self.has_capability("SNMP | v2c") def has_snmp_v3(self): return self.has_capability("SNMP | v3") def has_snmp_bulk(self): """ Check whether equipment supports SNMP BULK """ return self.has_capability("SNMP | Bulk") def has_capability(self, capability, allow_zero=False): """ Check whether equipment supports capability """ if allow_zero: return self.capabilities.get(capability) is not None else: return bool(self.capabilities.get(capability)) def ignored_exceptions(self, iterable): """ Context manager to silently ignore specified exceptions """ return IgnoredExceptionsContextManager(iterable) def iter_pairs(self, g, offset=0): """ Convert iterable g to a pairs i.e. 
[1, 2, 3, 4] -> [(1, 2), (3, 4)] :param g: Iterable :param offset: Skip first recirds :return: """ g = iter(g) if offset: for _ in range(offset): next(g) return zip(g, g) def to_reuse_cli_session(self): return self.reuse_cli_session def to_keep_cli_session(self): return self.keep_cli_session def start_tracking(self): self.logger.debug("Start tracking") self.to_track = True def stop_tracking(self): self.logger.debug("Stop tracking") self.to_track = False self.cli_tracked_data = {} def push_cli_tracking(self, r, state): if state == "prompt": if self.cli_tracked_command in self.cli_tracked_data: self.cli_tracked_data[self.cli_tracked_command] += [r] else: self.cli_tracked_data[self.cli_tracked_command] = [r] elif state in self.cli_fsm_tracked_data: self.cli_fsm_tracked_data[state] += [r] else: self.cli_fsm_tracked_data[state] = [r] def push_snmp_tracking(self, oid, tlv): self.logger.debug("PUSH SNMP %s: %r", oid, tlv) def iter_cli_tracking(self): """ Yields command, packets for collected data :return: """ for cmd in self.cli_tracked_data: self.logger.debug("Collecting %d tracked CLI items", len(self.cli_tracked_data[cmd])) yield cmd, self.cli_tracked_data[cmd] self.cli_tracked_data = {} def iter_cli_fsm_tracking(self): for state in self.cli_fsm_tracked_data: yield state, self.cli_fsm_tracked_data[state] def request_beef(self): """ Download and return beef :return: """ if not hasattr(self, "_beef"): self.logger.debug("Requesting beef") beef_storage_url = self.credentials.get("beef_storage_url") beef_path = self.credentials.get("beef_path") if not beef_storage_url: self.logger.debug("No storage URL") self._beef = None return None if not beef_path: self.logger.debug("No beef path") self._beef = None return None from .beef import Beef try: beef = Beef.load(beef_storage_url, beef_path) except IOError as e: self.logger.error("Beef load error: %s", e) return None self._beef = beef return self._beef @property def is_beefed(self): return self.credentials.get("cli_protocol") == 
"beef"