Example #1
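A unit test for Superset's `extract_table_references` helper. It covers plain SELECTs, Jinja-templated SQL, fully qualified and quoted identifiers, JOINs, and the warning-logging fallback path used when the primary parser cannot handle a statement.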
def test_extract_table_references(mocker: MockerFixture) -> None:
    """
    Test the ``extract_table_references`` helper function.
    """
    assert extract_table_references("SELECT 1", "trino") == set()
    assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
        Table(table="some_table", schema=None, catalog=None)
    }
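    # Jinja template expressions in the SQL should not break table extraction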
    assert extract_table_references(
        "SELECT {{ jinja }} FROM some_table",
        "trino") == {Table(table="some_table", schema=None, catalog=None)}
    assert extract_table_references(
        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino") == {
            Table(table="some_table",
                  schema="some_schema",
                  catalog="some_catalog")
        }

    # with identifier quotes
    assert extract_table_references(
        "SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`",
        "mysql") == {
            Table(table="some_table",
                  schema="some_schema",
                  catalog="some_catalog")
        }
    assert extract_table_references(
        'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino") == {
            Table(table="some_table",
                  schema="some_schema",
                  catalog="some_catalog")
        }

    assert extract_table_references(
        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
        "trino",
    ) == {
        Table(table="some_table", schema=None, catalog=None),
        Table(table="other_table", schema=None, catalog=None),
    }

    # test falling back to sqlparse
    logger = mocker.patch("superset.sql_parse.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(
        sql,
        "trino",
    ) == {Table(table="other_table", schema=None, catalog=None)}
    logger.warning.assert_called_once()

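    # with show_warning=False the fallback should not log a warning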
    logger = mocker.patch("superset.migrations.shared.utils.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(sql, "trino", show_warning=False) == {
        Table(table="other_table", schema=None, catalog=None)
    }
    logger.warning.assert_not_called()
Example #2
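A post-processing step that appears to be part of a Superset Alembic migration: it walks all `SqlaTable` rows joined to their `NewDataset` records in batches, quotes physical table expressions where the dialect requires it, copies the schema into `extra_json`, and uses `extract_table_references` to link virtual datasets to the physical tables they reference.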
def postprocess_datasets(session: Session) -> None:
    """
    Postprocess datasets after insertion to
      - Quote table names for physical datasets (if needed)
      - Link referenced tables to virtual datasets
    """
    total = session.query(SqlaTable).count()
    if not total:
        return

    offset = 0
    limit = 10000

    joined_tables = sa.join(
        NewDataset,
        SqlaTable,
        NewDataset.uuid == SqlaTable.uuid,
    ).join(
        Database,
        Database.id == SqlaTable.database_id,
        isouter=True,
    )
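    # sanity check: every SqlaTable row should already have a matching NewDataset row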
    assert session.query(func.count()).select_from(joined_tables).scalar() == total

    print(f">> Run postprocessing on {total} datasets")

    update_count = 0

    def print_update_count() -> None:
        if SHOW_PROGRESS:
            print(
                f"   Will update {update_count} datasets" + " " * 20,
                end="\r",
            )

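    # process datasets in batches of `limit` rows to keep memory usage bounded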
    while offset < total:
        print(
            f"   Process dataset {offset + 1}~{min(total, offset + limit)}..."
            + " " * 30
        )
        for (
            database_id,
            dataset_id,
            expression,
            extra,
            is_physical,
            schema,
            sqlalchemy_uri,
        ) in session.execute(
            select(
                [
                    NewDataset.database_id,
                    NewDataset.id.label("dataset_id"),
                    NewDataset.expression,
                    SqlaTable.extra,
                    NewDataset.is_physical,
                    SqlaTable.schema,
                    Database.sqlalchemy_uri,
                ]
            )
            .select_from(joined_tables)
            .offset(offset)
            .limit(limit)
        ):
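            # the driver name from the SQLAlchemy URI determines the dialect used for quoting and parsing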
            drivername = (sqlalchemy_uri or "").split("://")[0]
            updates = {}
            updated = False
            if is_physical and drivername and expression:
                quoted_expression = get_identifier_quoter(drivername)(expression)
                if quoted_expression != expression:
                    updates["expression"] = quoted_expression

            # add schema name to `dataset.extra_json` so we don't have to join
            # tables in order to use datasets
            if schema:
                try:
                    extra_json = json.loads(extra) if extra else {}
                except json.decoder.JSONDecodeError:
                    extra_json = {}
                extra_json["schema"] = schema
                updates["extra_json"] = json.dumps(extra_json)

            if updates:
                session.execute(
                    sa.update(NewDataset)
                    .where(NewDataset.id == dataset_id)
                    .values(**updates)
                )
                updated = True

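            # for virtual datasets, parse the SQL expression and link any referenced physical tables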
            if not is_physical and drivername and expression:
                table_references = extract_table_references(
                    expression, get_dialect_name(drivername), show_warning=False
                )
                found_tables = find_tables(
                    session, database_id, schema, table_references
                )
                if found_tables:
                    op.bulk_insert(
                        dataset_table_association_table,
                        [
                            {"dataset_id": dataset_id, "table_id": table.id}
                            for table in found_tables
                        ],
                    )
                    updated = True

            if updated:
                update_count += 1
                print_update_count()

        session.flush()
        offset += limit

    if SHOW_PROGRESS:
        print("")