def _test_get_foreign_keys(self, schema=None):
    """Reflect the FKs of both test tables and compare them to expectations."""
    meta = MetaData(testing.db)
    (users, addresses) = createTables(meta, schema)
    meta.create_all()
    insp = Inspector(meta.bind)
    try:
        expected_schema = schema
        # Both tables reference users.user_id; only the constrained
        # (local) column differs between the two cases.
        cases = [
            (users, ['parent_user_id']),
            (addresses, ['remote_user_id']),
        ]
        for table, constrained in cases:
            fkeys = insp.get_foreign_keys(table.name, schema=schema)
            fkey1 = fkeys[0]
            self.assert_(fkey1['name'] is not None)
            eq_(fkey1['referred_schema'], expected_schema)
            eq_(fkey1['referred_table'], users.name)
            eq_(fkey1['referred_columns'], ['user_id', ])
            eq_(fkey1['constrained_columns'], constrained)
    finally:
        addresses.drop()
        users.drop()
def _test_get_foreign_keys(self, schema=None):
    """Reflect the FKs of both test tables, resolving the expected schema to
    the dialect's default schema when none was given explicitly."""
    meta = MetaData(testing.db)
    (users, addresses) = createTables(meta, schema)
    meta.create_all()
    insp = Inspector(meta.bind)
    try:
        expected_schema = schema
        if schema is None:
            # Reflection reports the dialect's default schema rather than
            # None; fall back to None when the dialect cannot say.
            try:
                expected_schema = meta.bind.dialect.get_default_schema_name(
                    meta.bind)
            except NotImplementedError:
                expected_schema = None
        # Both tables reference users.user_id; only the constrained
        # (local) column differs between the two cases.
        cases = [
            (users, ['parent_user_id']),
            (addresses, ['remote_user_id']),
        ]
        for table, constrained in cases:
            fkeys = insp.get_foreign_keys(table.name, schema=schema)
            fkey1 = fkeys[0]
            self.assert_(fkey1['name'] is not None)
            self.assertEqual(fkey1['referred_schema'], expected_schema)
            self.assertEqual(fkey1['referred_table'], users.name)
            self.assertEqual(fkey1['referred_columns'], ['user_id', ])
            self.assertEqual(fkey1['constrained_columns'], constrained)
    finally:
        addresses.drop()
        users.drop()
def test_compare_get_foreign_keys_for_sql_and_odbc(self, schema, table, engine_name):
    """The SQL-fallback and ODBC FK-metadata code paths must agree."""
    with self.engine_map[engine_name].begin() as conn:
        # Without an explicit schema the lookup runs against the session's
        # current schema, so open one first.
        if schema is None:
            conn.execute("OPEN SCHEMA %s" % self.schema_2)
        dialect = Inspector(conn).dialect
        common = dict(connection=conn, table_name=table, schema=schema)
        via_fallback = dialect.get_foreign_keys(use_sql_fallback=True, **common)
        via_odbc = dialect.get_foreign_keys(**common)
        assert str(via_fallback) == str(via_odbc)
def test_get_foreign_keys_table_name_none(self, use_sql_fallback, engine_name):
    """Passing table_name=None must yield an empty FK list, not an error."""
    with self.engine_map[engine_name].begin() as conn:
        dialect = Inspector(conn).dialect
        result = dialect.get_foreign_keys(
            connection=conn,
            schema=self.schema,
            table_name=None,
            use_sql_fallback=use_sql_fallback,
        )
        assert result == []
def generic_find_fk_constraint_name(  # pylint: disable=invalid-name
    table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Optional[str]:
    """Utility to find a foreign-key constraint name in alembic migrations"""
    # Lazily scan the reflected FKs and return the first whose target table
    # and referred-column set match; None when nothing matches.
    matches = (
        fk["name"]
        for fk in insp.get_foreign_keys(table)
        if fk["referred_table"] == referenced
        and set(fk["referred_columns"]) == columns
    )
    return next(matches, None)
def generic_find_fk_constraint_names(table: str, columns: Set[str],
                                     referenced: str,
                                     insp: Inspector) -> Set[str]:
    """Utility to find foreign-key constraint names in alembic migrations.

    Returns the names of every FK on *table* that points at *referenced*
    with exactly the given referred-column set; empty set when none match.
    """
    # Idiomatic set comprehension instead of a manual loop + .add().
    return {
        fk["name"]
        for fk in insp.get_foreign_keys(table)
        if fk["referred_table"] == referenced
        and set(fk["referred_columns"]) == columns
    }
def _get_foreign_keys(self, dataset_urn: str, inspector: Inspector,
                      schema: str, table: str) -> List[ForeignKeyConstraint]:
    """Reflect the table's foreign keys and convert each record into a
    ForeignKeyConstraint; degrade to an empty list on lookup failures."""
    try:
        fk_records = inspector.get_foreign_keys(table, schema)
        return [
            self.get_foreign_key_metadata(dataset_urn, schema, rec, inspector)
            for rec in fk_records
        ]
    except KeyError:
        # certain databases like MySQL cause issues due to lower-case/upper-case irregularities
        logger.debug(
            f"{dataset_urn}: failure in foreign key extraction... skipping"
        )
        return []
def test_get_foreign_keys(self, use_sql_fallback, engine_name):
    """A composite FK on table "s" reflects with full column detail."""
    with self.engine_map[engine_name].begin() as conn:
        dialect = Inspector(conn).dialect
        actual = dialect.get_foreign_keys(
            connection=conn,
            schema=self.schema,
            table_name="s",
            use_sql_fallback=use_sql_fallback,
        )
        assert actual == [{
            'name': 'fk_test',
            'constrained_columns': ['fid1', 'fid2'],
            'referred_schema': 'test_get_metadata_functions_schema',
            'referred_table': 't',
            'referred_columns': ['pid1', 'pid2'],
        }]
class Db(object):
    """A reflected source (or target) database plus per-table bookkeeping.

    Reflects every schema in ``schemas``, monkeypatches helper methods onto
    each reflected Table, and wires up parent/child FK relationships
    (including extra "constraints" supplied via config) so a consistent
    subset of rows can be copied into a target Db.
    """

    def __init__(self, sqla_conn, args, schemas=[None]):
        # NOTE(review): mutable default argument; all no-arg constructions
        # share the same list object.
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for schema in self.schemas:
            meta = sa.MetaData(
                bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                # Apply include/exclude table patterns from the CLI args.
                if args.tables and not _table_matches_any_pattern(
                        tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name,
                                              self.args.exclude_tables):
                    continue
                tbl.db = self
                # TODO: Replace all these monkeypatches with an instance assigment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                          schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                         schema=tbl.schema)
                # Without a real PK, fall back to all columns as the row key.
                if not tbl.pk:
                    tbl.pk = [
                        d["name"] for d in self.inspector.get_columns(tbl.name)
                    ]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                # Full-copy tables get an exact count; others an estimate.
                estimate_rows = not _table_matches_any_pattern(
                    tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
        # Extra FK-like relationships from config, keyed "schema.table" or
        # bare table name.
        all_constraints = args.config.get("constraints", {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            # Register this table as a child of every table it references.
            for fk in tbl.fks + constraints:
                fk["constrained_schema"] = tbl_schema
                fk["constrained_table"] = tbl_name
                # TODO: check against constrained_table
                self.tables[(fk["referred_schema"],
                             fk["referred_table"])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        """Pair each source table with its counterpart in ``target_db`` and
        compute how many rows the target should end up with."""
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name,
                                          self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    # Logarithmic mode samples huge tables proportionally
                    # less; keep at least one row either way.
                    if self.args.logarithmic:
                        target.n_rows_desired = int(
                            math.pow(10,
                                     math.log10(tbl.n_rows) *
                                     self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(
                            tbl.n_rows * self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        """Print the planned per-table row counts and ask the user whether to
        proceed (auto-confirmed by --yes)."""
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append(
                "Create %d rows from %d in %s.%s" %
                (tbl.target.n_rows_desired, tbl.n_rows, tbl_schema or "",
                 tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == "y")

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        """Queue ``source_row`` for insertion into ``target``, recursively
        ensuring its FK parents / constraint-referenced rows exist first,
        then scheduling child rows that point back at it."""
        logging.debug("create_row_in %s:%s " %
                      (target.name, target.pk_val(source_row)))
        pks = tuple((source_row[key] for key in target.pk))
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return
        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk["referred_schema"],
                                                  fk["referred_table"])]
                slct = sa.sql.select([target_parent])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk["referred_columns"],
                                                   fk["constrained_columns"]):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)
            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(
                    constraint["referred_schema"],
                    constraint["referred_table"])]
                slct = sa.sql.select([target_referred])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                        constraint["referred_columns"],
                        constraint["constrained_columns"]):
                    slct = slct.where(target_referred.c[referred_col] ==
                                      source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs, the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db,
                                               target_referred)
            pks = tuple((source_row[key] for key in target.pk))
            target.pending[pks] = source_row
            target.n_rows += 1
        # Schedule children referencing this row (breadth-limited unless
        # prioritized).
        for child_fk in target.child_fks:
            child = self.tables[(child_fk["constrained_schema"],
                                 child_fk["constrained_table"])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk["constrained_columns"],
                                             child_fk["referred_columns"]):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        # Total buffered-but-unflushed rows across all tables.
        return functools.reduce(
            lambda count, table: count + len(table.pending),
            self.tables.values(), 0)

    def flush(self):
        """Bulk-insert all buffered rows and mark their keys done."""
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):
        """Drive the subsetting: seed forced rows, then repeatedly fill the
        least-complete target table until all scores exceed 0.97."""
        for (tbl_name, pks) in self.args.force_rows.items():
            if "." in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split(".", 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target,
                                       prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))
        # import pdb; pdb.set_trace()
        while True:
            # Always work on the emptiest (least complete) table first.
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" %
                          sum((t.n_rows for t in target_db.tables.values())))
            logging.debug(
                "target tables with 0 n_rows: %s" %
                ", ".join(t.name for t in target_db.tables.values()
                          if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target,
                               prioritized=prioritized)
            if target_db.pending > self.args.buffer:
                target_db.flush()
        target_db.flush()
def verify_thd(conn):
    """Check the migrated schema: expected row contents and no 'slave'-named
    items anywhere in the schema.

    Runs in a db thread; ``conn`` is the open connection.  Uses ``self``
    from the enclosing test case for assertions.
    """
    metadata = sa.MetaData()
    metadata.bind = conn

    # Verify database contents.

    # 'workers' table contents.
    workers = sautils.Table('workers', metadata, autoload=True)
    c = workers.c
    q = sa.select(
        [c.id, c.name, c.info]
    ).order_by(c.id)
    self.assertEqual(
        q.execute().fetchall(), [
            (30, 'worker-1', '{}'),
            (31, 'worker-2', '{"a": 1}'),
        ])

    # 'builds' table contents.
    builds = sautils.Table('builds', metadata, autoload=True)
    c = builds.c
    q = sa.select(
        [c.id, c.number, c.builderid, c.buildrequestid, c.workerid,
         c.masterid, c.started_at, c.complete_at, c.state_string, c.results]
    ).order_by(c.id)
    self.assertEqual(
        q.execute().fetchall(), [
            (40, 1, None, 20, 30, 10, 1000, None, 'state', None),
            (41, 2, 50, 21, None, 11, 2000, 3000, 'state 2', 9),
        ])

    # 'configured_workers' table contents.
    configured_workers = sautils.Table(
        'configured_workers', metadata, autoload=True)
    c = configured_workers.c
    q = sa.select(
        [c.id, c.buildermasterid, c.workerid]
    ).order_by(c.id)
    self.assertEqual(
        q.execute().fetchall(), [
            (60, 70, 30),
            (61, 71, 31),
        ])

    # 'connected_workers' table contents.
    connected_workers = sautils.Table(
        'connected_workers', metadata, autoload=True)
    c = connected_workers.c
    q = sa.select(
        [c.id, c.masterid, c.workerid]
    ).order_by(c.id)
    self.assertEqual(
        q.execute().fetchall(), [
            (80, 10, 30),
            (81, 11, 31),
        ])

    # Verify that there is no "slave"-named items in schema.
    inspector = Inspector(conn)

    def check_name(name, table_name, item_type):
        # Unnamed items (e.g. anonymous constraints) are skipped.
        if not name:
            return
        self.assertTrue(
            "slave" not in name.lower(),
            msg="'slave'-named {type} in table '{table}': "
            "'{name}'".format(
                type=item_type, table=table_name, name=name))

    # Check every table.
    for table_name in inspector.get_table_names():
        # Check table name.
        check_name(table_name, table_name, "table name")

        # Check column names.
        for column_info in inspector.get_columns(table_name):
            check_name(column_info['name'], table_name, "column")

        # Check foreign key names.
        for fk_info in inspector.get_foreign_keys(table_name):
            check_name(fk_info['name'], table_name, "foreign key")

        # Check indexes names.
        for index_info in inspector.get_indexes(table_name):
            check_name(index_info['name'], table_name, "index")

        # Check primary keys constraints names.
        pk_info = inspector.get_pk_constraint(table_name)
        check_name(pk_info.get('name'), table_name, "primary key")

    # Test that no "slave"-named items present in schema
    for name in inspector.get_schema_names():
        self.assertTrue("slave" not in name.lower())
class Db(object):
    """A reflected database (single schema) with per-table helpers for
    copying a consistent subset of rows into a target Db.

    In this revision there is no insert buffering: rows are written to the
    target immediately, and row existence is checked with ``_exists``.
    """

    def __init__(self, sqla_conn, args, schema=None):
        self.args = args
        self.sqla_conn = sqla_conn
        self.schema = schema
        self.engine = sa.create_engine(sqla_conn)
        self.meta = sa.MetaData(
            bind=self.engine)  # excised schema=schema to prevent errors
        self.meta.reflect(schema=self.schema)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for tbl in self.meta.sorted_tables:
            tbl.db = self
            # TODO: Replace all these monkeypatches with an instance assigment
            tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
            tbl.random_row_func = types.MethodType(_random_row_func, tbl)
            tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                      schema=tbl.schema)
            tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                     schema=tbl.schema)
            tbl.filtered_by = types.MethodType(_filtered_by, tbl)
            tbl.by_pk = types.MethodType(_by_pk, tbl)
            tbl.pk_val = types.MethodType(_pk_val, tbl)
            tbl.exists = types.MethodType(_exists, tbl)
            tbl.child_fks = []
            tbl.find_n_rows(estimate=True)
            self.tables[(tbl.schema, tbl.name)] = tbl
        # Register each table as a child of every table it references.
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            for fk in tbl.fks:
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name
                # TODO: check against constrained_table
                self.tables[(fk['referred_schema'],
                             fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        """Pair each source table with its target counterpart and compute the
        desired number of target rows."""
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            if tbl.n_rows:
                # Logarithmic mode samples big tables proportionally less;
                # always keep at least one row.
                if self.args.logarithmic:
                    target.n_rows_desired = int(
                        math.pow(
                            10,
                            math.log10(tbl.n_rows) * self.args.fraction)) or 1
                else:
                    target.n_rows_desired = int(
                        tbl.n_rows * self.args.fraction) or 1
            else:
                target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        """Print the planned per-table row counts and ask the user whether to
        proceed (auto-confirmed by --yes)."""
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows,
                            tbl_schema or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        """Insert ``source_row`` into ``target`` (after recursively creating
        its FK parents), then queue child rows that reference it."""
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))
        row_exists = target.exists(**(dict(source_row)))
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return
        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'],
                                                  fk['referred_table'])]
                slct = sa.sql.select([
                    target_parent,
                ])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)
            # Write the row immediately (no buffering in this revision).
            ins = target.insert().values(**source_row)
            target_db.conn.execute(ins)
            target.n_rows += 1
        # Queue children referencing this row (breadth-limited unless
        # prioritized).
        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'],
                                 child_fk['constrained_table'])]
            slct = sa.sql.select([
                child,
            ])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif (n == 0):
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    def create_subset_in(self, target_db):
        """Seed forced rows, then repeatedly fill the least-complete target
        table until every completeness score exceeds 0.97."""
        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target,
                                       prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))
        while True:
            # Always work on the emptiest (least complete) table first.
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                return
            logging.debug("total n_rows in target: %d" % sum(
                (t.n_rows for t in target_db.tables.values())))
            logging.debug(
                "target tables with 0 n_rows: %s" %
                ", ".join(t.name for t in target_db.tables.values()
                          if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                return
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target,
                               prioritized=prioritized)
def verify_thd(conn):
    """Check the migrated schema: expected row contents and no 'slave'-named
    items anywhere in the schema.

    Runs in a db thread; ``conn`` is the open connection.  Uses ``self``
    from the enclosing test case for assertions.
    """
    metadata = sa.MetaData()
    metadata.bind = conn

    # Verify database contents.

    # 'workers' table contents.
    workers = sautils.Table('workers', metadata, autoload=True)
    c = workers.c
    q = sa.select([c.id, c.name, c.info]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (30, u'worker-1', u'{}'),
        (31, u'worker-2', u'{"a": 1}'),
    ])

    # 'builds' table contents.
    builds = sautils.Table('builds', metadata, autoload=True)
    c = builds.c
    q = sa.select([
        c.id, c.number, c.builderid, c.buildrequestid, c.workerid,
        c.masterid, c.started_at, c.complete_at, c.state_string, c.results
    ]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (40, 1, None, 20, 30, 10, 1000, None, u'state', None),
        (41, 2, 50, 21, None, 11, 2000, 3000, u'state 2', 9),
    ])

    # 'configured_workers' table contents.
    configured_workers = sautils.Table('configured_workers',
                                       metadata,
                                       autoload=True)
    c = configured_workers.c
    q = sa.select([c.id, c.buildermasterid, c.workerid]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (60, 70, 30),
        (61, 71, 31),
    ])

    # 'connected_workers' table contents.
    connected_workers = sautils.Table('connected_workers',
                                      metadata,
                                      autoload=True)
    c = connected_workers.c
    q = sa.select([c.id, c.masterid, c.workerid]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (80, 10, 30),
        (81, 11, 31),
    ])

    # Verify that there is no "slave"-named items in schema.
    inspector = Inspector(conn)

    def check_name(name, table_name, item_type):
        # Unnamed items (e.g. anonymous constraints) are skipped.
        if not name:
            return
        self.assertTrue(
            u"slave" not in name.lower(),
            msg=u"'slave'-named {type} in table '{table}': "
            u"'{name}'".format(type=item_type, table=table_name, name=name))

    # Check every table.
    for table_name in inspector.get_table_names():
        # Check table name.
        check_name(table_name, table_name, u"table name")

        # Check column names.
        for column_info in inspector.get_columns(table_name):
            check_name(column_info['name'], table_name, u"column")

        # Check foreign key names.
        for fk_info in inspector.get_foreign_keys(table_name):
            check_name(fk_info['name'], table_name, u"foreign key")

        # Check indexes names.
        for index_info in inspector.get_indexes(table_name):
            check_name(index_info['name'], table_name, u"index")

        # Check primary keys constraints names.
        pk_info = inspector.get_pk_constraint(table_name)
        check_name(pk_info.get('name'), table_name, u"primary key")

    # Test that no "slave"-named items present in schema
    for name in inspector.get_schema_names():
        self.assertTrue(u"slave" not in name.lower())
def loop_tables(
    self,
    inspector: Inspector,
    schema: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[SqlWorkUnit]:
    """Yield a SqlWorkUnit per table in *schema*, deduplicating tables and
    attaching comment, property, primary-key and foreign-key metadata.

    Tables failing the configured table_pattern are reported as dropped;
    column/comment lookup failures degrade to warnings instead of aborting.
    """
    tables_seen: Set[str] = set()
    for table in inspector.get_table_names(schema):
        schema, table = self.standardize_schema_table_names(schema=schema,
                                                            entity=table)
        dataset_name = self.get_identifier(schema=schema,
                                           entity=table,
                                           inspector=inspector)
        # Some dialects can report the same table twice (e.g. case
        # variants); only emit the first occurrence.
        if dataset_name not in tables_seen:
            tables_seen.add(dataset_name)
        else:
            logger.debug(
                f"{dataset_name} has already been seen, skipping...")
            continue

        self.report.report_entity_scanned(dataset_name, ent_type="table")

        if not sql_config.table_pattern.allowed(dataset_name):
            self.report.report_dropped(dataset_name)
            continue

        columns = []
        try:
            columns = inspector.get_columns(table, schema)
            if len(columns) == 0:
                self.report.report_warning(dataset_name,
                                           "missing column information")
        except Exception as e:
            self.report.report_warning(
                dataset_name,
                f"unable to get column information due to an error -> {e}",
            )

        try:
            # SQLALchemy stubs are incomplete and missing this method.
            # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
            table_info: dict = inspector.get_table_comment(
                table, schema)  # type: ignore
        except NotImplementedError:
            # Dialect cannot fetch comments; fall back to empty metadata.
            description: Optional[str] = None
            properties: Dict[str, str] = {}
        except ProgrammingError as pe:
            # Snowflake needs schema names quoted when fetching table comments.
            logger.debug(
                f"Encountered ProgrammingError. Retrying with quoted schema name for schema {schema} and table {table}",
                pe,
            )
            table_info = inspector.get_table_comment(
                table, f'"{schema}"')  # type: ignore
            # BUGFIX: the retried table_info was previously discarded,
            # leaving description/properties empty; use the retry result.
            description = table_info["text"]
            properties = table_info.get("properties", {})
        else:
            description = table_info["text"]
            # The "properties" field is a non-standard addition to SQLAlchemy's interface.
            properties = table_info.get("properties", {})

        dataset_urn = make_dataset_urn(self.platform, dataset_name,
                                       self.config.env)
        dataset_snapshot = DatasetSnapshot(
            urn=dataset_urn,
            aspects=[StatusClass(removed=False)],
        )

        # Record the table in the checkpoint state so stateful ingestion
        # can detect deletions between runs.
        if self.is_stateful_ingestion_configured():
            cur_checkpoint = self.get_current_checkpoint(
                self.get_default_ingestion_job_id())
            if cur_checkpoint is not None:
                checkpoint_state = cast(BaseSQLAlchemyCheckpointState,
                                        cur_checkpoint.state)
                checkpoint_state.add_table_urn(dataset_urn)

        if description is not None or properties:
            dataset_properties = DatasetPropertiesClass(
                description=description,
                customProperties=properties,
            )
            dataset_snapshot.aspects.append(dataset_properties)

        pk_constraints: dict = inspector.get_pk_constraint(table, schema)
        try:
            foreign_keys = [
                self.get_foreign_key_metadata(dataset_urn, schema, fk_rec,
                                              inspector)
                for fk_rec in inspector.get_foreign_keys(table, schema)
            ]
        except KeyError:
            # certain databases like MySQL cause issues due to lower-case/upper-case irregularities
            logger.debug(
                f"{dataset_urn}: failure in foreign key extraction... skipping"
            )
            foreign_keys = []

        schema_fields = self.get_schema_fields(dataset_name, columns,
                                               pk_constraints)
        schema_metadata = get_schema_metadata(
            self.report,
            dataset_name,
            self.platform,
            columns,
            pk_constraints,
            foreign_keys,
            schema_fields,
        )
        dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = SqlWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu
class Db(object):
    """A reflected database with per-table helpers for copying a consistent
    subset of rows into a target Db.

    Rows are buffered in ``table.pending`` and bulk-flushed (or written
    immediately when buffer size is 0); a SIGNAL_ROW_ADDED signal fires for
    every row scheduled.
    """

    def __init__(self, sqla_conn, args, schemas=None):
        if schemas is None:
            schemas = [None]
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for schema in self.schemas:
            meta = sa.MetaData(
                bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                # Apply include/exclude table patterns from the CLI args.
                if args.tables and not _table_matches_any_pattern(
                        tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name,
                                              self.args.exclude_tables):
                    continue
                tbl.db = self
                if self.engine.name == 'postgresql':
                    fix_postgres_array_of_enum(self.conn, tbl)
                # TODO: Replace all these monkeypatches with an instance assigment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                          schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                         schema=tbl.schema)
                # Without a real PK, fall back to all columns as the row key.
                if not tbl.pk:
                    tbl.pk = [
                        d['name'] for d in self.inspector.get_columns(
                            tbl.name, schema=tbl.schema)
                    ]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                # Full-copy tables get an exact count; others an estimate.
                estimate_rows = not _table_matches_any_pattern(
                    tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
        # Extra FK-like relationships from config, keyed "schema.table" or
        # bare table name.
        all_constraints = args.config.get('constraints', {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            # Register this table as a child of every table it references.
            for fk in (tbl.fks + constraints):
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name
                # TODO: check against constrained_table
                self.tables[(fk['referred_schema'],
                             fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        """Pair each source table with its counterpart in ``target_db`` and
        compute how many rows the target should end up with."""
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name,
                                          self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    # Logarithmic mode samples huge tables proportionally
                    # less; keep at least one row either way.
                    if self.args.logarithmic:
                        target.n_rows_desired = int(
                            math.pow(
                                10,
                                math.log10(tbl.n_rows) *
                                self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(
                            tbl.n_rows * self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        """Print the planned per-table row counts and ask the user whether to
        proceed (auto-confirmed by --yes)."""
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows,
                            tbl_schema or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        """Schedule ``source_row`` for the target table, recursively ensuring
        its FK parents and config-constraint referents exist, then notify
        listeners and queue child rows that reference it."""
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))
        pks = hashable((source_row[key] for key in target.pk))
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return
        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'],
                                                  fk['referred_table'])]
                slct = sa.sql.select([
                    target_parent,
                ])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)
            pks = hashable((source_row[key] for key in target.pk))
            target.n_rows += 1
            # insert source row here to prevent recursion
            if self.args.buffer == 0:
                target_db.insert_one(target, pks, source_row)
            else:
                target.pending[pks] = source_row
            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(
                    constraint['referred_schema'],
                    constraint['referred_table'])]
                slct = sa.sql.select([
                    target_referred,
                ])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                        constraint['referred_columns'],
                        constraint['constrained_columns']):
                    slct = slct.where(target_referred.c[referred_col] ==
                                      source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs, the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db,
                                               target_referred)
        # defer signal?
        signal(SIGNAL_ROW_ADDED).send(self,
                                      source_row=source_row,
                                      target_db=target_db,
                                      target_table=target,
                                      prioritized=prioritized)
        # Queue children referencing this row (breadth-limited unless
        # prioritized).
        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'],
                                 child_fk['constrained_table'])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        # Total buffered-but-unflushed rows across all tables.
        return functools.reduce(
            lambda count, table: count + len(table.pending),
            self.tables.values(), 0)

    def insert_one(self, table, pk, values):
        """Write a single row immediately and mark its key done."""
        self.conn.execute(table.insert(), values)
        table.done.add(pk)

    def flush(self):
        """Bulk-insert all buffered rows and mark their keys done."""
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):
        """Seed forced rows, then repeatedly fill the least-complete target
        table until every completeness score exceeds 0.97, flushing the
        pending buffer as it fills."""
        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target,
                                       prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))
        while True:
            # Always work on the emptiest (least complete) table first.
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" % sum(
                (t.n_rows for t in target_db.tables.values())))
            logging.debug(
                "target tables with 0 n_rows: %s" %
                ", ".join(t.name for t in target_db.tables.values()
                          if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target,
                               prioritized=prioritized)
            if target_db.pending > self.args.buffer > 0:
                target_db.flush()
        if self.args.buffer > 0:
            target_db.flush()
def loop_tables(
    self,
    inspector: Inspector,
    schema: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[SqlWorkUnit]:
    """Emit one SqlWorkUnit per allowed table in ``schema``.

    For each table the inspector reports, builds a DatasetSnapshot carrying
    the table description/custom properties (when the dialect supports table
    comments) and schema metadata with PK/FK information, then yields it
    wrapped in a SqlWorkUnit. Tables rejected by ``sql_config.table_pattern``
    are reported as dropped and skipped.
    """
    for raw_table in inspector.get_table_names(schema):
        # Bug fix: the original rebound the ``schema`` parameter here, so a
        # table whose standardized schema differed would leak into the
        # standardization of every subsequent table. Use fresh locals.
        table_schema, table = self.standardize_schema_table_names(
            schema=schema, entity=raw_table
        )
        dataset_name = self.get_identifier(
            schema=table_schema, entity=table, inspector=inspector
        )
        self.report.report_entity_scanned(dataset_name, ent_type="table")

        if not sql_config.table_pattern.allowed(dataset_name):
            self.report.report_dropped(dataset_name)
            continue

        columns = inspector.get_columns(table, table_schema)
        if not columns:
            self.report.report_warning(dataset_name, "missing column information")

        try:
            # SQLALchemy stubs are incomplete and missing this method.
            # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
            table_info: dict = inspector.get_table_comment(table, table_schema)  # type: ignore
        except NotImplementedError:
            # Dialect does not support table comments: no description/properties.
            description: Optional[str] = None
            properties: Dict[str, str] = {}
        else:
            description = table_info["text"]
            # The "properties" field is a non-standard addition to SQLAlchemy's interface.
            properties = table_info.get("properties", {})

        datasetUrn = f"urn:li:dataset:(urn:li:dataPlatform:{self.platform},{dataset_name},{self.config.env})"
        dataset_snapshot = DatasetSnapshot(
            urn=datasetUrn,
            aspects=[],
        )
        if description is not None or properties:
            dataset_properties = DatasetPropertiesClass(
                description=description,
                customProperties=properties,
            )
            dataset_snapshot.aspects.append(dataset_properties)

        pk_constraints: dict = inspector.get_pk_constraint(table, table_schema)
        try:
            foreign_keys = [
                self.get_foreign_key_metadata(datasetUrn, fk_rec, inspector)
                for fk_rec in inspector.get_foreign_keys(table, table_schema)
            ]
        except KeyError:
            # certain databases like MySQL cause issues due to lower-case/upper-case irregularities
            logger.debug(
                f"{datasetUrn}: failure in foreign key extraction... skipping"
            )
            foreign_keys = []

        schema_metadata = get_schema_metadata(
            self.report,
            dataset_name,
            self.platform,
            columns,
            pk_constraints,
            foreign_keys,
        )
        dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = SqlWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu
class Db(object):
    """A database endpoint (source or target) for subsetting.

    Reflects the schema at construction time and monkeypatches each reflected
    ``Table`` with the helpers the subsetting algorithm needs (row counting,
    pk lookup, foreign-key metadata, work queues, etc.).
    """

    def __init__(self, sqla_conn, args, schema=None):
        self.args = args
        self.sqla_conn = sqla_conn
        self.schema = schema
        self.engine = sa.create_engine(sqla_conn)
        self.meta = sa.MetaData(bind=self.engine)  # excised schema=schema to prevent errors
        self.meta.reflect(schema=self.schema)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        # (schema, table_name) -> Table, kept in dependency-sorted order.
        self.tables = OrderedDict()
        for tbl in self.meta.sorted_tables:
            tbl.db = self
            # TODO: Replace all these monkeypatches with an instance assigment
            tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
            tbl.random_row_func = types.MethodType(_random_row_func, tbl)
            tbl.fks = self.inspector.get_foreign_keys(tbl.name, schema=tbl.schema)
            # NOTE(review): Inspector.get_primary_keys is a deprecated API in
            # newer SQLAlchemy (get_pk_constraint) -- confirm pinned version.
            tbl.pk = self.inspector.get_primary_keys(tbl.name, schema=tbl.schema)
            tbl.filtered_by = types.MethodType(_filtered_by, tbl)
            tbl.by_pk = types.MethodType(_by_pk, tbl)
            tbl.pk_val = types.MethodType(_pk_val, tbl)
            tbl.exists = types.MethodType(_exists, tbl)
            tbl.child_fks = []
            tbl.find_n_rows(estimate=True)
            self.tables[(tbl.schema, tbl.name)] = tbl
        # Invert FK relationships: register each FK on its *referred* (parent)
        # table so dependent child rows can be located from a parent row.
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            for fk in tbl.fks:
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk['referred_schema'], fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        """Pair every source table with its counterpart in ``target_db``.

        Installs per-table random-row generators, creates the requested /
        required work queues on the target, and computes how many rows the
        target table should end up with (a plain fraction, or a logarithmic
        fraction when the --logarithmic flag is set).
        """
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            if tbl.n_rows:
                if self.args.logarithmic:
                    # 10 ** (log10(n) * fraction): shrinks large tables much
                    # more aggressively than small ones; never below one row.
                    target.n_rows_desired = int(math.pow(10, math.log10(tbl.n_rows) * self.args.fraction)) or 1
                else:
                    target.n_rows_desired = int(tbl.n_rows * self.args.fraction) or 1
            else:
                target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(_completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        """Show the planned per-table row counts and ask the user to proceed.

        Returns True when --yes was passed, or when the interactive answer is
        empty or begins with 'y'.
        """
        message = []
        # Sort by table name -- the second element of the (schema, name) key.
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" % (tbl.target.n_rows_desired, tbl.n_rows, tbl_schema or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        """Copy ``source_row`` into table ``target`` inside ``target_db``.

        Recursively creates any parent rows required by foreign keys first,
        then queues child rows that reference this row. ``prioritized`` rows
        are processed even when they already exist, and their children go on
        the ``required`` queue rather than the best-effort ``requested`` one.
        """
        logging.debug('create_row_in %s:%s ' % (target.name, target.pk_val(source_row)))
        row_exists = target.exists(**(dict(source_row)))
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return
        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'], fk['referred_table'])]
                slct = sa.sql.select([target_parent, ])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'], fk['constrained_columns']):
                    slct = slct.where(target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                # An all-NULL foreign key references nothing; skip the lookup.
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        # Parent missing in target: fetch it from the source
                        # connection and recurse. (The same select is run on
                        # both connections -- presumably source and target
                        # share the schema; verify against callers.)
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db, target_parent)
            ins = target.insert().values(**source_row)
            target_db.conn.execute(ins)
            target.n_rows += 1
        # Queue child rows that reference this row, one query per child FK.
        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'], child_fk['constrained_table'])]
            slct = sa.sql.select([child, ])
            for (child_col, this_col) in zip(child_fk['constrained_columns'], child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                # Best-effort children are capped by the --children argument.
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif (n == 0):
                    # First best-effort child jumps the queue; rest go to the back.
                    child.target.requested.appendleft((desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    def create_subset_in(self, target_db):
        """Drive the subsetting: seed forced rows, then fill until complete.

        First copies every row listed in --force (prioritized), then
        repeatedly picks the least-complete target table and copies one of
        its source rows, stopping when every table scores above 0.97 or no
        tables with source rows remain.
        """
        for (tbl_name, pks) in self.args.force_rows.items():
            # --force keys may be schema-qualified ("schema.table").
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target, prioritized=True)
                else:
                    # NOTE(review): the two adjacent string literals concatenate
                    # without a space ("...source db,could not create"), and
                    # logging.warn is a deprecated alias of logging.warning.
                    logging.warn("requested %s:%s not found in source db," "could not create" % (source.name, pk))
        while True:
            # Always work on the emptiest (least complete) table next.
            targets = sorted(target_db.tables.values(), key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                # Skip tables whose source has no rows at all.
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                return
            logging.debug("total n_rows in target: %d" % sum((t.n_rows for t in target_db.tables.values())))
            logging.debug("target tables with 0 n_rows: %s" % ", ".join(t.name for t in target_db.tables.values() if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" % (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                return
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target, prioritized=prioritized)