Example #1
0
class BaseController:
    """Central application controller: owns settings, external clients,
    scheduler and the generic CRUD/filtering helpers used by the web UI."""

    # Maps a severity string to the matching logging callable
    # (error/info/warning are imported at module level, not visible here).
    log_severity = {"error": error, "info": info, "warning": warning}

    # Endpoints that exchange JSON payloads with the web UI.
    json_endpoints = [
        "multiselect_filtering",
        "save_settings",
        "table_filtering",
        "view_filtering",
    ]

    # Endpoints exposed through the REST API.
    rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    def __init__(self):
        """Load settings, start optional clients and the connection pools."""
        self.settings = settings
        self.rbac = rbac
        self.properties = properties
        self.load_custom_properties()
        self.path = Path.cwd()
        self.init_scheduler()
        # Each external client is only started when enabled in settings.
        if settings["tacacs"]["active"]:
            self.init_tacacs_client()
        if settings["ldap"]["active"]:
            self.init_ldap_client()
        if settings["vault"]["active"]:
            self.init_vault_client()
        if settings["syslog"]["active"]:
            self.init_syslog_server()
        # Make user-provided code importable, if a path is configured.
        if settings["paths"]["custom_code"]:
            sys_path.append(settings["paths"]["custom_code"])
        self.fetch_version()
        self.init_logs()
        self.init_connection_pools()

    def configure_database(self):
        """Create tables and seed the database on first startup.

        If no "admin" user exists yet, this is a fresh install: create the
        admin user, import either the example or the default migration, and
        pull the git content.
        """
        self.init_services()
        Base.metadata.create_all(bind=engine)
        configure_mappers()
        configure_events(self)
        self.init_forms()
        self.clean_database()
        if not fetch("user", allow_none=True, name="admin"):
            self.configure_server_id()
            self.create_admin_user()
            Session.commit()
            if self.settings["app"]["create_examples"]:
                self.migration_import(name="examples",
                                      import_export_types=import_classes)
                self.update_credentials()
            else:
                self.migration_import(name="default",
                                      import_export_types=import_classes)
            self.get_git_content()
            Session.commit()

    def clean_database(self):
        """Mark runs left in "Running" state by a previous process as aborted."""
        for run in fetch("run",
                         all_matches=True,
                         allow_none=True,
                         status="Running"):
            run.status = "Aborted (app reload)"
        Session.commit()

    def fetch_version(self):
        """Read the application version from package.json into self.version."""
        with open(self.path / "package.json") as package_file:
            self.version = load(package_file)["version"]

    def configure_server_id(self):
        """Register the local machine as a "server" entry in the database.

        The MAC-address-derived node id (uuid.getnode) is used as the name.
        """
        server_properties = {
            "name": str(getnode()),
            "description": "Localhost",
            "ip_address": "0.0.0.0",
            "status": "Up",
        }
        factory("server", **server_properties)

    def create_admin_user(self) -> None:
        """Create the default admin user; set a placeholder password if empty."""
        admin = factory("user", **{"name": "admin", "group": "Admin"})
        if not admin.password:
            admin.update(password="******")

    def update_credentials(self):
        """Import the bundled example topology spreadsheet (binary mode)."""
        with open(self.path / "files" / "spreadsheets" / "usa.xls",
                  "rb") as file:
            self.topology_import(file)

    def get_git_content(self):
        """Pull (or clone) the configured git repository into network_data.

        Failures are logged but not raised: the database update that follows
        runs regardless, on whatever content is locally available.
        """
        repo = self.settings["app"]["git_repository"]
        if not repo:
            return
        local_path = self.path / "network_data"
        try:
            if exists(local_path):
                Repo(local_path).remotes.origin.pull()
            else:
                local_path.mkdir(parents=True, exist_ok=True)
                Repo.clone_from(repo, local_path)
        except Exception as exc:
            self.log("error", f"Git pull failed ({str(exc)})")
        self.update_database_configurations_from_git()

    def load_custom_properties(self):
        """Register user-defined properties in the module-level registries.

        Extends property_names, model_properties and private_properties
        (all module-level, shared across the app) from settings.
        """
        for model, values in self.properties["custom"].items():
            property_names.update(
                {k: v["pretty_name"]
                 for k, v in values.items()})
            model_properties[model].extend(list(values))
            private_properties.extend(
                list(p for p, v in values.items() if v.get("private", False)))

    def init_logs(self):
        """Configure root logging: rotating file handler + console stream."""
        log_level = self.settings["app"]["log_level"].upper()
        folder = self.path / "logs"
        folder.mkdir(parents=True, exist_ok=True)
        basicConfig(
            # Resolve e.g. "INFO" to the logging module's numeric constant.
            level=getattr(import_module("logging"), log_level),
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%m-%d-%Y %H:%M:%S",
            handlers=[
                RotatingFileHandler(folder / "enms.log",
                                    maxBytes=20_000_000,
                                    backupCount=10),
                StreamHandler(),
            ],
        )

    def init_connection_pools(self):
        """Create a requests Session with retry/pool settings for http & https."""
        self.request_session = RequestSession()
        retry = Retry(**self.settings["requests"]["retries"])
        for protocol in ("http", "https"):
            self.request_session.mount(
                f"{protocol}://",
                HTTPAdapter(
                    max_retries=retry,
                    **self.settings["requests"]["pool"],
                ),
            )

    def init_scheduler(self):
        """Start the APScheduler background scheduler with a sqlite job store."""
        self.scheduler = BackgroundScheduler({
            "apscheduler.jobstores.default": {
                "type": "sqlalchemy",
                "url": "sqlite:///jobs.sqlite",
            },
            "apscheduler.executors.default": {
                "class": "apscheduler.executors.pool:ThreadPoolExecutor",
                "max_workers": "50",
            },
            # APScheduler expects these values as strings.
            "apscheduler.job_defaults.misfire_grace_time":
            "5",
            "apscheduler.job_defaults.coalesce":
            "true",
            "apscheduler.job_defaults.max_instances":
            "3",
        })
        self.scheduler.start()

    def init_forms(self):
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            spec = spec_from_file_location(
                str(file).split("/")[-1][:-3], str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_services(self):
        """Import all service modules (built-in and custom) so they register.

        Skips __init__ files and, unless examples are enabled, any module
        whose path contains "examples". Import errors on custom services are
        logged and do not abort startup.
        """
        path_services = [self.path / "eNMS" / "services"]
        if self.settings["paths"]["custom_services"]:
            path_services.append(
                Path(self.settings["paths"]["custom_services"]))
        for path in path_services:
            for file in path.glob("**/*.py"):
                if "init" in str(file):
                    continue
                if not self.settings["app"][
                        "create_examples"] and "examples" in str(file):
                    continue
                info(f"Loading service: {file}")
                spec = spec_from_file_location(
                    str(file).split("/")[-1][:-3], str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as e:
                    error(f"Error loading custom service '{file}' ({str(e)})")

    def init_ldap_client(self):
        """Create the LDAP Server object from the configured address."""
        self.ldap_client = Server(self.settings["ldap"]["server"],
                                  get_info=ALL)

    def init_tacacs_client(self):
        """Create the TACACS+ client (port 49 is the standard TACACS+ port);
        the shared secret comes from the TACACS_PASSWORD env variable."""
        self.tacacs_client = TACACSClient(self.settings["tacacs"]["address"],
                                          49, environ.get("TACACS_PASSWORD"))

    def init_vault_client(self):
        """Create the Vault client and, if configured, unseal the vault
        with up to 5 keys read from UNSEAL_VAULT_KEY1..5 env variables."""
        self.vault_client = VaultClient()
        self.vault_client.token = environ.get("VAULT_TOKEN")
        if self.vault_client.sys.is_sealed(
        ) and self.settings["vault"]["unseal"]:
            keys = [environ.get(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)]
            self.vault_client.sys.submit_unseal_keys(filter(None, keys))

    def init_syslog_server(self):
        """Create the syslog server from settings and start listening."""
        self.syslog_server = SyslogServer(self.settings["syslog"]["address"],
                                          self.settings["syslog"]["port"])
        self.syslog_server.start()

    def update_parameters(self, **kwargs):
        """Update the single parameters row and mirror the values on self."""
        Session.query(models["parameters"]).one().update(**kwargs)
        self.__dict__.update(**kwargs)

    def delete_instance(self, instance_type, instance_id):
        """Delete the instance of the given type by id; returns delete()'s result."""
        return delete(instance_type, id=instance_id)

    def get(self, instance_type, id):
        """Return the serialized form of the instance with the given id."""
        return fetch(instance_type, id=id).serialized

    def get_properties(self, instance_type, id):
        """Return the properties dict of the instance with the given id."""
        return fetch(instance_type, id=id).get_properties()

    def get_all(self, instance_type):
        """Return the properties of every instance of the given type."""
        return [
            instance.get_properties() for instance in fetch_all(instance_type)
        ]

    def update(self, type, **kwargs):
        """Create or update an instance from form data; return it serialized.

        An empty "id" means the object must be new. On failure the session is
        rolled back and an {"alert": ...} dict is returned instead of raising.
        """
        try:
            must_be_new = kwargs.get("id") == ""
            for arg in ("name", "scoped_name"):
                if arg in kwargs:
                    kwargs[arg] = kwargs[arg].strip()
            kwargs["last_modified"] = self.get_time()
            kwargs["creator"] = kwargs["user"] = getattr(
                current_user, "name", "admin")
            instance = factory(type, must_be_new=must_be_new, **kwargs)
            # "original" carries the id of an object being duplicated.
            if kwargs.get("original"):
                fetch(type, id=kwargs["original"]).duplicate(clone=instance)
            Session.flush()
            return instance.serialized
        except Exception as exc:
            Session.rollback()
            if isinstance(exc, IntegrityError):
                return {
                    "alert":
                    f"There is already a {type} with the same parameters."
                }
            return {"alert": str(exc)}

    def log(self, severity, content):
        """Record a changelog entry, then emit the message at *severity*
        through the matching logging function from log_severity."""
        factory(
            "changelog",
            **{
                "severity": severity,
                "content": content,
                "user": getattr(current_user, "name", "admin"),
            },
        )
        self.log_severity[severity](content)

    def count_models(self):
        """Build the dashboard data: per-type instance counts, and per-type
        frequency of the first configured dashboard property."""
        return {
            "counters": {
                instance_type: count(instance_type)
                for instance_type in properties["dashboard"]
            },
            "properties": {
                instance_type: Counter(
                    str(
                        getattr(instance, properties["dashboard"]
                                [instance_type][0]))
                    for instance in fetch_all(instance_type))
                for instance_type in properties["dashboard"]
            },
        }

    def compare(self, type, result1, result2):
        """Return both results as line lists plus SequenceMatcher opcodes,
        for rendering a side-by-side diff in the UI."""
        first = self.str_dict(getattr(fetch(type, id=result1),
                                      "result")).splitlines()
        second = self.str_dict(getattr(fetch(type, id=result2),
                                       "result")).splitlines()
        opcodes = SequenceMatcher(None, first, second).get_opcodes()
        return {"first": first, "second": second, "opcodes": opcodes}

    def build_filtering_constraints(self, obj_type, **kwargs):
        """Translate the filtering form into a list of SQLAlchemy constraints.

        Scalar properties support equality, inclusion (substring) and regex
        matching; regex falls back to inclusion on sqlite, which has no regex
        operator. Relationship filters support none / any / not_any.
        """
        model, constraints = models[obj_type], []
        for property in model_properties[obj_type]:
            value = kwargs["form"].get(property)
            if not value:
                continue
            filter = kwargs["form"].get(f"{property}_filter")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(model, property) == (value == "bool-true")
            elif filter == "equality":
                constraint = getattr(model, property) == value
            elif not filter or filter == "inclusion" or DIALECT == "sqlite":
                constraint = getattr(model, property).contains(value)
            else:
                # mysql spells the regex operator "regexp"; postgres uses "~".
                regex_operator = "regexp" if DIALECT == "mysql" else "~"
                constraint = getattr(model, property).op(regex_operator)(value)
            constraints.append(constraint)
        for related_model, relation_properties in relationships[
                obj_type].items():
            relation_ids = [
                int(id) for id in kwargs["form"].get(related_model, [])
            ]
            filter = kwargs["form"].get(f"{related_model}_filter")
            if filter == "none":
                constraint = ~getattr(model, related_model).any()
            elif not relation_ids:
                continue
            elif relation_properties["list"]:
                constraint = getattr(model, related_model).any(
                    models[relation_properties["model"]].id.in_(relation_ids))
                if filter == "not_any":
                    constraint = ~constraint
            else:
                # Scalar relationship: match any of the selected ids.
                constraint = or_(
                    getattr(model, related_model).has(id=relation_id)
                    for relation_id in relation_ids)
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, type, **params):
        """Serve paginated name-substring search results for select2 widgets
        (10 items per page, 1-based page number in params)."""
        model = models[type]
        results = Session.query(model).filter(
            model.name.contains(params.get("term")))
        return {
            "items":
            [{
                "text": result.ui_name,
                "id": str(result.id)
            }
             for result in results.limit(10).offset((int(params["page"]) - 1) *
                                                    10).all()],
            "total_count":
            results.count(),
        }

    def table_filtering(self, table, **kwargs):
        """Serve one page of a server-side DataTable.

        Builds ordering from the DataTables request, combines the generic
        filtering constraints with table-specific ones (result / service /
        run), and returns the draw/recordsTotal/recordsFiltered/data payload
        that DataTables expects.
        """
        model = models[table]
        # Resolve the requested sort column, then its asc()/desc() method;
        # either getattr may yield None, in which case ordering is skipped.
        ordering = getattr(
            getattr(
                model,
                kwargs["columns"][int(kwargs["order"][0]["column"])]["data"],
                None,
            ),
            kwargs["order"][0]["dir"],
            None,
        )
        constraints = self.build_filtering_constraints(table, **kwargs)
        if table == "result":
            constraints.append(
                getattr(
                    models["result"],
                    "device"
                    if kwargs["instance"]["type"] == "device" else "service",
                ).has(id=kwargs["instance"]["id"]))
            if kwargs.get("runtime"):
                constraints.append(
                    models["result"].parent_runtime == kwargs["runtime"])
        if table == "service":
            workflow_id = kwargs["form"].get("workflow-filtering")
            if workflow_id:
                constraints.append(models["service"].workflows.any(
                    models["workflow"].id == int(workflow_id)))
            else:
                # Top-level view: only services outside any workflow.
                if kwargs["form"].get("parent-filtering", "true") == "true":
                    constraints.append(~models["service"].workflows.any())
        if table == "run":
            constraints.append(
                models["run"].parent_runtime == models["run"].runtime)
        result = Session.query(model).filter(and_(*constraints))
        if ordering:
            result = result.order_by(ordering())
        return {
            "draw":
            int(kwargs["draw"]),
            "recordsTotal":
            Session.query(func.count(model.id)).scalar(),
            "recordsFiltered":
            get_query_count(result),
            "data": [
                obj.table_properties(**kwargs) for obj in result.limit(
                    int(kwargs["length"])).offset(int(kwargs["start"])).all()
            ],
        }

    def allowed_file(self, name, allowed_modules):
        """Return True if *name* has an extension listed in *allowed_modules*.

        The extension is only extracted when a dot is present: the previous
        version evaluated ``name.rsplit(".", 1)[1]`` unconditionally and
        raised IndexError for extension-less names.
        """
        if "." not in name:
            return False
        return name.rsplit(".", 1)[1].lower() in allowed_modules

    def get_time(self):
        """Return the current local time as a string (datetime default format)."""
        return f"{datetime.now()}"

    def send_email(
        self,
        subject,
        content,
        recipients="",
        sender=None,
        filename=None,
        file_content=None,
    ):
        """Build and send a MIME email through the configured SMTP server.

        recipients is a comma-separated string; sender falls back to the
        configured mail sender. When filename is given, file_content is
        attached under that name. With use_tls enabled, the connection is
        upgraded and authenticated (password from MAIL_PASSWORD env var).
        """
        sender = sender or self.settings["mail"]["sender"]
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            # Bug fix: the header previously hard-coded filename="(unknown)"
            # (an f-string with no placeholder); use the real attachment name.
            attached_file[
                "Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.settings["mail"]["server"],
                      self.settings["mail"]["port"])
        if self.settings["mail"]["use_tls"]:
            server.starttls()
            password = environ.get("MAIL_PASSWORD", "")
            server.login(self.settings["mail"]["username"], password)
        server.sendmail(sender, recipients.split(","), message.as_string())
        server.close()

    def str_dict(self, input, depth=0):
        tab = "\t" * depth
        if isinstance(input, list):
            result = "\n"
            for element in input:
                result += f"{tab}- {self.str_dict(element, depth + 1)}\n"
            return result
        elif isinstance(input, dict):
            result = ""
            for key, value in input.items():
                result += f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
            return result
        else:
            return str(input)

    def strip_all(self, input):
        """Return *input* with every punctuation character and space removed."""
        removal_table = str.maketrans("", "", punctuation + " ")
        return input.translate(removal_table)

    def switch_menu(self, user_id):
        """Toggle the user's collapsed/expanded sidebar menu preference."""
        user = fetch("user", id=user_id)
        user.small_menu = not user.small_menu

    def update_database_configurations_from_git(self):
        """Sync device objects from the pulled network_data directory.

        Each subdirectory named after a device provides a data.yml of
        properties plus optional configuration / operational_data files.
        Pools depending on those fields are recomputed afterwards.
        """
        for dir in scandir(self.path / "network_data"):
            device = fetch("device", allow_none=True, name=dir.name)
            if not device:
                continue
            with open(Path(dir.path) / "data.yml") as data:
                # NOTE(review): yaml.load without an explicit Loader is unsafe
                # on untrusted input - confirm the git repository is trusted,
                # or switch to yaml.safe_load.
                parameters = yaml.load(data)
                device.update(**{"dont_update_pools": True, **parameters})
            for data in ("configuration", "operational_data"):
                filepath = Path(dir.path) / data
                if not filepath.exists():
                    continue
                with open(filepath) as file:
                    setattr(device, data, file.read())
        Session.commit()
        for pool in fetch_all("pool"):
            if pool.device_configuration or pool.device_operational_data:
                pool.compute_pool()
Example #2
0
 def init_syslog_server(self):
     """Create the syslog server from settings and start listening."""
     self.syslog_server = SyslogServer(self.settings["syslog"]["address"],
                                       self.settings["syslog"]["port"])
     self.syslog_server.start()
Example #3
0
class BaseController:
    """Application controller: settings, encryption, redis, logging and the
    generic CRUD/filtering helpers used by the web UI and REST API."""

    # Severities accepted by the log() method, lowest to highest.
    log_levels = ["debug", "info", "warning", "error", "critical"]

    # Endpoints exposed through the REST API.
    rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    # Pretty names of custom properties; class-level, shared by all instances
    # and filled in by load_custom_properties().
    property_names = {}

    def __init__(self):
        """Load settings, optional clients, encryption, redis and pools.

        pre_init()/post_init() are extension hooks run before and after the
        standard initialization sequence.
        """
        self.pre_init()
        self.settings = settings
        self.rbac = rbac
        self.properties = properties
        self.database = database
        self.logging = logging
        self.load_custom_properties()
        self.path = Path.cwd()
        self.init_encryption()
        self.use_vault = settings["vault"]["use_vault"]
        if self.use_vault:
            self.init_vault_client()
        if settings["syslog"]["active"]:
            self.init_syslog_server()
        # Make user-provided code importable, if a path is configured.
        if settings["paths"]["custom_code"]:
            sys_path.append(settings["paths"]["custom_code"])
        self.fetch_version()
        self.init_logs()
        self.init_redis()
        self.init_scheduler()
        self.init_connection_pools()
        self.post_init()

    def init_encryption(self):
        """Choose encrypt/decrypt callables: Fernet if FERNET_KEY is set,
        otherwise plain (non-cryptographic) base64 encoding."""
        self.fernet_encryption = environ.get("FERNET_KEY")
        if self.fernet_encryption:
            fernet = Fernet(self.fernet_encryption)
            self.encrypt, self.decrypt = fernet.encrypt, fernet.decrypt
        else:
            self.encrypt, self.decrypt = b64encode, b64decode

    def encrypt_password(self, password):
        if isinstance(password, str):
            password = str.encode(password)
        return self.encrypt(password)

    def get_password(self, password):
        if not password:
            return
        if self.fernet_encryption and isinstance(password, str):
            password = str.encode(password)
        return str(self.decrypt(password), "utf-8")

    def initialize_database(self):
        """Create tables and seed the database on first startup.

        A missing "admin" user marks a fresh install: create the admin user,
        run the startup migration (configurable, defaults to "default"),
        import credentials and pull git content. Server id registration and
        run-status reset happen on every startup.
        """
        self.init_services()
        db.base.metadata.create_all(bind=db.engine)
        configure_mappers()
        db.configure_application_events(self)
        self.init_forms()
        if not db.fetch("user", allow_none=True, name="admin"):
            self.create_admin_user()
            db.session.commit()
            self.migration_import(
                name=self.settings["app"].get("startup_migration", "default"),
                import_export_types=db.import_export_models,
            )
            self.update_credentials()
            self.get_git_content()
        self.configure_server_id()
        self.reset_run_status()
        db.session.commit()

    def reset_run_status(self):
        """Abort runs left in "Running" state and idle their services."""
        for run in db.fetch("run",
                            all_matches=True,
                            allow_none=True,
                            status="Running"):
            run.status = "Aborted (RELOAD)"
            run.service.status = "Idle"
        db.session.commit()

    def fetch_version(self):
        """Read the application version from package.json into self.version."""
        with open(self.path / "package.json") as package_file:
            self.version = load(package_file)["version"]

    def configure_server_id(self):
        """Register the local machine as a "server" entry; the name is the
        MAC-address-derived node id (uuid.getnode)."""
        db.factory(
            "server",
            **{
                "name": str(getnode()),
                "description": "Localhost",
                "ip_address": "0.0.0.0",
                "status": "Up",
            },
        )

    def create_admin_user(self):
        """Create the default admin user; set a placeholder password if empty."""
        admin = db.factory("user", name="admin", is_admin=True)
        if not admin.password:
            admin.update(password="******")

    def update_credentials(self):
        """Import the bundled example topology spreadsheet (binary mode)."""
        with open(self.path / "files" / "spreadsheets" / "usa.xls",
                  "rb") as file:
            self.topology_import(file)

    def get_git_content(self):
        """Pull (or clone) the configured git repository into network_data.

        Failures are logged but not raised: the database update that follows
        runs regardless, on whatever content is locally available.
        """
        repo = self.settings["app"]["git_repository"]
        if not repo:
            return
        local_path = self.path / "network_data"
        try:
            if exists(local_path):
                Repo(local_path).remotes.origin.pull()
            else:
                local_path.mkdir(parents=True, exist_ok=True)
                Repo.clone_from(repo, local_path)
        except Exception as exc:
            self.log("error", f"Git pull failed ({str(exc)})")
        self.update_database_configurations_from_git()

    def load_custom_properties(self):
        """Register user-defined properties in the shared registries.

        Note: self.property_names is a class-level dict, so these updates are
        shared by every controller instance.
        """
        for model, values in self.properties["custom"].items():
            for property, property_dict in values.items():
                pretty_name = property_dict["pretty_name"]
                self.property_names[property] = pretty_name
                model_properties[model].append(property)
                if property_dict.get("private"):
                    db.private_properties.append(property)
                # Device properties flagged "configuration" also show up in
                # the configuration panel.
                if model == "device" and property_dict.get("configuration"):
                    self.configuration_properties[property] = pretty_name

    def init_logs(self):
        """Configure logging from setup/logging.json, then apply the
        requested levels to third-party ("external") loggers."""
        folder = self.path / "logs"
        folder.mkdir(parents=True, exist_ok=True)
        with open(self.path / "setup" / "logging.json", "r") as logging_config:
            logging_config = load(logging_config)
        dictConfig(logging_config)
        for logger, log_level in logging_config["external_loggers"].items():
            info(f"Changing {logger} log level to '{log_level}'")
            # Resolve e.g. "info" to the logging module's numeric constant.
            log_level = getattr(import_module("logging"), log_level.upper())
            getLogger(logger).setLevel(log_level)

    def init_connection_pools(self):
        """Create a requests Session with retry/pool settings for http & https."""
        self.request_session = RequestSession()
        retry = Retry(**self.settings["requests"]["retries"])
        for protocol in ("http", "https"):
            self.request_session.mount(
                f"{protocol}://",
                HTTPAdapter(max_retries=retry,
                            **self.settings["requests"]["pool"]),
            )

    def init_forms(self):
        """Import every form module under eNMS/forms so they self-register.

        NOTE(review): the module name is derived by splitting on "/", which
        assumes POSIX path separators - Path.stem would be portable.
        """
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            spec = spec_from_file_location(
                str(file).split("/")[-1][:-3], str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_redis(self):
        """Connect to redis when REDIS_ADDR is set; otherwise the queue is
        None and callers fall back to in-process storage."""
        host = environ.get("REDIS_ADDR")
        self.redis_queue = (Redis(
            host=host,
            port=6379,
            db=0,
            charset="utf-8",
            decode_responses=True,
            # Short timeout: redis being down must not block requests.
            socket_timeout=0.1,
        ) if host else None)

    def init_scheduler(self):
        """Record the external scheduler address (None when not configured)."""
        self.scheduler_address = environ.get("SCHEDULER_ADDR")

    def init_services(self):
        """Import all service modules (built-in and custom) so they register.

        Skips __init__ files and, unless the startup migration is "examples",
        any module whose path contains "examples". Import errors on custom
        services are logged and do not abort startup.
        """
        path_services = [self.path / "eNMS" / "services"]
        load_examples = self.settings["app"].get(
            "startup_migration") == "examples"
        if self.settings["paths"]["custom_services"]:
            path_services.append(
                Path(self.settings["paths"]["custom_services"]))
        for path in path_services:
            for file in path.glob("**/*.py"):
                if "init" in str(file):
                    continue
                if not load_examples and "examples" in str(file):
                    continue
                info(f"Loading service: {file}")
                spec = spec_from_file_location(
                    str(file).split("/")[-1][:-3], str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as exc:
                    error(
                        f"Error loading custom service '{file}' ({str(exc)})")

    def init_vault_client(self):
        """Create the Vault client (address/token from env) and, if
        configured, unseal it with keys from UNSEAL_VAULT_KEY1..5."""
        url = environ.get("VAULT_ADDR", "http://127.0.0.1:8200")
        self.vault_client = VaultClient(url=url,
                                        token=environ.get("VAULT_TOKEN"))
        if self.vault_client.sys.is_sealed(
        ) and self.settings["vault"]["unseal_vault"]:
            keys = [environ.get(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)]
            self.vault_client.sys.submit_unseal_keys(filter(None, keys))

    def init_syslog_server(self):
        """Create the syslog server from settings and start listening."""
        self.syslog_server = SyslogServer(self.settings["syslog"]["address"],
                                          self.settings["syslog"]["port"])
        self.syslog_server.start()

    def redis(self, operation, *args, **kwargs):
        """Run a redis command by name, logging (not raising) on connection
        or timeout errors; returns None in that case."""
        try:
            return getattr(self.redis_queue, operation)(*args, **kwargs)
        except (ConnectionError, TimeoutError) as exc:
            self.log("error",
                     f"Redis Queue Unreachable ({exc})",
                     change_log=False)

    def log_queue(self, runtime, service, log=None, mode="add"):
        """Append to ("add") or retrieve (any other mode) a service's run
        logs, backed by redis when available, else by self.run_logs.

        NOTE(review): on the redis path the in-process entry is reset to None
        before pushing - presumably to mark redis as the source of truth;
        confirm against the readers of run_logs.
        """
        if self.redis_queue:
            key = f"{runtime}/{service}/logs"
            self.run_logs[runtime][int(service)] = None
            if mode == "add":
                log = self.redis("lpush", key, log)
            else:
                # lrange returns newest-first; reverse to chronological order.
                log = self.redis("lrange", key, 0, -1)
                if log:
                    log = log[::-1]
        else:
            if mode == "add":
                return self.run_logs[runtime][int(service)].append(log)
            else:
                # mode doubles as the dict method name (e.g. "get", "pop").
                log = getattr(self.run_logs[runtime], mode)(int(service), [])
        return log

    def delete_instance(self, model, instance_id):
        """Delete the instance of the given model by id; returns db.delete()'s result."""
        return db.delete(model, id=instance_id)

    def get(self, model, id):
        """Return the serialized form of the instance with the given id."""
        return db.fetch(model, id=id).serialized

    def get_properties(self, model, id):
        """Return the properties dict of the instance with the given id."""
        return db.fetch(model, id=id).get_properties()

    def get_all(self, model):
        """Return the properties of every instance of the given model."""
        return [instance.get_properties() for instance in db.fetch_all(model)]

    def update(self, type, **kwargs):
        """Create or update an instance from form data; return it serialized.

        An empty "id" means the object must be new. RBAC violations and other
        failures roll back the session and return an {"alert": ...} dict
        instead of raising.
        """
        try:
            must_be_new = kwargs.get("id") == ""
            for arg in ("name", "scoped_name"):
                if arg in kwargs:
                    kwargs[arg] = kwargs[arg].strip()
            kwargs["last_modified"] = self.get_time()
            kwargs["creator"] = kwargs["user"] = getattr(
                current_user, "name", "")
            instance = db.factory(type, must_be_new=must_be_new, **kwargs)
            # "copy" carries the id of an object being duplicated.
            if kwargs.get("copy"):
                db.fetch(type, id=kwargs["copy"]).duplicate(clone=instance)
            db.session.flush()
            return instance.serialized
        except db.rbac_error:
            return {"alert": "Error 403 - Operation not allowed."}
        except Exception as exc:
            db.session.rollback()
            if isinstance(exc, IntegrityError):
                return {
                    "alert":
                    f"There is already a {type} with the same parameters."
                }
            return {"alert": str(exc)}

    def log(self,
            severity,
            content,
            user=None,
            change_log=True,
            logger="root"):
        """Emit *content* at *severity* on *logger*, optionally recording a
        changelog entry; returns the logger's settings dict.

        Precedence note: the changelog condition parses as
        "change_log or (logger and logger_settings.get('change_log'))".
        """
        logger_settings = self.logging["loggers"].get(logger, {})
        if logger:
            getattr(getLogger(logger), severity)(content)
        if change_log or logger and logger_settings.get("change_log"):
            db.factory(
                "changelog",
                **{
                    "severity": severity,
                    "content": content,
                    "user": user or getattr(current_user, "name", "admin"),
                },
            )
        return logger_settings

    def count_models(self):
        """Build the dashboard data: per-model instance counts, and per-model
        frequency of the first configured dashboard property."""
        return {
            "counters": {
                model: db.query(model).count()
                for model in properties["dashboard"]
            },
            "properties": {
                model: Counter(
                    str(getattr(instance, properties["dashboard"][model][0]))
                    for instance in db.fetch_all(model))
                for model in properties["dashboard"]
            },
        }

    def compare(self, type, device_name, v1, v2):
        """Return a unified diff between two results or two git file versions.

        For result types, v1/v2 are result ids; otherwise they are versions
        of the device's network data retrieved from git.
        """
        if type in ("result", "device_result"):
            first = self.str_dict(getattr(db.fetch("result", id=v1), "result"))
            second = self.str_dict(getattr(db.fetch("result", id=v2),
                                           "result"))
        else:
            first = self.get_git_network_data(device_name, v1)[type]
            second = self.get_git_network_data(device_name, v2)[type]
        return "\n".join(
            unified_diff(
                first.splitlines(),
                second.splitlines(),
                fromfile="-",
                tofile="-",
                lineterm="",
            ))

    def build_filtering_constraints(self, model, **kwargs):
        """Translate the filtering form into a list of SQLAlchemy constraints.

        Scalar properties support equality, inclusion (substring) and regex
        matching; regex falls back to inclusion on sqlite, which has no regex
        operator. Relationship filters support none / any / all / not_any.
        """
        table, constraints = models[model], []
        for property in model_properties[model]:
            value = kwargs["form"].get(property)
            if not value:
                continue
            filter = kwargs["form"].get(f"{property}_filter")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(table, property) == (value == "bool-true")
            elif filter == "equality":
                constraint = getattr(table, property) == value
            elif not filter or filter == "inclusion" or db.dialect == "sqlite":
                constraint = getattr(table, property).contains(value)
            else:
                # NOTE(review): "compile" here is presumably re.compile
                # (imported at module level), used to validate the pattern -
                # the caller catches regex_error. Confirm the import.
                compile(value)
                # mysql spells the regex operator "regexp"; postgres uses "~".
                regex_operator = "regexp" if db.dialect == "mysql" else "~"
                constraint = getattr(table, property).op(regex_operator)(value)
            constraints.append(constraint)
        for related_model, relation_properties in relationships[model].items():
            relation_ids = [
                int(id) for id in kwargs["form"].get(related_model, [])
            ]
            filter = kwargs["form"].get(f"{related_model}_filter")
            if filter == "none":
                constraint = ~getattr(table, related_model).any()
            elif not relation_ids:
                continue
            elif relation_properties["list"]:
                # "all" requires every selected id; any other filter matches
                # at least one; "not_any" then negates the whole constraint.
                constraint = (and_ if filter == "all" else or_)(
                    getattr(table, related_model).any(id=relation_id)
                    for relation_id in relation_ids)
                if filter == "not_any":
                    constraint = ~constraint
            else:
                # Scalar relationship: match any of the selected ids.
                constraint = or_(
                    getattr(table, related_model).has(id=relation_id)
                    for relation_id in relation_ids)
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, model, **params):
        """Serve one page of results for a multiselect dropdown widget.

        Matches instances of *model* whose name contains params["term"]
        and returns a page of 10 items (page number in params["page"])
        plus the total number of matches.
        """
        table = models[model]
        results = db.query(model).filter(
            table.name.contains(params.get("term")))
        return {
            "items":
            [{
                "text": result.ui_name,
                "id": str(result.id)
            }
             for result in results.limit(10).offset((int(params["page"]) - 1) *
                                                    10).all()],
            "total_count":
            results.count(),
        }

    def filtering(self, model, **kwargs):
        """Server-side processing endpoint for the DataTables-style tables.

        Builds ordering and filtering constraints from the request
        parameters and returns the draw counter, total / filtered record
        counts and one page of serialized rows (plus the full result set
        when exporting).
        """
        table = models[model]
        # Resolve the column to order by, then its ordering method (the
        # "dir" value, e.g. asc/desc); either lookup may fail, in which
        # case ordering is skipped below.
        ordering = getattr(
            getattr(
                table,
                kwargs["columns"][int(kwargs["order"][0]["column"])]["data"],
                None,
            ),
            kwargs["order"][0]["dir"],
            None,
        )
        try:
            constraints = self.build_filtering_constraints(model, **kwargs)
        except regex_error:
            return {"error": "Invalid regular expression as search parameter."}
        # model-specific constraints defined on the table class
        constraints.extend(table.filtering_constraints(**kwargs))
        query = db.query(model)
        total_records, query = query.count(), query.filter(and_(*constraints))
        if ordering:
            query = query.order_by(ordering())
        table_result = {
            "draw":
            int(kwargs["draw"]),
            "recordsTotal":
            total_records,
            "recordsFiltered":
            query.count(),
            "data": [
                obj.table_properties(**kwargs) for obj in query.limit(
                    int(kwargs["length"])).offset(int(kwargs["start"])).all()
            ],
        }
        if kwargs.get("export"):
            table_result["full_result"] = [
                obj.table_properties(**kwargs) for obj in query.all()
            ]
        return table_result

    def allowed_file(self, name, allowed_modules):
        """Return True when *name* has an extension listed in *allowed_modules*.

        Fix: short-circuit when the name contains no dot — the original
        computed the extension unconditionally and raised IndexError for
        extension-less names.
        """
        if "." not in name:
            return False
        return name.rsplit(".", 1)[1].lower() in allowed_modules

    def get_time(self):
        """Return the current local time as a string (str of datetime.now())."""
        now = datetime.now()
        return str(now)

    def send_email(
        self,
        subject,
        content,
        recipients="",
        reply_to=None,
        sender=None,
        filename=None,
        file_content=None,
    ):
        """Send an email via the SMTP server configured in the settings.

        *recipients* is a comma-separated string of addresses; when
        *filename* is given, *file_content* is attached under that name.
        Fix: the attachment's Content-Disposition header now carries the
        actual *filename* instead of a hard-coded placeholder string.
        """
        sender = sender or self.settings["mail"]["sender"]
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.add_header("reply-to", reply_to
                           or self.settings["mail"]["reply_to"])
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            attached_file[
                "Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.settings["mail"]["server"],
                      self.settings["mail"]["port"])
        if self.settings["mail"]["use_tls"]:
            server.starttls()
            # password comes from the environment, never from the settings
            password = environ.get("MAIL_PASSWORD", "")
            server.login(self.settings["mail"]["username"], password)
        server.sendmail(sender, recipients.split(","), message.as_string())
        server.close()

    def contains_set(self, input):
        if isinstance(input, set):
            return True
        elif isinstance(input, list):
            return any(self.contains_set(x) for x in input)
        elif isinstance(input, dict):
            return any(self.contains_set(x) for x in input.values())
        else:
            return False

    def str_dict(self, input, depth=0):
        tab = "\t" * depth
        if isinstance(input, list):
            result = "\n"
            for element in input:
                result += f"{tab}- {self.str_dict(element, depth + 1)}\n"
            return result
        elif isinstance(input, dict):
            result = ""
            for key, value in input.items():
                result += f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
            return result
        else:
            return str(input)

    def strip_all(self, input):
        return input.translate(str.maketrans("", "", f"{punctuation} "))

    def update_database_configurations_from_git(self):
        """Synchronize devices from the network_data git clone.

        For each device folder: apply the parameters from data.yml, load
        each configuration property file into the device, then recompute
        every pool that depends on a device configuration property.
        Fix: use yaml.safe_load instead of yaml.load without a Loader —
        the latter is deprecated and can execute arbitrary YAML tags.
        """
        for dir in scandir(self.path / "network_data"):
            device = db.fetch("device", allow_none=True, name=dir.name)
            filepath = Path(dir.path) / "data.yml"
            if not device:
                continue
            if filepath.exists():
                with open(filepath) as data:
                    # safe_load: plain data only, no arbitrary object tags
                    parameters = yaml.safe_load(data)
                    # dont_update_pools: pools are recomputed once, below
                    device.update(**{"dont_update_pools": True, **parameters})
            for data in self.configuration_properties:
                filepath = Path(dir.path) / data
                if not filepath.exists():
                    continue
                with open(filepath) as file:
                    setattr(device, data, file.read())
        db.session.commit()
        for pool in db.fetch_all("pool"):
            if any(
                    getattr(pool, f"device_{property}")
                    for property in self.configuration_properties):
                pool.compute_pool()
# Example #4
class BaseController:
    """Application-wide controller: initialization, database setup,
    filtering endpoints and miscellaneous helpers."""

    # Severity levels accepted by the logging helpers.
    log_levels = ["debug", "info", "warning", "error", "critical"]

    # Controller methods exposed through the REST API.
    rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    # Maps property names to their display ("pretty") names.
    property_names = {}

    def __init__(self):
        """Load settings and initialize every application subsystem.

        The sequence is order-sensitive: pre_init/post_init hooks wrap it,
        and custom properties must be loaded before the configuration
        properties derived from them.
        """
        self.pre_init()
        self.settings = settings
        self.properties = properties
        self.database = database
        self.logging = logging
        self.cli_command = self.detect_cli()
        self.load_custom_properties()
        self.load_configuration_properties()
        self.path = Path.cwd()
        self.init_rbac()
        self.init_encryption()
        self.use_vault = settings["vault"]["use_vault"]
        if self.use_vault:
            self.init_vault_client()
        if settings["syslog"]["active"]:
            self.init_syslog_server()
        if settings["paths"]["custom_code"]:
            # make user-provided modules importable
            sys_path.append(settings["paths"]["custom_code"])
        self.fetch_version()
        self.init_logs()
        self.init_redis()
        self.init_scheduler()
        self.init_connection_pools()
        self.post_init()

    def init_encryption(self):
        self.fernet_encryption = getenv("FERNET_KEY")
        if self.fernet_encryption:
            fernet = Fernet(self.fernet_encryption)
            self.encrypt, self.decrypt = fernet.encrypt, fernet.decrypt
        else:
            self.encrypt, self.decrypt = b64encode, b64decode

    def detect_cli(self):
        """Return True when the app was started from the "flask" CLI command
        (no CLI context at all raises RuntimeError, which maps to False)."""
        try:
            return get_current_context().info_name == "flask"
        except RuntimeError:
            return False

    def encrypt_password(self, password):
        if isinstance(password, str):
            password = str.encode(password)
        return self.encrypt(password)

    def get_password(self, password):
        if not password:
            return
        if self.fernet_encryption and isinstance(password, str):
            password = str.encode(password)
        return str(self.decrypt(password), "utf-8")

    def initialize_database(self):
        """Create the database schema and seed it on first startup.

        Loads plugins and service classes, creates all tables, then —
        unless running from the CLI — builds the forms and, when the
        database is fresh (no admin user yet), creates the admin account,
        runs the startup migration and pulls the git content.
        """
        self.init_plugins()
        # flatten the per-model private property lists into one list
        db.private_properties_list = list(set(sum(db.private_properties.values(), [])))
        self.init_services()
        db.base.metadata.create_all(bind=db.engine)
        configure_mappers()
        db.configure_model_events(self)
        if self.cli_command:
            return
        self.init_forms()
        if not db.fetch("user", allow_none=True, name="admin"):
            self.create_admin_user()
            self.migration_import(
                name=self.settings["app"].get("startup_migration", "default"),
                import_export_types=db.import_export_models,
            )
            self.update_credentials()
            self.get_git_content()
        self.configure_server_id()
        self.reset_run_status()
        db.session.commit()

    def reset_run_status(self):
        """Mark runs left in "Running" state (e.g. after a restart or crash)
        as aborted, and set their services back to idle."""
        for run in db.fetch("run", all_matches=True, allow_none=True, status="Running"):
            run.status = "Aborted (RELOAD)"
            run.service.status = "Idle"
        db.session.commit()

    def fetch_version(self):
        """Read the application version from package.json into self.version."""
        package_path = self.path / "package.json"
        with open(package_path) as package_file:
            package = load(package_file)
        self.version = package["version"]

    def configure_server_id(self):
        """Register (or update) the local server in the database, named
        after the machine's node id (presumably uuid.getnode — confirm)."""
        db.factory(
            "server",
            **{
                "name": str(getnode()),
                "description": "Localhost",
                "ip_address": "0.0.0.0",
                "status": "Up",
            },
        )

    def create_admin_user(self):
        """Create the built-in admin account, assigning the default password
        only when none is already set."""
        admin = db.factory("user", name="admin", is_admin=True, commit=True)
        if not admin.password:
            admin.update(password="******")

    def update_credentials(self):
        """Import the bundled example topology spreadsheet (usa.xls)."""
        with open(self.path / "files" / "spreadsheets" / "usa.xls", "rb") as file:
            self.topology_import(file)

    def get_git_content(self):
        """Clone or pull the configured git repository into network_data,
        then synchronize the database from its content.

        Git failures are logged but never fatal (best effort).
        """
        repo = self.settings["app"]["git_repository"]
        if not repo:
            return
        local_path = self.path / "network_data"
        try:
            if exists(local_path):
                Repo(local_path).remotes.origin.pull()
            else:
                local_path.mkdir(parents=True, exist_ok=True)
                Repo.clone_from(repo, local_path)
        except Exception as exc:
            self.log("error", f"Git pull failed ({str(exc)})")
        self.update_database_configurations_from_git()

    def load_custom_properties(self):
        """Register user-defined properties: display names, per-model
        property lists, private properties, and device configuration
        properties."""
        for model, values in self.properties["custom"].items():
            for property, property_dict in values.items():
                pretty_name = property_dict["pretty_name"]
                self.property_names[property] = pretty_name
                model_properties[model].append(property)
                if property_dict.get("private"):
                    if model not in db.private_properties:
                        db.private_properties[model] = []
                    db.private_properties[model].append(property)
                if model == "device" and property_dict.get("configuration"):
                    self.configuration_properties[property] = pretty_name

    def load_configuration_properties(self):
        """Expose each device configuration property (and its per-timestamp
        columns) in the device filtering form and the configuration table."""
        for property, title in self.configuration_properties.items():
            self.properties["filtering"]["device"].append(property)
            # insert(-1): keep the original last column at the end
            self.properties["tables"]["configuration"].insert(
                -1,
                {
                    "data": property,
                    "title": title,
                    "search": "text",
                    "width": "70%",
                    # only the main "configuration" column is shown by default
                    "visible": property == "configuration",
                    "orderable": False,
                },
            )
            for timestamp in self.configuration_timestamps:
                self.properties["tables"]["configuration"].insert(
                    -1,
                    {
                        "data": f"last_{property}_{timestamp}",
                        "title": f"Last {title} {timestamp.capitalize()}",
                        "search": "text",
                        "visible": False,
                    },
                )

    def init_logs(self):
        """Create the logs folder, apply the dictConfig logging setup from
        setup/logging.json and set external loggers to their configured
        levels."""
        folder = self.path / "logs"
        folder.mkdir(parents=True, exist_ok=True)
        with open(self.path / "setup" / "logging.json", "r") as logging_config:
            logging_config = load(logging_config)
        dictConfig(logging_config)
        for logger, log_level in logging_config["external_loggers"].items():
            info(f"Changing {logger} log level to '{log_level}'")
            # resolve e.g. "debug" -> logging.DEBUG
            log_level = getattr(import_module("logging"), log_level.upper())
            getLogger(logger).setLevel(log_level)

    def init_connection_pools(self):
        """Create the shared requests session, mounting retry and pool
        settings for both http and https."""
        self.request_session = RequestSession()
        retry = Retry(**self.settings["requests"]["retries"])
        for protocol in ("http", "https"):
            self.request_session.mount(
                f"{protocol}://",
                HTTPAdapter(max_retries=retry, **self.settings["requests"]["pool"]),
            )

    def init_forms(self):
        """Import every form module under eNMS/forms so the form classes
        register themselves.

        Fix: derive the module name with Path.stem instead of splitting the
        path string on "/" — the original broke on Windows path separators.
        This also matches how init_services names its modules.
        """
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            spec = spec_from_file_location(file.stem, str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_rbac(self):
        """Build self.rbac from the static rbac setup, collecting every page
        and subpage whose access mode is "access".

        Fix: iterate rbac["menu"].values() instead of .items() — the key
        was discarded (PERF102).
        """
        self.rbac = {"pages": [], **rbac}
        for category in rbac["menu"].values():
            for page, page_values in category["pages"].items():
                if page_values["rbac"] == "access":
                    self.rbac["pages"].append(page)
                for subpage, subpage_values in page_values.get("subpages", {}).items():
                    if subpage_values["rbac"] == "access":
                        self.rbac["pages"].append(subpage)

    def init_redis(self):
        """Connect to redis when REDIS_ADDR is set; otherwise the queue is
        None and callers fall back to in-memory storage (see log_queue)."""
        host = getenv("REDIS_ADDR")
        self.redis_queue = (
            Redis(
                host=host,
                port=6379,
                db=0,
                charset="utf-8",
                decode_responses=True,
                # short timeout: redis must never block request handling
                socket_timeout=0.1,
            )
            if host
            else None
        )

    def init_scheduler(self):
        """Store the address of the external scheduler service
        (SCHEDULER_ADDR environment variable)."""
        self.scheduler_address = getenv("SCHEDULER_ADDR")

    def update_settings(self, old, new):
        for key, value in new.items():
            if key not in old:
                old[key] = value
            else:
                old_value = old[key]
                if isinstance(old_value, list):
                    old_value.extend(value)
                elif isinstance(old_value, dict):
                    self.update_settings(old_value, value)
                else:
                    old[key] = value

        return old

    def init_plugins(self):
        """Discover and load the plugins found in the configured plugin path.

        A plugin is a folder containing a settings.json file; inactive or
        broken plugins are skipped (errors are logged, never raised).
        """
        self.plugins = {}
        for plugin_path in Path(self.settings["app"]["plugin_path"]).iterdir():
            if not Path(plugin_path / "settings.json").exists():
                continue
            try:
                with open(plugin_path / "settings.json", "r") as file:
                    settings = load(file)
                if not settings["active"]:
                    continue
                self.plugins[plugin_path.stem] = {
                    "settings": settings,
                    "module": import_module(f"eNMS.plugins.{plugin_path.stem}"),
                }
                # merge the plugin's database/properties/rbac sections into
                # the controller-wide setup dictionaries
                for setup_file in ("database", "properties", "rbac"):
                    property = getattr(self, setup_file)
                    self.update_settings(property, settings.get(setup_file, {}))
            except Exception as exc:
                error(f"Could not load plugin '{plugin_path.stem}' ({exc})")
                continue
            info(f"Loading plugin: {settings['name']}")

    def init_services(self):
        """Import every service module (built-in and custom service paths) so
        the service classes register themselves.

        Example services are skipped unless the startup migration is
        "examples"; broken custom services are logged, not fatal.
        """
        path_services = [self.path / "eNMS" / "services"]
        load_examples = self.settings["app"].get("startup_migration") == "examples"
        if self.settings["paths"]["custom_services"]:
            path_services.append(Path(self.settings["paths"]["custom_services"]))
        for path in path_services:
            for file in path.glob("**/*.py"):
                if "init" in str(file):
                    continue
                if not load_examples and "examples" in str(file):
                    continue
                info(f"Loading service: {file}")
                spec = spec_from_file_location(file.stem, str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as exc:
                    error(f"Error loading custom service '{file}' ({str(exc)})")

    def init_vault_client(self):
        """Connect to the Vault server and, when configured, unseal it with
        the UNSEAL_VAULT_KEY1..5 environment variables."""
        url = getenv("VAULT_ADDR", "http://127.0.0.1:8200")
        self.vault_client = VaultClient(url=url, token=getenv("VAULT_TOKEN"))
        if self.vault_client.sys.is_sealed() and self.settings["vault"]["unseal_vault"]:
            keys = [getenv(f"UNSEAL_VAULT_KEY{i}") for i in range(1, 6)]
            # filter(None, ...): ignore unset keys
            self.vault_client.sys.submit_unseal_keys(filter(None, keys))

    def init_syslog_server(self):
        """Start the embedded syslog server on the configured address/port."""
        self.syslog_server = SyslogServer(
            self.settings["syslog"]["address"], self.settings["syslog"]["port"]
        )
        self.syslog_server.start()

    def redis(self, operation, *args, **kwargs):
        try:
            return getattr(self.redis_queue, operation)(*args, **kwargs)
        except (ConnectionError, TimeoutError) as exc:
            self.log("error", f"Redis Queue Unreachable ({exc})", change_log=False)

    def log_queue(self, runtime, service, log=None, mode="add", start_line=0):
        """Store or retrieve service logs for a run, using redis when it is
        available and the in-memory run_logs dictionary otherwise.

        With mode "add", *log* is appended; with any other mode, the logs
        from *start_line* onwards are returned.
        """
        if self.redis_queue:
            key = f"{runtime}/{service}/logs"
            # NOTE(review): the in-memory slot is reset to None whenever
            # redis is the backend — confirm this is intentional.
            self.run_logs[runtime][int(service)] = None
            if mode == "add":
                log = self.redis("lpush", key, log)
            else:
                log = self.redis("lrange", key, 0, -1)
                if log:
                    # lpush stores newest first: reverse for chronological order
                    log = log[::-1][start_line:]
        else:
            if mode == "add":
                return self.run_logs[runtime][int(service)].append(log)
            else:
                # mode is used as a dict method name here (presumably "get")
                full_log = getattr(self.run_logs[runtime], mode)(int(service), [])
                log = full_log[start_line:]
        return log

    def delete_instance(self, model, instance_id):
        """Delete the *model* instance with the given id; returns whatever
        db.delete returns."""
        return db.delete(model, id=instance_id)

    def get(self, model, id):
        """Return the serialized form of the *model* instance with this id."""
        return db.fetch(model, id=id).serialized

    def get_properties(self, model, id):
        """Return the property dictionary of the *model* instance with this id."""
        return db.fetch(model, id=id).get_properties()

    def get_all(self, model):
        """Return the property dictionaries of every instance of *model*."""
        return [instance.get_properties() for instance in db.fetch_all(model)]

    def update(self, type, **kwargs):
        """Create or update an instance of *type* from form data.

        Returns the serialized instance on success, or an {"alert": ...}
        dictionary on permission, integrity or other errors.
        """
        try:
            kwargs.update(
                {
                    "last_modified": self.get_time(),
                    "update_pools": True,
                    # an empty id means the form was submitted for creation
                    "must_be_new": kwargs.get("id") == "",
                }
            )
            for arg in ("name", "scoped_name"):
                if arg in kwargs:
                    kwargs[arg] = kwargs[arg].strip()
            if kwargs["must_be_new"]:
                kwargs["creator"] = kwargs["user"] = getattr(current_user, "name", "")
            instance = db.factory(type, **kwargs)
            if kwargs.get("copy"):
                # duplicate an existing instance into the newly created one
                db.fetch(type, id=kwargs["copy"]).duplicate(clone=instance)
            db.session.flush()
            return instance.serialized
        except db.rbac_error:
            return {"alert": "Error 403 - Operation not allowed."}
        except Exception as exc:
            db.session.rollback()
            if isinstance(exc, IntegrityError):
                return {"alert": f"There is already a {type} with the same parameters."}
            self.log("error", format_exc())
            return {"alert": str(exc)}

    def log(self, severity, content, user=None, change_log=True, logger="root"):
        """Log *content* at the given severity and optionally record a
        changelog entry in the database; returns the logger's settings."""
        logger_settings = self.logging["loggers"].get(logger, {})
        if logger:
            getattr(getLogger(logger), severity)(content)
        # NOTE(review): parses as `change_log or (logger and ...)` — confirm
        # a truthy change_log is always meant to create a changelog entry.
        if change_log or logger and logger_settings.get("change_log"):
            db.factory(
                "changelog",
                **{
                    "severity": severity,
                    "content": content,
                    "user": user or getattr(current_user, "name", ""),
                },
            )
        return logger_settings

    def compare(self, type, id, v1, v2, context_lines):
        """Return a unified diff between two versions of a result, or of a
        device's git-stored network data, with *context_lines* of context."""
        if type in ("result", "device_result"):
            first = self.str_dict(getattr(db.fetch("result", id=v1), "result"))
            second = self.str_dict(getattr(db.fetch("result", id=v2), "result"))
        else:
            device = db.fetch("device", id=id)
            # get_git_network_data also normalizes the version labels
            result1, v1 = self.get_git_network_data(device.name, v1)
            result2, v2 = self.get_git_network_data(device.name, v2)
            first, second = result1[type], result2[type]
        return "\n".join(
            unified_diff(
                first.splitlines(),
                second.splitlines(),
                fromfile=f"V1 ({v1})",
                tofile=f"V2 ({v2})",
                lineterm="",
                n=int(context_lines),
            )
        )

    def build_filtering_constraints(self, model, **kwargs):
        """Translate form values and extra constraints into SQLAlchemy
        constraints for the scalar properties of *model*.

        Relationship constraints are handled separately by
        build_relationship_constraints.
        """
        table, constraints = models[model], []
        # explicit constraints take precedence over form values
        constraint_dict = {**kwargs["form"], **kwargs.get("constraints", {})}
        for property in model_properties[model]:
            value = constraint_dict.get(property)
            if not value:
                continue
            filter_value = constraint_dict.get(f"{property}_filter")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(table, property) == (value == "bool-true")
            elif filter_value == "equality":
                constraint = getattr(table, property) == value
            elif (
                not filter_value
                or filter_value == "inclusion"
                or db.dialect == "sqlite"
            ):
                # sqlite has no regex operator: substring match instead,
                # escaping %/_ wildcards for string values
                constraint = getattr(table, property).contains(
                    value, autoescape=isinstance(value, str)
                )
            else:
                # validate the pattern first; the error raised for an
                # invalid regex is handled by the caller (see filtering)
                compile(value)
                regex_operator = "regexp" if db.dialect == "mysql" else "~"
                constraint = getattr(table, property).op(regex_operator)(value)
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, model, **params):
        """Paginated name search for multiselect dropdowns: one page of 10
        matches (page number in params["page"]) plus the total count."""
        table = models[model]
        results = db.query(model).filter(table.name.contains(params.get("term")))
        return {
            "items": [
                {"text": result.ui_name, "id": str(result.id)}
                for result in results.limit(10)
                .offset((int(params["page"]) - 1) * 10)
                .all()
            ],
            "total_count": results.count(),
        }

    def build_relationship_constraints(self, query, model, **kwargs):
        """Join *query* against each related table for which ids were
        selected, keeping only rows related to those ids."""
        table = models[model]
        # explicit constraints take precedence over form values
        constraint_dict = {**kwargs["form"], **kwargs.get("constraints", {})}
        for related_model, relation_properties in relationships[model].items():
            relation_ids = [int(id) for id in constraint_dict.get(related_model, [])]
            if not relation_ids:
                continue
            # alias the table so several relationships can join the same one
            related_table = aliased(models[relation_properties["model"]])
            query = query.join(related_table, getattr(table, related_model)).filter(
                related_table.id.in_(relation_ids)
            )
        return query

    def filtering(self, model, bulk=False, **kwargs):
        """Filter instances of *model* from the request parameters.

        When *bulk* is truthy, returns the matching instances ("object")
        or their ids; otherwise returns a DataTables-style response with
        record counts and one page of serialized rows (full results are
        added on export / clipboard requests).
        """
        table, query = models[model], db.query(model)
        # count on the id column only, to avoid serializing full rows
        total_records = query.with_entities(table.id).count()
        try:
            constraints = self.build_filtering_constraints(model, **kwargs)
        except regex_error:
            return {"error": "Invalid regular expression as search parameter."}
        constraints.extend(table.filtering_constraints(**kwargs))
        query = self.build_relationship_constraints(query, model, **kwargs)
        query = query.filter(and_(*constraints))
        filtered_records = query.with_entities(table.id).count()
        if bulk:
            instances = query.all()
            return instances if bulk == "object" else [obj.id for obj in instances]
        data = kwargs["columns"][int(kwargs["order"][0]["column"])]["data"]
        # column or ordering method may not exist: ordering is then skipped
        ordering = getattr(getattr(table, data, None), kwargs["order"][0]["dir"], None)
        if ordering:
            query = query.order_by(ordering())
        table_result = {
            "draw": int(kwargs["draw"]),
            "recordsTotal": total_records,
            "recordsFiltered": filtered_records,
            "data": [
                obj.table_properties(**kwargs)
                for obj in query.limit(int(kwargs["length"]))
                .offset(int(kwargs["start"]))
                .all()
            ],
        }
        if kwargs.get("export"):
            table_result["full_result"] = [
                obj.table_properties(**kwargs) for obj in query.all()
            ]
        if kwargs.get("clipboard"):
            table_result["full_result"] = ",".join(obj.name for obj in query.all())
        return table_result

    def allowed_file(self, name, allowed_modules):
        """Return True when *name* has an extension listed in *allowed_modules*.

        Fix: short-circuit when the name contains no dot — the original
        computed the extension unconditionally and raised IndexError for
        extension-less names.
        """
        if "." not in name:
            return False
        return name.rsplit(".", 1)[1].lower() in allowed_modules

    def bulk_deletion(self, table, **kwargs):
        """Delete every instance matching the filtering form; returns the
        number of deleted instances."""
        instances = self.filtering(table, bulk=True, form=kwargs)
        for instance_id in instances:
            db.delete(table, id=instance_id)
        return len(instances)

    def bulk_edit(self, table, **kwargs):
        """Apply the checked form values to every instance listed in the
        "-"-separated id string; returns the number of edited instances.

        Only properties whose "bulk-edit-<property>" checkbox is set are
        applied.
        """
        instances = kwargs.pop("id").split("-")
        kwargs = {
            property: value
            for property, value in kwargs.items()
            if kwargs.get(f"bulk-edit-{property}")
        }
        for instance_id in instances:
            db.factory(table, id=instance_id, **kwargs)
        return len(instances)

    def get_time(self):
        """Return the current local timestamp as a string."""
        current_time = datetime.now()
        return str(current_time)

    def remove_instance(self, **kwargs):
        """Detach one instance from a relation of the target object.

        Dynamic (non manually-defined) pools cannot be edited by hand.
        Fix: corrected the garbled alert message ("is an allowed" ->
        "is not allowed"), matching add_instances_in_bulk's wording.
        """
        instance = db.fetch(kwargs["instance"]["type"], id=kwargs["instance"]["id"])
        target = db.fetch(kwargs["relation"]["type"], id=kwargs["relation"]["id"])
        if target.type == "pool" and not target.manually_defined:
            return {"alert": "Removing an object from a dynamic pool is not allowed."}
        getattr(target, kwargs["relation"]["relation"]["to"]).remove(instance)
        self.update_rbac(instance)

    def add_instances_in_bulk(self, **kwargs):
        """Add a set of instances (given by ids and/or comma-separated
        names) to a relation of the target object.

        Returns the number of newly attached instances and the target's
        base properties, or an alert when the target is a dynamic pool or
        a name cannot be resolved.
        """
        target = db.fetch(kwargs["relation_type"], id=kwargs["relation_id"])
        if target.type == "pool" and not target.manually_defined:
            return {"alert": "Adding objects to a dynamic pool is not allowed."}
        model, property = kwargs["model"], kwargs["property"]
        instances = set(db.objectify(model, kwargs["instances"]))
        if kwargs["names"]:
            for name in [instance.strip() for instance in kwargs["names"].split(",")]:
                instance = db.fetch(model, allow_none=True, name=name)
                if not instance:
                    return {"alert": f"{model.capitalize()} '{name}' does not exist."}
                instances.add(instance)
        # only attach instances that are not already in the relation
        instances = instances - set(getattr(target, property))
        for instance in instances:
            getattr(target, property).append(instance)
        target.last_modified = self.get_time()
        self.update_rbac(*instances)
        return {"number": len(instances), "target": target.base_properties}

    def bulk_removal(
        self,
        table,
        target_type,
        target_id,
        target_property,
        constraint_property,
        **kwargs,
    ):
        """Remove every filtered instance from a relation of the target
        object; returns the number of removals.

        Dynamic (non manually-defined) pools cannot be edited by hand.
        Fix: corrected the garbled alert message ("is an allowed" ->
        "is not allowed").
        """
        # restrict the filtering to instances related to the target
        kwargs[constraint_property] = [target_id]
        target = db.fetch(target_type, id=target_id)
        if target.type == "pool" and not target.manually_defined:
            return {"alert": "Removing objects from a dynamic pool is not allowed."}
        instances = self.filtering(table, bulk="object", form=kwargs)
        for instance in instances:
            getattr(target, target_property).remove(instance)
        self.update_rbac(*instances)
        return len(instances)

    def update_rbac(self, *instances):
        for instance in instances:
            if instance.type != "user":
                continue
            instance.update_rbac()

    def send_email(
        self,
        subject,
        content,
        recipients="",
        reply_to=None,
        sender=None,
        filename=None,
        file_content=None,
    ):
        """Send an email through the SMTP server defined in the settings.

        *recipients* is a comma-separated string of addresses; when
        *filename* is given, *file_content* is attached under that name.
        Fix: the attachment's Content-Disposition header now carries the
        actual *filename* instead of a hard-coded placeholder string.
        """
        sender = sender or self.settings["mail"]["sender"]
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.add_header("reply-to", reply_to or self.settings["mail"]["reply_to"])
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            attached_file["Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.settings["mail"]["server"], self.settings["mail"]["port"])
        if self.settings["mail"]["use_tls"]:
            server.starttls()
            # password comes from the environment, never from the settings
            password = getenv("MAIL_PASSWORD", "")
            server.login(self.settings["mail"]["username"], password)
        server.sendmail(sender, recipients.split(","), message.as_string())
        server.close()

    def contains_set(self, input):
        if isinstance(input, set):
            return True
        elif isinstance(input, list):
            return any(self.contains_set(x) for x in input)
        elif isinstance(input, dict):
            return any(self.contains_set(x) for x in input.values())
        else:
            return False

    def str_dict(self, input, depth=0):
        tab = "\t" * depth
        if isinstance(input, list):
            result = "\n"
            for element in input:
                result += f"{tab}- {self.str_dict(element, depth + 1)}\n"
            return result
        elif isinstance(input, dict):
            result = ""
            for key, value in input.items():
                result += f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
            return result
        else:
            return str(input)

    def strip_all(self, input):
        return input.translate(str.maketrans("", "", f"{punctuation} "))

    def update_database_configurations_from_git(self):
        """Load device configuration properties and their timestamps from
        the network_data git clone into the database, then recompute every
        pool that depends on a device configuration property."""
        for dir in scandir(self.path / "network_data"):
            device = db.fetch("device", allow_none=True, name=dir.name)
            timestamp_path = Path(dir.path) / "timestamps.json"
            if not device:
                continue
            try:
                with open(timestamp_path) as file:
                    timestamps = load(file)
            except Exception:
                # missing or unreadable timestamps file: proceed without it
                timestamps = {}
            for property in self.configuration_properties:
                for timestamp, value in timestamps.get(property, {}).items():
                    setattr(device, f"last_{property}_{timestamp}", value)
                filepath = Path(dir.path) / property
                if not filepath.exists():
                    continue
                with open(filepath) as file:
                    setattr(device, property, file.read())
        db.session.commit()
        for pool in db.fetch_all("pool"):
            if any(
                getattr(pool, f"device_{property}")
                for property in self.configuration_properties
            ):
                pool.compute_pool()
# Example #5
class BaseController:
    """Application-wide settings plus page/endpoint whitelists.

    Every setting is resolved once, at import time, from an environment
    variable with a hard-coded fallback.  ``environ.get`` returns a string
    whenever the variable is set, so numeric settings are explicitly cast so
    their type is the same whether or not the variable is defined.
    """

    cluster = int(environ.get("CLUSTER", False))
    cluster_id = int(environ.get("CLUSTER_ID", True))
    # The misspelled "CLUSER_SCAN_SUBNET" name is kept as a fallback so
    # existing deployments relying on the old variable keep working.
    cluster_scan_subnet = environ.get(
        "CLUSTER_SCAN_SUBNET",
        environ.get("CLUSER_SCAN_SUBNET", "192.168.105.0/24"),
    )
    cluster_scan_protocol = environ.get("CLUSTER_SCAN_PROTOCOL", "http")
    # Cast to float: the timeout previously stayed a raw string whenever
    # CLUSTER_SCAN_TIMEOUT was set in the environment.
    cluster_scan_timeout = float(environ.get("CLUSTER_SCAN_TIMEOUT", 0.05))
    config_mode = environ.get("CONFIG_MODE", "Debug")
    custom_code_path = environ.get("CUSTOM_CODE_PATH")
    # Cast to float/int for the same reason as the scan timeout above.
    default_longitude = float(environ.get("DEFAULT_LONGITUDE", -96.0))
    default_latitude = float(environ.get("DEFAULT_LATITUDE", 33.0))
    default_zoom_level = int(environ.get("DEFAULT_ZOOM_LEVEL", 5))
    default_view = environ.get("DEFAULT_VIEW", "2D")
    default_marker = environ.get("DEFAULT_MARKER", "Image")
    documentation_url = environ.get("DOCUMENTATION_URL",
                                    "https://enms.readthedocs.io/en/latest/")
    create_examples = int(environ.get("CREATE_EXAMPLES", True))
    custom_services_path = environ.get("CUSTOM_SERVICES_PATH")
    log_level = environ.get("LOG_LEVEL", "DEBUG")
    git_automation = environ.get("GIT_AUTOMATION")
    git_configurations = environ.get("GIT_CONFIGURATIONS")
    gotty_port_redirection = int(environ.get("GOTTY_PORT_REDIRECTION", False))
    gotty_bypass_key_prompt = int(environ.get("GOTTY_BYPASS_KEY_PROMPT",
                                              False))
    gotty_port = -1  # -1 means "no gotty port allocated yet"
    gotty_start_port = int(environ.get("GOTTY_START_PORT", 9000))
    gotty_end_port = int(environ.get("GOTTY_END_PORT", 9100))
    ldap_server = environ.get("LDAP_SERVER")
    ldap_userdn = environ.get("LDAP_USERDN")
    ldap_basedn = environ.get("LDAP_BASEDN")
    ldap_admin_group = environ.get("LDAP_ADMIN_GROUP", "")
    mail_server = environ.get("MAIL_SERVER", "smtp.googlemail.com")
    mail_port = int(environ.get("MAIL_PORT", "587"))
    mail_use_tls = int(environ.get("MAIL_USE_TLS", True))
    mail_username = environ.get("MAIL_USERNAME")
    mail_password = environ.get("MAIL_PASSWORD")
    mail_sender = environ.get("MAIL_SENDER", "*****@*****.**")
    mail_recipients = environ.get("MAIL_RECIPIENTS", "")
    mattermost_url = environ.get("MATTERMOST_URL")
    mattermost_channel = environ.get("MATTERMOST_CHANNEL")
    mattermost_verify_certificate = int(
        environ.get("MATTERMOST_VERIFY_CERTIFICATE", True))
    opennms_login = environ.get("OPENNMS_LOGIN")
    opennms_devices = environ.get("OPENNMS_DEVICES", "")
    opennms_rest_api = environ.get("OPENNMS_REST_API")
    playbook_path = environ.get("PLAYBOOK_PATH")
    server_addr = environ.get("SERVER_ADDR", "http://SERVER_IP")
    slack_token = environ.get("SLACK_TOKEN")
    slack_channel = environ.get("SLACK_CHANNEL")
    syslog_addr = environ.get("SYSLOG_ADDR", "0.0.0.0")
    syslog_port = int(environ.get("SYSLOG_PORT", 514))
    tacacs_addr = environ.get("TACACS_ADDR")
    tacacs_password = environ.get("TACACS_PASSWORD")
    unseal_vault = environ.get("UNSEAL_VAULT")
    use_ldap = int(environ.get("USE_LDAP", False))
    use_syslog = int(environ.get("USE_SYSLOG", False))
    use_tacacs = int(environ.get("USE_TACACS", False))
    use_vault = int(environ.get("USE_VAULT", False))
    vault_addr = environ.get("VAULT_ADDR")

    # Maps a severity name to the matching logging function.
    log_severity = {"error": error, "info": info, "warning": warning}

    # Pages reachable without authentication.
    free_access_pages = ["/", "/login"]

    # GET pages served by the web application.
    valid_pages = [
        "/administration",
        "/calendar/run",
        "/calendar/task",
        "/dashboard",
        "/login",
        "/table/changelog",
        "/table/configuration",
        "/table/device",
        "/table/event",
        "/table/pool",
        "/table/link",
        "/table/run",
        "/table/server",
        "/table/service",
        "/table/syslog",
        "/table/task",
        "/table/user",
        "/view/network",
        "/view/site",
        "/workflow_builder",
    ]

    # POST endpoints accepted by the web application.
    valid_post_endpoints = [
        "stop_workflow",
        "add_edge",
        "add_services_to_workflow",
        "calendar_init",
        "clear_results",
        "clear_configurations",
        "compare",
        "connection",
        "counters",
        "count_models",
        "create_label",
        "database_deletion",
        "delete_edge",
        "delete_instance",
        "delete_label",
        "delete_node",
        "duplicate_workflow",
        "export_service",
        "export_to_google_earth",
        "export_topology",
        "get",
        "get_all",
        "get_cluster_status",
        "get_configuration",
        "get_device_configuration",
        "get_device_logs",
        "get_exported_services",
        "get_git_content",
        "get_service_logs",
        "get_properties",
        "get_result",
        "get_runtimes",
        "get_view_topology",
        "get_workflow_state",
        "import_service",
        "import_topology",
        "migration_export",
        "migration_import",
        "multiselect_filtering",
        "query_netbox",
        "query_librenms",
        "query_opennms",
        "reset_status",
        "run_service",
        "save_parameters",
        "save_pool_objects",
        "save_positions",
        "scan_cluster",
        "scan_playbook_folder",
        "scheduler",
        "skip_services",
        "table_filtering",
        "task_action",
        "topology_import",
        "update",
        "update_parameters",
        "update_pool",
        "update_all_pools",
        "view_filtering",
    ]

    # Endpoints also exposed through the REST API.
    valid_rest_endpoints = [
        "get_cluster_status",
        "get_git_content",
        "update_all_pools",
        "update_database_configurations_from_git",
    ]

    def __init__(self, path):
        """Initialize the controller.

        Loads custom properties/config, starts the scheduler, and brings up
        every optional integration (TACACS+, LDAP, Vault, syslog) whose
        USE_* environment flag is set.

        :param path: root directory of the application.
        """
        self.path = path
        self.custom_properties = self.load_custom_properties()
        self.custom_config = self.load_custom_config()
        self.init_scheduler()
        # Optional integrations, each gated by an environment flag.
        if self.use_tacacs:
            self.init_tacacs_client()
        if self.use_ldap:
            self.init_ldap_client()
        if self.use_vault:
            self.init_vault_client()
        if self.use_syslog:
            self.init_syslog_server()
        if self.custom_code_path:
            # Make user-provided modules importable.
            sys_path.append(self.custom_code_path)
        self.create_google_earth_styles()
        self.fetch_version()
        self.init_logs()

    def configure_database(self):
        """Create the database schema and seed it on first start.

        A missing "admin" user indicates a fresh database: in that case the
        server row, admin user and an initial dataset (examples or default,
        depending on create_examples) are imported.
        """
        self.init_services()
        Base.metadata.create_all(bind=engine)
        configure_mappers()
        configure_events(self)
        self.init_forms()
        self.clean_database()
        # Seed only once: the presence of "admin" marks an initialized DB.
        if not fetch("user", allow_none=True, name="admin"):
            self.init_parameters()
            self.configure_server_id()
            self.create_admin_user()
            Session.commit()
            if self.create_examples:
                self.migration_import(name="examples",
                                      import_export_types=import_classes)
                self.update_credentials()
            else:
                self.migration_import(name="default",
                                      import_export_types=import_classes)
            self.get_git_content()
            Session.commit()

    def clean_database(self):
        """Mark runs left in "Running" state by a previous process as aborted.

        Called at startup so that runs interrupted by an application reload
        are not reported as still in progress.
        """
        interrupted_runs = fetch(
            "run", all_matches=True, allow_none=True, status="Running"
        )
        for interrupted_run in interrupted_runs:
            interrupted_run.status = "Aborted (app reload)"
        Session.commit()

    def create_google_earth_styles(self):
        """Build one KML point style per device icon for Google Earth export."""
        self.google_earth_styles = {}
        for icon in device_icons:
            icon_style = Style()
            icon_style.labelstyle.color = Color.blue
            icon_style.iconstyle.icon.href = (
                f"{self.path}/eNMS/static/images/2D/{icon}.gif"
            )
            self.google_earth_styles[icon] = icon_style

    def fetch_version(self):
        with open(self.path / "package.json") as package_file:
            self.version = load(package_file)["version"]

    def init_parameters(self):
        """Synchronize controller attributes with the "parameters" table.

        First start (no row yet): persist the current, environment-derived
        attribute values as the parameters row.  Subsequent starts: load the
        stored parameters back onto the controller, overriding the defaults.
        """
        parameters = Session.query(models["parameters"]).one_or_none()
        if not parameters:
            parameters = models["parameters"]()
            parameters.update(
                **{
                    property: getattr(self, property)
                    for property in model_properties["parameters"]
                    if hasattr(self, property)
                })
            Session.add(parameters)
            Session.commit()
        else:
            # Database values take precedence over environment defaults.
            for parameter in parameters.get_properties():
                setattr(self, parameter, getattr(parameters, parameter))

    def configure_server_id(self):
        """Register the local machine as a server instance.

        The MAC-address-derived node id (uuid.getnode) is used as the name.
        """
        server_properties = {
            "name": str(getnode()),
            "description": "Localhost",
            "ip_address": "0.0.0.0",
            "status": "Up",
        }
        factory("server", **server_properties)

    def create_admin_user(self) -> None:
        """Ensure the default "admin" account exists, with a default password
        if none was set yet."""
        admin_user = factory("user", name="admin")
        if not admin_user.password:
            admin_user.password = "******"

    def update_credentials(self):
        """Import the bundled example topology spreadsheet (usa.xls)."""
        spreadsheet_path = self.path / "projects" / "spreadsheets" / "usa.xls"
        with open(spreadsheet_path, "rb") as spreadsheet:
            self.topology_import(spreadsheet)

    @property
    def config(self):
        """Current parameters as a plain dict ({} when no row exists yet)."""
        parameters = Session.query(models["parameters"]).one_or_none()
        if not parameters:
            return {}
        return parameters.get_properties()

    def get_git_content(self):
        """Pull (or clone) the configured git repositories.

        For each repository type ("configurations", "automation") with a
        configured remote: pull if a local checkout exists, clone otherwise,
        then refresh the database from the configurations repository.  The
        pull and clone paths previously duplicated the same try/except and
        update logic; they now share it.
        """
        for repository_type in ("configurations", "automation"):
            repo = getattr(self, f"git_{repository_type}")
            if not repo:
                continue
            local_path = self.path / "git" / repository_type
            repo_contents_exist = False
            for entry in scandir(local_path):
                if entry.name == ".gitkeep":
                    # Placeholder file kept so the directory exists in git.
                    remove(entry)
                if entry.name == ".git":
                    repo_contents_exist = True
            try:
                if repo_contents_exist:
                    Repo(local_path).remotes.origin.pull()
                else:
                    Repo.clone_from(repo, local_path)
                if repository_type == "configurations":
                    self.update_database_configurations_from_git()
            except Exception as e:
                action = "pull" if repo_contents_exist else "clone"
                info(
                    f"Cannot {action} {repository_type} git repository ({str(e)})"
                )

    def load_custom_config(self):
        filepath = environ.get("PATH_CUSTOM_CONFIG")
        if not filepath:
            return {}
        else:
            with open(filepath, "r") as config:
                return load(config)

    def load_custom_properties(self):
        """Load user-defined device properties and register them in the
        module-level property registries.

        Reads the YAML file pointed to by PATH_CUSTOM_PROPERTIES (no custom
        properties when unset).  Properties marked "private" are excluded
        from public tables and pool filtering.
        """
        filepath = environ.get("PATH_CUSTOM_PROPERTIES")
        if not filepath:
            custom_properties = {}
        else:
            with open(filepath, "r") as properties:
                # NOTE(review): yaml.load without an explicit Loader is
                # unsafe on untrusted input; consider yaml.safe_load.
                custom_properties = yaml.load(properties)
        property_names.update(
            {k: v["pretty_name"]
             for k, v in custom_properties.items()})
        public_custom_properties = {
            k: v
            for k, v in custom_properties.items()
            if not v.get("private", False)
        }
        device_properties.extend(list(custom_properties))
        pool_device_properties.extend(list(public_custom_properties))
        for properties_table in table_properties, filtering_properties:
            properties_table["device"].extend(list(public_custom_properties))
        # "add_to_dashboard" is optional, like "private": use .get so a
        # property definition without the key no longer raises a KeyError.
        device_diagram_properties.extend(
            list(p for p, v in custom_properties.items()
                 if v.get("add_to_dashboard", False)))
        private_properties.extend(
            list(p for p, v in custom_properties.items()
                 if v.get("private", False)))
        return custom_properties

    def init_logs(self):
        basicConfig(
            level=getattr(import_module("logging"), self.log_level),
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%m-%d-%Y %H:%M:%S",
            handlers=[
                RotatingFileHandler(self.path / "logs" / "enms.log",
                                    maxBytes=20_000_000,
                                    backupCount=10),
                StreamHandler(),
            ],
        )

    def init_scheduler(self):
        """Create and start the APScheduler background scheduler.

        Jobs are persisted in a local SQLite store and executed by a
        50-worker thread pool.
        """
        scheduler_config = {
            "apscheduler.jobstores.default": {
                "type": "sqlalchemy",
                "url": "sqlite:///jobs.sqlite",
            },
            "apscheduler.executors.default": {
                "class": "apscheduler.executors.pool:ThreadPoolExecutor",
                "max_workers": "50",
            },
            # Job defaults: 5s misfire tolerance, coalesce missed runs,
            # at most 3 concurrent instances per job.
            "apscheduler.job_defaults.misfire_grace_time": "5",
            "apscheduler.job_defaults.coalesce": "true",
            "apscheduler.job_defaults.max_instances": "3",
        }
        self.scheduler = BackgroundScheduler(scheduler_config)
        self.scheduler.start()

    def init_forms(self):
        for file in (self.path / "eNMS" / "forms").glob("**/*.py"):
            spec = spec_from_file_location(
                str(file).split("/")[-1][:-3], str(file))
            spec.loader.exec_module(module_from_spec(spec))

    def init_services(self):
        path_services = [self.path / "eNMS" / "services"]
        if self.custom_services_path:
            path_services.append(Path(self.custom_services_path))
        for path in path_services:
            for file in path.glob("**/*.py"):
                if "init" in str(file):
                    continue
                if not self.create_examples and "examples" in str(file):
                    continue
                spec = spec_from_file_location(
                    str(file).split("/")[-1][:-3], str(file))
                try:
                    spec.loader.exec_module(module_from_spec(spec))
                except InvalidRequestError as e:
                    error(f"Error loading custom service '{file}' ({str(e)})")

    def init_ldap_client(self):
        """Create the LDAP server handle used for user authentication."""
        self.ldap_client = Server(self.ldap_server, get_info=ALL)

    def init_tacacs_client(self):
        """Create the TACACS+ client (standard port 49) for authentication."""
        self.tacacs_client = TACACSClient(self.tacacs_addr, 49,
                                          self.tacacs_password)

    def init_vault_client(self):
        """Connect to Hashicorp Vault and optionally unseal it.

        The token is read from VAULT_TOKEN; when UNSEAL_VAULT is set and the
        vault is sealed, the UNSEAL_VAULT_KEY1..5 keys are submitted.
        """
        client = VaultClient()
        client.url = self.vault_addr
        client.token = environ.get("VAULT_TOKEN")
        self.vault_client = client
        if client.sys.is_sealed() and self.unseal_vault:
            unseal_keys = [
                environ.get(f"UNSEAL_VAULT_KEY{index}") for index in range(1, 6)
            ]
            client.sys.submit_unseal_keys(filter(None, unseal_keys))

    def init_syslog_server(self):
        """Start the embedded syslog server on the configured address/port."""
        self.syslog_server = SyslogServer(self.syslog_addr, self.syslog_port)
        self.syslog_server.start()

    def update_parameters(self, **kwargs):
        """Persist new parameter values and mirror them on the controller."""
        parameters = Session.query(models["parameters"]).one()
        parameters.update(**kwargs)
        self.__dict__.update(**kwargs)

    def delete_instance(self, cls, instance_id):
        """Delete the `cls` instance with the given id via the shared
        `delete` helper and return the helper's result."""
        return delete(cls, id=instance_id)

    def get(self, cls, id):
        """Return the serialized form of the `cls` instance with this id."""
        instance = fetch(cls, id=id)
        return instance.serialized

    def get_properties(self, cls, id):
        """Return the property dict of the `cls` instance with this id."""
        instance = fetch(cls, id=id)
        return instance.get_properties()

    def get_all(self, cls):
        """Return the property dicts of every instance of model `cls`."""
        return [record.get_properties() for record in fetch_all(cls)]

    def update(self, cls, **kwargs):
        """Create or update an instance of model `cls` from form data.

        Returns the serialized instance on success, or an {"error": ...}
        dict on failure (the transaction is rolled back in that case).
        """
        try:
            # An empty "id" field means the submission must create a new
            # instance rather than update an existing one.
            must_be_new = kwargs.get("id") == ""
            kwargs["name"] = kwargs["name"].strip()
            kwargs["last_modified"] = self.get_time()
            kwargs["creator"] = getattr(current_user, "name", "admin")
            instance = factory(cls, must_be_new=must_be_new, **kwargs)
            Session.flush()
            return instance.serialized
        except Exception as exc:
            Session.rollback()
            # Name uniqueness is enforced at the database level.
            if isinstance(exc, IntegrityError):
                return {
                    "error": (f"There already is a {cls} with the same name")
                }
            return {"error": str(exc)}

    def log(self, severity, content):
        """Record a changelog entry and emit it to the application log.

        `severity` must be one of the log_severity keys
        ("error", "info", "warning").
        """
        changelog = {
            "severity": severity,
            "content": content,
            "user": getattr(current_user, "name", "admin"),
        }
        factory("changelog", **changelog)
        self.log_severity[severity](content)

    def count_models(self):
        """Return dashboard data: per-model instance counts and the value
        distribution of each model's main diagram property."""
        counters = {cls: count(cls) for cls in diagram_classes}
        distributions = {}
        for cls in diagram_classes:
            diagram_property = type_to_diagram_properties[cls][0]
            distributions[cls] = Counter(
                str(getattr(instance, diagram_property))
                for instance in fetch_all(cls)
            )
        return {"counters": counters, "properties": distributions}

    def compare(self, type, result1, result2):
        """Diff the `type` property of two records, returning both sides as
        line lists plus SequenceMatcher opcodes for rendering."""
        def lines_of(record_id):
            record = fetch(type, id=record_id)
            return self.str_dict(getattr(record, type)).splitlines()

        first = lines_of(result1)
        second = lines_of(result2)
        opcodes = SequenceMatcher(None, first, second).get_opcodes()
        return {"first": first, "second": second, "opcodes": opcodes}

    def build_filtering_constraints(self, obj_type, kwargs):
        """Translate table-filtering form data into SQLAlchemy constraints.

        For each filterable column the form may provide a value and a filter
        mode ("equality", "inclusion" or regex); relationship fields are
        filtered by the ids selected in the form.  Returns the list of
        constraints for the caller to combine (with and_/or_).
        """
        model, constraints = models[obj_type], []
        for property in filtering_properties[obj_type]:
            value = kwargs.get(f"form[{property}]")
            if not value:
                continue
            filter = kwargs.get(f"form[{property}_filter]")
            if value in ("bool-true", "bool-false"):
                constraint = getattr(model, property) == (value == "bool-true")
            elif filter == "equality":
                constraint = getattr(model, property) == value
            elif filter == "inclusion" or DIALECT == "sqlite":
                # SQLite has no regex operator: fall back to substring match.
                constraint = getattr(model, property).contains(value)
            else:
                # Regex operator syntax differs between MySQL and PostgreSQL.
                regex_operator = "regexp" if DIALECT == "mysql" else "~"
                constraint = getattr(model, property).op(regex_operator)(value)
            constraints.append(constraint)
        for related_model, relation_properties in relationships[
                obj_type].items():
            relation_ids = [
                int(id) for id in kwargs.getlist(f"form[{related_model}][]")
            ]
            filter = kwargs.get(f"form[{related_model}_filter]")
            if filter == "none":
                # Match objects with no related instance at all.
                constraint = ~getattr(model, related_model).any()
            elif not relation_ids:
                continue
            elif relation_properties["list"]:
                constraint = getattr(model, related_model).any(
                    models[relation_properties["model"]].id.in_(relation_ids))
                if filter == "not_any":
                    constraint = ~constraint
            else:
                # Scalar relationship: match any of the selected ids.
                constraint = or_(
                    getattr(model, related_model).has(id=relation_id)
                    for relation_id in relation_ids)
            constraints.append(constraint)
        return constraints

    def multiselect_filtering(self, type, params):
        """Paginated name search for select2-style multiselect widgets.

        Returns at most 10 matches for the requested page, plus the total
        match count.
        """
        model = models[type]
        query = Session.query(model.id, model.name).filter(
            model.name.contains(params.get("term")))
        page = int(params["page"])
        matches = query.limit(10).offset((page - 1) * 10).all()
        return {
            "items": [
                {"text": match.name, "id": str(match.id)} for match in matches
            ],
            "total_count": query.count(),
        }

    def table_filtering(self, table, kwargs):
        """Serve a server-side DataTables request for the given table.

        Builds filtering constraints from the form data, applies ordering
        and pagination, and returns the draw/recordsTotal/recordsFiltered/
        data payload expected by DataTables.
        """
        model, properties = models[table], table_properties[table]
        # "all" combines the constraints with AND, anything else with OR.
        operator = and_ if kwargs.get("form[operator]",
                                      "all") == "all" else or_
        column_index = int(kwargs["order[0][column]"])
        if column_index < len(properties):
            # asc/desc method of the column, per the DataTables sort request.
            order_property = getattr(model, properties[column_index])
            order_function = getattr(order_property, kwargs["order[0][dir]"],
                                     None)
        else:
            order_function = None
        constraints = self.build_filtering_constraints(table, kwargs)
        if table == "result":
            # Restrict results to the service/run the table belongs to.
            constraints.append(
                getattr(
                    models["result"],
                    "service" if "service" in kwargs["instance[type]"] else
                    kwargs["instance[type]"],
                ).has(id=kwargs["instance[id]"]))
            if kwargs.get("service[runtime]"):
                constraints.append(models["result"].parent_runtime ==
                                   kwargs.get("service[runtime]"))
        elif table == "configuration" and kwargs.get("instance[id]"):
            # Restrict configurations to the selected device.
            constraints.append(
                getattr(models[table],
                        "device").has(id=kwargs["instance[id]"]))
        result = Session.query(model).filter(operator(*constraints))
        if order_function:
            result = result.order_by(order_function())
        return {
            "draw":
            int(kwargs["draw"]),
            "recordsTotal":
            Session.query(func.count(model.id)).scalar(),
            "recordsFiltered":
            get_query_count(result),
            "data": [[
                getattr(obj, f"table_{property}", getattr(obj, property))
                for property in properties
            ] + obj.generate_row(table) for obj in result.limit(
                int(kwargs["length"])).offset(int(kwargs["start"])).all()],
        }

    def allowed_file(self, name, allowed_modules):
        """Return whether `name` is a filename with an allowed extension.

        The original computed the extension unconditionally, so a name
        without any "." raised an IndexError; it now returns False instead.
        """
        if "." not in name:
            return False
        extension = name.rsplit(".", 1)[1].lower()
        return extension in allowed_modules

    def get_time(self):
        return str(datetime.now())

    def send_email(
        self,
        subject,
        content,
        sender=None,
        recipients=None,
        filename=None,
        file_content=None,
    ):
        """Send an email through the configured SMTP server.

        Falls back to the configured mail_sender / mail_recipients when no
        explicit sender or recipients are given; `recipients` is a
        comma-separated string.  When `filename` is set, `file_content` is
        attached under that name.
        """
        sender = sender or self.mail_sender
        recipients = recipients or self.mail_recipients
        message = MIMEMultipart()
        message["From"] = sender
        message["To"] = recipients
        message["Date"] = formatdate(localtime=True)
        message["Subject"] = subject
        message.attach(MIMEText(content))
        if filename:
            attached_file = MIMEApplication(file_content, Name=filename)
            # Embed the real attachment name: the original f-string had no
            # placeholder, so every attachment was mislabeled.
            attached_file[
                "Content-Disposition"] = f'attachment; filename="{filename}"'
            message.attach(attached_file)
        server = SMTP(self.mail_server, self.mail_port)
        if self.mail_use_tls:
            server.starttls()
            server.login(self.mail_username, self.mail_password)
        server.sendmail(sender, recipients.split(","), message.as_string())
        server.close()

    def str_dict(self, input, depth=0):
        tab = "\t" * depth
        if isinstance(input, list):
            result = "\n"
            for element in input:
                result += f"{tab}- {self.str_dict(element, depth + 1)}\n"
            return result
        elif isinstance(input, dict):
            result = ""
            for key, value in input.items():
                result += f"\n{tab}{key}: {self.str_dict(value, depth + 1)}"
            return result
        else:
            return str(input)

    def strip_all(self, input):
        return input.translate(str.maketrans("", "", f"{punctuation} "))

    def update_database_configurations_from_git(self):
        """Refresh device configurations from the local git checkout.

        Each sub-directory of git/configurations is expected to be named
        after a device and to contain a data.yml metadata file plus one file
        (named after the device) with the configuration text.  Pools
        filtering on device configuration are recomputed afterwards.
        """
        for dir in scandir(self.path / "git" / "configurations"):
            if dir.name == ".git":
                continue
            device = fetch("device", allow_none=True, name=dir.name)
            if device:
                with open(Path(dir.path) / "data.yml") as data:
                    # NOTE(review): yaml.load without an explicit Loader is
                    # unsafe on untrusted repositories.
                    parameters = yaml.load(data)
                    # Pool membership is recomputed once at the end instead
                    # of on every device update.
                    device.update(**{"dont_update_pools": True, **parameters})
                    config_file = Path(dir.path) / dir.name
                    if not config_file.exists():
                        continue
                    with open(config_file) as f:
                        # Store the text both as the current configuration
                        # and in the per-timestamp history.
                        device.configuration = device.configurations[str(
                            parameters["last_update"])] = f.read()
        Session.commit()
        for pool in fetch_all("pool"):
            if pool.device_configuration:
                pool.compute_pool()
# Пример #6 (Example 6)
# 0
 def init_syslog_server(self):
     """Start the embedded syslog server on the configured address/port."""
     self.syslog_server = SyslogServer(self.syslog_addr, self.syslog_port)
     self.syslog_server.start()