def test_extract_table_references(mocker: MockerFixture) -> None:
    """
    Test the ``extract_table_references`` helper function.
    """
    assert extract_table_references("SELECT 1", "trino") == set()
    assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
        Table(table="some_table", schema=None, catalog=None)
    }
    assert extract_table_references(
        "SELECT {{ jinja }} FROM some_table", "trino"
    ) == {Table(table="some_table", schema=None, catalog=None)}
    assert extract_table_references(
        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
    ) == {
        Table(table="some_table", schema="some_schema", catalog="some_catalog")
    }

    # with identifier quotes
    assert extract_table_references(
        "SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`", "mysql"
    ) == {
        Table(table="some_table", schema="some_schema", catalog="some_catalog")
    }
    assert extract_table_references(
        'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino"
    ) == {
        Table(table="some_table", schema="some_schema", catalog="some_catalog")
    }

    assert extract_table_references(
        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
        "trino",
    ) == {
        Table(table="some_table", schema=None, catalog=None),
        Table(table="other_table", schema=None, catalog=None),
    }

    # test falling back to sqlparse
    logger = mocker.patch("superset.sql_parse.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(
        sql,
        "trino",
    ) == {Table(table="other_table", schema=None, catalog=None)}
    logger.warning.assert_called_once()

    logger = mocker.patch("superset.migrations.shared.utils.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(sql, "trino", show_warning=False) == {
        Table(table="other_table", schema=None, catalog=None)
    }
    logger.warning.assert_not_called()
def postprocess_datasets(session: Session) -> None:
    """
    Postprocess datasets after insertion to
      - Quote table names for physical datasets (if needed)
      - Link referenced tables to virtual datasets
    """
    total = session.query(SqlaTable).count()
    if not total:
        return

    offset = 0
    limit = 10000

    joined_tables = sa.join(
        NewDataset,
        SqlaTable,
        NewDataset.uuid == SqlaTable.uuid,
    ).join(
        Database,
        Database.id == SqlaTable.database_id,
        isouter=True,
    )
    assert session.query(func.count()).select_from(joined_tables).scalar() == total

    print(f">> Run postprocessing on {total} datasets")

    update_count = 0

    def print_update_count():
        if SHOW_PROGRESS:
            print(
                f"   Will update {update_count} datasets" + " " * 20,
                end="\r",
            )

    # process datasets in batches of `limit` rows to keep memory usage bounded
    while offset < total:
        print(
            f"   Process dataset {offset + 1}~{min(total, offset + limit)}..."
            + " " * 30
        )
        for (
            database_id,
            dataset_id,
            expression,
            extra,
            is_physical,
            schema,
            sqlalchemy_uri,
        ) in session.execute(
            select(
                [
                    NewDataset.database_id,
                    NewDataset.id.label("dataset_id"),
                    NewDataset.expression,
                    SqlaTable.extra,
                    NewDataset.is_physical,
                    SqlaTable.schema,
                    Database.sqlalchemy_uri,
                ]
            )
            .select_from(joined_tables)
            .offset(offset)
            .limit(limit)
        ):
            drivername = (sqlalchemy_uri or "").split("://")[0]
            updates = {}
            updated = False

            # quote the table expression of physical datasets if the dialect
            # requires it
            if is_physical and drivername and expression:
                quoted_expression = get_identifier_quoter(drivername)(expression)
                if quoted_expression != expression:
                    updates["expression"] = quoted_expression

            # add schema name to `dataset.extra_json` so we don't have to join
            # tables in order to use datasets
            if schema:
                try:
                    extra_json = json.loads(extra) if extra else {}
                except json.decoder.JSONDecodeError:
                    extra_json = {}
                extra_json["schema"] = schema
                updates["extra_json"] = json.dumps(extra_json)

            if updates:
                session.execute(
                    sa.update(NewDataset)
                    .where(NewDataset.id == dataset_id)
                    .values(**updates)
                )
                updated = True

            # link tables referenced in the SQL expression to virtual datasets
            if not is_physical and drivername and expression:
                table_references = extract_table_references(
                    expression, get_dialect_name(drivername), show_warning=False
                )
                found_tables = find_tables(
                    session, database_id, schema, table_references
                )
                if found_tables:
                    op.bulk_insert(
                        dataset_table_association_table,
                        [
                            {"dataset_id": dataset_id, "table_id": table.id}
                            for table in found_tables
                        ],
                    )
                    updated = True

            if updated:
                update_count += 1
                print_update_count()

        session.flush()
        offset += limit

    if SHOW_PROGRESS:
        print("")
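
# Illustrative sketch only (an assumption, not part of this migration): roughly how
# ``postprocess_datasets`` would be wired into an Alembic ``upgrade`` step. The
# ``example_upgrade`` name is hypothetical; ``op`` and the SQLAlchemy ``Session``
# are the same names already used by the function above.
def example_upgrade() -> None:
    bind = op.get_bind()
    session = Session(bind=bind)
    postprocess_datasets(session)
    session.commit()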