示例#1
0
    def __init__(self, manager):
        """Load schema, SQL data-type and index metadata for the SQL backend.

        :param manager: stored as ``self.manager``.

        Reads ``CN_SQL_DB_DIALECT`` (default ``"mysql"``) and several JSON files
        under ``/app/static``; the index definitions file is chosen by dialect.
        """
        self.manager = manager

        self.db_dialect = os.environ.get("CN_SQL_DB_DIALECT", "mysql")
        self.schema_files = [
            "/app/static/jans_schema.json",
            "/app/static/custom_schema.json",
        ]

        self.client = SQLClient()

        def _load_json(path):
            # JSON text is UTF-8 (RFC 8259); don't depend on the process locale.
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        self.sql_data_types = _load_json("/app/static/sql/sql_data_types.json")

        # attribute types of all schema files, concatenated in file order
        self.attr_types = []
        for fn in self.schema_files:
            self.attr_types += _load_json(fn)["attributeTypes"]

        self.opendj_attr_types = _load_json(
            "/app/static/sql/opendj_attributes_syntax.json")
        self.sql_data_types_mapping = _load_json(
            "/app/static/sql/ldap_sql_data_type_mapping.json")

        # MySQL has its own index definitions; any other dialect uses the PostgreSQL ones
        index_fn = "mysql_index.json" if self.db_dialect == "mysql" else "postgresql_index.json"
        self.sql_indexes = _load_json(f"/app/static/sql/{index_fn}")
def test_sql_client_getattr(monkeypatch, gmanager, tmpdir):
    """``SQLClient.__getattr__("create_table")`` returns a truthy attribute."""
    from jans.pycloudlib.persistence.sql import SQLClient

    password_file = tmpdir.join("sql_password")
    password_file.write("secret")

    monkeypatch.setenv("CN_SQL_DB_DIALECT", "mysql")
    monkeypatch.setenv("CN_SQL_PASSWORD_FILE", str(password_file))

    sql_client = SQLClient(gmanager)
    assert sql_client.__getattr__("create_table")
def test_sql_client_getattr_error(monkeypatch, gmanager, tmpdir):
    """``SQLClient.__getattr__`` raises ``AttributeError`` for an unknown name."""
    from jans.pycloudlib.persistence.sql import SQLClient

    password_file = tmpdir.join("sql_password")
    password_file.write("secret")

    monkeypatch.setenv("CN_SQL_DB_DIALECT", "mysql")
    monkeypatch.setenv("CN_SQL_PASSWORD_FILE", str(password_file))

    sql_client = SQLClient(gmanager)
    with pytest.raises(AttributeError):
        sql_client.__getattr__("random_attr")
示例#4
0
def wait_for_sql_conn(manager, **kwargs):
    """Check that a connection to the SQL database can be established.

    ``manager`` and ``kwargs`` are accepted but not used here.

    :raises WaitError: if ``SQLClient().connected()`` returns a falsy value.
    """
    is_connected = SQLClient().connected()
    if is_connected:
        return
    raise WaitError("SQL backend is unreachable")
示例#5
0
def wait_for_sql(manager, **kwargs):
    """Check that the SQL database holds the config-api client row.

    Looks up the ``jca_client_id`` config value in the ``jansClnt`` table.

    :raises WaitError: if ``row_exists`` returns a falsy value.
    """
    sql_client = SQLClient()
    client_row_found = sql_client.row_exists(
        "jansClnt", manager.config.get("jca_client_id"))

    if client_row_found:
        return
    raise WaitError("SQL is not fully initialized")
def test_sql_client_init(monkeypatch, dialect, gmanager, tmpdir):
    """``SQLClient``'s adapter reports the dialect set in ``CN_SQL_DB_DIALECT``."""
    from jans.pycloudlib.persistence.sql import SQLClient

    password_file = tmpdir.join("sql_password")
    password_file.write("secret")

    monkeypatch.setenv("CN_SQL_DB_DIALECT", dialect)
    monkeypatch.setenv("CN_SQL_PASSWORD_FILE", str(password_file))

    sql_client = SQLClient(gmanager)
    assert sql_client.adapter.dialect == dialect
示例#7
0
class SqlPersistence:
    """Read the jans-auth dynamic configuration stored in SQL."""

    def __init__(self, manager):
        # ``manager`` is accepted but not used by this implementation
        self.client = SQLClient()

    def get_auth_config(self):
        """Return ``jansConfDyn`` of the ``jans-auth`` row in ``jansAppConf``.

        :returns: the column value, or ``""`` when the client returns no row
            or the row has no ``jansConfDyn`` key.
        """
        config = self.client.get(
            "jansAppConf",
            "jans-auth",
            ["jansConfDyn"],
        )
        # a missing row may come back as None/empty; don't call .get on None
        if not config:
            return ""
        return config.get("jansConfDyn", "")
示例#8
0
class SqlPersistence(BasePersistence):
    """SQL-backed access to the ``jans-auth`` row of ``jansAppConf``."""

    def __init__(self, manager):
        # ``manager`` is accepted but not used by the SQL implementation
        self.client = SQLClient()

    def get_auth_config(self):
        """Fetch ``jansRevision`` and ``jansConfDyn`` of the ``jans-auth`` row.

        :returns: the row dict with an added ``"id": "jans-auth"`` key, or
            ``{}`` when the client returns nothing.
        """
        row = self.client.get(
            "jansAppConf",
            "jans-auth",
            ["jansRevision", "jansConfDyn"],
        )
        if row:
            row["id"] = "jans-auth"
            return row
        return {}

    def modify_auth_config(self, id_, rev, conf_dynamic):
        """Store ``rev`` and the JSON-serialized ``conf_dynamic`` in row ``id_``.

        :returns: the result of ``self.client.update``.
        """
        changes = {
            "jansRevision": rev,
            "jansConfDyn": json.dumps(conf_dynamic),
        }
        return self.client.update("jansAppConf", id_, changes)
 def __init__(self, manager):
     """Create the SQL client; ``manager`` is accepted but not used here."""
     self.client = SQLClient()
示例#10
0
 def __init__(self, manager):
     """Initialise the SQL-backed instance.

     :param manager: stored as ``self.manager``.
     """
     super().__init__()
     self.manager = manager
     self.client = SQLClient()
     # backend identifier string
     self.type = "sql"
示例#11
0
class SQLBackend(BaseBackend):
    """SQL implementation of the backend that patches existing Jans entries.

    Rows are addressed by ``doc_id`` inside the table passed through the
    ``table_name`` keyword argument. Multi-valued columns hold JSON objects of
    the form ``{"v": [...]}``.

    ``jans_admin_ui_role_id``, ``jans_admin_ui_claim``, ``jans_scim_scopes``,
    ``jans_stat_scopes`` and ``jans_attrs`` are not set in this class —
    presumably provided by ``BaseBackend``; confirm.
    """

    def __init__(self, manager):
        super().__init__()
        self.manager = manager
        self.client = SQLClient()
        self.type = "sql"

    @staticmethod
    def _json_values(attrs, name):
        """Return the mutable ``"v"`` list of JSON column ``name`` in ``attrs``.

        A missing or NULL column, or one without a ``"v"`` list, is replaced by
        ``{"v": []}`` in ``attrs`` so callers can append to the returned list.
        """
        column = attrs.get(name)
        if not column:
            column = {"v": []}
            attrs[name] = column
        if column.get("v") is None:
            column["v"] = []
        return column["v"]

    def get_entry(self, key, filter_="", attrs=None, **kwargs):
        """Return row ``key`` of ``kwargs["table_name"]`` as an ``Entry``, or ``None``.

        ``filter_`` is accepted but not used by the SQL implementation.
        """
        table_name = kwargs.get("table_name")
        entry = self.client.get(table_name, key, attrs)

        if not entry:
            return None
        return Entry(key, entry)

    def modify_entry(self, key, attrs=None, **kwargs):
        """Update row ``key`` of ``kwargs["table_name"]`` with ``attrs``.

        :returns: a ``(client.update result, "")`` tuple.
        """
        attrs = attrs or {}
        table_name = kwargs.get("table_name")
        return self.client.update(table_name, key, attrs), ""

    def update_people_entries(self):
        """Give the default admin user the ``api-admin`` admin-UI role if it has none."""
        admin_inum = self.manager.config.get("admin_inum")
        id_ = doc_id_from_dn(f"inum={admin_inum},ou=people,o=jans")
        kwargs = {"table_name": "jansPerson"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        # the column may be NULL/missing or an empty hash ({"v": []})
        if not self._json_values(entry.attrs, "jansAdminUIRole"):
            entry.attrs["jansAdminUIRole"] = {"v": ["api-admin"]}
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_scopes_entries(self):
        """Add the admin-UI role claim to the scope ``self.jans_admin_ui_role_id``."""
        id_ = doc_id_from_dn(self.jans_admin_ui_role_id)
        kwargs = {"table_name": "jansScope"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        claims = self._json_values(entry.attrs, "jansClaim")
        if self.jans_admin_ui_claim not in claims:
            claims.append(self.jans_admin_ui_claim)
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_clients_entries(self):
        """Add the admin UI redirect URI and SCIM/stat scopes to the config-api client."""
        jca_client_id = self.manager.config.get("jca_client_id")
        id_ = doc_id_from_dn(f"inum={jca_client_id},ou=clients,o=jans")
        kwargs = {"table_name": "jansClnt"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        should_update = False

        # modify redirect UI of config-api client
        hostname = self.manager.config.get("hostname")
        admin_uri = f"https://{hostname}/admin"
        redirect_uris = self._json_values(entry.attrs, "jansRedirectURI")
        if admin_uri not in redirect_uris:
            redirect_uris.append(admin_uri)
            should_update = True

        # add jans_stat, SCIM users.read, SCIM users.write scopes to config-api client
        scopes = self._json_values(entry.attrs, "jansScope")
        for scope in self.jans_scim_scopes + self.jans_stat_scopes:
            if scope not in scopes:
                scopes.append(scope)
                should_update = True

        if should_update:
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_scim_scopes_entries(self):
        """Set ``jansAttrs`` on the SCIM users.read/users.write scopes when it is empty."""
        ids = [doc_id_from_dn(scope) for scope in self.jans_scim_scopes]
        kwargs = {"table_name": "jansScope"}

        for id_ in ids:
            entry = self.get_entry(id_, **kwargs)
            if not entry:
                continue

            # a full SQL row always has the column (possibly NULL), so test the
            # value rather than key presence (same check as update_base_entries)
            if not entry.attrs.get("jansAttrs"):
                entry.attrs["jansAttrs"] = self.jans_attrs
                self.modify_entry(id_, entry.attrs, **kwargs)

    def update_base_entries(self):
        """Set ``jansManagerGrp`` on the base organization entry when it is empty."""
        id_ = doc_id_from_dn(JANS_BASE_ID)
        kwargs = {"table_name": "jansOrganization"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        if not entry.attrs.get("jansManagerGrp"):
            entry.attrs["jansManagerGrp"] = JANS_MANAGER_GROUP
            self.modify_entry(id_, entry.attrs, **kwargs)
示例#12
0
class SQLBackend:
    def __init__(self, manager):
        """Load schema, SQL data-type and index metadata for the SQL backend.

        :param manager: stored as ``self.manager``.

        Reads ``CN_SQL_DB_DIALECT`` (default ``"mysql"``) and several JSON files
        under ``/app/static``; the index definitions file is chosen by dialect.
        """
        self.manager = manager

        self.db_dialect = os.environ.get("CN_SQL_DB_DIALECT", "mysql")
        self.schema_files = [
            "/app/static/jans_schema.json",
            "/app/static/custom_schema.json",
        ]

        self.client = SQLClient()

        def _load_json(path):
            # JSON text is UTF-8 (RFC 8259); don't depend on the process locale.
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        self.sql_data_types = _load_json("/app/static/sql/sql_data_types.json")

        # attribute types of all schema files, concatenated in file order
        self.attr_types = []
        for fn in self.schema_files:
            self.attr_types += _load_json(fn)["attributeTypes"]

        self.opendj_attr_types = _load_json(
            "/app/static/sql/opendj_attributes_syntax.json")
        self.sql_data_types_mapping = _load_json(
            "/app/static/sql/ldap_sql_data_type_mapping.json")

        # MySQL has its own index definitions; any other dialect uses the PostgreSQL ones
        index_fn = "mysql_index.json" if self.db_dialect == "mysql" else "postgresql_index.json"
        self.sql_indexes = _load_json(f"/app/static/sql/{index_fn}")

    def get_attr_syntax(self, attr):
        for attr_type in self.attr_types:
            if attr not in attr_type["names"]:
                continue
            if attr_type.get("multivalued"):
                return "JSON"
            return attr_type["syntax"]

        # fallback to OpenDJ attribute type
        return self.opendj_attr_types.get(
            attr) or "1.3.6.1.4.1.1466.115.121.1.15"

    def get_data_type(self, attr, table=None):
        # check from SQL data types first
        type_def = self.sql_data_types.get(attr)

        if type_def:
            type_ = type_def.get(self.db_dialect) or type_def["mysql"]

            if table in type_.get("tables", {}):
                type_ = type_["tables"][table]

            data_type = type_["type"]
            if "size" in type_:
                data_type = f"{data_type}({type_['size']})"
            return data_type

        # data type is undefined, hence check from syntax
        syntax = self.get_attr_syntax(attr)
        syntax_def = self.sql_data_types_mapping[syntax]
        type_ = syntax_def.get(self.db_dialect) or syntax_def["mysql"]

        char_type = "VARCHAR"
        if self.db_dialect == "spanner":
            char_type = "STRING"

        if type_["type"] != char_type:
            data_type = type_["type"]
        else:
            if type_["size"] <= 127:
                data_type = f"{char_type}({type_['size']})"
            elif type_["size"] <= 255:
                data_type = "TINYTEXT" if self.db_dialect == "mysql" else "TEXT"
            else:
                data_type = "TEXT"

        if data_type == "TEXT" and self.db_dialect == "spanner":
            data_type = "STRING(MAX)"
        return data_type

    def create_tables(self):
        schemas = {}
        attrs = {}
        # cached schemas that holds table's column and its type
        table_columns = defaultdict(dict)

        for fn in self.schema_files:
            with open(fn) as f:
                schema = json.loads(f.read())

            for oc in schema["objectClasses"]:
                schemas[oc["names"][0]] = oc

            for attr in schema["attributeTypes"]:
                attrs[attr["names"][0]] = attr

        for table, oc in schemas.items():
            if oc.get("sql", {}).get("ignore"):
                continue

            # ``oc["may"]`` contains list of attributes
            if "sql" in oc:
                oc["may"] += oc["sql"].get("include", [])

                for inc_oc in oc["sql"].get("includeObjectClass", []):
                    oc["may"] += schemas[inc_oc]["may"]

            doc_id_type = self.get_data_type("doc_id", table)
            table_columns[table].update({
                "doc_id":
                doc_id_type,
                "objectClass":
                "VARCHAR(48)"
                if self.db_dialect != "spanner" else "STRING(48)",
                "dn":
                "VARCHAR(128)"
                if self.db_dialect != "spanner" else "STRING(128)",
            })

            # make sure ``oc["may"]`` doesn't have duplicate attribute
            for attr in set(oc["may"]):
                data_type = self.get_data_type(attr, table)
                table_columns[table].update({attr: data_type})

        for table, attr_mapping in table_columns.items():
            self.client.create_table(table, attr_mapping, "doc_id")

        # for name, attr in attrs.items():
        #     table = attr.get("sql", {}).get("add_table")
        #     logger.info(name)
        #     logger.info(table)
        #     if not table:
        #         continue

        #     data_type = self.get_data_type(name, table)
        #     col_def = f"{attr} {data_type}"

        #     sql_cmd = f"ALTER TABLE {table} ADD {col_def};"
        #     logger.info(sql_cmd)

    def get_index_fields(self, table_name):
        fields = self.sql_indexes.get(table_name, {}).get("fields", [])
        fields += self.sql_indexes["__common__"]["fields"]

        # make unique fields
        return list(set(fields))

    def create_mysql_indexes(self, table_name: str, column_mapping: dict):
        """Create MySQL indexes for ``table_name``.

        For each column (other than ``doc_id``) returned by
        :meth:`get_index_fields`, a plain index is created; JSON columns instead
        get one functional index per template in
        ``self.sql_indexes["__common__"]["JSON"]``. Then the table's ``custom``
        index expressions are created. All statements go through
        ``self.client.create_index``.

        :param table_name: table to index.
        :param column_mapping: column name -> column type of the table.
        """
        fields = self.get_index_fields(table_name)

        for column_name, column_type in column_mapping.items():
            if column_name == "doc_id" or column_name not in fields:
                continue

            # FIELD_RE is defined elsewhere in the module; presumably it replaces
            # characters not allowed in an index name
            index_name = f"{table_name}_{FIELD_RE.sub('_', column_name)}"

            if column_type.lower() != "json":
                query = f"CREATE INDEX {self.client.quoted_id(index_name)} ON {self.client.quoted_id(table_name)} ({self.client.quoted_id(column_name)})"
                self.client.create_index(query)
            else:
                # TODO: revise JSON type
                #
                # some MySQL versions don't support JSON array (NotSupportedError)
                # also some of them don't support functional index that returns
                # JSON or Geometry value
                for i, index_str in enumerate(
                        self.sql_indexes["__common__"]["JSON"], start=1):
                    # fill ``$field`` / ``$data_type`` placeholders of the template
                    index_str_fmt = Template(index_str).safe_substitute({
                        "field":
                        column_name,
                        "data_type":
                        column_type,
                    })
                    # NOTE(review): the name does not include the column, so a
                    # second indexed JSON column in the same table reuses the
                    # same names — confirm create_index tolerates/ignores this.
                    name = f"{table_name}_json_{i}"
                    query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} (({index_str_fmt}))"
                    self.client.create_index(query)

        for i, custom in enumerate(self.sql_indexes.get(table_name,
                                                        {}).get("custom", []),
                                   start=1):
            # jansPerson table has unsupported custom index expressions that need to be skipped if mysql < 8.0
            # NOTE(review): this is a lexical string comparison (e.g. "10.5" < "8.0"
            # is True); confirm the format/type of ``server_version``.
            if table_name == "jansPerson" and self.client.server_version < "8.0":
                continue
            name = f"{table_name}_CustomIdx{i}"
            query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} ({custom})"
            self.client.create_index(query)

    def create_pgsql_indexes(self, table_name: str, column_mapping: dict):
        """Create PostgreSQL indexes for ``table_name``.

        For each column (other than ``doc_id``) returned by
        :meth:`get_index_fields`, a plain index is created; JSON columns instead
        get one expression index per template in
        ``self.sql_indexes["__common__"]["JSON"]``. Then the table's ``custom``
        index expressions are created. All statements go through
        ``self.client.create_index``.

        :param table_name: table to index.
        :param column_mapping: column name -> column type of the table.
        """
        fields = self.get_index_fields(table_name)

        for column_name, column_type in column_mapping.items():
            if column_name == "doc_id" or column_name not in fields:
                continue

            # FIELD_RE is defined elsewhere in the module; presumably it replaces
            # characters not allowed in an index name
            index_name = f"{table_name}_{FIELD_RE.sub('_', column_name)}"

            if column_type.lower() != "json":
                query = f"CREATE INDEX {self.client.quoted_id(index_name)} ON {self.client.quoted_id(table_name)} ({self.client.quoted_id(column_name)})"
                self.client.create_index(query)
            else:
                for i, index_str in enumerate(
                        self.sql_indexes["__common__"]["JSON"], start=1):
                    # fill ``$field`` / ``$data_type`` placeholders of the template
                    index_str_fmt = Template(index_str).safe_substitute({
                        "field":
                        column_name,
                        "data_type":
                        column_type,
                    })
                    # NOTE(review): the name does not include the column, so a
                    # second indexed JSON column in the same table reuses the
                    # same names — confirm create_index tolerates/ignores this.
                    name = f"{table_name}_json_{i}"
                    query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} (({index_str_fmt}))"
                    self.client.create_index(query)

        for i, custom in enumerate(self.sql_indexes.get(table_name,
                                                        {}).get("custom", []),
                                   start=1):
            name = f"{table_name}_custom_{i}"
            query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} (({custom}))"
            self.client.create_index(query)

    def create_indexes(self):
        for table_name, column_mapping in self.client.get_table_mapping(
        ).items():
            if self.db_dialect == "pgsql":
                index_func = self.create_pgsql_indexes
            elif self.db_dialect == "mysql":
                index_func = self.create_mysql_indexes
            # run the callback
            index_func(table_name, column_mapping)

    def import_ldif(self):
        """Render every LDIF template and insert its entries into SQL tables.

        The template list comes from ``get_ldif_mappings`` given the
        ``optional_scopes`` config value (a JSON list, default ``"[]"``). Each
        template is rendered from ``/app/templates`` into ``/app/tmp`` and its
        rows are inserted with ``self.client.insert_into``.
        """
        scopes_json = self.manager.config.get("optional_scopes", "[]")
        ldif_mappings = get_ldif_mappings(json.loads(scopes_json))

        ctx = prepare_template_ctx(self.manager)

        for files in ldif_mappings.values():
            for file_ in files:
                logger.info(f"Importing {file_} file")
                src, dst = f"/app/templates/{file_}", f"/app/tmp/{file_}"
                os.makedirs(os.path.dirname(dst), exist_ok=True)

                render_ldif(src, dst, ctx)

                for table_name, row in self.data_from_ldif(dst):
                    self.client.insert_into(table_name, row)

    def initialize(self):
        """Bootstrap the SQL database: tables, schema fixes, indexes, then data.

        The order matters: indexes are created from the table mapping read after
        the schema update, and LDIF data is imported last.
        """
        logger.info("Creating tables (if not exist)")
        self.create_tables()

        logger.info("Updating schema (if required)")
        self.update_schema()

        # force-reload metadata as we may have changed the schema
        self.client.adapter._metadata = None

        logger.info("Creating indexes (if not exist)")
        self.create_indexes()

        self.import_ldif()

    def transform_value(self, key, values):
        type_ = self.sql_data_types.get(key)

        if not type_:
            attr_syntax = self.get_attr_syntax(key)
            type_ = self.sql_data_types_mapping[attr_syntax]

        type_ = type_.get(self.db_dialect) or type_["mysql"]
        data_type = type_["type"]

        if data_type in (
                "SMALLINT",
                "BOOL",
        ):
            if values[0].lower() in ("1", "on", "true", "yes", "ok"):
                return 1 if data_type == "SMALLINT" else True
            return 0 if data_type == "SMALLINT" else False

        if data_type == "INT":
            return int(values[0])

        if data_type in (
                "DATETIME(3)",
                "TIMESTAMP",
        ):
            dval = values[0].strip("Z")
            sep = " "
            postfix = ""
            if self.db_dialect == "spanner":
                sep = "T"
                postfix = "Z"
            # return "{}-{}-{} {}:{}:{}{}".format(dval[0:4], dval[4:6], dval[6:8], dval[8:10], dval[10:12], dval[12:14], dval[14:17])
            return "{}-{}-{}{}{}:{}:{}{}{}".format(
                dval[0:4],
                dval[4:6],
                dval[6:8],
                sep,
                dval[8:10],
                dval[10:12],
                dval[12:14],
                dval[14:17],
                postfix,
            )

        if data_type == "JSON":
            # return json.dumps({"v": values})
            return {"v": values}

        if data_type == "ARRAY<STRING(MAX)>":
            return values

        # fallback
        return values[0]

    def data_from_ldif(self, filename):
        """Yield ``(table_name, columns)`` for each entry of LDIF file ``filename``.

        ``top`` is dropped from the object classes and the last remaining one
        names the table. Entries whose only remaining object class is
        ``organizationalUnit`` or ``organization`` (case-insensitive) are
        skipped. ``columns`` is an ``OrderedDict`` starting with ``doc_id``,
        ``objectClass`` and ``dn``, followed by every other attribute converted
        with :meth:`transform_value`.
        """
        with open(filename, "rb") as fd:
            for dn, entry in LDIFParser(fd).parse():
                doc_id = doc_id_from_dn(dn)

                object_classes = entry.get("objectClass") or entry.get("objectclass")
                if object_classes:
                    if "top" in object_classes:
                        object_classes.remove("top")

                    only_container = (
                        len(object_classes) == 1
                        and object_classes[0].lower() in ("organizationalunit", "organization")
                    )
                    if only_container:
                        continue

                table_name = object_classes[-1]

                # drop whichever spelling of the object-class key is present
                for oc_key in ("objectClass", "objectclass"):
                    if oc_key in entry:
                        entry.pop(oc_key)
                        break

                columns = OrderedDict(doc_id=doc_id, objectClass=table_name, dn=dn)
                for attr_name in entry:
                    columns[attr_name] = self.transform_value(attr_name, entry[attr_name])
                yield table_name, columns

    def update_schema(self):
        """Apply in-place schema migrations to existing tables.

        Migrations shown here:

        1. ``jansClnt.jansDefAcrValues``: if its current type differs from
           :meth:`get_data_type`, the column is dropped and re-added with the new
           type, and each old scalar value is written back as ``{"v": [value]}``
           (or ``{"v": []}`` when empty).
        2. ``jansToken.jansUsrDN``: added when the column is missing.
        """
        table_mapping = self.client.get_table_mapping()

        # 1 - jansDefAcrValues is changed to multivalued (JSON type)
        table_name = "jansClnt"
        col_name = "jansDefAcrValues"
        # NOTE(review): raises KeyError if the table or column is absent from the mapping
        old_data_type = table_mapping[table_name][col_name]
        data_type = self.get_data_type(col_name, table_name)

        if data_type != old_data_type:
            # get the value first before updating column type
            acr_values = {
                row["doc_id"]: row[col_name]
                for row in self.client.search(table_name, ["doc_id", col_name])
            }

            # to change the storage format of a JSON column, drop the column and
            # add the column back specifying the new storage format
            # NOTE(review): old values live only in ``acr_values`` from here on; if
            # the process stops before the updates below, they are lost.
            # Executes plain SQL strings on the connection — presumably SQLAlchemy
            # 1.x semantics; confirm against the pinned SQLAlchemy version.
            with self.client.adapter.engine.connect() as conn:
                conn.execute(
                    f"ALTER TABLE {self.client.quoted_id(table_name)} DROP COLUMN {self.client.quoted_id(col_name)}"
                )
                conn.execute(
                    f"ALTER TABLE {self.client.quoted_id(table_name)} ADD COLUMN {self.client.quoted_id(col_name)} {data_type}"
                )

            # force-reload metadata as we may have changed the schema before migrating old data
            self.client.adapter._metadata = None

            # wrap each old scalar value into the multi-valued {"v": [...]} shape
            for doc_id, value in acr_values.items():
                if not value:
                    value_list = []
                else:
                    value_list = [value]
                self.client.update(table_name, doc_id,
                                   {col_name: {
                                       "v": value_list
                                   }})

        # 2 - jansUsrDN column must be in jansToken table
        table_name = "jansToken"
        col_name = "jansUsrDN"

        if col_name not in table_mapping[table_name]:
            data_type = self.get_data_type(col_name, table_name)
            with self.client.adapter.engine.connect() as conn:
                conn.execute(
                    f"ALTER TABLE {self.client.quoted_id(table_name)} ADD COLUMN {self.client.quoted_id(col_name)} {data_type}"
                )