class BaseController:
    """Application controller (first revision in this file).

    Groups start-up initialisation (scheduler, authentication clients, vault,
    syslog, logging, HTTP pools) with the generic CRUD, filtering, logging and
    mail helpers defined below.

    NOTE(review): the block structure was reconstructed from a
    whitespace-collapsed source; spots where the nesting could not be derived
    from syntax alone carry their own NOTE(review) comment.
    """

    # Maps a severity name to the logging function called by ``log``.
    log_severity = {"error": error, "info": info, "warning": warning}

    json_endpoints = [
        "multiselect_filtering",
        "save_settings",
        "table_filtering",
        "view_filtering",
    ]

    rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    def __init__(self):
        """Store the configuration objects and run every enabled initialiser."""
        self.settings = settings
        self.rbac = rbac
        self.properties = properties
        self.load_custom_properties()
        self.path = Path.cwd()
        self.init_scheduler()
        if settings["tacacs"]["active"]:
            self.init_tacacs_client()
        if settings["ldap"]["active"]:
            self.init_ldap_client()
        if settings["vault"]["active"]:
            self.init_vault_client()
        if settings["syslog"]["active"]:
            self.init_syslog_server()
        if settings["paths"]["custom_code"]:
            sys_path.append(settings["paths"]["custom_code"])
        self.fetch_version()
        self.init_logs()
        self.init_connection_pools()

    def configure_database(self):
        """Load services and forms, create the tables and seed a new database.

        The seeding branch only runs when no user named "admin" exists.
        """
        self.init_services()
        Base.metadata.create_all(bind=engine)
        configure_mappers()
        configure_events(self)
        self.init_forms()
        self.clean_database()
        if not fetch("user", allow_none=True, name="admin"):
            # NOTE(review): everything down to the final commit is assumed to
            # belong to this first-start branch - confirm against upstream.
            self.configure_server_id()
            self.create_admin_user()
            Session.commit()
            if self.settings["app"]["create_examples"]:
                self.migration_import(
                    name="examples", import_export_types=import_classes
                )
                self.update_credentials()
            else:
                self.migration_import(
                    name="default", import_export_types=import_classes
                )
            self.get_git_content()
            Session.commit()

    def clean_database(self):
        """Mark every run still flagged "Running" as aborted, then commit."""
        for run in fetch("run", all_matches=True, allow_none=True, status="Running"):
            run.status = "Aborted (app reload)"
        Session.commit()

    def fetch_version(self):
        """Read the "version" key of ``package.json`` into ``self.version``."""
        with open(self.path / "package.json") as package_file:
            self.version = load(package_file)["version"]

    def configure_server_id(self):
        """Register the local host as a "server" object named after ``getnode()``."""
        factory(
            "server",
            **{
                "name": str(getnode()),
                "description": "Localhost",
                "ip_address": "0.0.0.0",
                "status": "Up",
            },
        )

    def create_admin_user(self) -> None:
        """Create the "admin" user and give it a password if it has none."""
        admin = factory("user", **{"name": "admin", "group": "Admin"})
        if not admin.password:
            admin.update(password="******")

    def update_credentials(self):
        """Run ``topology_import`` on the bundled ``usa.xls`` spreadsheet."""
        with open(self.path / "files" / "spreadsheets" / "usa.xls", "rb") as file:
            self.topology_import(file)

    def get_git_content(self):
        """Pull (or clone) the configured git repository, then import its data.

        Does nothing when no repository is configured. Git failures are logged
        and the database import still runs afterwards.
        """
        repo = self.settings["app"]["git_repository"]
        if not repo:
            return
        local_path = self.path / "network_data"
        try:
            if exists(local_path):
                Repo(local_path).remotes.origin.pull()
            else:
                local_path.mkdir(parents=True, exist_ok=True)
                Repo.clone_from(repo, local_path)
        except Exception as exc:
            self.log("error", f"Git pull failed ({str(exc)})")
        self.update_database_configurations_from_git()

    def load_custom_properties(self):
        """Register the custom properties declared under ``properties["custom"]``."""
        for model, values in self.properties["custom"].items():
            property_names.update(
                {name: spec["pretty_name"] for name, spec in values.items()}
            )
            model_properties[model].extend(list(values))
            private_properties.extend(
                name for name, spec in values.items() if spec.get("private", False)
            )

    def init_logs(self):
        """Configure root logging to a rotating ``logs/enms.log`` file and a stream."""
        log_level = self.settings["app"]["log_level"].upper()
        folder = self.path / "logs"
        folder.mkdir(parents=True, exist_ok=True)
        basicConfig(
            # The level constant is looked up by name on the logging module.
            level=getattr(import_module("logging"), log_level),
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%m-%d-%Y %H:%M:%S",
            handlers=[
                RotatingFileHandler(
                    folder / "enms.log", maxBytes=20_000_000, backupCount=10
                ),
                StreamHandler(),
            ],
        )

    def init_connection_pools(self):
        """Create the shared request session with retry and pool settings."""
        self.request_session = RequestSession()
        retry = Retry(**self.settings["requests"]["retries"])
        for protocol in ("http", "https"):
            self.request_session.mount(
                f"{protocol}://",
                HTTPAdapter(
                    max_retries=retry,
                    **self.settings["requests"]["pool"],
                ),
            )

    def init_scheduler(self):
        """Create and start the background scheduler."""
        self.scheduler = BackgroundScheduler(
            {
                "apscheduler.jobstores.default": {
                    "type": "sqlalchemy",
                    "url": "sqlite:///jobs.sqlite",
                },
                "apscheduler.executors.default": {
                    "class": "apscheduler.executors.pool:ThreadPoolExecutor",
                    "max_workers": "50",
                },
                "apscheduler.job_defaults.misfire_grace_time": "5",
                "apscheduler.job_defaults.coalesce": "true",
                "apscheduler.job_defaults.max_instances": "3",
            }
        )
        self.scheduler.start()

    def init_forms(self):
        """Execute every python module found under ``eNMS/forms``."""
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            # ``stem`` is the file name without ".py", whatever the path separator.
            spec = spec_from_file_location(file.stem, str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_services(self):
        """Execute every service module from the built-in and custom folders."""
        path_services = [self.path / "eNMS" / "services"]
        if self.settings["paths"]["custom_services"]:
            path_services.append(Path(self.settings["paths"]["custom_services"]))
        load_examples = self.settings["app"]["create_examples"]
        for path in path_services:
            for file in path.glob("**/*.py"):
                # Substring tests on the whole path: any path containing
                # "init" (e.g. __init__.py) is skipped.
                if "init" in str(file):
                    continue
                if not load_examples and "examples" in str(file):
                    continue
                info(f"Loading service: {file}")
                spec = spec_from_file_location(file.stem, str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as e:
                    error(f"Error loading custom service '{file}' ({str(e)})")

    def init_ldap_client(self):
        """Create the LDAP server object from the ``ldap`` settings."""
        self.ldap_client = Server(self.settings["ldap"]["server"], get_info=ALL)

    def init_tacacs_client(self):
        """Create the TACACS client (port 49, secret from ``TACACS_PASSWORD``)."""
        self.tacacs_client = TACACSClient(
            self.settings["tacacs"]["address"], 49, environ.get("TACACS_PASSWORD")
        )

    def init_vault_client(self):
        """Create the vault client and unseal the vault if configured to."""
        self.vault_client = VaultClient()
        self.vault_client.token = environ.get("VAULT_TOKEN")
        if self.vault_client.sys.is_sealed() and self.settings["vault"]["unseal"]:
            # Keys come from UNSEAL_VAULT_KEY1..5; unset ones are dropped.
            keys = [environ.get(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)]
            self.vault_client.sys.submit_unseal_keys(filter(None, keys))

    def init_syslog_server(self):
        """Create the syslog server from the ``syslog`` settings and start it."""
        self.syslog_server = SyslogServer(
            self.settings["syslog"]["address"], self.settings["syslog"]["port"]
        )
        self.syslog_server.start()

    def update_parameters(self, **kwargs):
        """Update the single "parameters" row and mirror the values on ``self``."""
        Session.query(models["parameters"]).one().update(**kwargs)
        self.__dict__.update(**kwargs)

    def delete_instance(self, instance_type, instance_id):
        """Delete the instance of ``instance_type`` with id ``instance_id``."""
        return delete(instance_type, id=instance_id)

    def get(self, instance_type, id):
        """Return the ``serialized`` form of the requested instance."""
        return fetch(instance_type, id=id).serialized

    def get_properties(self, instance_type, id):
        """Return ``get_properties()`` of the requested instance."""
        return fetch(instance_type, id=id).get_properties()

    def get_all(self, instance_type):
        """Return ``get_properties()`` for every instance of ``instance_type``."""
        return [instance.get_properties() for instance in fetch_all(instance_type)]

    def update(self, type, **kwargs):
        """Create or update an instance of ``type`` from form data.

        Returns the serialized instance on success, or an ``{"alert": ...}``
        dict after rolling the session back when any exception is raised.
        """
        try:
            # An empty-string id is passed on as ``must_be_new=True``.
            must_be_new = kwargs.get("id") == ""
            for arg in ("name", "scoped_name"):
                if arg in kwargs:
                    kwargs[arg] = kwargs[arg].strip()
            kwargs["last_modified"] = self.get_time()
            kwargs["creator"] = kwargs["user"] = getattr(
                current_user, "name", "admin"
            )
            instance = factory(type, must_be_new=must_be_new, **kwargs)
            if kwargs.get("original"):
                fetch(type, id=kwargs["original"]).duplicate(clone=instance)
            Session.flush()
            return instance.serialized
        except Exception as exc:
            Session.rollback()
            if isinstance(exc, IntegrityError):
                return {
                    "alert": f"There is already a {type} with the same parameters."
                }
            return {"alert": str(exc)}

    def log(self, severity, content):
        """Store a changelog entry and emit ``content`` at ``severity``.

        ``severity`` must be a key of ``log_severity``.
        """
        factory(
            "changelog",
            **{
                "severity": severity,
                "content": content,
                "user": getattr(current_user, "name", "admin"),
            },
        )
        self.log_severity[severity](content)

    def count_models(self):
        """Return per-model counts and value counters for the dashboard models.

        For each model, the counted value is the first property listed for it
        in ``properties["dashboard"]``.
        """
        return {
            "counters": {
                instance_type: count(instance_type)
                for instance_type in properties["dashboard"]
            },
            "properties": {
                instance_type: Counter(
                    str(getattr(instance, properties["dashboard"][instance_type][0]))
                    for instance in fetch_all(instance_type)
                )
                for instance_type in properties["dashboard"]
            },
        }

    def compare(self, type, result1, result2):
        """Return both results as line lists together with their diff opcodes."""
        first = self.str_dict(getattr(fetch(type, id=result1), "result")).splitlines()
        second = self.str_dict(getattr(fetch(type, id=result2), "result")).splitlines()
        opcodes = SequenceMatcher(None, first, second).get_opcodes()
        return {"first": first, "second": second, "opcodes": opcodes}

    def build_filtering_constraints(self, obj_type, **kwargs):
        """Build the list of query constraints described by ``kwargs["form"]``."""
        model, constraints = models[obj_type], []
        form = kwargs["form"]
        # Constraints on the model's own properties.
        for prop in model_properties[obj_type]:
            value = form.get(prop)
            if not value:
                continue
            filter_mode = form.get(f"{prop}_filter")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(model, prop) == (value == "bool-true")
            elif filter_mode == "equality":
                constraint = getattr(model, prop) == value
            elif not filter_mode or filter_mode == "inclusion" or DIALECT == "sqlite":
                constraint = getattr(model, prop).contains(value)
            else:
                regex_operator = "regexp" if DIALECT == "mysql" else "~"
                constraint = getattr(model, prop).op(regex_operator)(value)
            constraints.append(constraint)
        # Constraints on related models.
        for related_model, relation_properties in relationships[obj_type].items():
            relation_ids = [int(value) for value in form.get(related_model, [])]
            filter_mode = form.get(f"{related_model}_filter")
            if filter_mode == "none":
                constraint = ~getattr(model, related_model).any()
            elif not relation_ids:
                continue
            elif relation_properties["list"]:
                constraint = getattr(model, related_model).any(
                    models[relation_properties["model"]].id.in_(relation_ids)
                )
                if filter_mode == "not_any":
                    constraint = ~constraint
            else:
                constraint = or_(
                    getattr(model, related_model).has(id=relation_id)
                    for relation_id in relation_ids
                )
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, type, **params):
        """Return one page of 10 name matches plus the total match count."""
        model = models[type]
        results = Session.query(model).filter(model.name.contains(params.get("term")))
        page_offset = (int(params["page"]) - 1) * 10
        return {
            "items": [
                {"text": result.ui_name, "id": str(result.id)}
                for result in results.limit(10).offset(page_offset).all()
            ],
            "total_count": results.count(),
        }

    def table_filtering(self, table, **kwargs):
        """Return the filtered, ordered and paginated rows of ``table``."""
        model = models[table]
        # Resolve the ordering as <model>.<column>.<direction>; None if unknown.
        ordering = getattr(
            getattr(
                model,
                kwargs["columns"][int(kwargs["order"][0]["column"])]["data"],
                None,
            ),
            kwargs["order"][0]["dir"],
            None,
        )
        constraints = self.build_filtering_constraints(table, **kwargs)
        if table == "result":
            constraints.append(
                getattr(
                    models["result"],
                    "device" if kwargs["instance"]["type"] == "device" else "service",
                ).has(id=kwargs["instance"]["id"])
            )
            # NOTE(review): the runtime constraint is assumed to apply to the
            # "result" table only - confirm against upstream.
            if kwargs.get("runtime"):
                constraints.append(
                    models["result"].parent_runtime == kwargs["runtime"]
                )
        if table == "service":
            workflow_id = kwargs["form"].get("workflow-filtering")
            if workflow_id:
                constraints.append(
                    models["service"].workflows.any(
                        models["workflow"].id == int(workflow_id)
                    )
                )
            else:
                if kwargs["form"].get("parent-filtering", "true") == "true":
                    constraints.append(~models["service"].workflows.any())
        if table == "run":
            constraints.append(models["run"].parent_runtime == models["run"].runtime)
        result = Session.query(model).filter(and_(*constraints))
        if ordering:
            result = result.order_by(ordering())
        return {
            "draw": int(kwargs["draw"]),
            "recordsTotal": Session.query(func.count(model.id)).scalar(),
            "recordsFiltered": get_query_count(result),
            "data": [
                obj.table_properties(**kwargs)
                for obj in result.limit(int(kwargs["length"]))
                .offset(int(kwargs["start"]))
                .all()
            ],
        }

    def allowed_file(self, name, allowed_modules):
        """Return True when ``name`` has an extension listed in ``allowed_modules``.

        The extension is compared in lower case. A name without any dot is
        rejected instead of raising ``IndexError``.
        """
        if "." not in name:
            return False
        return name.rsplit(".", 1)[1].lower() in allowed_modules

    def get_time(self):
        """Return the current local date and time as a string."""
        return str(datetime.now())

    def send_email(
        self,
        subject,
        content,
        recipients="",
        sender=None,
        filename=None,
        file_content=None,
    ):
        """Send an email, optionally with one attachment.

        ``recipients`` is a comma-separated string. ``sender`` defaults to the
        configured mail sender. The SMTP connection is closed even when
        sending fails.
        """
        sender = sender or self.settings["mail"]["sender"]
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            attached_file["Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.settings["mail"]["server"], self.settings["mail"]["port"])
        try:
            if self.settings["mail"]["use_tls"]:
                # NOTE(review): login is assumed to be part of the TLS branch -
                # confirm against upstream.
                server.starttls()
                password = environ.get("MAIL_PASSWORD", "")
                server.login(self.settings["mail"]["username"], password)
            server.sendmail(sender, recipients.split(","), message.as_string())
        finally:
            server.close()

    def str_dict(self, input, depth=0):
        """Render nested lists and dicts as an indented multi-line string."""
        tab = "\t" * depth
        if isinstance(input, list):
            return "\n" + "".join(
                f"{tab}- {self.str_dict(element, depth + 1)}\n" for element in input
            )
        if isinstance(input, dict):
            return "".join(
                f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
                for key, value in input.items()
            )
        return str(input)

    def strip_all(self, input):
        """Remove every punctuation character and space from ``input``."""
        return input.translate(str.maketrans("", "", f"{punctuation} "))

    def switch_menu(self, user_id):
        """Toggle the ``small_menu`` flag of the given user."""
        user = fetch("user", id=user_id)
        user.small_menu = not user.small_menu

    def update_database_configurations_from_git(self):
        """Import device data from ``network_data`` and recompute affected pools.

        Each entry of ``network_data`` is matched to a device by name; entries
        without a matching device, and missing files, are skipped.
        """
        for entry in scandir(self.path / "network_data"):
            device = fetch("device", allow_none=True, name=entry.name)
            if not device:
                continue
            device_folder = Path(entry.path)
            parameters_path = device_folder / "data.yml"
            if parameters_path.exists():
                with open(parameters_path) as parameters_file:
                    # NOTE(review): yaml.load without an explicit Loader can
                    # build arbitrary objects and is rejected by recent PyYAML
                    # releases; consider yaml.safe_load if the repository
                    # content is not fully trusted.
                    parameters = yaml.load(parameters_file)
                device.update(**{"dont_update_pools": True, **parameters})
            for property_name in ("configuration", "operational_data"):
                filepath = device_folder / property_name
                if not filepath.exists():
                    continue
                with open(filepath) as file:
                    setattr(device, property_name, file.read())
        Session.commit()
        for pool in fetch_all("pool"):
            if pool.device_configuration or pool.device_operational_data:
                pool.compute_pool()
def init_syslog_server(self):
    """Create the syslog server from the ``syslog`` settings and start it."""
    address = self.settings["syslog"]["address"]
    port = self.settings["syslog"]["port"]
    self.syslog_server = SyslogServer(address, port)
    self.syslog_server.start()
class BaseController:
    """Application controller (second revision in this file).

    Groups start-up initialisation (encryption, vault, syslog, logging, redis,
    HTTP pools) with the generic CRUD, filtering, logging and mail helpers
    defined below. ``pre_init``, ``post_init``, ``run_logs``,
    ``configuration_properties`` and a few other members used here are not
    defined in this class and must be provided elsewhere.

    NOTE(review): the block structure was reconstructed from a
    whitespace-collapsed source; spots where the nesting could not be derived
    from syntax alone carry their own NOTE(review) comment.
    """

    log_levels = ["debug", "info", "warning", "error", "critical"]

    rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    # Class-level dict filled by ``load_custom_properties``.
    property_names = {}

    def __init__(self):
        """Store the configuration objects and run every enabled initialiser."""
        self.pre_init()
        self.settings = settings
        self.rbac = rbac
        self.properties = properties
        self.database = database
        self.logging = logging
        self.load_custom_properties()
        self.path = Path.cwd()
        self.init_encryption()
        self.use_vault = settings["vault"]["use_vault"]
        if self.use_vault:
            self.init_vault_client()
        if settings["syslog"]["active"]:
            self.init_syslog_server()
        if settings["paths"]["custom_code"]:
            sys_path.append(settings["paths"]["custom_code"])
        self.fetch_version()
        self.init_logs()
        self.init_redis()
        self.init_scheduler()
        self.init_connection_pools()
        self.post_init()

    def init_encryption(self):
        """Select Fernet when ``FERNET_KEY`` is set, base64 encoding otherwise."""
        self.fernet_encryption = environ.get("FERNET_KEY")
        if self.fernet_encryption:
            fernet = Fernet(self.fernet_encryption)
            self.encrypt, self.decrypt = fernet.encrypt, fernet.decrypt
        else:
            self.encrypt, self.decrypt = b64encode, b64decode

    def encrypt_password(self, password):
        """Return ``password`` passed through ``self.encrypt`` (str is encoded first)."""
        if isinstance(password, str):
            password = str.encode(password)
        return self.encrypt(password)

    def get_password(self, password):
        """Return the decoded password as text, or None for an empty value."""
        if not password:
            return
        if self.fernet_encryption and isinstance(password, str):
            password = str.encode(password)
        return str(self.decrypt(password), "utf-8")

    def initialize_database(self):
        """Load services and forms, create the tables and seed a new database.

        The seeding branch only runs when no user named "admin" exists.
        """
        self.init_services()
        db.base.metadata.create_all(bind=db.engine)
        configure_mappers()
        db.configure_application_events(self)
        self.init_forms()
        if not db.fetch("user", allow_none=True, name="admin"):
            self.create_admin_user()
            db.session.commit()
            self.migration_import(
                name=self.settings["app"].get("startup_migration", "default"),
                import_export_types=db.import_export_models,
            )
            self.update_credentials()
            self.get_git_content()
        # NOTE(review): the three calls below are assumed to run on every
        # start, outside the first-start branch - confirm against upstream.
        self.configure_server_id()
        self.reset_run_status()
        db.session.commit()

    def reset_run_status(self):
        """Abort every run still flagged "Running" and set its service to idle."""
        for run in db.fetch("run", all_matches=True, allow_none=True, status="Running"):
            run.status = "Aborted (RELOAD)"
            run.service.status = "Idle"
        db.session.commit()

    def fetch_version(self):
        """Read the "version" key of ``package.json`` into ``self.version``."""
        with open(self.path / "package.json") as package_file:
            self.version = load(package_file)["version"]

    def configure_server_id(self):
        """Register the local host as a "server" object named after ``getnode()``."""
        db.factory(
            "server",
            **{
                "name": str(getnode()),
                "description": "Localhost",
                "ip_address": "0.0.0.0",
                "status": "Up",
            },
        )

    def create_admin_user(self):
        """Create the "admin" user and give it a password if it has none."""
        admin = db.factory("user", name="admin", is_admin=True)
        if not admin.password:
            admin.update(password="******")

    def update_credentials(self):
        """Run ``topology_import`` on the bundled ``usa.xls`` spreadsheet."""
        with open(self.path / "files" / "spreadsheets" / "usa.xls", "rb") as file:
            self.topology_import(file)

    def get_git_content(self):
        """Pull (or clone) the configured git repository, then import its data.

        Does nothing when no repository is configured. Git failures are logged
        and the database import still runs afterwards.
        """
        repo = self.settings["app"]["git_repository"]
        if not repo:
            return
        local_path = self.path / "network_data"
        try:
            if exists(local_path):
                Repo(local_path).remotes.origin.pull()
            else:
                local_path.mkdir(parents=True, exist_ok=True)
                Repo.clone_from(repo, local_path)
        except Exception as exc:
            self.log("error", f"Git pull failed ({str(exc)})")
        self.update_database_configurations_from_git()

    def load_custom_properties(self):
        """Register the custom properties declared under ``properties["custom"]``."""
        for model, values in self.properties["custom"].items():
            for property_name, property_dict in values.items():
                pretty_name = property_dict["pretty_name"]
                self.property_names[property_name] = pretty_name
                model_properties[model].append(property_name)
                if property_dict.get("private"):
                    db.private_properties.append(property_name)
                if model == "device" and property_dict.get("configuration"):
                    self.configuration_properties[property_name] = pretty_name

    def init_logs(self):
        """Apply ``setup/logging.json`` and the external logger levels it lists."""
        folder = self.path / "logs"
        folder.mkdir(parents=True, exist_ok=True)
        with open(self.path / "setup" / "logging.json", "r") as config_file:
            logging_config = load(config_file)
        dictConfig(logging_config)
        for logger, log_level in logging_config["external_loggers"].items():
            info(f"Changing {logger} log level to '{log_level}'")
            # The level constant is looked up by name on the logging module.
            log_level = getattr(import_module("logging"), log_level.upper())
            getLogger(logger).setLevel(log_level)

    def init_connection_pools(self):
        """Create the shared request session with retry and pool settings."""
        self.request_session = RequestSession()
        retry = Retry(**self.settings["requests"]["retries"])
        for protocol in ("http", "https"):
            self.request_session.mount(
                f"{protocol}://",
                HTTPAdapter(max_retries=retry, **self.settings["requests"]["pool"]),
            )

    def init_forms(self):
        """Execute every python module found under ``eNMS/forms``."""
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            # ``stem`` is the file name without ".py", whatever the path separator.
            spec = spec_from_file_location(file.stem, str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_redis(self):
        """Create the redis client when ``REDIS_ADDR`` is set, else store None."""
        host = environ.get("REDIS_ADDR")
        self.redis_queue = (
            Redis(
                host=host,
                port=6379,
                db=0,
                charset="utf-8",
                decode_responses=True,
                socket_timeout=0.1,
            )
            if host
            else None
        )

    def init_scheduler(self):
        """Read the scheduler address from ``SCHEDULER_ADDR``."""
        self.scheduler_address = environ.get("SCHEDULER_ADDR")

    def init_services(self):
        """Execute every service module from the built-in and custom folders."""
        path_services = [self.path / "eNMS" / "services"]
        load_examples = self.settings["app"].get("startup_migration") == "examples"
        if self.settings["paths"]["custom_services"]:
            path_services.append(Path(self.settings["paths"]["custom_services"]))
        for path in path_services:
            for file in path.glob("**/*.py"):
                # Substring tests on the whole path: any path containing
                # "init" (e.g. __init__.py) is skipped.
                if "init" in str(file):
                    continue
                if not load_examples and "examples" in str(file):
                    continue
                info(f"Loading service: {file}")
                spec = spec_from_file_location(file.stem, str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as exc:
                    error(f"Error loading custom service '{file}' ({str(exc)})")

    def init_vault_client(self):
        """Create the vault client and unseal the vault if configured to."""
        url = environ.get("VAULT_ADDR", "http://127.0.0.1:8200")
        self.vault_client = VaultClient(url=url, token=environ.get("VAULT_TOKEN"))
        if self.vault_client.sys.is_sealed() and self.settings["vault"]["unseal_vault"]:
            # Keys come from UNSEAL_VAULT_KEY1..5; unset ones are dropped.
            keys = [environ.get(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)]
            self.vault_client.sys.submit_unseal_keys(filter(None, keys))

    def init_syslog_server(self):
        """Create the syslog server from the ``syslog`` settings and start it."""
        self.syslog_server = SyslogServer(
            self.settings["syslog"]["address"], self.settings["syslog"]["port"]
        )
        self.syslog_server.start()

    def redis(self, operation, *args, **kwargs):
        """Call ``operation`` on the redis client.

        Connection and timeout errors are logged and None is returned.
        """
        try:
            return getattr(self.redis_queue, operation)(*args, **kwargs)
        except (ConnectionError, TimeoutError) as exc:
            self.log("error", f"Redis Queue Unreachable ({exc})", change_log=False)

    def log_queue(self, runtime, service, log=None, mode="add"):
        """Append to, or read from, the log queue of a service run.

        With a redis client the queue lives under the key
        ``<runtime>/<service>/logs``; otherwise ``self.run_logs`` is used and
        ``mode`` names the method called on ``self.run_logs[runtime]``.
        """
        if self.redis_queue:
            key = f"{runtime}/{service}/logs"
            self.run_logs[runtime][int(service)] = None
            if mode == "add":
                log = self.redis("lpush", key, log)
            else:
                log = self.redis("lrange", key, 0, -1)
                if log:
                    # lpush prepends, so reverse to get insertion order.
                    log = log[::-1]
        else:
            if mode == "add":
                return self.run_logs[runtime][int(service)].append(log)
            else:
                log = getattr(self.run_logs[runtime], mode)(int(service), [])
        return log

    def delete_instance(self, model, instance_id):
        """Delete the instance of ``model`` with id ``instance_id``."""
        return db.delete(model, id=instance_id)

    def get(self, model, id):
        """Return the ``serialized`` form of the requested instance."""
        return db.fetch(model, id=id).serialized

    def get_properties(self, model, id):
        """Return ``get_properties()`` of the requested instance."""
        return db.fetch(model, id=id).get_properties()

    def get_all(self, model):
        """Return ``get_properties()`` for every instance of ``model``."""
        return [instance.get_properties() for instance in db.fetch_all(model)]

    def update(self, type, **kwargs):
        """Create or update an instance of ``type`` from form data.

        Returns the serialized instance on success, or an ``{"alert": ...}``
        dict on an RBAC error or (after a rollback) on any other exception.
        """
        try:
            # An empty-string id is passed on as ``must_be_new=True``.
            must_be_new = kwargs.get("id") == ""
            for arg in ("name", "scoped_name"):
                if arg in kwargs:
                    kwargs[arg] = kwargs[arg].strip()
            kwargs["last_modified"] = self.get_time()
            kwargs["creator"] = kwargs["user"] = getattr(current_user, "name", "")
            instance = db.factory(type, must_be_new=must_be_new, **kwargs)
            if kwargs.get("copy"):
                db.fetch(type, id=kwargs["copy"]).duplicate(clone=instance)
            db.session.flush()
            return instance.serialized
        except db.rbac_error:
            return {"alert": "Error 403 - Operation not allowed."}
        except Exception as exc:
            db.session.rollback()
            if isinstance(exc, IntegrityError):
                return {
                    "alert": f"There is already a {type} with the same parameters."
                }
            return {"alert": str(exc)}

    def log(self, severity, content, user=None, change_log=True, logger="root"):
        """Emit ``content`` on ``logger`` and optionally store a changelog entry.

        A changelog entry is created when ``change_log`` is true or when the
        logger's settings enable it. Returns the settings found for ``logger``
        (an empty dict when it has none).
        """
        logger_settings = self.logging["loggers"].get(logger, {})
        if logger:
            getattr(getLogger(logger), severity)(content)
        if change_log or logger and logger_settings.get("change_log"):
            db.factory(
                "changelog",
                **{
                    "severity": severity,
                    "content": content,
                    "user": user or getattr(current_user, "name", "admin"),
                },
            )
        return logger_settings

    def count_models(self):
        """Return per-model counts and value counters for the dashboard models.

        For each model, the counted value is the first property listed for it
        in ``properties["dashboard"]``.
        """
        return {
            "counters": {
                model: db.query(model).count() for model in properties["dashboard"]
            },
            "properties": {
                model: Counter(
                    str(getattr(instance, properties["dashboard"][model][0]))
                    for instance in db.fetch_all(model)
                )
                for model in properties["dashboard"]
            },
        }

    def compare(self, type, device_name, v1, v2):
        """Return a unified diff between two results or two git data versions."""
        if type in ("result", "device_result"):
            first = self.str_dict(getattr(db.fetch("result", id=v1), "result"))
            second = self.str_dict(getattr(db.fetch("result", id=v2), "result"))
        else:
            first = self.get_git_network_data(device_name, v1)[type]
            second = self.get_git_network_data(device_name, v2)[type]
        return "\n".join(
            unified_diff(
                first.splitlines(),
                second.splitlines(),
                fromfile="-",
                tofile="-",
                lineterm="",
            )
        )

    def build_filtering_constraints(self, model, **kwargs):
        """Build the list of query constraints described by ``kwargs["form"]``.

        An invalid regular expression propagates the exception raised by
        ``compile``.
        """
        table, constraints = models[model], []
        form = kwargs["form"]
        # Constraints on the model's own properties.
        for prop in model_properties[model]:
            value = form.get(prop)
            if not value:
                continue
            filter_mode = form.get(f"{prop}_filter")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(table, prop) == (value == "bool-true")
            elif filter_mode == "equality":
                constraint = getattr(table, prop) == value
            elif not filter_mode or filter_mode == "inclusion" or db.dialect == "sqlite":
                constraint = getattr(table, prop).contains(value)
            else:
                # Called only so that an invalid pattern raises here.
                compile(value)
                regex_operator = "regexp" if db.dialect == "mysql" else "~"
                constraint = getattr(table, prop).op(regex_operator)(value)
            constraints.append(constraint)
        # Constraints on related models.
        for related_model, relation_properties in relationships[model].items():
            relation_ids = [int(value) for value in form.get(related_model, [])]
            filter_mode = form.get(f"{related_model}_filter")
            if filter_mode == "none":
                constraint = ~getattr(table, related_model).any()
            elif not relation_ids:
                continue
            elif relation_properties["list"]:
                constraint = (and_ if filter_mode == "all" else or_)(
                    getattr(table, related_model).any(id=relation_id)
                    for relation_id in relation_ids
                )
                if filter_mode == "not_any":
                    constraint = ~constraint
            else:
                constraint = or_(
                    getattr(table, related_model).has(id=relation_id)
                    for relation_id in relation_ids
                )
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, model, **params):
        """Return one page of 10 name matches plus the total match count."""
        table = models[model]
        results = db.query(model).filter(table.name.contains(params.get("term")))
        page_offset = (int(params["page"]) - 1) * 10
        return {
            "items": [
                {"text": result.ui_name, "id": str(result.id)}
                for result in results.limit(10).offset(page_offset).all()
            ],
            "total_count": results.count(),
        }

    def filtering(self, model, **kwargs):
        """Return the filtered, ordered and paginated rows of ``model``.

        Returns an ``{"error": ...}`` dict when a search value is not a valid
        regular expression. With ``export`` set, the unpaginated rows are
        added under "full_result".
        """
        table = models[model]
        # Resolve the ordering as <table>.<column>.<direction>; None if unknown.
        ordering = getattr(
            getattr(
                table,
                kwargs["columns"][int(kwargs["order"][0]["column"])]["data"],
                None,
            ),
            kwargs["order"][0]["dir"],
            None,
        )
        try:
            constraints = self.build_filtering_constraints(model, **kwargs)
        except regex_error:
            return {"error": "Invalid regular expression as search parameter."}
        constraints.extend(table.filtering_constraints(**kwargs))
        query = db.query(model)
        # Count before filtering so "recordsTotal" covers the whole table.
        total_records, query = query.count(), query.filter(and_(*constraints))
        if ordering:
            query = query.order_by(ordering())
        table_result = {
            "draw": int(kwargs["draw"]),
            "recordsTotal": total_records,
            "recordsFiltered": query.count(),
            "data": [
                obj.table_properties(**kwargs)
                for obj in query.limit(int(kwargs["length"]))
                .offset(int(kwargs["start"]))
                .all()
            ],
        }
        if kwargs.get("export"):
            table_result["full_result"] = [
                obj.table_properties(**kwargs) for obj in query.all()
            ]
        return table_result

    def allowed_file(self, name, allowed_modules):
        """Return True when ``name`` has an extension listed in ``allowed_modules``.

        The extension is compared in lower case. A name without any dot is
        rejected instead of raising ``IndexError``.
        """
        if "." not in name:
            return False
        return name.rsplit(".", 1)[1].lower() in allowed_modules

    def get_time(self):
        """Return the current local date and time as a string."""
        return str(datetime.now())

    def send_email(
        self,
        subject,
        content,
        recipients="",
        reply_to=None,
        sender=None,
        filename=None,
        file_content=None,
    ):
        """Send an email, optionally with one attachment.

        ``recipients`` is a comma-separated string. ``sender`` and
        ``reply_to`` default to the configured mail settings. The SMTP
        connection is closed even when sending fails.
        """
        sender = sender or self.settings["mail"]["sender"]
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.add_header("reply-to", reply_to or self.settings["mail"]["reply_to"])
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            attached_file["Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.settings["mail"]["server"], self.settings["mail"]["port"])
        try:
            if self.settings["mail"]["use_tls"]:
                # NOTE(review): login is assumed to be part of the TLS branch -
                # confirm against upstream.
                server.starttls()
                password = environ.get("MAIL_PASSWORD", "")
                server.login(self.settings["mail"]["username"], password)
            server.sendmail(sender, recipients.split(","), message.as_string())
        finally:
            server.close()

    def contains_set(self, input):
        """Return True when ``input`` is a set or contains one at any depth.

        Only lists and dict values are searched recursively.
        """
        if isinstance(input, set):
            return True
        if isinstance(input, list):
            return any(self.contains_set(x) for x in input)
        if isinstance(input, dict):
            return any(self.contains_set(x) for x in input.values())
        return False

    def str_dict(self, input, depth=0):
        """Render nested lists and dicts as an indented multi-line string."""
        tab = "\t" * depth
        if isinstance(input, list):
            return "\n" + "".join(
                f"{tab}- {self.str_dict(element, depth + 1)}\n" for element in input
            )
        if isinstance(input, dict):
            return "".join(
                f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
                for key, value in input.items()
            )
        return str(input)

    def strip_all(self, input):
        """Remove every punctuation character and space from ``input``."""
        return input.translate(str.maketrans("", "", f"{punctuation} "))

    def update_database_configurations_from_git(self):
        """Import device data from ``network_data`` and recompute affected pools.

        Each entry of ``network_data`` is matched to a device by name; entries
        without a matching device, and missing files, are skipped.
        """
        for entry in scandir(self.path / "network_data"):
            device = db.fetch("device", allow_none=True, name=entry.name)
            if not device:
                continue
            device_folder = Path(entry.path)
            parameters_path = device_folder / "data.yml"
            if parameters_path.exists():
                with open(parameters_path) as parameters_file:
                    # NOTE(review): yaml.load without an explicit Loader can
                    # build arbitrary objects and is rejected by recent PyYAML
                    # releases; consider yaml.safe_load if the repository
                    # content is not fully trusted.
                    parameters = yaml.load(parameters_file)
                device.update(**{"dont_update_pools": True, **parameters})
            for property_name in self.configuration_properties:
                filepath = device_folder / property_name
                if not filepath.exists():
                    continue
                with open(filepath) as file:
                    setattr(device, property_name, file.read())
        db.session.commit()
        for pool in db.fetch_all("pool"):
            if any(
                getattr(pool, f"device_{property_name}")
                for property_name in self.configuration_properties
            ):
                pool.compute_pool()
class BaseController:
    """Base application controller.

    Wires settings, database, logging, encryption and optional integrations
    (Vault, syslog, Redis) together, and exposes the generic CRUD, filtering
    and bulk helpers used by the rest of the application.

    ``pre_init``, ``post_init``, ``migration_import``, ``topology_import``,
    ``get_git_network_data``, ``configuration_properties``,
    ``configuration_timestamps`` and ``run_logs`` are referenced here but not
    defined in this class; they are expected to be provided elsewhere
    (presumably by a subclass or mixin -- confirm against the caller).
    """

    # Severity names accepted by the logging helpers.
    log_levels = ["debug", "info", "warning", "error", "critical"]

    # Controller methods exposed through the REST API.
    rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    # Maps a custom property name to its pretty name; filled by
    # load_custom_properties. Class-level, hence shared by all instances.
    property_names = {}

    def __init__(self):
        """Initialize every subsystem of the controller, in dependency order."""
        self.pre_init()
        self.settings = settings
        self.properties = properties
        self.database = database
        self.logging = logging
        self.cli_command = self.detect_cli()
        self.load_custom_properties()
        self.load_configuration_properties()
        self.path = Path.cwd()
        self.init_rbac()
        self.init_encryption()
        self.use_vault = settings["vault"]["use_vault"]
        if self.use_vault:
            self.init_vault_client()
        if settings["syslog"]["active"]:
            self.init_syslog_server()
        if settings["paths"]["custom_code"]:
            sys_path.append(settings["paths"]["custom_code"])
        self.fetch_version()
        self.init_logs()
        self.init_redis()
        self.init_scheduler()
        self.init_connection_pools()
        self.post_init()

    def init_encryption(self):
        """Select the encrypt/decrypt callables.

        Fernet is used when the FERNET_KEY environment variable is set;
        otherwise the controller falls back to base64 encoding/decoding.
        """
        self.fernet_encryption = getenv("FERNET_KEY")
        if self.fernet_encryption:
            fernet = Fernet(self.fernet_encryption)
            self.encrypt, self.decrypt = fernet.encrypt, fernet.decrypt
        else:
            self.encrypt, self.decrypt = b64encode, b64decode

    def detect_cli(self):
        """Return True when running inside a click context named "flask"."""
        try:
            return get_current_context().info_name == "flask"
        except RuntimeError:
            # get_current_context raises RuntimeError when no context is active.
            return False

    def encrypt_password(self, password):
        """Encrypt ``password``, encoding it to bytes first if it is a str."""
        if isinstance(password, str):
            password = str.encode(password)
        return self.encrypt(password)

    def get_password(self, password):
        """Decrypt ``password`` and return it as a UTF-8 string.

        Returns None when ``password`` is empty or None.
        """
        if not password:
            return
        if self.fernet_encryption and isinstance(password, str):
            password = str.encode(password)
        return str(self.decrypt(password), "utf-8")

    def initialize_database(self):
        """Load plugins and services, create the schema and seed initial data.

        When invoked from the CLI, only the schema and model events are set
        up. The admin user, startup migration, credentials and git content
        are only created when no "admin" user exists yet.
        """
        self.init_plugins()
        # Flatten the per-model private properties into one de-duplicated list.
        db.private_properties_list = list(set(sum(db.private_properties.values(), [])))
        self.init_services()
        db.base.metadata.create_all(bind=db.engine)
        configure_mappers()
        db.configure_model_events(self)
        if self.cli_command:
            return
        self.init_forms()
        if not db.fetch("user", allow_none=True, name="admin"):
            self.create_admin_user()
            self.migration_import(
                name=self.settings["app"].get("startup_migration", "default"),
                import_export_types=db.import_export_models,
            )
            self.update_credentials()
            self.get_git_content()
        self.configure_server_id()
        self.reset_run_status()
        db.session.commit()

    def reset_run_status(self):
        """Mark every run still flagged "Running" as aborted; idle its service."""
        for run in db.fetch("run", all_matches=True, allow_none=True, status="Running"):
            run.status = "Aborted (RELOAD)"
            run.service.status = "Idle"
        db.session.commit()

    def fetch_version(self):
        """Read the application version from ``package.json``."""
        with open(self.path / "package.json") as package_file:
            self.version = load(package_file)["version"]

    def configure_server_id(self):
        """Create or update the "server" entry named after this host's node id."""
        db.factory(
            "server",
            **{
                "name": str(getnode()),
                "description": "Localhost",
                "ip_address": "0.0.0.0",
                "status": "Up",
            },
        )

    def create_admin_user(self):
        """Create the "admin" user and give it a password if it has none."""
        admin = db.factory("user", name="admin", is_admin=True, commit=True)
        if not admin.password:
            admin.update(password="******")

    def update_credentials(self):
        """Import the bundled ``usa.xls`` spreadsheet via ``topology_import``."""
        with open(self.path / "files" / "spreadsheets" / "usa.xls", "rb") as file:
            self.topology_import(file)

    def get_git_content(self):
        """Pull (or clone) the configured git repository into ``network_data``.

        Does nothing when no repository is configured. Git failures are
        logged, not raised; the database is refreshed from the local copy
        either way.
        """
        repo = self.settings["app"]["git_repository"]
        if not repo:
            return
        local_path = self.path / "network_data"
        try:
            if exists(local_path):
                Repo(local_path).remotes.origin.pull()
            else:
                local_path.mkdir(parents=True, exist_ok=True)
                Repo.clone_from(repo, local_path)
        except Exception as exc:
            self.log("error", f"Git pull failed ({str(exc)})")
        self.update_database_configurations_from_git()

    def load_custom_properties(self):
        """Register the custom properties declared under ``properties["custom"]``.

        Each property is added to ``property_names`` and ``model_properties``;
        private ones are recorded in ``db.private_properties`` and device
        properties flagged "configuration" in ``configuration_properties``.
        """
        for model, values in self.properties["custom"].items():
            for property, property_dict in values.items():
                pretty_name = property_dict["pretty_name"]
                self.property_names[property] = pretty_name
                model_properties[model].append(property)
                if property_dict.get("private"):
                    if model not in db.private_properties:
                        db.private_properties[model] = []
                    db.private_properties[model].append(property)
                if model == "device" and property_dict.get("configuration"):
                    self.configuration_properties[property] = pretty_name

    def load_configuration_properties(self):
        """Add configuration properties to device filtering and the table.

        For every configuration property, one column is inserted for the
        property itself and one hidden column per configuration timestamp.
        ``insert(-1, ...)`` places each column just before the last one.
        """
        for property, title in self.configuration_properties.items():
            self.properties["filtering"]["device"].append(property)
            self.properties["tables"]["configuration"].insert(
                -1,
                {
                    "data": property,
                    "title": title,
                    "search": "text",
                    "width": "70%",
                    "visible": property == "configuration",
                    "orderable": False,
                },
            )
            for timestamp in self.configuration_timestamps:
                self.properties["tables"]["configuration"].insert(
                    -1,
                    {
                        "data": f"last_{property}_{timestamp}",
                        "title": f"Last {title} {timestamp.capitalize()}",
                        "search": "text",
                        "visible": False,
                    },
                )

    def init_logs(self):
        """Create the ``logs`` folder and configure logging from ``logging.json``.

        Also applies the level configured for each entry of the
        "external_loggers" section of that file.
        """
        folder = self.path / "logs"
        folder.mkdir(parents=True, exist_ok=True)
        with open(self.path / "setup" / "logging.json", "r") as logging_config:
            logging_config = load(logging_config)
        dictConfig(logging_config)
        for logger, log_level in logging_config["external_loggers"].items():
            info(f"Changing {logger} log level to '{log_level}'")
            # Resolve the level name (e.g. "info") to the logging constant.
            log_level = getattr(import_module("logging"), log_level.upper())
            getLogger(logger).setLevel(log_level)

    def init_connection_pools(self):
        """Create the shared requests session with retry and pool settings."""
        self.request_session = RequestSession()
        retry = Retry(**self.settings["requests"]["retries"])
        for protocol in ("http", "https"):
            self.request_session.mount(
                f"{protocol}://",
                HTTPAdapter(max_retries=retry, **self.settings["requests"]["pool"]),
            )

    def init_forms(self):
        """Execute every Python module found under ``eNMS/forms``."""
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            # file.stem is the file name without ".py"; unlike splitting the
            # path on "/", it does not depend on the OS path separator.
            spec = spec_from_file_location(file.stem, str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_rbac(self):
        """Build ``self.rbac``; list the pages and subpages with "access" rbac."""
        self.rbac = {"pages": [], **rbac}
        for category in rbac["menu"].values():
            for page, page_values in category["pages"].items():
                if page_values["rbac"] == "access":
                    self.rbac["pages"].append(page)
                for subpage, subpage_values in page_values.get("subpages", {}).items():
                    if subpage_values["rbac"] == "access":
                        self.rbac["pages"].append(subpage)

    def init_redis(self):
        """Create the Redis client if REDIS_ADDR is set; otherwise store None."""
        host = getenv("REDIS_ADDR")
        self.redis_queue = (
            Redis(
                host=host,
                port=6379,
                db=0,
                charset="utf-8",
                decode_responses=True,
                socket_timeout=0.1,
            )
            if host
            else None
        )

    def init_scheduler(self):
        """Read the scheduler address from the SCHEDULER_ADDR variable."""
        self.scheduler_address = getenv("SCHEDULER_ADDR")

    def update_settings(self, old, new):
        """Recursively merge ``new`` into ``old`` in place and return ``old``.

        Lists are extended, dicts are merged recursively, and any other value
        (or a key missing from ``old``) is overwritten.
        """
        for key, value in new.items():
            if key not in old:
                old[key] = value
            else:
                old_value = old[key]
                if isinstance(old_value, list):
                    old_value.extend(value)
                elif isinstance(old_value, dict):
                    self.update_settings(old_value, value)
                else:
                    old[key] = value
        return old

    def init_plugins(self):
        """Load every active plugin found in the configured plugin folder.

        A plugin is a directory containing a ``settings.json``. Its
        "database", "properties" and "rbac" sections are merged into the
        controller's matching attributes. A plugin that fails to load is
        logged and skipped.
        """
        self.plugins = {}
        for plugin_path in Path(self.settings["app"]["plugin_path"]).iterdir():
            if not Path(plugin_path / "settings.json").exists():
                continue
            try:
                with open(plugin_path / "settings.json", "r") as file:
                    # Named plugin_settings to avoid shadowing the module-level
                    # ``settings`` object.
                    plugin_settings = load(file)
                if not plugin_settings["active"]:
                    continue
                self.plugins[plugin_path.stem] = {
                    "settings": plugin_settings,
                    "module": import_module(f"eNMS.plugins.{plugin_path.stem}"),
                }
                for setup_file in ("database", "properties", "rbac"):
                    target = getattr(self, setup_file)
                    self.update_settings(target, plugin_settings.get(setup_file, {}))
            except Exception as exc:
                error(f"Could not load plugin '{plugin_path.stem}' ({exc})")
                continue
            info(f"Loading plugin: {plugin_settings['name']}")

    def init_services(self):
        """Execute every service module (built-in and custom).

        Files whose path contains "init" are skipped; files whose path
        contains "examples" are skipped unless the startup migration is
        "examples".
        """
        path_services = [self.path / "eNMS" / "services"]
        load_examples = self.settings["app"].get("startup_migration") == "examples"
        if self.settings["paths"]["custom_services"]:
            path_services.append(Path(self.settings["paths"]["custom_services"]))
        for path in path_services:
            for file in path.glob("**/*.py"):
                if "init" in str(file):
                    continue
                if not load_examples and "examples" in str(file):
                    continue
                info(f"Loading service: {file}")
                spec = spec_from_file_location(file.stem, str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as exc:
                    error(f"Error loading custom service '{file}' ({str(exc)})")

    def init_vault_client(self):
        """Create the Vault client and unseal the vault when configured to."""
        url = getenv("VAULT_ADDR", "http://127.0.0.1:8200")
        self.vault_client = VaultClient(url=url, token=getenv("VAULT_TOKEN"))
        if self.vault_client.sys.is_sealed() and self.settings["vault"]["unseal_vault"]:
            # Keys are read from UNSEAL_VAULT_KEY1..5; unset ones are dropped.
            keys = [getenv(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)]
            self.vault_client.sys.submit_unseal_keys(filter(None, keys))

    def init_syslog_server(self):
        """Create and start the syslog server on the configured address/port."""
        self.syslog_server = SyslogServer(
            self.settings["syslog"]["address"], self.settings["syslog"]["port"]
        )
        self.syslog_server.start()

    def redis(self, operation, *args, **kwargs):
        """Call ``operation`` on the Redis client and return its result.

        Connection and timeout errors are logged (not to the changelog) and
        None is returned.
        """
        try:
            return getattr(self.redis_queue, operation)(*args, **kwargs)
        except (ConnectionError, TimeoutError) as exc:
            self.log("error", f"Redis Queue Unreachable ({exc})", change_log=False)

    def log_queue(self, runtime, service, log=None, mode="add", start_line=0):
        """Append to, or read from, the log of a service for a given runtime.

        With Redis, logs live in the list ``{runtime}/{service}/logs``;
        without it, in ``self.run_logs``. In "add" mode ``log`` is appended;
        in any other mode the stored log is returned from ``start_line`` on.
        Without Redis, ``mode`` is also the name of the method called on
        ``self.run_logs[runtime]`` (e.g. "get" or "pop").
        """
        if self.redis_queue:
            key = f"{runtime}/{service}/logs"
            self.run_logs[runtime][int(service)] = None
            if mode == "add":
                log = self.redis("lpush", key, log)
            else:
                log = self.redis("lrange", key, 0, -1)
                if log:
                    # lpush prepends, so reverse to restore insertion order.
                    log = log[::-1][start_line:]
        else:
            if mode == "add":
                return self.run_logs[runtime][int(service)].append(log)
            else:
                full_log = getattr(self.run_logs[runtime], mode)(int(service), [])
                log = full_log[start_line:]
        return log

    def delete_instance(self, model, instance_id):
        """Delete the ``model`` instance with id ``instance_id``."""
        return db.delete(model, id=instance_id)

    def get(self, model, id):
        """Return the serialized form of the ``model`` instance with ``id``."""
        return db.fetch(model, id=id).serialized

    def get_properties(self, model, id):
        """Return the properties of the ``model`` instance with ``id``."""
        return db.fetch(model, id=id).get_properties()

    def get_all(self, model):
        """Return the properties of every instance of ``model``."""
        return [instance.get_properties() for instance in db.fetch_all(model)]

    def update(self, type, **kwargs):
        """Create or update an instance of ``type`` from ``kwargs``.

        Returns the serialized instance on success, or an ``{"alert": ...}``
        dict on an RBAC error, an integrity error or any other failure (the
        session is rolled back for the last two).
        """
        try:
            kwargs.update(
                {
                    "last_modified": self.get_time(),
                    "update_pools": True,
                    # An empty id means a creation, not an edit.
                    "must_be_new": kwargs.get("id") == "",
                }
            )
            for arg in ("name", "scoped_name"):
                if arg in kwargs:
                    kwargs[arg] = kwargs[arg].strip()
            if kwargs["must_be_new"]:
                kwargs["creator"] = kwargs["user"] = getattr(current_user, "name", "")
            instance = db.factory(type, **kwargs)
            if kwargs.get("copy"):
                db.fetch(type, id=kwargs["copy"]).duplicate(clone=instance)
            db.session.flush()
            return instance.serialized
        except db.rbac_error:
            return {"alert": "Error 403 - Operation not allowed."}
        except Exception as exc:
            db.session.rollback()
            if isinstance(exc, IntegrityError):
                return {"alert": f"There is already a {type} with the same parameters."}
            self.log("error", format_exc())
            return {"alert": str(exc)}

    def log(self, severity, content, user=None, change_log=True, logger="root"):
        """Log ``content`` at ``severity`` and optionally record a changelog.

        A changelog entry is created when ``change_log`` is true or when the
        logger's settings enable "change_log". Returns the logger's settings
        (an empty dict if the logger is not configured).
        """
        logger_settings = self.logging["loggers"].get(logger, {})
        if logger:
            getattr(getLogger(logger), severity)(content)
        if change_log or (logger and logger_settings.get("change_log")):
            db.factory(
                "changelog",
                **{
                    "severity": severity,
                    "content": content,
                    "user": user or getattr(current_user, "name", ""),
                },
            )
        return logger_settings

    def compare(self, type, id, v1, v2, context_lines):
        """Return a unified diff between two versions of a result or git data.

        For "result"/"device_result", ``v1`` and ``v2`` are result ids;
        otherwise they are passed to ``get_git_network_data`` for the device
        with id ``id``, and ``type`` selects the property to compare.
        """
        if type in ("result", "device_result"):
            first = self.str_dict(getattr(db.fetch("result", id=v1), "result"))
            second = self.str_dict(getattr(db.fetch("result", id=v2), "result"))
        else:
            device = db.fetch("device", id=id)
            result1, v1 = self.get_git_network_data(device.name, v1)
            result2, v2 = self.get_git_network_data(device.name, v2)
            first, second = result1[type], result2[type]
        return "\n".join(
            unified_diff(
                first.splitlines(),
                second.splitlines(),
                fromfile=f"V1 ({v1})",
                tofile=f"V2 ({v2})",
                lineterm="",
                n=int(context_lines),
            )
        )

    def build_filtering_constraints(self, model, **kwargs):
        """Build the list of column constraints for a filtering request.

        Values come from ``kwargs["form"]``, overridden by
        ``kwargs["constraints"]``. Each property is matched as a boolean, by
        equality, by inclusion (always on sqlite) or by regular expression.
        An invalid regular expression raises the ``re`` error.
        """
        table, constraints = models[model], []
        constraint_dict = {**kwargs["form"], **kwargs.get("constraints", {})}
        for property in model_properties[model]:
            value = constraint_dict.get(property)
            if not value:
                continue
            filter_value = constraint_dict.get(f"{property}_filter")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(table, property) == (value == "bool-true")
            elif filter_value == "equality":
                constraint = getattr(table, property) == value
            elif (
                not filter_value
                or filter_value == "inclusion"
                or db.dialect == "sqlite"
            ):
                constraint = getattr(table, property).contains(
                    value, autoescape=isinstance(value, str)
                )
            else:
                # Validate the pattern first so that the caller can report it.
                compile(value)
                regex_operator = "regexp" if db.dialect == "mysql" else "~"
                constraint = getattr(table, property).op(regex_operator)(value)
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, model, **params):
        """Return one page (10 items) of instances whose name contains "term"."""
        table = models[model]
        results = db.query(model).filter(table.name.contains(params.get("term")))
        return {
            "items": [
                {"text": result.ui_name, "id": str(result.id)}
                for result in results.limit(10)
                .offset((int(params["page"]) - 1) * 10)
                .all()
            ],
            "total_count": results.count(),
        }

    def build_relationship_constraints(self, query, model, **kwargs):
        """Join ``query`` on each relationship for which ids were provided."""
        table = models[model]
        constraint_dict = {**kwargs["form"], **kwargs.get("constraints", {})}
        for related_model, relation_properties in relationships[model].items():
            relation_ids = [int(id) for id in constraint_dict.get(related_model, [])]
            if not relation_ids:
                continue
            # Alias the related table so the same model can be joined twice.
            related_table = aliased(models[relation_properties["model"]])
            query = query.join(related_table, getattr(table, related_model)).filter(
                related_table.id.in_(relation_ids)
            )
        return query

    def filtering(self, model, bulk=False, **kwargs):
        """Filter the instances of ``model`` for a table or a bulk operation.

        With ``bulk`` set, returns the matching objects (``bulk == "object"``)
        or their ids. Otherwise returns the table payload ("draw",
        "recordsTotal", "recordsFiltered", "data"), plus "full_result" when
        "export" or "clipboard" is requested. An invalid regular expression
        yields an ``{"error": ...}`` dict.
        """
        table, query = models[model], db.query(model)
        total_records = query.with_entities(table.id).count()
        try:
            constraints = self.build_filtering_constraints(model, **kwargs)
        except regex_error:
            return {"error": "Invalid regular expression as search parameter."}
        constraints.extend(table.filtering_constraints(**kwargs))
        query = self.build_relationship_constraints(query, model, **kwargs)
        query = query.filter(and_(*constraints))
        filtered_records = query.with_entities(table.id).count()
        if bulk:
            instances = query.all()
            return instances if bulk == "object" else [obj.id for obj in instances]
        data = kwargs["columns"][int(kwargs["order"][0]["column"])]["data"]
        # Ordering is skipped when the column or the direction is unknown.
        ordering = getattr(getattr(table, data, None), kwargs["order"][0]["dir"], None)
        if ordering:
            query = query.order_by(ordering())
        table_result = {
            "draw": int(kwargs["draw"]),
            "recordsTotal": total_records,
            "recordsFiltered": filtered_records,
            "data": [
                obj.table_properties(**kwargs)
                for obj in query.limit(int(kwargs["length"]))
                .offset(int(kwargs["start"]))
                .all()
            ],
        }
        if kwargs.get("export"):
            table_result["full_result"] = [
                obj.table_properties(**kwargs) for obj in query.all()
            ]
        if kwargs.get("clipboard"):
            table_result["full_result"] = ",".join(obj.name for obj in query.all())
        return table_result

    def allowed_file(self, name, allowed_modules):
        """Return True if ``name`` has an extension listed in ``allowed_modules``.

        The extension is compared in lower case. A name without any "."
        returns False.
        """
        if "." not in name:
            # Without this guard, rsplit(...)[1] raises IndexError.
            return False
        return name.rsplit(".", 1)[1].lower() in allowed_modules

    def bulk_deletion(self, table, **kwargs):
        """Delete every instance matching the form; return how many there were."""
        instances = self.filtering(table, bulk=True, form=kwargs)
        for instance_id in instances:
            db.delete(table, id=instance_id)
        return len(instances)

    def bulk_edit(self, table, **kwargs):
        """Apply the checked "bulk-edit-<property>" values to several instances.

        ``kwargs["id"]`` is a "-"-separated list of instance ids. Returns the
        number of instances edited.
        """
        instances = kwargs.pop("id").split("-")
        kwargs = {
            property: value
            for property, value in kwargs.items()
            if kwargs.get(f"bulk-edit-{property}")
        }
        for instance_id in instances:
            db.factory(table, id=instance_id, **kwargs)
        return len(instances)

    def get_time(self):
        """Return the current local date and time as a string."""
        return str(datetime.now())

    def remove_instance(self, **kwargs):
        """Remove an instance from a relationship of a target object.

        Returns an alert dict when the target is a dynamic pool (a pool that
        is not manually defined); otherwise returns None.
        """
        instance = db.fetch(kwargs["instance"]["type"], id=kwargs["instance"]["id"])
        target = db.fetch(kwargs["relation"]["type"], id=kwargs["relation"]["id"])
        if target.type == "pool" and not target.manually_defined:
            return {"alert": "Removing an object from a dynamic pool is not allowed."}
        getattr(target, kwargs["relation"]["relation"]["to"]).remove(instance)
        self.update_rbac(instance)

    def add_instances_in_bulk(self, **kwargs):
        """Add instances (by id and/or comma-separated names) to a target.

        Returns an alert dict when the target is a dynamic pool or when a
        name does not exist; otherwise the number of newly added instances
        and the target's base properties.
        """
        target = db.fetch(kwargs["relation_type"], id=kwargs["relation_id"])
        if target.type == "pool" and not target.manually_defined:
            return {"alert": "Adding objects to a dynamic pool is not allowed."}
        model, property = kwargs["model"], kwargs["property"]
        instances = set(db.objectify(model, kwargs["instances"]))
        if kwargs["names"]:
            for name in [instance.strip() for instance in kwargs["names"].split(",")]:
                instance = db.fetch(model, allow_none=True, name=name)
                if not instance:
                    return {"alert": f"{model.capitalize()} '{name}' does not exist."}
                instances.add(instance)
        # Only keep the instances that are not already related to the target.
        instances = instances - set(getattr(target, property))
        for instance in instances:
            getattr(target, property).append(instance)
        target.last_modified = self.get_time()
        self.update_rbac(*instances)
        return {"number": len(instances), "target": target.base_properties}

    def bulk_removal(
        self,
        table,
        target_type,
        target_id,
        target_property,
        constraint_property,
        **kwargs,
    ):
        """Remove from a target every related instance matching the form.

        Returns an alert dict when the target is a dynamic pool; otherwise
        the number of instances removed.
        """
        kwargs[constraint_property] = [target_id]
        target = db.fetch(target_type, id=target_id)
        if target.type == "pool" and not target.manually_defined:
            return {"alert": "Removing objects from a dynamic pool is not allowed."}
        instances = self.filtering(table, bulk="object", form=kwargs)
        for instance in instances:
            getattr(target, target_property).remove(instance)
        self.update_rbac(*instances)
        return len(instances)

    def update_rbac(self, *instances):
        """Call ``update_rbac`` on each instance whose type is "user"."""
        for instance in instances:
            if instance.type != "user":
                continue
            instance.update_rbac()

    def send_email(
        self,
        subject,
        content,
        recipients="",
        reply_to=None,
        sender=None,
        filename=None,
        file_content=None,
    ):
        """Send an email, optionally with one attachment.

        ``recipients`` is a comma-separated string of addresses. ``sender``
        and ``reply_to`` default to the values of the "mail" settings. The
        password is read from the MAIL_PASSWORD environment variable.
        """
        sender = sender or self.settings["mail"]["sender"]
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.add_header("reply-to", reply_to or self.settings["mail"]["reply_to"])
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            attached_file["Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.settings["mail"]["server"], self.settings["mail"]["port"])
        try:
            if self.settings["mail"]["use_tls"]:
                server.starttls()
            password = getenv("MAIL_PASSWORD", "")
            server.login(self.settings["mail"]["username"], password)
            server.sendmail(sender, recipients.split(","), message.as_string())
        finally:
            # Always release the connection, even if login or sending fails.
            server.close()

    def contains_set(self, input):
        """Return True if ``input`` is, or contains at any depth, a set.

        Only lists and dict values are searched recursively.
        """
        if isinstance(input, set):
            return True
        elif isinstance(input, list):
            return any(self.contains_set(x) for x in input)
        elif isinstance(input, dict):
            return any(self.contains_set(x) for x in input.values())
        else:
            return False

    def str_dict(self, input, depth=0):
        """Render nested lists/dicts as an indented, human-readable string.

        List items are written as "- item" lines and dict entries as
        "key: value" lines, indented with one tab per nesting level.
        """
        tab = "\t" * depth
        if isinstance(input, list):
            return "\n" + "".join(
                f"{tab}- {self.str_dict(element, depth + 1)}\n" for element in input
            )
        if isinstance(input, dict):
            return "".join(
                f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
                for key, value in input.items()
            )
        return str(input)

    def strip_all(self, input):
        """Return ``input`` without ASCII punctuation and space characters."""
        return input.translate(str.maketrans("", "", f"{punctuation} "))

    def update_database_configurations_from_git(self):
        """Load device configuration data from the ``network_data`` folder.

        Each sub-directory named after an existing device may contain a
        ``timestamps.json`` file and one file per configuration property.
        Pools that filter on a configuration property are then recomputed.
        """
        for dir in scandir(self.path / "network_data"):
            device = db.fetch("device", allow_none=True, name=dir.name)
            timestamp_path = Path(dir.path) / "timestamps.json"
            if not device:
                continue
            try:
                with open(timestamp_path) as file:
                    timestamps = load(file)
            except Exception:
                # Best effort: a missing or invalid file means no timestamps.
                timestamps = {}
            for property in self.configuration_properties:
                for timestamp, value in timestamps.get(property, {}).items():
                    setattr(device, f"last_{property}_{timestamp}", value)
                filepath = Path(dir.path) / property
                if not filepath.exists():
                    continue
                with open(filepath) as file:
                    setattr(device, property, file.read())
        db.session.commit()
        for pool in db.fetch_all("pool"):
            if any(
                getattr(pool, f"device_{property}")
                for property in self.configuration_properties
            ):
                pool.compute_pool()
class BaseController: cluster = int(environ.get("CLUSTER", False)) cluster_id = int(environ.get("CLUSTER_ID", True)) cluster_scan_subnet = environ.get("CLUSER_SCAN_SUBNET", "192.168.105.0/24") cluster_scan_protocol = environ.get("CLUSTER_SCAN_PROTOCOL", "http") cluster_scan_timeout = environ.get("CLUSTER_SCAN_TIMEOUT", 0.05) config_mode = environ.get("CONFIG_MODE", "Debug") custom_code_path = environ.get("CUSTOM_CODE_PATH") default_longitude = environ.get("DEFAULT_LONGITUDE", -96.0) default_latitude = environ.get("DEFAULT_LATITUDE", 33.0) default_zoom_level = environ.get("DEFAULT_ZOOM_LEVEL", 5) default_view = environ.get("DEFAULT_VIEW", "2D") default_marker = environ.get("DEFAULT_MARKER", "Image") documentation_url = environ.get("DOCUMENTATION_URL", "https://enms.readthedocs.io/en/latest/") create_examples = int(environ.get("CREATE_EXAMPLES", True)) custom_services_path = environ.get("CUSTOM_SERVICES_PATH") log_level = environ.get("LOG_LEVEL", "DEBUG") git_automation = environ.get("GIT_AUTOMATION") git_configurations = environ.get("GIT_CONFIGURATIONS") gotty_port_redirection = int(environ.get("GOTTY_PORT_REDIRECTION", False)) gotty_bypass_key_prompt = int(environ.get("GOTTY_BYPASS_KEY_PROMPT", False)) gotty_port = -1 gotty_start_port = int(environ.get("GOTTY_START_PORT", 9000)) gotty_end_port = int(environ.get("GOTTY_END_PORT", 9100)) ldap_server = environ.get("LDAP_SERVER") ldap_userdn = environ.get("LDAP_USERDN") ldap_basedn = environ.get("LDAP_BASEDN") ldap_admin_group = environ.get("LDAP_ADMIN_GROUP", "") mail_server = environ.get("MAIL_SERVER", "smtp.googlemail.com") mail_port = int(environ.get("MAIL_PORT", "587")) mail_use_tls = int(environ.get("MAIL_USE_TLS", True)) mail_username = environ.get("MAIL_USERNAME") mail_password = environ.get("MAIL_PASSWORD") mail_sender = environ.get("MAIL_SENDER", "*****@*****.**") mail_recipients = environ.get("MAIL_RECIPIENTS", "") mattermost_url = environ.get("MATTERMOST_URL") mattermost_channel = 
environ.get("MATTERMOST_CHANNEL") mattermost_verify_certificate = int( environ.get("MATTERMOST_VERIFY_CERTIFICATE", True)) opennms_login = environ.get("OPENNMS_LOGIN") opennms_devices = environ.get("OPENNMS_DEVICES", "") opennms_rest_api = environ.get("OPENNMS_REST_API") playbook_path = environ.get("PLAYBOOK_PATH") server_addr = environ.get("SERVER_ADDR", "http://SERVER_IP") slack_token = environ.get("SLACK_TOKEN") slack_channel = environ.get("SLACK_CHANNEL") syslog_addr = environ.get("SYSLOG_ADDR", "0.0.0.0") syslog_port = int(environ.get("SYSLOG_PORT", 514)) tacacs_addr = environ.get("TACACS_ADDR") tacacs_password = environ.get("TACACS_PASSWORD") unseal_vault = environ.get("UNSEAL_VAULT") use_ldap = int(environ.get("USE_LDAP", False)) use_syslog = int(environ.get("USE_SYSLOG", False)) use_tacacs = int(environ.get("USE_TACACS", False)) use_vault = int(environ.get("USE_VAULT", False)) vault_addr = environ.get("VAULT_ADDR") log_severity = {"error": error, "info": info, "warning": warning} free_access_pages = ["/", "/login"] valid_pages = [ "/administration", "/calendar/run", "/calendar/task", "/dashboard", "/login", "/table/changelog", "/table/configuration", "/table/device", "/table/event", "/table/pool", "/table/link", "/table/run", "/table/server", "/table/service", "/table/syslog", "/table/task", "/table/user", "/view/network", "/view/site", "/workflow_builder", ] valid_post_endpoints = [ "stop_workflow", "add_edge", "add_services_to_workflow", "calendar_init", "clear_results", "clear_configurations", "compare", "connection", "counters", "count_models", "create_label", "database_deletion", "delete_edge", "delete_instance", "delete_label", "delete_node", "duplicate_workflow", "export_service", "export_to_google_earth", "export_topology", "get", "get_all", "get_cluster_status", "get_configuration", "get_device_configuration", "get_device_logs", "get_exported_services", "get_git_content", "get_service_logs", "get_properties", "get_result", "get_runtimes", 
"get_view_topology", "get_workflow_state", "import_service", "import_topology", "migration_export", "migration_import", "multiselect_filtering", "query_netbox", "query_librenms", "query_opennms", "reset_status", "run_service", "save_parameters", "save_pool_objects", "save_positions", "scan_cluster", "scan_playbook_folder", "scheduler", "skip_services", "table_filtering", "task_action", "topology_import", "update", "update_parameters", "update_pool", "update_all_pools", "view_filtering", ] valid_rest_endpoints = [ "get_cluster_status", "get_git_content", "update_all_pools", "update_database_configurations_from_git", ] def __init__(self, path): self.path = path self.custom_properties = self.load_custom_properties() self.custom_config = self.load_custom_config() self.init_scheduler() if self.use_tacacs: self.init_tacacs_client() if self.use_ldap: self.init_ldap_client() if self.use_vault: self.init_vault_client() if self.use_syslog: self.init_syslog_server() if self.custom_code_path: sys_path.append(self.custom_code_path) self.create_google_earth_styles() self.fetch_version() self.init_logs() def configure_database(self): self.init_services() Base.metadata.create_all(bind=engine) configure_mappers() configure_events(self) self.init_forms() self.clean_database() if not fetch("user", allow_none=True, name="admin"): self.init_parameters() self.configure_server_id() self.create_admin_user() Session.commit() if self.create_examples: self.migration_import(name="examples", import_export_types=import_classes) self.update_credentials() else: self.migration_import(name="default", import_export_types=import_classes) self.get_git_content() Session.commit() def clean_database(self): for run in fetch("run", all_matches=True, allow_none=True, status="Running"): run.status = "Aborted (app reload)" Session.commit() def create_google_earth_styles(self): self.google_earth_styles = {} for icon in device_icons: point_style = Style() point_style.labelstyle.color = Color.blue path_icon = 
f"{self.path}/eNMS/static/images/2D/{icon}.gif" point_style.iconstyle.icon.href = path_icon self.google_earth_styles[icon] = point_style def fetch_version(self): with open(self.path / "package.json") as package_file: self.version = load(package_file)["version"] def init_parameters(self): parameters = Session.query(models["parameters"]).one_or_none() if not parameters: parameters = models["parameters"]() parameters.update( **{ property: getattr(self, property) for property in model_properties["parameters"] if hasattr(self, property) }) Session.add(parameters) Session.commit() else: for parameter in parameters.get_properties(): setattr(self, parameter, getattr(parameters, parameter)) def configure_server_id(self): factory( "server", **{ "name": str(getnode()), "description": "Localhost", "ip_address": "0.0.0.0", "status": "Up", }, ) def create_admin_user(self) -> None: admin = factory("user", **{"name": "admin"}) if not admin.password: admin.password = "******" def update_credentials(self): with open(self.path / "projects" / "spreadsheets" / "usa.xls", "rb") as file: self.topology_import(file) @property def config(self): parameters = Session.query(models["parameters"]).one_or_none() return parameters.get_properties() if parameters else {} def get_git_content(self): for repository_type in ("configurations", "automation"): repo = getattr(self, f"git_{repository_type}") if not repo: continue local_path = self.path / "git" / repository_type repo_contents_exist = False for entry in scandir(local_path): if entry.name == ".gitkeep": remove(entry) if entry.name == ".git": repo_contents_exist = True if repo_contents_exist: try: Repo(local_path).remotes.origin.pull() if repository_type == "configurations": self.update_database_configurations_from_git() except Exception as e: info( f"Cannot pull {repository_type} git repository ({str(e)})" ) else: try: Repo.clone_from(repo, local_path) if repository_type == "configurations": self.update_database_configurations_from_git() except 
Exception as e: info( f"Cannot clone {repository_type} git repository ({str(e)})" ) def load_custom_config(self): filepath = environ.get("PATH_CUSTOM_CONFIG") if not filepath: return {} else: with open(filepath, "r") as config: return load(config) def load_custom_properties(self): filepath = environ.get("PATH_CUSTOM_PROPERTIES") if not filepath: custom_properties = {} else: with open(filepath, "r") as properties: custom_properties = yaml.load(properties) property_names.update( {k: v["pretty_name"] for k, v in custom_properties.items()}) public_custom_properties = { k: v for k, v in custom_properties.items() if not v.get("private", False) } device_properties.extend(list(custom_properties)) pool_device_properties.extend(list(public_custom_properties)) for properties_table in table_properties, filtering_properties: properties_table["device"].extend(list(public_custom_properties)) device_diagram_properties.extend( list(p for p, v in custom_properties.items() if v["add_to_dashboard"])) private_properties.extend( list(p for p, v in custom_properties.items() if v.get("private", False))) return custom_properties def init_logs(self): basicConfig( level=getattr(import_module("logging"), self.log_level), format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%m-%d-%Y %H:%M:%S", handlers=[ RotatingFileHandler(self.path / "logs" / "enms.log", maxBytes=20_000_000, backupCount=10), StreamHandler(), ], ) def init_scheduler(self): self.scheduler = BackgroundScheduler({ "apscheduler.jobstores.default": { "type": "sqlalchemy", "url": "sqlite:///jobs.sqlite", }, "apscheduler.executors.default": { "class": "apscheduler.executors.pool:ThreadPoolExecutor", "max_workers": "50", }, "apscheduler.job_defaults.misfire_grace_time": "5", "apscheduler.job_defaults.coalesce": "true", "apscheduler.job_defaults.max_instances": "3", }) self.scheduler.start() def init_forms(self): for file in (self.path / "eNMS" / "forms").glob("**/*.py"): spec = spec_from_file_location( 
str(file).split("/")[-1][:-3], str(file)) spec.loader.exec_module(module_from_spec(spec)) def init_services(self): path_services = [self.path / "eNMS" / "services"] if self.custom_services_path: path_services.append(Path(self.custom_services_path)) for path in path_services: for file in path.glob("**/*.py"): if "init" in str(file): continue if not self.create_examples and "examples" in str(file): continue spec = spec_from_file_location( str(file).split("/")[-1][:-3], str(file)) try: spec.loader.exec_module(module_from_spec(spec)) except InvalidRequestError as e: error(f"Error loading custom service '{file}' ({str(e)})") def init_ldap_client(self): self.ldap_client = Server(self.ldap_server, get_info=ALL) def init_tacacs_client(self): self.tacacs_client = TACACSClient(self.tacacs_addr, 49, self.tacacs_password) def init_vault_client(self): self.vault_client = VaultClient() self.vault_client.url = self.vault_addr self.vault_client.token = environ.get("VAULT_TOKEN") if self.vault_client.sys.is_sealed() and self.unseal_vault: keys = [environ.get(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)] self.vault_client.sys.submit_unseal_keys(filter(None, keys)) def init_syslog_server(self): self.syslog_server = SyslogServer(self.syslog_addr, self.syslog_port) self.syslog_server.start() def update_parameters(self, **kwargs): Session.query(models["parameters"]).one().update(**kwargs) self.__dict__.update(**kwargs) def delete_instance(self, cls, instance_id): return delete(cls, id=instance_id) def get(self, cls, id): return fetch(cls, id=id).serialized def get_properties(self, cls, id): return fetch(cls, id=id).get_properties() def get_all(self, cls): return [instance.get_properties() for instance in fetch_all(cls)] def update(self, cls, **kwargs): try: must_be_new = kwargs.get("id") == "" kwargs["name"] = kwargs["name"].strip() kwargs["last_modified"] = self.get_time() kwargs["creator"] = getattr(current_user, "name", "admin") instance = factory(cls, must_be_new=must_be_new, 
**kwargs) Session.flush() return instance.serialized except Exception as exc: Session.rollback() if isinstance(exc, IntegrityError): return { "error": (f"There already is a {cls} with the same name") } return {"error": str(exc)} def log(self, severity, content): factory( "changelog", **{ "severity": severity, "content": content, "user": getattr(current_user, "name", "admin"), }, ) self.log_severity[severity](content) def count_models(self): return { "counters": {cls: count(cls) for cls in diagram_classes}, "properties": { cls: Counter( str(getattr(instance, type_to_diagram_properties[cls][0])) for instance in fetch_all(cls)) for cls in diagram_classes }, } def compare(self, type, result1, result2): first = self.str_dict(getattr(fetch(type, id=result1), type)).splitlines() second = self.str_dict(getattr(fetch(type, id=result2), type)).splitlines() opcodes = SequenceMatcher(None, first, second).get_opcodes() return {"first": first, "second": second, "opcodes": opcodes} def build_filtering_constraints(self, obj_type, kwargs): model, constraints = models[obj_type], [] for property in filtering_properties[obj_type]: value = kwargs.get(f"form[{property}]") if not value: continue filter = kwargs.get(f"form[{property}_filter]") if value in ("bool-true", "bool-false"): constraint = getattr(model, property) == (value == "bool-true") elif filter == "equality": constraint = getattr(model, property) == value elif filter == "inclusion" or DIALECT == "sqlite": constraint = getattr(model, property).contains(value) else: regex_operator = "regexp" if DIALECT == "mysql" else "~" constraint = getattr(model, property).op(regex_operator)(value) constraints.append(constraint) for related_model, relation_properties in relationships[ obj_type].items(): relation_ids = [ int(id) for id in kwargs.getlist(f"form[{related_model}][]") ] filter = kwargs.get(f"form[{related_model}_filter]") if filter == "none": constraint = ~getattr(model, related_model).any() elif not relation_ids: continue elif 
relation_properties["list"]: constraint = getattr(model, related_model).any( models[relation_properties["model"]].id.in_(relation_ids)) if filter == "not_any": constraint = ~constraint else: constraint = or_( getattr(model, related_model).has(id=relation_id) for relation_id in relation_ids) constraints.append(constraint) return constraints def multiselect_filtering(self, type, params): model = models[type] results = Session.query(model.id, model.name).filter( model.name.contains(params.get("term"))) return { "items": [{ "text": r.name, "id": str(r.id) } for r in results.limit(10).offset((int(params["page"]) - 1) * 10).all()], "total_count": results.count(), } def table_filtering(self, table, kwargs): model, properties = models[table], table_properties[table] operator = and_ if kwargs.get("form[operator]", "all") == "all" else or_ column_index = int(kwargs["order[0][column]"]) if column_index < len(properties): order_property = getattr(model, properties[column_index]) order_function = getattr(order_property, kwargs["order[0][dir]"], None) else: order_function = None constraints = self.build_filtering_constraints(table, kwargs) if table == "result": constraints.append( getattr( models["result"], "service" if "service" in kwargs["instance[type]"] else kwargs["instance[type]"], ).has(id=kwargs["instance[id]"])) if kwargs.get("service[runtime]"): constraints.append(models["result"].parent_runtime == kwargs.get("service[runtime]")) elif table == "configuration" and kwargs.get("instance[id]"): constraints.append( getattr(models[table], "device").has(id=kwargs["instance[id]"])) result = Session.query(model).filter(operator(*constraints)) if order_function: result = result.order_by(order_function()) return { "draw": int(kwargs["draw"]), "recordsTotal": Session.query(func.count(model.id)).scalar(), "recordsFiltered": get_query_count(result), "data": [[ getattr(obj, f"table_{property}", getattr(obj, property)) for property in properties ] + obj.generate_row(table) for obj in 
result.limit( int(kwargs["length"])).offset(int(kwargs["start"])).all()], } def allowed_file(self, name, allowed_modules): allowed_syntax = "." in name allowed_extension = name.rsplit(".", 1)[1].lower() in allowed_modules return allowed_syntax and allowed_extension def get_time(self): return str(datetime.now()) def send_email( self, subject, content, sender=None, recipients=None, filename=None, file_content=None, ): sender = sender or self.mail_sender recipients = recipients or self.mail_recipients message = MIMEMultipart() message["From"] = sender message["To"] = recipients message["Date"] = formatdate(localtime=True) message["Subject"] = subject message.attach(MIMEText(content)) if filename: attached_file = MIMEApplication(file_content, Name=filename) attached_file[ "Content-Disposition"] = f'attachment; filename="{filename}"' message.attach(attached_file) server = SMTP(self.mail_server, self.mail_port) if self.mail_use_tls: server.starttls() server.login(self.mail_username, self.mail_password) server.sendmail(sender, recipients.split(","), message.as_string()) server.close() def str_dict(self, input, depth=0): tab = "\t" * depth if isinstance(input, list): result = "\n" for element in input: result += f"{tab}- {self.str_dict(element, depth + 1)}\n" return result elif isinstance(input, dict): result = "" for key, value in input.items(): result += f"\n{tab}{key}: {self.str_dict(value, depth + 1)}" return result else: return str(input) def strip_all(self, input): return input.translate(str.maketrans("", "", f"{punctuation} ")) def update_database_configurations_from_git(self): for dir in scandir(self.path / "git" / "configurations"): if dir.name == ".git": continue device = fetch("device", allow_none=True, name=dir.name) if device: with open(Path(dir.path) / "data.yml") as data: parameters = yaml.load(data) device.update(**{"dont_update_pools": True, **parameters}) config_file = Path(dir.path) / dir.name if not config_file.exists(): continue with open(config_file) 
as f: device.configuration = device.configurations[str( parameters["last_update"])] = f.read() Session.commit() for pool in fetch_all("pool"): if pool.device_configuration: pool.compute_pool()
def init_syslog_server(self):
    """Create the syslog server from ``syslog_addr`` / ``syslog_port`` and start it.

    The server is stored on ``self.syslog_server`` before it is started.
    """
    server = SyslogServer(self.syslog_addr, self.syslog_port)
    self.syslog_server = server
    server.start()