def test_sql_client_getattr(monkeypatch, gmanager, tmpdir):
    from jans.pycloudlib.persistence.sql import SQLClient

    monkeypatch.setenv("CN_SQL_DB_DIALECT", "mysql")

    src = tmpdir.join("sql_password")
    src.write("secret")
    monkeypatch.setenv("CN_SQL_PASSWORD_FILE", str(src))

    client = SQLClient(gmanager)
    assert client.__getattr__("create_table")
def test_sql_client_getattr_error(monkeypatch, gmanager, tmpdir):
    import pytest
    from jans.pycloudlib.persistence.sql import SQLClient

    monkeypatch.setenv("CN_SQL_DB_DIALECT", "mysql")

    src = tmpdir.join("sql_password")
    src.write("secret")
    monkeypatch.setenv("CN_SQL_PASSWORD_FILE", str(src))

    client = SQLClient(gmanager)
    with pytest.raises(AttributeError):
        client.__getattr__("random_attr")
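Both tests above exercise attribute delegation: lookups that miss on SQLClient fall through to the underlying dialect adapter. A minimal sketch of that pattern, an illustration rather than the library source:

class _DelegatingClient:
    """Illustrative only: forwards unknown attribute lookups to an adapter."""

    def __init__(self, adapter):
        self.adapter = adapter

    def __getattr__(self, name):
        # __getattr__ is only invoked after normal attribute lookup fails,
        # so adapter methods such as ``create_table`` resolve transparently.
        if hasattr(self.adapter, name):
            return getattr(self.adapter, name)
        raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")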
def wait_for_sql_conn(manager, **kwargs):
    """Wait for readiness/liveness of an SQL database connection."""
    # check the connection
    init = SQLClient().connected()
    if not init:
        raise WaitError("SQL backend is unreachable")
def wait_for_sql(manager, **kwargs):
    """Wait for readiness/liveness of an SQL database."""
    init = SQLClient().row_exists("jansClnt", manager.config.get("jca_client_id"))
    if not init:
        raise WaitError("SQL is not fully initialized")
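The wait_* helpers raise WaitError on failure instead of retrying themselves; the caller is expected to poll them. A hypothetical startup loop (the retry count and interval here are made-up values, and the actual library may wrap these helpers differently):

import time

def poll(callback, manager, retries=30, interval=5):
    # Call a wait_* helper until it stops raising WaitError or retries run out.
    for _ in range(retries):
        try:
            callback(manager)
            return True
        except WaitError:
            time.sleep(interval)
    return False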
@pytest.mark.parametrize("dialect", ["mysql", "pgsql"])
def test_sql_client_init(monkeypatch, dialect, gmanager, tmpdir):
    from jans.pycloudlib.persistence.sql import SQLClient

    monkeypatch.setenv("CN_SQL_DB_DIALECT", dialect)

    src = tmpdir.join("sql_password")
    src.write("secret")
    monkeypatch.setenv("CN_SQL_PASSWORD_FILE", str(src))

    client = SQLClient(gmanager)
    assert client.adapter.dialect == dialect
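For context, a sketch of how a dialect env var like CN_SQL_DB_DIALECT can be mapped to a SQLAlchemy engine URL; the host, port, and credential handling here is assumed for illustration, not the library's actual wiring:

import os
from sqlalchemy import create_engine

def make_engine(host="localhost", user="jans", password="", db="jans"):
    dialect = os.environ.get("CN_SQL_DB_DIALECT", "mysql")
    if dialect == "pgsql":
        url = f"postgresql+psycopg2://{user}:{password}@{host}:5432/{db}"
    else:
        url = f"mysql+pymysql://{user}:{password}@{host}:3306/{db}"
    # pool_pre_ping guards against stale pooled connections
    return create_engine(url, pool_pre_ping=True)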
class SqlPersistence:
    def __init__(self, manager):
        self.client = SQLClient()

    def get_auth_config(self):
        config = self.client.get(
            "jansAppConf",
            "jans-auth",
            ["jansConfDyn"],
        )
        return config.get("jansConfDyn", "")
class SqlPersistence(BasePersistence):
    def __init__(self, manager):
        self.client = SQLClient()

    def get_auth_config(self):
        config = self.client.get(
            "jansAppConf",
            "jans-auth",
            ["jansRevision", "jansConfDyn"],
        )
        if not config:
            return {}

        config["id"] = "jans-auth"
        return config

    def modify_auth_config(self, id_, rev, conf_dynamic):
        modified = self.client.update("jansAppConf", id_, {
            "jansRevision": rev,
            "jansConfDyn": json.dumps(conf_dynamic),
        })
        return modified
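A hypothetical round trip with SqlPersistence: read the current dynamic config, mutate it, and write it back with a bumped revision. The config key changed below is illustrative:

persistence = SqlPersistence(manager)
config = persistence.get_auth_config()
if config:
    conf_dynamic = json.loads(config["jansConfDyn"])
    conf_dynamic["loggingLevel"] = "DEBUG"  # illustrative change
    rev = int(config["jansRevision"]) + 1
    persistence.modify_auth_config(config["id"], rev, conf_dynamic)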
class SQLBackend(BaseBackend):
    def __init__(self, manager):
        super().__init__()
        self.manager = manager
        self.client = SQLClient()
        self.type = "sql"

    def get_entry(self, key, filter_="", attrs=None, **kwargs):
        table_name = kwargs.get("table_name")
        entry = self.client.get(table_name, key, attrs)

        if not entry:
            return None
        return Entry(key, entry)

    def modify_entry(self, key, attrs=None, **kwargs):
        attrs = attrs or {}
        table_name = kwargs.get("table_name")
        return self.client.update(table_name, key, attrs), ""

    def update_people_entries(self):
        # add jansAdminUIRole to the default admin user
        admin_inum = self.manager.config.get("admin_inum")

        id_ = doc_id_from_dn(f"inum={admin_inum},ou=people,o=jans")
        kwargs = {"table_name": "jansPerson"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        # the SQL entry may have an empty jansAdminUIRole hash ({"v": []})
        if not entry.attrs["jansAdminUIRole"]["v"]:
            entry.attrs["jansAdminUIRole"] = {"v": ["api-admin"]}
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_scopes_entries(self):
        # add jansAdminUIRole claim to the profile scope
        id_ = doc_id_from_dn(self.jans_admin_ui_role_id)
        kwargs = {"table_name": "jansScope"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        if self.jans_admin_ui_claim not in entry.attrs["jansClaim"]["v"]:
            entry.attrs["jansClaim"]["v"].append(self.jans_admin_ui_claim)
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_clients_entries(self):
        jca_client_id = self.manager.config.get("jca_client_id")
        id_ = doc_id_from_dn(f"inum={jca_client_id},ou=clients,o=jans")
        kwargs = {"table_name": "jansClnt"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        should_update = False

        # modify redirect URI of the config-api client
        hostname = self.manager.config.get("hostname")
        if f"https://{hostname}/admin" not in entry.attrs["jansRedirectURI"]["v"]:
            entry.attrs["jansRedirectURI"]["v"].append(f"https://{hostname}/admin")
            should_update = True

        # add jans_stat and SCIM users.read/users.write scopes to the config-api client
        for scope in (self.jans_scim_scopes + self.jans_stat_scopes):
            if scope not in entry.attrs["jansScope"]["v"]:
                entry.attrs["jansScope"]["v"].append(scope)
                should_update = True

        if should_update:
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_scim_scopes_entries(self):
        # add jansAttrs to the SCIM users.read and users.write scopes
        ids = [doc_id_from_dn(scope) for scope in self.jans_scim_scopes]
        kwargs = {"table_name": "jansScope"}

        for id_ in ids:
            entry = self.get_entry(id_, **kwargs)
            if not entry:
                continue

            if "jansAttrs" not in entry.attrs:
                entry.attrs["jansAttrs"] = self.jans_attrs
                self.modify_entry(id_, entry.attrs, **kwargs)

    def update_base_entries(self):
        # add jansManagerGrp to the base entry
        id_ = doc_id_from_dn(JANS_BASE_ID)
        kwargs = {"table_name": "jansOrganization"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        if not entry.attrs.get("jansManagerGrp"):
            entry.attrs["jansManagerGrp"] = JANS_MANAGER_GROUP
            self.modify_entry(id_, entry.attrs, **kwargs)
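The update_* methods above key every lookup on doc_id_from_dn. A sketch of its assumed behavior, derived purely from how it is used here (take the value of the DN's first RDN as the row's doc_id):

def _doc_id_from_dn(dn: str) -> str:
    # "inum=ABC123,ou=clients,o=jans" -> "ABC123" (assumed semantics)
    first_rdn = dn.split(",")[0]
    _, _, value = first_rdn.partition("=")
    return value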
# NOTE: this snippet relies on names defined elsewhere in the module/package:
# SQLClient, FIELD_RE, doc_id_from_dn, get_ldif_mappings, prepare_template_ctx,
# render_ldif, LDIFParser, logger, plus json/os, collections.OrderedDict,
# collections.defaultdict, and string.Template.
class SQLBackend:
    def __init__(self, manager):
        self.manager = manager
        self.db_dialect = os.environ.get("CN_SQL_DB_DIALECT", "mysql")
        self.schema_files = [
            "/app/static/jans_schema.json",
            "/app/static/custom_schema.json",
        ]
        self.client = SQLClient()

        with open("/app/static/sql/sql_data_types.json") as f:
            self.sql_data_types = json.loads(f.read())

        self.attr_types = []
        for fn in self.schema_files:
            with open(fn) as f:
                schema = json.loads(f.read())
            self.attr_types += schema["attributeTypes"]

        with open("/app/static/sql/opendj_attributes_syntax.json") as f:
            self.opendj_attr_types = json.loads(f.read())

        with open("/app/static/sql/ldap_sql_data_type_mapping.json") as f:
            self.sql_data_types_mapping = json.loads(f.read())

        if self.db_dialect == "mysql":
            index_fn = "mysql_index.json"
        else:
            index_fn = "postgresql_index.json"

        with open(f"/app/static/sql/{index_fn}") as f:
            self.sql_indexes = json.loads(f.read())

    def get_attr_syntax(self, attr):
        for attr_type in self.attr_types:
            if attr not in attr_type["names"]:
                continue
            if attr_type.get("multivalued"):
                return "JSON"
            return attr_type["syntax"]

        # fallback to OpenDJ attribute type
        return self.opendj_attr_types.get(attr) or "1.3.6.1.4.1.1466.115.121.1.15"

    def get_data_type(self, attr, table=None):
        # check the SQL data types first
        type_def = self.sql_data_types.get(attr)

        if type_def:
            type_ = type_def.get(self.db_dialect) or type_def["mysql"]

            if table in type_.get("tables", {}):
                type_ = type_["tables"][table]

            data_type = type_["type"]
            if "size" in type_:
                data_type = f"{data_type}({type_['size']})"
            return data_type

        # data type is undefined, hence derive it from the attribute syntax
        syntax = self.get_attr_syntax(attr)
        syntax_def = self.sql_data_types_mapping[syntax]
        type_ = syntax_def.get(self.db_dialect) or syntax_def["mysql"]

        char_type = "VARCHAR"
        if self.db_dialect == "spanner":
            char_type = "STRING"

        if type_["type"] != char_type:
            data_type = type_["type"]
        else:
            if type_["size"] <= 127:
                data_type = f"{char_type}({type_['size']})"
            elif type_["size"] <= 255:
                data_type = "TINYTEXT" if self.db_dialect == "mysql" else "TEXT"
            else:
                data_type = "TEXT"

        if data_type == "TEXT" and self.db_dialect == "spanner":
            data_type = "STRING(MAX)"
        return data_type

    def create_tables(self):
        schemas = {}
        attrs = {}
        # cached schema that holds each table's columns and their types
        table_columns = defaultdict(dict)

        for fn in self.schema_files:
            with open(fn) as f:
                schema = json.loads(f.read())

            for oc in schema["objectClasses"]:
                schemas[oc["names"][0]] = oc

            for attr in schema["attributeTypes"]:
                attrs[attr["names"][0]] = attr

        for table, oc in schemas.items():
            if oc.get("sql", {}).get("ignore"):
                continue

            # ``oc["may"]`` contains the list of attributes
            if "sql" in oc:
                oc["may"] += oc["sql"].get("include", [])

                for inc_oc in oc["sql"].get("includeObjectClass", []):
                    oc["may"] += schemas[inc_oc]["may"]

            doc_id_type = self.get_data_type("doc_id", table)
            table_columns[table].update({
                "doc_id": doc_id_type,
                "objectClass": "VARCHAR(48)" if self.db_dialect != "spanner" else "STRING(48)",
                "dn": "VARCHAR(128)" if self.db_dialect != "spanner" else "STRING(128)",
            })

            # make sure ``oc["may"]`` doesn't contain duplicate attributes
            for attr in set(oc["may"]):
                data_type = self.get_data_type(attr, table)
                table_columns[table].update({attr: data_type})

        for table, attr_mapping in table_columns.items():
            self.client.create_table(table, attr_mapping, "doc_id")

        # for name, attr in attrs.items():
        #     table = attr.get("sql", {}).get("add_table")
        #     logger.info(name)
        #     logger.info(table)
        #     if not table:
        #         continue
        #     data_type = self.get_data_type(name, table)
        #     col_def = f"{attr} {data_type}"
        #     sql_cmd = f"ALTER TABLE {table} ADD {col_def};"
        #     logger.info(sql_cmd)

    def get_index_fields(self, table_name):
        fields = self.sql_indexes.get(table_name, {}).get("fields", [])
        fields += self.sql_indexes["__common__"]["fields"]

        # make the fields unique
        return list(set(fields))

    def create_mysql_indexes(self, table_name: str, column_mapping: dict):
        fields = self.get_index_fields(table_name)

        for column_name, column_type in column_mapping.items():
            if column_name == "doc_id" or column_name not in fields:
                continue

            index_name = f"{table_name}_{FIELD_RE.sub('_', column_name)}"

            if column_type.lower() != "json":
                query = f"CREATE INDEX {self.client.quoted_id(index_name)} ON {self.client.quoted_id(table_name)} ({self.client.quoted_id(column_name)})"
                self.client.create_index(query)
            else:
                # TODO: revise JSON type
                #
                # some MySQL versions don't support JSON array (NotSupportedError);
                # some of them also don't support a functional index that returns
                # a JSON or Geometry value
                for i, index_str in enumerate(self.sql_indexes["__common__"]["JSON"], start=1):
                    index_str_fmt = Template(index_str).safe_substitute({
                        "field": column_name,
                        "data_type": column_type,
                    })
                    name = f"{table_name}_json_{i}"
                    query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} (({index_str_fmt}))"
                    self.client.create_index(query)

        for i, custom in enumerate(self.sql_indexes.get(table_name, {}).get("custom", []), start=1):
            # the jansPerson table has custom index expressions that are
            # unsupported (and must be skipped) on MySQL < 8.0
            if table_name == "jansPerson" and self.client.server_version < "8.0":
                continue

            name = f"{table_name}_CustomIdx{i}"
            query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} ({custom})"
            self.client.create_index(query)

    def create_pgsql_indexes(self, table_name: str, column_mapping: dict):
        fields = self.get_index_fields(table_name)

        for column_name, column_type in column_mapping.items():
            if column_name == "doc_id" or column_name not in fields:
                continue

            index_name = f"{table_name}_{FIELD_RE.sub('_', column_name)}"

            if column_type.lower() != "json":
                query = f"CREATE INDEX {self.client.quoted_id(index_name)} ON {self.client.quoted_id(table_name)} ({self.client.quoted_id(column_name)})"
                self.client.create_index(query)
            else:
                for i, index_str in enumerate(self.sql_indexes["__common__"]["JSON"], start=1):
                    index_str_fmt = Template(index_str).safe_substitute({
                        "field": column_name,
                        "data_type": column_type,
                    })
                    name = f"{table_name}_json_{i}"
                    query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} (({index_str_fmt}))"
                    self.client.create_index(query)

        for i, custom in enumerate(self.sql_indexes.get(table_name, {}).get("custom", []), start=1):
            name = f"{table_name}_custom_{i}"
            query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} (({custom}))"
            self.client.create_index(query)

    def create_indexes(self):
        for table_name, column_mapping in self.client.get_table_mapping().items():
            if self.db_dialect == "pgsql":
                index_func = self.create_pgsql_indexes
            elif self.db_dialect == "mysql":
                index_func = self.create_mysql_indexes
            else:
                # unsupported dialect; nothing to do
                continue

            # run the callback
            index_func(table_name, column_mapping)

    def import_ldif(self):
        optional_scopes = json.loads(
            self.manager.config.get("optional_scopes", "[]")
        )
        ldif_mappings = get_ldif_mappings(optional_scopes)

        ctx = prepare_template_ctx(self.manager)

        for _, files in ldif_mappings.items():
            for file_ in files:
                logger.info(f"Importing {file_} file")
                src = f"/app/templates/{file_}"
                dst = f"/app/tmp/{file_}"
                os.makedirs(os.path.dirname(dst), exist_ok=True)

                render_ldif(src, dst, ctx)

                for table_name, column_mapping in self.data_from_ldif(dst):
                    self.client.insert_into(table_name, column_mapping)

    def initialize(self):
        logger.info("Creating tables (if not exist)")
        self.create_tables()

        logger.info("Updating schema (if required)")
        self.update_schema()

        # force-reload metadata as we may have changed the schema
        self.client.adapter._metadata = None

        logger.info("Creating indexes (if not exist)")
        self.create_indexes()

        self.import_ldif()

    def transform_value(self, key, values):
        type_ = self.sql_data_types.get(key)

        if not type_:
            attr_syntax = self.get_attr_syntax(key)
            type_ = self.sql_data_types_mapping[attr_syntax]
        type_ = type_.get(self.db_dialect) or type_["mysql"]

        data_type = type_["type"]

        if data_type in ("SMALLINT", "BOOL"):
            if values[0].lower() in ("1", "on", "true", "yes", "ok"):
                return 1 if data_type == "SMALLINT" else True
            return 0 if data_type == "SMALLINT" else False

        if data_type == "INT":
            return int(values[0])

        if data_type in ("DATETIME(3)", "TIMESTAMP"):
            dval = values[0].strip("Z")
            sep = " "
            postfix = ""

            if self.db_dialect == "spanner":
                sep = "T"
                postfix = "Z"

            # reassemble "YYYYMMDDhhmmss.fff" into "YYYY-MM-DD hh:mm:ss.fff"
            # (with a "T" separator and trailing "Z" on Spanner)
            return "{}-{}-{}{}{}:{}:{}{}{}".format(
                dval[0:4], dval[4:6], dval[6:8], sep,
                dval[8:10], dval[10:12], dval[12:14], dval[14:17], postfix,
            )

        if data_type == "JSON":
            # return json.dumps({"v": values})
            return {"v": values}

        if data_type == "ARRAY<STRING(MAX)>":
            return values

        # fallback
        return values[0]

    def data_from_ldif(self, filename):
        with open(filename, "rb") as fd:
            parser = LDIFParser(fd)

            for dn, entry in parser.parse():
                doc_id = doc_id_from_dn(dn)

                oc = entry.get("objectClass") or entry.get("objectclass")
                if oc:
                    if "top" in oc:
                        oc.remove("top")
                    if len(oc) == 1 and oc[0].lower() in ("organizationalunit", "organization"):
                        continue

                table_name = oc[-1]

                if "objectClass" in entry:
                    entry.pop("objectClass")
                elif "objectclass" in entry:
                    entry.pop("objectclass")

                attr_mapping = OrderedDict({
                    "doc_id": doc_id,
                    "objectClass": table_name,
                    "dn": dn,
                })

                for attr in entry:
                    value = self.transform_value(attr, entry[attr])
                    attr_mapping[attr] = value
                yield table_name, attr_mapping

    def update_schema(self):
        table_mapping = self.client.get_table_mapping()

        # 1 - jansDefAcrValues is changed to multivalued (JSON type)
        table_name = "jansClnt"
        col_name = "jansDefAcrValues"

        old_data_type = table_mapping[table_name][col_name]
        data_type = self.get_data_type(col_name, table_name)

        if data_type != old_data_type:
            # get the values first before updating the column type
            acr_values = {
                row["doc_id"]: row[col_name]
                for row in self.client.search(table_name, ["doc_id", col_name])
            }

            # to change the storage format of a JSON column, drop the column and
            # add it back specifying the new storage format
            with self.client.adapter.engine.connect() as conn:
                conn.execute(
                    f"ALTER TABLE {self.client.quoted_id(table_name)} DROP COLUMN {self.client.quoted_id(col_name)}"
                )
                conn.execute(
                    f"ALTER TABLE {self.client.quoted_id(table_name)} ADD COLUMN {self.client.quoted_id(col_name)} {data_type}"
                )

            # force-reload metadata as we may have changed the schema
            # before migrating old data
            self.client.adapter._metadata = None

            for doc_id, value in acr_values.items():
                value_list = [value] if value else []
                self.client.update(table_name, doc_id, {col_name: {"v": value_list}})

        # 2 - jansUsrDN column must be in the jansToken table
        table_name = "jansToken"
        col_name = "jansUsrDN"

        if col_name not in table_mapping[table_name]:
            data_type = self.get_data_type(col_name, table_name)

            with self.client.adapter.engine.connect() as conn:
                conn.execute(
                    f"ALTER TABLE {self.client.quoted_id(table_name)} ADD COLUMN {self.client.quoted_id(col_name)} {data_type}"
                )
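A hypothetical entrypoint tying the bootstrap steps together; get_manager is assumed to be the package's manager factory:

if __name__ == "__main__":
    from jans.pycloudlib import get_manager  # assumed factory

    backend = SQLBackend(get_manager())
    backend.initialize()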