コード例 #1
0
ファイル: blob_api.py プロジェクト: tayloris/ert
class BlobApi:
    """Thin wrapper around a SQLAlchemy ``Session`` for storing ``ErtBlob`` rows.

    Can be used as a context manager; leaving the ``with`` block only calls
    :meth:`close`. Nothing is committed implicitly, so call :meth:`commit`
    explicitly to persist staged blobs.
    """

    def __init__(self, connection):
        # A dedicated session bound to the caller-supplied connection/engine.
        self._session = Session(bind=connection)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Parameter names shadow builtins but are kept for interface
        # compatibility. No commit/rollback is issued here, and returning
        # None means exceptions raised inside the ``with`` block propagate.
        self.close()

    def commit(self):
        """Commit the session's current transaction."""
        self._session.commit()

    def flush(self):
        """Flush pending changes to the database without committing."""
        self._session.flush()

    def rollback(self):
        """Roll back the session's current transaction."""
        self._session.rollback()

    def close(self):
        """Close the session."""
        self._session.close()

    def close_connection(self):
        """Close the connection object obtained from the session."""
        self._session.connection().close()

    def add_blob(self, data):
        """Create an ``ErtBlob`` holding ``data``, stage it, and return it.

        The blob is only added to the session; it is neither flushed nor
        committed here.
        """
        data_frame = ErtBlob(data=data)
        self._session.add(data_frame)
        return data_frame

    def get_blob(self, id):
        """Return the ``ErtBlob`` with primary key ``id`` (``None`` if absent)."""
        # NOTE(review): ``Query.get`` is a legacy API in SQLAlchemy 1.4+;
        # ``Session.get(ErtBlob, id)`` is the modern form if the installed
        # version supports it.
        return self._session.query(ErtBlob).get(id)

    def get_blobs(self, ids):
        """Return a lazily evaluated query over blobs whose id is in ``ids``.

        :param ids: a single id or any iterable of ids (list, tuple, set,
            generator, ...). ``str``/``bytes`` values are treated as one id.
        :returns: a ``Query`` that fetches rows one at a time
            (``yield_per(1)``) with eager loading disabled.
        """
        return (
            self._session.query(ErtBlob)
            .filter(ErtBlob.id.in_(self._as_id_list(ids)))
            .yield_per(1)
            .enable_eagerloads(False)
        )

    @staticmethod
    def _as_id_list(ids):
        """Normalize a single id or an iterable of ids into a list."""
        # Strings/bytes are iterable but represent a single id here.
        if isinstance(ids, (str, bytes)):
            return [ids]
        try:
            iterator = iter(ids)
        except TypeError:
            # Not iterable: a scalar id.
            return [ids]
        return list(iterator)
コード例 #2
0
def _inspect_schema_version(session: Session) -> int:
    """Infer the schema version from the database structure itself.

    Used when no schema version has been recorded in the db. That happens
    either because the db was just created with the current schema, or
    because it predates schema-version tracking. The two cases are told
    apart by looking for the index that schema version 1 added (an index on
    ``events`` whose only column is ``time_fired``). Once legacy databases no
    longer need handling, this probe can be dropped and a fresh db assumed.

    Side effects: stages rows on ``session``; nothing is flushed or
    committed here.

    :returns: ``SCHEMA_VERSION`` for a fresh db, otherwise ``0``.
    """
    inspector = sqlalchemy.inspect(session.connection())
    has_version_one_index = any(
        idx["column_names"] == ["time_fired"]
        for idx in inspector.get_indexes("events")
    )

    if not has_version_one_index:
        # Pre-versioning db: record version 0 so it gets migrated.
        legacy_marker = SchemaChanges(schema_version=0)
        session.add(legacy_marker)
        return cast(int, legacy_marker.schema_version)

    # Freshly created db already carries the current schema.
    session.add(StatisticsRuns(start=get_start_time()))
    session.add(SchemaChanges(schema_version=SCHEMA_VERSION))
    return SCHEMA_VERSION
コード例 #3
0
ファイル: db.py プロジェクト: mhenc/airflow
def downgrade(to_revision,
              sql=False,
              from_revision=None,
              session: Session = NEW_SESSION):
    """
    Downgrade the airflow metastore schema to a prior version.

    :param to_revision: The alembic revision to downgrade *to*.
    :param sql: if True, print sql statements but do not run them
    :param from_revision: if supplied, alembic revision to downgrade *from*. This may only
        be used in conjunction with ``sql=True`` because if we actually run the commands,
        we should only downgrade from the *current* revision.
    :param session: sqlalchemy session for connection to airflow metadata database
    :raises ValueError: if ``from_revision`` is given while ``sql`` is False.
    :raises RuntimeError: if ``settings.SQL_ALCHEMY_CONN`` is not set.
    :raises SystemExit: with status 1 if ``_check_migration_errors`` reports
        any problem.
    """
    # ``sys.exit`` rather than the ``site``-injected ``exit`` builtin, which
    # is meant for the interactive shell and is missing under ``python -S``.
    import sys

    if from_revision and not sql:
        raise ValueError(
            "`from_revision` can't be combined with `sql=False`. When actually "
            "applying a downgrade (instead of just generating sql), we always "
            "downgrade from current revision.")

    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.")

    # alembic adds significant import time, so we import it lazily
    from alembic import command

    log.info("Attempting downgrade to revision %s", to_revision)

    config = _get_alembic_config()

    # Alembic's Config uses ConfigParser interpolation, so a literal '%' in
    # the connection URL must be doubled.
    config.set_main_option('sqlalchemy.url',
                           settings.SQL_ALCHEMY_CONN.replace('%', '%%'))

    # Log every reported migration error (with a one-time header) before
    # aborting, so the operator sees all of them at once.
    errors_seen = False
    for err in _check_migration_errors(session=session):
        if not errors_seen:
            log.error(
                "Automatic migration failed.  You may need to apply downgrades manually.  "
            )
            errors_seen = True
        log.error("%s", err)

    if errors_seen:
        sys.exit(1)

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        if sql:
            log.warning("Generating sql scripts for manual migration.")

            conn = session.connection()

            from alembic.migration import MigrationContext

            # Without an explicit start point, generate SQL from the revision
            # currently stamped in the database.
            migration_ctx = MigrationContext.configure(conn)
            if not from_revision:
                from_revision = migration_ctx.get_current_revision()
            revision_range = f"{from_revision}:{to_revision}"
            _offline_migration(command.downgrade,
                               config=config,
                               revision=revision_range)
        else:
            log.info("Applying downgrade migrations.")
            command.downgrade(config, revision=to_revision, sql=sql)
コード例 #4
0
def test_sql_lab_insert_rls(
    mocker: MockerFixture,
    session: Session,
    app_context: None,
) -> None:
    """
    Integration test for `insert_rls`.

    Executes the same SQL Lab query twice through ``execute_sql_statement``
    as a user holding an "Admin" role. The first run has no row-level security
    (RLS) filter. Before the second run, a regular RLS filter (``c > 5``) on
    table ``t`` is attached to that role. Both the returned rows and
    ``query.executed_sql`` are checked, so the second run must show the RLS
    predicate injected into the SQL. (``insert_rls`` is presumably reached
    indirectly via ``execute_sql_statement``; it is not called directly here.)
    """
    # Imported inside the test, presumably so the models load only once the
    # ``app_context`` fixture is active. TODO confirm against conftest.
    from flask_appbuilder.security.sqla.models import Role, User

    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    from superset.security.manager import SupersetSecurityManager
    from superset.sql_lab import execute_sql_statement
    from superset.utils.core import RowLevelSecurityFilterType

    # Create the ORM tables (``Query.metadata``) on the engine behind the
    # ``session`` fixture.
    engine = session.connection().engine
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    # Seed table ``t`` with the values 0..9 through a raw DB-API connection.
    # The ``?`` placeholder suggests the fixture engine is SQLite
    # (assumption; confirm against the ``session`` fixture).
    # NOTE(review): neither ``connection`` nor ``cursor`` is closed in this
    # test; consider closing them to avoid leaking the connection.
    connection = engine.raw_connection()
    connection.execute("CREATE TABLE t (c INTEGER)")
    for i in range(10):
        connection.execute("INSERT INTO t VALUES (?)", (i, ))

    # Cursor handed to ``execute_sql_statement`` in both runs below.
    cursor = connection.cursor()

    # The SQL Lab query under test; the result sets below show that
    # ``limit=5`` caps the number of rows returned.
    query = Query(
        sql="SELECT c FROM t",
        client_id="abcde",
        database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
        schema=None,
        limit=5,
        select_as_cta_used=False,
    )
    session.add(query)
    session.commit()

    # User with an "Admin" role; the RLS filter below is attached to this
    # same role object.
    admin = User(
        first_name="Alice",
        last_name="Doe",
        email="*****@*****.**",
        username="******",
        roles=[Role(name="Admin")],
    )

    # first without RLS
    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    # Only 5 rows come back, yet the executed SQL uses ``LIMIT 6``.
    # Presumably one extra row is fetched to detect truncation; confirm.
    assert (superset_result_set.to_pandas_df().to_markdown() == """
|    |   c |
|---:|----:|
|  0 |   0 |
|  1 |   1 |
|  2 |   2 |
|  3 |   3 |
|  4 |   4 |""".strip())
    assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

    # now with RLS
    # ``database_id=1`` presumably matches the Database row created together
    # with the query above (first row in a fresh DB). TODO confirm.
    rls = RowLevelSecurityFilter(
        filter_type=RowLevelSecurityFilterType.REGULAR,
        tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
        roles=[admin.roles[0]],
        group_key=None,
        clause="c > 5",
    )
    session.add(rls)
    session.flush()
    # Make the security manager resolve the user to ``admin``, and make
    # every ``is_feature_enabled`` check inside ``superset.sql_lab`` return
    # True (presumably including the flag that enables RLS in SQL Lab).
    mocker.patch.object(SupersetSecurityManager,
                        "find_user",
                        return_value=admin)
    mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    # Only rows with c > 5 remain, and the predicate appears in the SQL.
    assert (superset_result_set.to_pandas_df().to_markdown() == """
|    |   c |
|---:|----:|
|  0 |   6 |
|  1 |   7 |
|  2 |   8 |
|  3 |   9 |""".strip())
    assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"