def wait_for_spanner_conn(manager, **kwargs):
    """Wait for readiness/liveness of a Spanner database connection."""
    # Probe the backend once; raise so the wait loop retries until reachable.
    if not SpannerClient().connected():
        raise WaitError("Spanner backend is unreachable")
def wait_for_spanner(manager, **kwargs):
    """Wait for readiness/liveness of a Spanner database."""
    # The config-api client row is written near the end of persistence
    # initialization, so its presence signals a fully initialized database.
    client_id = manager.config.get("jca_client_id")
    if not SpannerClient().row_exists("jansClnt", client_id):
        raise WaitError("Spanner is not fully initialized")
def __init__(self, manager):
    """Collect schema and type-mapping metadata used to provision Spanner."""

    def _read_json(path):
        # Small helper: parse a JSON document from disk.
        with open(path) as fp:
            return json.loads(fp.read())

    self.manager = manager
    self.db_dialect = "spanner"
    self.schema_files = [
        "/app/static/jans_schema.json",
        "/app/static/custom_schema.json",
    ]
    self.client = SpannerClient()

    # attribute name -> per-dialect SQL type definition
    self.sql_data_types = _read_json("/app/static/sql/sql_data_types.json")

    # flatten attributeTypes from every bundled schema file
    self.attr_types = []
    for schema_file in self.schema_files:
        self.attr_types += _read_json(schema_file)["attributeTypes"]

    # fallback attribute syntaxes (OpenDJ) and syntax OID -> SQL type mapping
    self.opendj_attr_types = _read_json("/app/static/sql/opendj_attributes_syntax.json")
    self.sql_data_types_mapping = _read_json("/app/static/sql/ldap_sql_data_type_mapping.json")

    index_fn = "spanner_index.json"
    self.sql_indexes = _read_json(f"/app/static/sql/{index_fn}")

    # interleaved sub-table definitions for this dialect (empty when absent)
    self.sub_tables = _read_json("/app/static/sql/sub_tables.json").get(self.db_dialect) or {}
def __init__(self, manager):
    # NOTE(review): ``manager`` is accepted for interface parity with the
    # other backend constructors but is not stored here -- confirm this is
    # intentional for this class.
    self.client = SpannerClient()
class SpannerBackend:
    """Bootstrap helper that provisions a Google Spanner database.

    Builds tables, interleaved sub-tables, and secondary indexes from the
    bundled LDAP-style JSON schema files, applies in-place schema updates,
    and imports seed data rendered from LDIF templates.
    """

    def __init__(self, manager):
        # NOTE(review): ``manager`` appears to be the cloud-native config
        # manager (it exposes ``.config.get``) -- confirm against caller.
        self.manager = manager
        self.db_dialect = "spanner"
        self.schema_files = [
            "/app/static/jans_schema.json",
            "/app/static/custom_schema.json",
        ]
        self.client = SpannerClient()

        # attribute name -> per-dialect SQL type definition
        with open("/app/static/sql/sql_data_types.json") as f:
            self.sql_data_types = json.loads(f.read())

        # flatten attributeTypes from every schema file
        self.attr_types = []
        for fn in self.schema_files:
            with open(fn) as f:
                schema = json.loads(f.read())
                self.attr_types += schema["attributeTypes"]

        # fallback attribute -> LDAP syntax OID mapping (from OpenDJ)
        with open("/app/static/sql/opendj_attributes_syntax.json") as f:
            self.opendj_attr_types = json.loads(f.read())

        # LDAP syntax OID -> per-dialect SQL type mapping
        with open("/app/static/sql/ldap_sql_data_type_mapping.json") as f:
            self.sql_data_types_mapping = json.loads(f.read())

        index_fn = "spanner_index.json"
        with open(f"/app/static/sql/{index_fn}") as f:
            self.sql_indexes = json.loads(f.read())

        # with open("/app/static/couchbase/index.json") as f:
        #     # prefix = os.environ.get("CN_COUCHBASE_BUCKET_PREFIX", "jans")
        #     txt = f.read()  # .replace("!bucket_prefix!", prefix)
        #     self.cb_indexes = json.loads(txt)

        # interleaved sub-table definitions for this dialect (may be empty)
        with open("/app/static/sql/sub_tables.json") as f:
            self.sub_tables = json.loads(f.read()).get(self.db_dialect) or {}

    def get_attr_syntax(self, attr):
        """Return the LDAP syntax OID for ``attr`` (or literal ``"JSON"``).

        Multivalued attributes are flagged as ``"JSON"`` so they get a
        JSON-ish column type; unknown attributes fall back to the OpenDJ
        syntax table, then to the DirectoryString OID.
        """
        for attr_type in self.attr_types:
            if attr not in attr_type["names"]:
                continue
            if attr_type.get("multivalued"):
                return "JSON"
            return attr_type["syntax"]

        # fallback to OpenDJ attribute type
        return self.opendj_attr_types.get(
            attr) or "1.3.6.1.4.1.1466.115.121.1.15"

    def get_data_type(self, attr, table=None):
        """Resolve the Spanner column type for ``attr``.

        ``table`` allows a per-table override defined in the type mapping.
        """
        # check from SQL data types first
        type_def = self.sql_data_types.get(attr)

        if type_def:
            # NOTE(review): no fallback if the dialect key is missing --
            # ``type_`` would be None and the next line would raise.
            type_ = type_def.get(self.db_dialect)  # or type_def["mysql"]

            # a table-specific override may exist
            if table in type_.get("tables", {}):
                type_ = type_["tables"][table]

            data_type = type_["type"]
            if "size" in type_:
                data_type = f"{data_type}({type_['size']})"
            return data_type

        # data type is undefined, hence check from syntax
        syntax = self.get_attr_syntax(attr)
        syntax_def = self.sql_data_types_mapping[syntax]
        type_ = syntax_def.get(self.db_dialect)  # or syntax_def["mysql"]

        # char_type = "VARCHAR"
        # if self.db_dialect == "spanner":
        char_type = "STRING"

        if type_["type"] != char_type:
            # not STRING
            data_type = type_["type"]
        else:
            # if "size" in type_:
            #     size = type_["size"]
            #     # data_type = f"{char_type}(type['size'])"
            # else:
            #     # data_type = "STRING(MAX)"
            #     size = "MAX"
            size = type_.get("size") or "MAX"
            data_type = f"{char_type}({size})"
            # if type_["size"] <= 127:
            #     data_type = f"{char_type}({type_['size']})"
            # elif type_["size"] <= 255:
            #     data_type = "TINYTEXT" if self.db_dialect == "mysql" else "TEXT"
            # else:
            #     data_type = "TEXT"

        # if data_type == "TEXT" and self.db_dialect == "spanner":
        #     data_type = "STRING(MAX)"
        return data_type

    def create_tables(self):
        """Create one table per (non-ignored) objectClass from the schemas."""
        schemas = {}
        # NOTE(review): ``attrs`` is populated but only referenced by the
        # commented-out ALTER TABLE block at the end of this method.
        attrs = {}
        # cached schemas that holds table's column and its type
        table_columns = defaultdict(dict)

        for fn in self.schema_files:
            with open(fn) as f:
                schema = json.loads(f.read())

            for oc in schema["objectClasses"]:
                schemas[oc["names"][0]] = oc

            for attr in schema["attributeTypes"]:
                attrs[attr["names"][0]] = attr

        for table, oc in schemas.items():
            if oc.get("sql", {}).get("ignore"):
                continue

            # ``oc["may"]`` contains list of attributes
            if "sql" in oc:
                oc["may"] += oc["sql"].get("include", [])

                for inc_oc in oc["sql"].get("includeObjectClass", []):
                    oc["may"] += schemas[inc_oc]["may"]

            # every table gets the three bookkeeping columns
            doc_id_type = self.get_data_type("doc_id", table)
            table_columns[table].update({
                "doc_id": doc_id_type,
                "objectClass": "STRING(48)",
                "dn": "STRING(128)",
            })

            # make sure ``oc["may"]`` doesn't have duplicate attribute
            for attr in set(oc["may"]):
                data_type = self.get_data_type(attr, table)
                table_columns[table].update({attr: data_type})

        for table, attr_mapping in table_columns.items():
            self.client.create_table(table, attr_mapping, "doc_id")

        # for name, attr in attrs.items():
        #     table = attr.get("sql", {}).get("add_table")
        #     logger.info(name)
        #     logger.info(table)
        #     if not table:
        #         continue
        #     data_type = self.get_data_type(name, table)
        #     col_def = f"{attr} {data_type}"
        #     sql_cmd = f"ALTER TABLE {table} ADD {col_def};"
        #     logger.info(sql_cmd)

    # def _fields_from_cb_indexes(self):
    #     fields = []
    #     for _, data in self.cb_indexes.items():
    #         # extract and flatten
    #         attrs = list(itertools.chain.from_iterable(data["attributes"]))
    #         fields += attrs
    #         for static in data["static"]:
    #             attrs = [
    #                 attr for attr in static[0]
    #                 if "(" not in attr
    #             ]
    #             fields += attrs
    #     fields = list(set(fields))
    #     # exclude objectClass
    #     if "objectClass" in fields:
    #         fields.remove("objectClass")
    #     return fields

    def get_index_fields(self, table_name):
        """Return the unique list of indexable fields for ``table_name``."""
        # cb_fields = self._fields_from_cb_indexes()
        fields = self.sql_indexes.get(table_name, {}).get("fields", [])
        fields += self.sql_indexes["__common__"]["fields"]
        # fields += cb_fields

        # make unique fields
        return list(set(fields))

    def create_spanner_indexes(self, table_name: str, column_mapping: dict):
        """Create secondary indexes for each indexable non-``doc_id`` column."""
        fields = self.get_index_fields(table_name)

        for column_name, column_type in column_mapping.items():
            if column_name == "doc_id" or column_name not in fields:
                continue

            index_name = f"{table_name}_{FIELD_RE.sub('_', column_name)}"
            if column_type.lower() != "array":
                query = f"CREATE INDEX {self.client.quoted_id(index_name)} ON {self.client.quoted_id(table_name)} ({self.client.quoted_id(column_name)})"
                self.client.create_index(query)
            else:
                # TODO: how to create index for ARRAY?
                pass

        custom_indexes = self.sql_indexes.get(table_name, {}).get("custom", [])
        for i, custom in enumerate(custom_indexes, start=1):
            name = f"{table_name}_CustomIdx_{i}"
            query = f"CREATE INDEX {self.client.quoted_id(name)} ON {self.client.quoted_id(table_name)} ({custom})"
            self.client.create_index(query)

    def create_indexes(self):
        """Create indexes for every table known to the client."""
        for table_name, column_mapping in self.client.get_table_mapping().items():
            # run the callback
            self.create_spanner_indexes(table_name, column_mapping)

    def import_ldif(self):
        """Render LDIF templates and insert their entries as table rows."""
        optional_scopes = json.loads(
            self.manager.config.get("optional_scopes", "[]"))
        ldif_mappings = get_ldif_mappings(optional_scopes)

        ctx = prepare_template_ctx(self.manager)

        for _, files in ldif_mappings.items():
            for file_ in files:
                logger.info(f"Importing {file_} file")
                src = f"/app/templates/{file_}"
                dst = f"/app/tmp/{file_}"
                os.makedirs(os.path.dirname(dst), exist_ok=True)

                render_ldif(src, dst, ctx)

                for table_name, column_mapping in self.data_from_ldif(dst):
                    self.client.insert_into(table_name, column_mapping)
                    # inject rows into subtable (if any)
                    self.insert_into_subtable(table_name, column_mapping)

    def initialize(self):
        """Entry point: tables, sub-tables, schema updates, indexes, data."""
        logger.info("Creating tables (if not exist)")
        self.create_tables()
        self.create_subtables()

        logger.info("Updating schema (if required)")
        self.update_schema()

        logger.info("Creating indexes (if not exist)")
        self.create_indexes()

        self.import_ldif()

    def transform_value(self, key, values):
        """Convert a list of LDIF string values into a Spanner column value.

        Returns a scalar, a JSON-style dict ``{"v": [...]}``, or the raw
        list depending on the resolved column type for ``key``.
        """
        type_ = self.sql_data_types.get(key)

        if not type_:
            attr_syntax = self.get_attr_syntax(key)
            type_ = self.sql_data_types_mapping[attr_syntax]

        type_ = type_.get(self.db_dialect)  # or type_["mysql"]
        data_type = type_["type"]

        if data_type in ("SMALLINT", "BOOL",):
            # truthy string markers map to 1/True, everything else to 0/False
            if values[0].lower() in ("1", "on", "true", "yes", "ok"):
                return 1 if data_type == "SMALLINT" else True
            return 0 if data_type == "SMALLINT" else False

        if data_type == "INT":
            return int(values[0])

        if data_type in ("DATETIME(3)", "TIMESTAMP",):
            # incoming value is a generalized-time blob, e.g.
            # ``YYYYmmddHHMMSS.fffZ``; re-format as RFC 3339 for Spanner
            dval = values[0].strip("Z")
            # sep = " "
            # postfix = ""
            # if self.db_dialect == "spanner":
            sep = "T"
            postfix = "Z"
            # return "{}-{}-{} {}:{}:{}{}".format(dval[0:4], dval[4:6], dval[6:8], dval[8:10], dval[10:12], dval[12:14], dval[14:17])
            return "{}-{}-{}{}{}:{}:{}{}{}".format(
                dval[0:4],
                dval[4:6],
                dval[6:8],
                sep,
                dval[8:10],
                dval[10:12],
                dval[12:14],
                dval[14:17],
                postfix,
            )

        if data_type == "JSON":
            # return json.dumps({"v": values})
            return {"v": values}

        if data_type == "ARRAY<STRING(MAX)>":
            return values

        # fallback
        return values[0]

    def data_from_ldif(self, filename):
        """Yield ``(table_name, column_mapping)`` for each LDIF entry."""
        with open(filename, "rb") as fd:
            parser = LDIFParser(fd)

            for dn, entry in parser.parse():
                doc_id = doc_id_from_dn(dn)

                oc = entry.get("objectClass") or entry.get("objectclass")
                if oc:
                    if "top" in oc:
                        oc.remove("top")

                    # skip organizational container entries
                    if len(oc) == 1 and oc[0].lower() in ("organizationalunit", "organization"):
                        continue

                # the most specific objectClass becomes the table name
                table_name = oc[-1]

                # entry.pop(rdn_name)
                if "objectClass" in entry:
                    entry.pop("objectClass")
                elif "objectclass" in entry:
                    entry.pop("objectclass")

                attr_mapping = OrderedDict({
                    "doc_id": doc_id,
                    "objectClass": table_name,
                    "dn": dn,
                })

                for attr in entry:
                    # TODO: check if attr in sub table
                    value = self.transform_value(attr, entry[attr])
                    attr_mapping[attr] = value
                yield table_name, attr_mapping

    def create_subtables(self):
        """Create interleaved sub-tables (plus an index) for this dialect."""
        for table_name, columns in self.sub_tables.items():
            for column_name, column_type in columns:
                subtable_name = f"{table_name}_{column_name}"
                self.client.create_subtable(
                    table_name,
                    subtable_name,
                    {
                        "doc_id": "STRING(64)",
                        "dict_doc_id": "STRING(64)",
                        column_name: column_type,
                    },
                    "doc_id",
                    "dict_doc_id",
                )

                index_name = f"{subtable_name}Idx"
                query = f"CREATE INDEX {self.client.quoted_id(index_name)} ON {self.client.quoted_id(subtable_name)} ({self.client.quoted_id(column_name)})"
                self.client.create_index(query)

    def column_in_subtable(self, table_name, column):
        """Return True if ``column`` is stored in a sub-table of ``table_name``."""
        exists = False

        # column_mapping is a list
        column_mapping = self.sub_tables.get(table_name, [])
        for cm in column_mapping:
            if column == cm[0]:
                exists = True
                break
        return exists

    def insert_into_subtable(self, table_name, column_mapping):
        """Fan out multivalued columns into their interleaved sub-tables.

        Each item gets a deterministic ``dict_doc_id`` derived from its
        SHA-256 hex digest so re-imports stay idempotent per value.
        """
        for column, value in column_mapping.items():
            if not self.column_in_subtable(table_name, column):
                continue

            for item in value:
                hashed = hashlib.sha256()
                hashed.update(item.encode())
                # equivalent to hexdigest()
                dict_doc_id = hashed.digest().hex()

                self.client.insert_into(
                    f"{table_name}_{column}",
                    {
                        "doc_id": column_mapping["doc_id"],
                        "dict_doc_id": dict_doc_id,
                        column: item
                    },
                )

    def update_schema(self):
        """Apply in-place schema migrations for existing deployments."""
        table_mapping = self.client.get_table_mapping()

        # 1 - jansDefAcrValues is changed to multivalued (JSON type)
        table_name = "jansClnt"
        col_name = "jansDefAcrValues"

        old_data_type = table_mapping[table_name][col_name]
        data_type = self.get_data_type(col_name, table_name)

        if not old_data_type.startswith("ARRAY"):
            # get the value first before updating column type
            acr_values = {
                row["doc_id"]: row[col_name]
                for row in self.client.search(table_name, ["doc_id", col_name])
            }

            # to change the storage format of a JSON column, drop the column and
            # add the column back specifying the new storage format
            self.client.database.update_ddl([
                f"ALTER TABLE {self.client.quoted_id(table_name)} DROP COLUMN {self.client.quoted_id(col_name)}"
            ])
            self.client.database.update_ddl([
                f"ALTER TABLE {self.client.quoted_id(table_name)} ADD COLUMN {self.client.quoted_id(col_name)} {data_type}"
            ])

            # re-populate the dropped values, now wrapped as multivalued
            for doc_id, value in acr_values.items():
                if not value:
                    value_list = []
                else:
                    value_list = [value]
                self.client.update(
                    table_name,
                    doc_id,
                    {col_name: self.transform_value(col_name, value_list)})

        # 2 - jansUsrDN column must be in jansToken table
        table_name = "jansToken"
        col_name = "jansUsrDN"

        if col_name not in table_mapping[table_name]:
            data_type = self.get_data_type(col_name, table_name)
            self.client.database.update_ddl([
                f"ALTER TABLE {self.client.quoted_id(table_name)} ADD COLUMN {self.client.quoted_id(col_name)} {data_type}"
            ])
def __init__(self, manager):
    """Attach the config manager and a Spanner client to this backend."""
    super().__init__()
    self.type = "spanner"
    self.client = SpannerClient()
    self.manager = manager
class SpannerBackend(BaseBackend):
    """Upgrade/repair helper for an existing Spanner-backed installation.

    Each ``update_*_entries`` method idempotently patches a well-known row
    (admin user, scopes, config-api client, base entry) in place.
    """

    def __init__(self, manager):
        super().__init__()
        self.manager = manager
        self.client = SpannerClient()
        self.type = "spanner"

    def get_entry(self, key, filter_="", attrs=None, **kwargs):
        """Fetch a row by ``key`` and wrap it as an ``Entry`` (or ``None``).

        ``filter_`` is accepted for interface parity with other backends
        but is not used by the Spanner implementation.
        """
        table_name = kwargs.get("table_name")
        entry = self.client.get(table_name, key, attrs)

        if not entry:
            return None
        return Entry(key, entry)

    def modify_entry(self, key, attrs=None, **kwargs):
        """Update a row; returns ``(result, "")`` for interface parity."""
        attrs = attrs or {}
        table_name = kwargs.get("table_name")
        return self.client.update(table_name, key, attrs), ""

    def update_people_entries(self):
        # add jansAdminUIRole to default admin user
        admin_inum = self.manager.config.get("admin_inum")
        id_ = doc_id_from_dn(f"inum={admin_inum},ou=people,o=jans")
        kwargs = {"table_name": "jansPerson"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        # sql entry may have empty jansAdminUIRole hash ({"v": []})
        if not entry.attrs["jansAdminUIRole"]:
            entry.attrs["jansAdminUIRole"] = ["api-admin"]
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_scopes_entries(self):
        # add jansAdminUIRole claim to profile scope
        id_ = doc_id_from_dn(self.jans_admin_ui_role_id)
        kwargs = {"table_name": "jansScope"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        if self.jans_admin_ui_claim not in entry.attrs["jansClaim"]:
            entry.attrs["jansClaim"].append(self.jans_admin_ui_claim)
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_clients_entries(self):
        """Patch the config-api client: redirect URI and required scopes."""
        jca_client_id = self.manager.config.get("jca_client_id")
        id_ = doc_id_from_dn(f"inum={jca_client_id},ou=clients,o=jans")
        kwargs = {"table_name": "jansClnt"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        # only write back once if anything below changed
        should_update = False

        # modify redirect UI of config-api client
        hostname = self.manager.config.get("hostname")
        if f"https://{hostname}/admin" not in entry.attrs["jansRedirectURI"]:
            entry.attrs["jansRedirectURI"].append(f"https://{hostname}/admin")
            should_update = True

        # add jans_stat, SCIM users.read, SCIM users.write scopes to config-api client
        for scope in (self.jans_scim_scopes + self.jans_stat_scopes):
            if scope not in entry.attrs["jansScope"]:
                entry.attrs["jansScope"].append(scope)
                should_update = True

        if should_update:
            self.modify_entry(id_, entry.attrs, **kwargs)

    def update_scim_scopes_entries(self):
        # add jansAttrs to SCIM users.read and users.write scopes
        ids = [doc_id_from_dn(scope) for scope in self.jans_scim_scopes]
        kwargs = {"table_name": "jansScope"}

        for id_ in ids:
            entry = self.get_entry(id_, **kwargs)
            if not entry:
                continue

            if "jansAttrs" not in entry.attrs:
                entry.attrs["jansAttrs"] = self.jans_attrs
                self.modify_entry(id_, entry.attrs, **kwargs)

    def update_base_entries(self):
        # add jansManagerGrp to base entry
        id_ = doc_id_from_dn(JANS_BASE_ID)
        kwargs = {"table_name": "jansOrganization"}

        entry = self.get_entry(id_, **kwargs)
        if not entry:
            return

        if not entry.attrs.get("jansManagerGrp"):
            entry.attrs["jansManagerGrp"] = JANS_MANAGER_GROUP
            self.modify_entry(id_, entry.attrs, **kwargs)