def test_compare_get_columns_for_sql_and_odbc(self, schema, table, engine_name):
    with self.engine_map[engine_name].begin() as c:
        dialect = Inspector(c).dialect
        if schema is None:
            c.execute("OPEN SCHEMA %s" % self.schema)
        columns_fallback = dialect.get_columns(connection=c,
                                               table_name=table,
                                               schema=schema,
                                               use_sql_fallback=True)
        columns_odbc = dialect.get_columns(connection=c,
                                           table_name=table,
                                           schema=schema)
        # object equality doesn't work for sqltypes
        assert str(columns_fallback) == str(columns_odbc)
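# A minimal illustration (sketch, assuming sqlalchemy is importable here) of why
# the test above compares str() renderings: SQLAlchemy type instances compare by
# identity, so two INTEGER() objects are never equal even though they render
# identically.
from sqlalchemy import INTEGER

assert INTEGER() != INTEGER()            # distinct instances compare unequal
assert str(INTEGER()) == str(INTEGER())  # rendered forms match: "INTEGER"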
def test_compare_get_columns_none_table_for_sql_and_odbc(self, schema, engine_name):
    with self.engine_map[engine_name].begin() as c:
        if schema is None:
            c.execute("OPEN SCHEMA %s" % self.schema)
        dialect = Inspector(c).dialect
        table = None
        columns_fallback = dialect.get_columns(connection=c,
                                               table_name=table,
                                               schema=schema,
                                               use_sql_fallback=True)
        columns_odbc = dialect.get_columns(connection=c,
                                           table_name=table,
                                           schema=schema)
        # object equality doesn't work for sqltypes
        assert str(columns_fallback) == str(columns_odbc)
def loop_tables(
    self,
    inspector: Inspector,
    schema: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[SqlWorkUnit]:
    for table in inspector.get_table_names(schema):
        schema, table = self.standardize_schema_table_names(schema=schema,
                                                            entity=table)
        dataset_name = self.get_identifier(schema=schema,
                                           entity=table,
                                           inspector=inspector)
        self.report.report_entity_scanned(dataset_name, ent_type="table")

        if not sql_config.table_pattern.allowed(dataset_name):
            self.report.report_dropped(dataset_name)
            continue

        columns = inspector.get_columns(table, schema)
        if len(columns) == 0:
            self.report.report_warning(dataset_name,
                                       "missing column information")

        try:
            # SQLAlchemy stubs are incomplete and missing this method.
            # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
            table_info: dict = inspector.get_table_comment(table, schema)  # type: ignore
        except NotImplementedError:
            description: Optional[str] = None
            properties: Dict[str, str] = {}
        else:
            description = table_info["text"]
            # The "properties" field is a non-standard addition to SQLAlchemy's interface.
            properties = table_info.get("properties", {})

        # TODO: capture inspector.get_pk_constraint
        # TODO: capture inspector.get_sorted_table_and_fkc_names

        dataset_snapshot = DatasetSnapshot(
            urn=f"urn:li:dataset:(urn:li:dataPlatform:{self.platform},{dataset_name},{self.config.env})",
            aspects=[],
        )
        if description is not None or properties:
            dataset_properties = DatasetPropertiesClass(
                description=description,
                customProperties=properties,
            )
            dataset_snapshot.aspects.append(dataset_properties)

        schema_metadata = get_schema_metadata(self.report, dataset_name,
                                              self.platform, columns)
        dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = SqlWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu
def test_get_columns_table_name_none(self, use_sql_fallback, engine_name):
    with self.engine_map[engine_name].begin() as c:
        dialect = Inspector(c).dialect
        columns = dialect.get_columns(connection=c,
                                      schema=self.schema,
                                      table_name=None,
                                      use_sql_fallback=use_sql_fallback)
        assert columns == []
def _check_schema_consistency(config, db_name, schema_name, parsed_schema,
                              schema_version, schema_description,
                              target_engine):
    # Connect to the server that has the database that is being added.
    from sqlalchemy.engine.reflection import Inspector
    inspector = Inspector(target_engine)
    if schema_name not in inspector.get_schema_names():
        config.log.error("Schema '%s' not found.", db_name)
        raise MetaBException(MetaBException.DB_DOES_NOT_EXIST, db_name)
    db_tables = inspector.get_table_names(schema=schema_name)
    for table_name, parsed_table in parsed_schema.items():
        # Check parsed tables - we allow other tables in the schema.
        if table_name not in db_tables:
            config.log.error(
                "Table '%s' not found in db, present in ascii file.",
                table_name)
            raise MetaBException(MetaBException.TB_NOT_IN_DB, table_name)
        db_columns = inspector.get_columns(table_name=table_name,
                                           schema=schema_name)
        parsed_columns = parsed_table["columns"]
        if len(parsed_columns) != len(db_columns):
            config.log.error(
                "Number of columns in db for table %s (%d) "
                "differs from number of columns in schema (%d)",
                table_name, len(db_columns), len(parsed_columns))
            raise MetaBException(MetaBException.NOT_MATCHING)
        # get_columns() returns dicts, so compare against the column names.
        db_column_names = {column["name"] for column in db_columns}
        for column in parsed_columns:
            column_name = column["name"]
            if column_name not in db_column_names:
                config.log.error(
                    "Column '%s.%s' not found in db, "
                    "but exists in schema DDL", table_name, column_name)
                raise MetaBException(MetaBException.COL_NOT_IN_TB,
                                     column_name, table_name)

    # Get schema description and version; it is ok if it is missing.
    ret = target_engine.execute(
        "SELECT version, descr FROM %s.ZZZ_Schema_Description" % db_name)
    if ret.rowcount != 1:
        config.log.error("Db '%s' does not contain schema version/description",
                         db_name)
    else:
        (found_schema_version, found_schema_description) = ret.first()
        if found_schema_version != schema_version or \
           found_schema_description != schema_description:
            raise MetaBException(
                MetaBException.NOT_MATCHING,
                "Schema name or description does not match defined values.")
def get_columns(cls, inspector: Inspector, table_name: str,
                schema: Optional[str]) -> List[Dict[str, Any]]:
    """
    Get all columns from a given schema and table

    :param inspector: SqlAlchemy Inspector instance
    :param table_name: Table name
    :param schema: Schema name. If omitted, uses default schema for database
    :return: All columns in table
    """
    return inspector.get_columns(table_name, schema)
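# A minimal usage sketch for the wrapper above. The SQLite URL and the "users"
# table are illustrative assumptions, not part of the original code.
from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite:///example.db")
inspector = inspect(engine)  # returns an Inspector bound to the engine
for column in inspector.get_columns("users", schema=None):
    # Each entry is a dict with at least "name", "type", and "nullable" keys.
    print(column["name"], column["type"], column["nullable"])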
def _get_columns(self, dataset_name: str, inspector: Inspector, schema: str,
                 table: str) -> List[dict]:
    columns = []
    try:
        columns = inspector.get_columns(table, schema)
        if len(columns) == 0:
            self.report.report_warning(dataset_name,
                                       "missing column information")
    except Exception as e:
        self.report.report_warning(
            dataset_name,
            f"unable to get column information due to an error -> {e}",
        )
    return columns
def _test_get_columns(self, schema=None, table_type='table'):
    meta = MetaData(testing.db)
    (users, addresses) = createTables(meta, schema)
    table_names = ['users', 'email_addresses']
    meta.create_all()
    if table_type == 'view':
        createViews(meta.bind, schema)
        table_names = ['users_v', 'email_addresses_v']
    try:
        insp = Inspector(meta.bind)
        for (table_name, table) in zip(table_names, (users, addresses)):
            schema_name = schema
            cols = insp.get_columns(table_name, schema=schema_name)
            self.assert_(len(cols) > 0, len(cols))
            # should be in order
            for (i, col) in enumerate(table.columns):
                eq_(col.name, cols[i]['name'])
                ctype = cols[i]['type'].__class__
                ctype_def = col.type
                if isinstance(ctype_def, sa.types.TypeEngine):
                    ctype_def = ctype_def.__class__
                # Oracle returns Date for DateTime.
                if testing.against('oracle') \
                        and ctype_def in (sql_types.Date, sql_types.DateTime):
                    ctype_def = sql_types.Date
                # assert that the desired type and return type
                # share a base within one of the generic types.
                self.assert_(
                    len(
                        set(ctype.__mro__)
                        .intersection(ctype_def.__mro__)
                        .intersection([sql_types.Integer, sql_types.Numeric,
                                       sql_types.DateTime, sql_types.Date,
                                       sql_types.Time, sql_types.String,
                                       sql_types._Binary])
                    ) > 0,
                    "%s(%s), %s(%s)" % (col.name, col.type,
                                        cols[i]['name'], ctype))
    finally:
        if table_type == 'view':
            dropViews(meta.bind, schema)
        addresses.drop()
        users.drop()
def _test_get_columns(self, schema=None, table_type='table'):
    meta = MetaData(testing.db)
    (users, addresses) = createTables(meta, schema)
    table_names = ['users', 'email_addresses']
    meta.create_all()
    if table_type == 'view':
        createViews(meta.bind, schema)
        table_names = ['users_v', 'email_addresses_v']
    try:
        insp = Inspector(meta.bind)
        for (table_name, table) in zip(table_names, (users, addresses)):
            schema_name = schema
            if schema and testing.against('oracle'):
                schema_name = schema.upper()
            cols = insp.get_columns(table_name, schema=schema_name)
            self.assert_(len(cols) > 0, len(cols))
            # should be in order
            for (i, col) in enumerate(table.columns):
                self.assertEqual(col.name, cols[i]['name'])
                # coltype is tricky: it may not inherit from col.type
                # while they share the same base.
                ctype = cols[i]['type'].__class__
                ctype_def = col.type
                if isinstance(ctype_def, sa.types.TypeEngine):
                    ctype_def = ctype_def.__class__
                # Oracle returns Date for DateTime.
                if testing.against('oracle') \
                        and ctype_def in (sql_types.Date, sql_types.DateTime):
                    ctype_def = sql_types.Date
                self.assert_(
                    issubclass(ctype, ctype_def) or
                    len(set(ctype.__bases__)
                        .intersection(ctype_def.__bases__)) > 0,
                    "%s(%s), %s(%s)" % (col.name, col.type,
                                        cols[i]['name'], ctype))
    finally:
        if table_type == 'view':
            dropViews(meta.bind, schema)
        addresses.drop()
        users.drop()
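# A small illustration of the shared-base fallback in the test above (sketch,
# assuming sqlalchemy is installed): a column defined as Text may reflect back
# as VARCHAR. VARCHAR is not a subclass of Text, but both list String among
# their bases, so the assertion still passes.
from sqlalchemy.sql import sqltypes

assert not issubclass(sqltypes.VARCHAR, sqltypes.Text)
assert set(sqltypes.VARCHAR.__bases__) & set(sqltypes.Text.__bases__)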
def test_get_columns(self, use_sql_fallback, engine_name):
    with self.engine_map[engine_name].begin() as c:
        dialect = Inspector(c).dialect
        columns = dialect.get_columns(connection=c,
                                      schema=self.schema,
                                      table_name="t",
                                      use_sql_fallback=use_sql_fallback)
        expected = [
            {'default': None, 'is_distribution_key': False, 'name': 'pid1',
             'nullable': False, 'type': INTEGER()},
            {'default': None, 'is_distribution_key': False, 'name': 'pid2',
             'nullable': False, 'type': INTEGER()},
            {'default': None, 'is_distribution_key': False, 'name': 'name',
             'nullable': True, 'type': VARCHAR(length=20)},
            {'default': None, 'is_distribution_key': False, 'name': 'age',
             'nullable': True, 'type': INTEGER()},
        ]
        assert self.make_columns_comparable(expected) == \
            self.make_columns_comparable(columns)
def loop_views(
    self,
    inspector: Inspector,
    schema: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[SqlWorkUnit]:
    for view in inspector.get_view_names(schema):
        schema, view = sql_config.standardize_schema_table_names(schema, view)
        dataset_name = sql_config.get_identifier(schema, view)
        self.report.report_entity_scanned(dataset_name, ent_type="view")

        if not sql_config.view_pattern.allowed(dataset_name):
            self.report.report_dropped(dataset_name)
            continue

        try:
            columns = inspector.get_columns(view, schema)
        except KeyError:
            # For certain types of views, we are unable to fetch the list of columns.
            self.report.report_warning(dataset_name,
                                       "unable to get schema for this view")
            schema_metadata = None
        else:
            schema_metadata = get_schema_metadata(self.report, dataset_name,
                                                  self.platform, columns)

        try:
            # SQLAlchemy stubs are incomplete and missing this method.
            # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
            view_info: dict = inspector.get_table_comment(view, schema)  # type: ignore
        except NotImplementedError:
            description: Optional[str] = None
            properties: Dict[str, str] = {}
        else:
            description = view_info["text"]
            # The "properties" field is a non-standard addition to SQLAlchemy's interface.
            properties = view_info.get("properties", {})

        try:
            view_definition = inspector.get_view_definition(view, schema)
            if view_definition is None:
                view_definition = ""
            else:
                # Some dialects return a TextClause instead of a raw string,
                # so we need to convert them to a string.
                view_definition = str(view_definition)
        except NotImplementedError:
            view_definition = ""
        properties["view_definition"] = view_definition
        properties["is_view"] = "True"

        dataset_snapshot = DatasetSnapshot(
            urn=f"urn:li:dataset:(urn:li:dataPlatform:{self.platform},{dataset_name},{self.config.env})",
            aspects=[],
        )
        if description is not None or properties:
            dataset_properties = DatasetPropertiesClass(
                description=description,
                customProperties=properties,
                # uri=dataset_name,
            )
            dataset_snapshot.aspects.append(dataset_properties)

        if schema_metadata:
            dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = SqlWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu
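# A quick demonstration of the TextClause-to-str conversion used in loop_views
# above (sketch, assuming sqlalchemy is installed): str() renders the clause
# back to its raw SQL text.
from sqlalchemy import text

view_definition = text("SELECT id, name FROM users")
assert str(view_definition) == "SELECT id, name FROM users"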
def loop_tables(
    self,
    inspector: Inspector,
    schema: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[SqlWorkUnit]:
    tables_seen: Set[str] = set()
    for table in inspector.get_table_names(schema):
        schema, table = self.standardize_schema_table_names(schema=schema,
                                                            entity=table)
        dataset_name = self.get_identifier(schema=schema,
                                           entity=table,
                                           inspector=inspector)
        if dataset_name not in tables_seen:
            tables_seen.add(dataset_name)
        else:
            logger.debug(f"{dataset_name} has already been seen, skipping...")
            continue

        self.report.report_entity_scanned(dataset_name, ent_type="table")
        if not sql_config.table_pattern.allowed(dataset_name):
            self.report.report_dropped(dataset_name)
            continue

        columns = []
        try:
            columns = inspector.get_columns(table, schema)
            if len(columns) == 0:
                self.report.report_warning(dataset_name,
                                           "missing column information")
        except Exception as e:
            self.report.report_warning(
                dataset_name,
                f"unable to get column information due to an error -> {e}",
            )

        try:
            # SQLAlchemy stubs are incomplete and missing this method.
            # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
            table_info: dict = inspector.get_table_comment(table, schema)  # type: ignore
        except NotImplementedError:
            description: Optional[str] = None
            properties: Dict[str, str] = {}
        except ProgrammingError as pe:
            # Snowflake needs schema names quoted when fetching table comments.
            logger.debug(
                f"Encountered ProgrammingError. Retrying with quoted schema name for schema {schema} and table {table}",
                pe,
            )
            table_info = inspector.get_table_comment(table, f'"{schema}"')  # type: ignore
            description = table_info["text"]
            properties = table_info.get("properties", {})
        else:
            description = table_info["text"]
            # The "properties" field is a non-standard addition to SQLAlchemy's interface.
            properties = table_info.get("properties", {})

        dataset_urn = make_dataset_urn(self.platform, dataset_name,
                                       self.config.env)
        dataset_snapshot = DatasetSnapshot(
            urn=dataset_urn,
            aspects=[StatusClass(removed=False)],
        )

        if self.is_stateful_ingestion_configured():
            cur_checkpoint = self.get_current_checkpoint(
                self.get_default_ingestion_job_id())
            if cur_checkpoint is not None:
                checkpoint_state = cast(BaseSQLAlchemyCheckpointState,
                                        cur_checkpoint.state)
                checkpoint_state.add_table_urn(dataset_urn)

        if description is not None or properties:
            dataset_properties = DatasetPropertiesClass(
                description=description,
                customProperties=properties,
            )
            dataset_snapshot.aspects.append(dataset_properties)

        pk_constraints: dict = inspector.get_pk_constraint(table, schema)
        try:
            foreign_keys = [
                self.get_foreign_key_metadata(dataset_urn, schema, fk_rec,
                                              inspector)
                for fk_rec in inspector.get_foreign_keys(table, schema)
            ]
        except KeyError:
            # certain databases like MySQL cause issues due to
            # lower-case/upper-case irregularities
            logger.debug(
                f"{dataset_urn}: failure in foreign key extraction... skipping")
            foreign_keys = []

        schema_fields = self.get_schema_fields(dataset_name, columns,
                                               pk_constraints)
        schema_metadata = get_schema_metadata(
            self.report,
            dataset_name,
            self.platform,
            columns,
            pk_constraints,
            foreign_keys,
            schema_fields,
        )
        dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = SqlWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu
def _process_view(
    self,
    dataset_name: str,
    inspector: Inspector,
    schema: str,
    view: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
    try:
        columns = inspector.get_columns(view, schema)
    except KeyError:
        # For certain types of views, we are unable to fetch the list of columns.
        self.report.report_warning(dataset_name,
                                   "unable to get schema for this view")
        schema_metadata = None
    else:
        schema_fields = self.get_schema_fields(dataset_name, columns)
        schema_metadata = get_schema_metadata(
            self.report,
            dataset_name,
            self.platform,
            columns,
            canonical_schema=schema_fields,
        )

    try:
        # SQLAlchemy stubs are incomplete and missing this method.
        # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
        view_info: dict = inspector.get_table_comment(view, schema)  # type: ignore
    except NotImplementedError:
        description: Optional[str] = None
        properties: Dict[str, str] = {}
    else:
        description = view_info["text"]
        # The "properties" field is a non-standard addition to SQLAlchemy's interface.
        properties = view_info.get("properties", {})

    try:
        view_definition = inspector.get_view_definition(view, schema)
        if view_definition is None:
            view_definition = ""
        else:
            # Some dialects return a TextClause instead of a raw string,
            # so we need to convert them to a string.
            view_definition = str(view_definition)
    except NotImplementedError:
        view_definition = ""
    properties["view_definition"] = view_definition
    properties["is_view"] = "True"

    dataset_urn = make_dataset_urn_with_platform_instance(
        self.platform,
        dataset_name,
        self.config.platform_instance,
        self.config.env,
    )
    dataset_snapshot = DatasetSnapshot(
        urn=dataset_urn,
        aspects=[StatusClass(removed=False)],
    )

    db_name = self.get_db_name(inspector)
    yield from self.add_table_to_schema_container(dataset_urn, db_name, schema)

    if self.is_stateful_ingestion_configured():
        cur_checkpoint = self.get_current_checkpoint(
            self.get_default_ingestion_job_id())
        if cur_checkpoint is not None:
            checkpoint_state = cast(BaseSQLAlchemyCheckpointState,
                                    cur_checkpoint.state)
            checkpoint_state.add_view_urn(dataset_urn)

    if description is not None or properties:
        dataset_properties = DatasetPropertiesClass(
            description=description,
            customProperties=properties,
            # uri=dataset_name,
        )
        dataset_snapshot.aspects.append(dataset_properties)

    if schema_metadata:
        dataset_snapshot.aspects.append(schema_metadata)

    mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
    wu = SqlWorkUnit(id=dataset_name, mce=mce)
    self.report.report_workunit(wu)
    yield wu

    dpi_aspect = self.get_dataplatform_instance_aspect(dataset_urn=dataset_urn)
    if dpi_aspect:
        yield dpi_aspect

    yield from self._get_domain_wu(
        dataset_name=dataset_name,
        entity_urn=dataset_urn,
        entity_type="dataset",
        sql_config=sql_config,
    )
class Db(object):

    def __init__(self, sqla_conn, args, schemas=None):
        if schemas is None:
            schemas = [None]
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for schema in self.schemas:
            meta = sa.MetaData(bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                if args.tables and not _table_matches_any_pattern(
                        tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name,
                                              self.args.exclude_tables):
                    continue
                tbl.db = self
                if self.engine.name == 'postgresql':
                    fix_postgres_array_of_enum(self.conn, tbl)
                # TODO: Replace all these monkeypatches with an instance assignment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                          schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                         schema=tbl.schema)
                if not tbl.pk:
                    tbl.pk = [
                        d['name'] for d in self.inspector.get_columns(
                            tbl.name, schema=tbl.schema)
                    ]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                estimate_rows = not _table_matches_any_pattern(
                    tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
        all_constraints = args.config.get('constraints', {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            for fk in (tbl.fks + constraints):
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk['referred_schema'],
                             fk['referred_table'])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name,
                                          self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    if self.args.logarithmic:
                        target.n_rows_desired = int(
                            math.pow(10,
                                     math.log10(tbl.n_rows) *
                                     self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(
                            tbl.n_rows * self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows,
                            tbl_schema or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed?\n(Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))
        pks = hashable((source_row[key] for key in target.pk))
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return
        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'],
                                                  fk['referred_table'])]
                slct = sa.sql.select([target_parent])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)
            pks = hashable((source_row[key] for key in target.pk))
            target.n_rows += 1
            # insert source row here to prevent recursion
            if self.args.buffer == 0:
                target_db.insert_one(target, pks, source_row)
            else:
                target.pending[pks] = source_row
            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(
                    constraint['referred_schema'],
                    constraint['referred_table'])]
                slct = sa.sql.select([target_referred])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                        constraint['referred_columns'],
                        constraint['constrained_columns']):
                    slct = slct.where(target_referred.c[referred_col] ==
                                      source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs,
                        # the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db,
                                               target_referred)
        # defer signal?
        signal(SIGNAL_ROW_ADDED).send(self,
                                      source_row=source_row,
                                      target_db=target_db,
                                      target_table=target,
                                      prioritized=prioritized)
        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'],
                                 child_fk['constrained_table'])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        return functools.reduce(
            lambda count, table: count + len(table.pending),
            self.tables.values(), 0)

    def insert_one(self, table, pk, values):
        self.conn.execute(table.insert(), values)
        table.done.add(pk)

    def flush(self):
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):
        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target,
                                       prioritized=True)
                else:
                    logging.warning("requested %s:%s not found in source db, "
                                    "could not create" % (source.name, pk))
        while True:
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" % sum(
                t.n_rows for t in target_db.tables.values()))
            logging.debug("target tables with 0 n_rows: %s" % ", ".join(
                t.name for t in target_db.tables.values() if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target,
                               prioritized=prioritized)
            if target_db.pending > self.args.buffer > 0:
                target_db.flush()
        if self.args.buffer > 0:
            target_db.flush()
        try:
            manager.run()
        except SystemExit as e:
            self.assertEqual(e.code, 0)

        assert self.dbmigrate._get_db_version() == \
            self.dbmigrate._get_repo_version()

        i = Inspector(self.dbmigrate.db.engine)
        # check that table "test" exists
        assert 'test' in i.get_table_names()
        # check that column "column2" exists in table "test"
        assert 'column2' in [c['name'] for c in i.get_columns('test')]

    @with_database
    def test_migrate_downgrade_to_0(self):
        manager = Manager(self.app)
        manager.add_command('dbmigrate', dbmanager)
        sys.argv = ['manage.py', 'dbmigrate', 'migrate', '-v', '0']
        try:
            manager.run()
        except SystemExit as e:
            self.assertEqual(e.code, 0)

        i = Inspector(self.dbmigrate.db.engine)
def get_columns(cls, inspector: Inspector, table_name: str,
                schema: str) -> list:
    return inspector.get_columns(table_name, schema)
def verify_thd(conn):
    metadata = sa.MetaData()
    metadata.bind = conn

    # Verify database contents.

    # 'workers' table contents.
    workers = sautils.Table('workers', metadata, autoload=True)
    c = workers.c
    q = sa.select([c.id, c.name, c.info]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (30, u'worker-1', u'{}'),
        (31, u'worker-2', u'{"a": 1}'),
    ])

    # 'builds' table contents.
    builds = sautils.Table('builds', metadata, autoload=True)
    c = builds.c
    q = sa.select([
        c.id, c.number, c.builderid, c.buildrequestid, c.workerid,
        c.masterid, c.started_at, c.complete_at, c.state_string, c.results
    ]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (40, 1, None, 20, 30, 10, 1000, None, u'state', None),
        (41, 2, 50, 21, None, 11, 2000, 3000, u'state 2', 9),
    ])

    # 'configured_workers' table contents.
    configured_workers = sautils.Table('configured_workers', metadata,
                                       autoload=True)
    c = configured_workers.c
    q = sa.select([c.id, c.buildermasterid, c.workerid]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (60, 70, 30),
        (61, 71, 31),
    ])

    # 'connected_workers' table contents.
    connected_workers = sautils.Table('connected_workers', metadata,
                                      autoload=True)
    c = connected_workers.c
    q = sa.select([c.id, c.masterid, c.workerid]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (80, 10, 30),
        (81, 11, 31),
    ])

    # Verify that there are no "slave"-named items in the schema.
    inspector = Inspector(conn)

    def check_name(name, table_name, item_type):
        if not name:
            return
        self.assertTrue(
            u"slave" not in name.lower(),
            msg=u"'slave'-named {type} in table '{table}': "
                u"'{name}'".format(type=item_type, table=table_name,
                                   name=name))

    # Check every table.
    for table_name in inspector.get_table_names():
        # Check table name.
        check_name(table_name, table_name, u"table name")

        # Check column names.
        for column_info in inspector.get_columns(table_name):
            check_name(column_info['name'], table_name, u"column")

        # Check foreign key names.
        for fk_info in inspector.get_foreign_keys(table_name):
            check_name(fk_info['name'], table_name, u"foreign key")

        # Check index names.
        for index_info in inspector.get_indexes(table_name):
            check_name(index_info['name'], table_name, u"index")

        # Check primary key constraint names.
        pk_info = inspector.get_pk_constraint(table_name)
        check_name(pk_info.get('name'), table_name, u"primary key")

    # Test that no "slave"-named items are present in the schema.
    for name in inspector.get_schema_names():
        self.assertTrue(u"slave" not in name.lower())
class Db(object):

    def __init__(self, sqla_conn, args, schemas=None):
        if schemas is None:
            schemas = [None]
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()
        for schema in self.schemas:
            meta = sa.MetaData(bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                if args.tables and not _table_matches_any_pattern(
                        tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name,
                                              self.args.exclude_tables):
                    continue
                tbl.db = self
                # TODO: Replace all these monkeypatches with an instance assignment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                          schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                         schema=tbl.schema)
                if not tbl.pk:
                    tbl.pk = [
                        d["name"] for d in self.inspector.get_columns(tbl.name)
                    ]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                estimate_rows = not _table_matches_any_pattern(
                    tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
        all_constraints = args.config.get("constraints", {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            for fk in tbl.fks + constraints:
                fk["constrained_schema"] = tbl_schema
                fk["constrained_table"] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk["referred_schema"],
                             fk["referred_table"])].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name,
                                          self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    if self.args.logarithmic:
                        target.n_rows_desired = int(
                            math.pow(10,
                                     math.log10(tbl.n_rows) *
                                     self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(
                            tbl.n_rows * self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(
                _completeness_score, target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append(
                "Create %d rows from %d in %s.%s" %
                (tbl.target.n_rows_desired, tbl.n_rows, tbl_schema or "",
                 tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed?\n(Y/n) ").strip().lower()
        return (not response) or (response[0] == "y")

    def create_row_in(self, source_row, target_db, target, prioritized=False):
        logging.debug("create_row_in %s:%s " %
                      (target.name, target.pk_val(source_row)))
        pks = tuple(source_row[key] for key in target.pk)
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return
        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk["referred_schema"],
                                                  fk["referred_table"])]
                slct = sa.sql.select([target_parent])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk["referred_columns"],
                                                   fk["constrained_columns"]):
                    slct = slct.where(
                        target_parent.c[parent_col] == source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)
            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(
                    constraint["referred_schema"],
                    constraint["referred_table"])]
                slct = sa.sql.select([target_referred])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                        constraint["referred_columns"],
                        constraint["constrained_columns"]):
                    slct = slct.where(target_referred.c[referred_col] ==
                                      source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs,
                        # the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db,
                                               target_referred)
            pks = tuple(source_row[key] for key in target.pk)
            target.pending[pks] = source_row
            target.n_rows += 1
        for child_fk in target.child_fks:
            child = self.tables[(child_fk["constrained_schema"],
                                 child_fk["constrained_table"])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk["constrained_columns"],
                                             child_fk["referred_columns"]):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        return functools.reduce(
            lambda count, table: count + len(table.pending),
            self.tables.values(), 0)

    def flush(self):
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):
        for (tbl_name, pks) in self.args.force_rows.items():
            if "." in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split(".", 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row, target_db, source.target,
                                       prioritized=True)
                else:
                    logging.warning("requested %s:%s not found in source db, "
                                    "could not create" % (source.name, pk))
        # import pdb; pdb.set_trace()
        while True:
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" % sum(
                t.n_rows for t in target_db.tables.values()))
            logging.debug("target tables with 0 n_rows: %s" % ", ".join(
                t.name for t in target_db.tables.values() if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row, target_db, target,
                               prioritized=prioritized)
            if target_db.pending > self.args.buffer:
                target_db.flush()
        target_db.flush()
def verify_thd(conn):
    metadata = sa.MetaData()
    metadata.bind = conn

    # Verify database contents.

    # 'workers' table contents.
    workers = sautils.Table('workers', metadata, autoload=True)
    c = workers.c
    q = sa.select([c.id, c.name, c.info]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (30, 'worker-1', '{}'),
        (31, 'worker-2', '{"a": 1}'),
    ])

    # 'builds' table contents.
    builds = sautils.Table('builds', metadata, autoload=True)
    c = builds.c
    q = sa.select([
        c.id, c.number, c.builderid, c.buildrequestid, c.workerid,
        c.masterid, c.started_at, c.complete_at, c.state_string, c.results
    ]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (40, 1, None, 20, 30, 10, 1000, None, 'state', None),
        (41, 2, 50, 21, None, 11, 2000, 3000, 'state 2', 9),
    ])

    # 'configured_workers' table contents.
    configured_workers = sautils.Table('configured_workers', metadata,
                                       autoload=True)
    c = configured_workers.c
    q = sa.select([c.id, c.buildermasterid, c.workerid]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (60, 70, 30),
        (61, 71, 31),
    ])

    # 'connected_workers' table contents.
    connected_workers = sautils.Table('connected_workers', metadata,
                                      autoload=True)
    c = connected_workers.c
    q = sa.select([c.id, c.masterid, c.workerid]).order_by(c.id)
    self.assertEqual(q.execute().fetchall(), [
        (80, 10, 30),
        (81, 11, 31),
    ])

    # Verify that there are no "slave"-named items in the schema.
    inspector = Inspector(conn)

    def check_name(name, table_name, item_type):
        if not name:
            return
        self.assertTrue(
            "slave" not in name.lower(),
            msg="'slave'-named {type} in table '{table}': "
                "'{name}'".format(type=item_type, table=table_name,
                                  name=name))

    # Check every table.
    for table_name in inspector.get_table_names():
        # Check table name.
        check_name(table_name, table_name, "table name")

        # Check column names.
        for column_info in inspector.get_columns(table_name):
            check_name(column_info['name'], table_name, "column")

        # Check foreign key names.
        for fk_info in inspector.get_foreign_keys(table_name):
            check_name(fk_info['name'], table_name, "foreign key")

        # Check index names.
        for index_info in inspector.get_indexes(table_name):
            check_name(index_info['name'], table_name, "index")

        # Check primary key constraint names.
        pk_info = inspector.get_pk_constraint(table_name)
        check_name(pk_info.get('name'), table_name, "primary key")

    # Test that no "slave"-named items are present in the schema.
    for name in inspector.get_schema_names():
        self.assertTrue("slave" not in name.lower())
def get_columns(cls, inspector: Inspector, table_name: str,
                schema: Optional[str]) -> List[Dict[str, Any]]:
    return inspector.get_columns(table_name, schema)
def loop_tables(
    self,
    inspector: Inspector,
    schema: str,
    sql_config: SQLAlchemyConfig,
) -> Iterable[SqlWorkUnit]:
    for table in inspector.get_table_names(schema):
        schema, table = self.standardize_schema_table_names(schema=schema,
                                                            entity=table)
        dataset_name = self.get_identifier(schema=schema,
                                           entity=table,
                                           inspector=inspector)
        self.report.report_entity_scanned(dataset_name, ent_type="table")

        if not sql_config.table_pattern.allowed(dataset_name):
            self.report.report_dropped(dataset_name)
            continue

        columns = inspector.get_columns(table, schema)
        if len(columns) == 0:
            self.report.report_warning(dataset_name,
                                       "missing column information")

        try:
            # SQLAlchemy stubs are incomplete and missing this method.
            # PR: https://github.com/dropbox/sqlalchemy-stubs/pull/223.
            table_info: dict = inspector.get_table_comment(table, schema)  # type: ignore
        except NotImplementedError:
            description: Optional[str] = None
            properties: Dict[str, str] = {}
        else:
            description = table_info["text"]
            # The "properties" field is a non-standard addition to SQLAlchemy's interface.
            properties = table_info.get("properties", {})

        datasetUrn = f"urn:li:dataset:(urn:li:dataPlatform:{self.platform},{dataset_name},{self.config.env})"
        dataset_snapshot = DatasetSnapshot(
            urn=datasetUrn,
            aspects=[],
        )
        if description is not None or properties:
            dataset_properties = DatasetPropertiesClass(
                description=description,
                customProperties=properties,
            )
            dataset_snapshot.aspects.append(dataset_properties)

        pk_constraints: dict = inspector.get_pk_constraint(table, schema)
        try:
            foreign_keys = [
                self.get_foreign_key_metadata(datasetUrn, fk_rec, inspector)
                for fk_rec in inspector.get_foreign_keys(table, schema)
            ]
        except KeyError:
            # certain databases like MySQL cause issues due to
            # lower-case/upper-case irregularities
            logger.debug(
                f"{datasetUrn}: failure in foreign key extraction... skipping")
            foreign_keys = []

        schema_metadata = get_schema_metadata(
            self.report,
            dataset_name,
            self.platform,
            columns,
            pk_constraints,
            foreign_keys,
        )
        dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = SqlWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu