Example 1
def test_post_with_uuid(
    mocker: MockFixture,
    app_context: None,
    session: Session,
    client: Any,
    full_api_access: None,
) -> None:
    """
    Test that we can set the database UUID when creating it.
    """
    from superset.models.core import Database

    # create table for databases
    Database.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    response = client.post(
        "/api/v1/database/",
        json={
            "database_name": "my_db",
            "sqlalchemy_uri": "sqlite://",
            "uuid": "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb",
        },
    )
    assert response.status_code == 201

    database = session.query(Database).one()
    assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
Example 2
    def get_bind(self, mapper=None, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        bind_mode_context = _current_bind_mode_context()
        current_mode = getattr(bind_mode_context, 'current_mode', None)

        # Use the slave connection when either:
        #   1) no mode is explicitly specified and a SELECT statement is
        #      being executed, or
        #   2) _SLAVE is explicitly specified.
        if ((current_mode is None and isinstance(clause, Select))
                or current_mode is _SLAVE):
            state = get_state(self.app)
            slaves = self.app.config['SQLALCHEMY_DATABASE_SLAVE_URIS']
            if slaves:
                random_index = random.randrange(len(slaves))
                bind_key = 'slaves_{}'.format(random_index)
            else:
                bind_key = None
            return state.db.get_engine(self.app, bind=bind_key)

        return SessionBase.get_bind(self, mapper, clause)
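The two routing branches above consult a thread-local "bind mode" to decide between the default and the slave engines. A minimal sketch of how that mode could be stored and toggled, assuming a threading.local holder; `_SLAVE` and `_current_bind_mode_context()` are the names the snippet uses, while the `use_slave()` context manager is a hypothetical addition for illustration:

import threading
from contextlib import contextmanager

_SLAVE = object()                        # sentinel: "route to a read replica"
_bind_mode_context = threading.local()   # per-thread storage for the current mode


def _current_bind_mode_context():
    return _bind_mode_context


@contextmanager
def use_slave():
    """Force every statement inside the block onto the slave bind."""
    previous = getattr(_bind_mode_context, 'current_mode', None)
    _bind_mode_context.current_mode = _SLAVE
    try:
        yield
    finally:
        _bind_mode_context.current_mode = previous

With such a helper, `with use_slave(): ...` would route even non-SELECT clauses in the block to a replica, which corresponds to the `current_mode is _SLAVE` branch of `get_bind()`.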
Example 3
def test_table_model(session: Session) -> None:
    """
    Test basic attributes of a ``Table``.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Table.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="test://"),
        columns=[
            Column(
                name="ds",
                type="TIMESTAMP",
                expression="ds",
            )
        ],
    )
    session.add(table)
    session.flush()

    assert table.id == 1
    assert table.uuid is not None
    assert table.database_id == 1
    assert table.catalog == "my_catalog"
    assert table.schema == "my_schema"
    assert table.name == "my_table"
    assert [column.name for column in table.columns] == ["ds"]
Example 4
def test_table(session: Session) -> "SqlaTable":
    """
    Fixture that generates an in-memory table.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="event_time", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="id", type="INTEGER"),
        TableColumn(column_name="dttm", type="INTEGER"),
        TableColumn(column_name="duration_ms", type="INTEGER"),
    ]

    return SqlaTable(
        table_name="test_table",
        columns=columns,
        metrics=[],
        main_dttm_col=None,
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
Example 5
def test_import_dataset_managed_externally(app_context: None,
                                           session: Session) -> None:
    """
    Test importing a dataset that is managed externally.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import dataset_config

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    config = copy.deepcopy(dataset_config)
    config["is_managed_externally"] = True
    config["external_url"] = "https://example.org/my_table"
    config["database_id"] = database.id

    sqla_table = import_dataset(session, config)
    assert sqla_table.is_managed_externally is True
    assert sqla_table.external_url == "https://example.org/my_table"
Example 6
    def get_bind(self, mapper=None, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        bind_mode_context = _current_bind_mode_context()
        current_mode = getattr(bind_mode_context, 'current_mode', None)

        # Use the slave connection when either:
        #   1) no mode is explicitly specified and a SELECT statement is
        #      being executed, or
        #   2) _SLAVE is explicitly specified.
        if ((current_mode is None and isinstance(clause, Select)) or
                current_mode is _SLAVE):
            state = get_state(self.app)
            slaves = self.app.config['SQLALCHEMY_DATABASE_SLAVE_URIS']
            if slaves:
                random_index = random.randrange(len(slaves))
                bind_key = 'slaves_{}'.format(random_index)
            else:
                bind_key = None
            return state.db.get_engine(self.app, bind=bind_key)

        return SessionBase.get_bind(self, mapper, clause)
Example 7
@contextlib.contextmanager
def create_global_lock(session: Session, lock: DBLocks, lock_timeout=1800):
    """Contextmanager that will create and teardown a global db lock."""
    conn = session.get_bind().connect()
    dialect = conn.dialect
    try:
        if dialect.name == 'postgresql':
            conn.execute(text('SET LOCK_TIMEOUT to :timeout'),
                         timeout=lock_timeout)
            conn.execute(text('SELECT pg_advisory_lock(:id)'), id=lock.value)
        elif dialect.name == 'mysql' and dialect.server_version_info >= (5, 6):
            conn.execute(text("SELECT GET_LOCK(:id, :timeout)"),
                         id=str(lock),
                         timeout=lock_timeout)
        elif dialect.name == 'mssql':
            # TODO: make locking work for MSSQL
            pass

        yield
    finally:
        if dialect.name == 'postgresql':
            conn.execute('SET LOCK_TIMEOUT TO DEFAULT')
            (unlocked, ) = conn.execute(text('SELECT pg_advisory_unlock(:id)'),
                                        id=lock.value).fetchone()
            if not unlocked:
                raise RuntimeError("Error releasing DB lock!")
        elif dialect.name == 'mysql' and dialect.server_version_info >= (5, 6):
            conn.execute(text("select RELEASE_LOCK(:id)"), id=str(lock))
        elif dialect.name == 'mssql':
            # TODO: make locking work for MSSQL
            pass
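A usage sketch for the lock helper above, assuming it is decorated as a context manager (as its docstring implies); the `DBLocks` stand-in, the connection URI, and the guarded work are placeholders, not taken from the original:

import enum

from sqlalchemy import create_engine
from sqlalchemy.orm import Session


class DBLocks(enum.IntEnum):
    """Stand-in for the snippet's DBLocks enum; only carries a lock id."""
    MIGRATIONS = 1


engine = create_engine('postgresql+psycopg2://user:pass@localhost/mydb')  # illustrative DSN
with Session(engine) as session:
    # only one process at a time gets past this point for the same lock id
    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS, lock_timeout=1800):
        pass  # placeholder for the work that must not run concurrently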
Example 8
def test_import_database(session: Session) -> None:
    """
    Test importing a database.
    """
    from superset.databases.commands.importers.v1.utils import import_database
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import database_config

    engine = session.get_bind()
    Database.metadata.create_all(engine)  # pylint: disable=no-member

    config = copy.deepcopy(database_config)
    database = import_database(session, config)
    assert database.database_name == "imported_database"
    assert database.sqlalchemy_uri == "sqlite:///test.db"
    assert database.cache_timeout is None
    assert database.expose_in_sqllab is True
    assert database.allow_run_async is False
    assert database.allow_ctas is True
    assert database.allow_cvas is True
    assert database.allow_file_upload is True
    assert database.extra == "{}"
    assert database.uuid == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89"
    assert database.is_managed_externally is False
    assert database.external_url is None
Example 9
def test_cascade_delete_table(app_context: None, session: Session) -> None:
    """
    Test that deleting ``Table`` also deletes its columns.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Table.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    columns = session.query(Column).all()
    assert len(columns) == 2

    session.delete(table)
    session.flush()

    # test that columns were deleted
    columns = session.query(Column).all()
    assert len(columns) == 0
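The cascade exercised by this test normally comes from the parent-side relationship configuration. A self-contained sketch of the same pattern with illustrative models (not Superset's), assuming SQLAlchemy 1.4+:

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()


class Parent(Base):
    __tablename__ = 'parents'
    id = Column(Integer, primary_key=True)
    name = Column(String(250))
    # "all, delete-orphan" makes children follow the parent on delete
    columns = relationship('Child', cascade='all, delete-orphan', backref='parent')


class Child(Base):
    __tablename__ = 'children'
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey('parents.id'))
    name = Column(String(250))


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Parent(name='my_table', columns=[Child(name='longitude'), Child(name='latitude')]))
    session.flush()
    session.delete(session.query(Parent).one())
    session.flush()
    assert session.query(Child).count() == 0  # children removed with the parent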
Example 10
def test_quote_expressions(app_context: None, session: Session) -> None:
    """
    Test that expressions are quoted appropriately in columns and datasets.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="has space", type="INTEGER"),
        TableColumn(column_name="no_need", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert dataset.expression == '"old dataset"'
    assert dataset.columns[0].expression == '"has space"'
    assert dataset.columns[1].expression == "no_need"
Example 11
def check_run_id_null(session: Session) -> Iterable[str]:
    import sqlalchemy.schema

    metadata = sqlalchemy.schema.MetaData(session.bind)
    try:
        metadata.reflect(only=[DagRun.__tablename__], extend_existing=True, resolve_fks=False)
    except exc.InvalidRequestError:
        # Table doesn't exist -- empty db
        return

    # We can't use the model here since it may differ from the actual db state,
    # because this function runs prior to migration. Use the reflected table instead.
    dagrun_table = metadata.tables[DagRun.__tablename__]

    invalid_dagrun_filter = or_(
        dagrun_table.c.dag_id.is_(None),
        dagrun_table.c.run_id.is_(None),
        dagrun_table.c.execution_date.is_(None),
    )
    invalid_dagrun_count = session.query(dagrun_table.c.id).filter(invalid_dagrun_filter).count()
    if invalid_dagrun_count > 0:
        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2")
        if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
            yield _format_dangling_error(
                source_table=dagrun_table.name,
                target_table=dagrun_dangling_table_name,
                invalid_count=invalid_dagrun_count,
                reason="with a NULL dag_id, run_id, or execution_date",
            )
            return
        _move_dangling_run_data_to_new_table(session, dagrun_table, dagrun_dangling_table_name)
Example 12
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    datasets = session.query(Dataset).all()
    assert len(datasets) == 1

    session.delete(sqla_table)
    session.flush()

    # test that dataset was also deleted
    datasets = session.query(Dataset).all()
    assert len(datasets) == 0
Example 13
class _SignallingSession(Session):
    def __init__(self, db, autocommit=False, autoflush=False, **options):
        self.app = db.get_app()
        self._model_changes = {}
        Session.__init__(self,
                         autocommit=autocommit,
                         autoflush=autoflush,
                         bind=db.engine,
                         binds=db.get_binds(self.app),
                         **options)

    def get_bind(self, mapper, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)
        # if read-only, use the slave engine
        # see http://docs.sqlalchemy.org/en/rel_0_8/orm/session.html#custom-vertical-partitioning
        if not self._flushing:
            state = get_state(self.app)
            try:
                return state.db.get_engine(self.app, bind='slave')
            except Exception:
                pass
        return Session.get_bind(self, mapper, clause)
Example 14
def test_column_model(app_context: None, session: Session) -> None:
    """
    Test basic attributes of a ``Column``.
    """
    from superset.columns.models import Column

    engine = session.get_bind()
    Column.metadata.create_all(engine)  # pylint: disable=no-member

    column = Column(
        name="ds",
        type="TIMESTAMP",
        expression="ds",
    )

    session.add(column)
    session.flush()

    assert column.id == 1
    assert column.uuid is not None

    assert column.name == "ds"
    assert column.type == "TIMESTAMP"
    assert column.expression == "ds"

    # test that default values are set correctly
    assert column.description is None
    assert column.warning_text is None
    assert column.unit is None
    assert column.is_temporal is False
    assert column.is_spatial is False
    assert column.is_partition is False
    assert column.is_aggregation is False
    assert column.is_additive is False
    assert column.is_increase_desired is True
Example 15
def exists_orm(session: Session,
               ormclass: DeclarativeMeta,
               *criteria: Any) -> bool:
    # http://docs.sqlalchemy.org/en/latest/orm/query.html
    q = session.query(ormclass)
    for criterion in criteria:
        q = q.filter(criterion)

    # See this:
    # - https://bitbucket.org/zzzeek/sqlalchemy/issues/3212/misleading-documentation-for-queryexists  # noqa
    # - http://docs.sqlalchemy.org/en/latest/orm/query.html#sqlalchemy.orm.query.Query.exists  # noqa

    exists_clause = q.exists()
    if session.get_bind().dialect.name == 'mssql':
        # SQL Server
        result = session.query(literal(True)).filter(exists_clause).scalar()
        # SELECT 1 WHERE EXISTS (SELECT 1 FROM table WHERE ...)
        # ... giving 1 or None (no rows)
        # ... fine for SQL Server, but invalid for MySQL (no FROM clause)
    else:
        # MySQL, etc.
        result = session.query(exists_clause).scalar()
        # SELECT EXISTS (SELECT 1 FROM table WHERE ...)
        # ... giving 1 or 0
        # ... fine for MySQL, but invalid syntax for SQL server
    return bool(result)
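A hypothetical call site for `exists_orm`, using a throwaway SQLite database and an illustrative `User` model (neither is from the original), assuming SQLAlchemy 1.4+:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    username = Column(String(50))


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(username='alice'))
    session.flush()
    assert exists_orm(session, User, User.username == 'alice')
    assert not exists_orm(session, User, User.username == 'bob')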
Example 16
def test_query_dao_save_metadata(session: Session) -> None:
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    engine = session.get_bind()
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    session.add(db)
    session.add(query_obj)

    from superset.queries.dao import QueryDAO

    query = session.query(Query).one()
    QueryDAO.save_metadata(query=query, payload={"columns": []})
    assert query.extra.get("columns", None) == []
Example 17
 def get_bind(self, mapper, clause=None):
     if mapper is not None:
         key = mapper.mapped_table.name
         if isinstance(clause, Select):
             # get slave engine
             return self.db.get_table_engine(key, slave=True)
         return self.db.get_table_engine(key, slave=False)
     return SessionBase.get_bind(self, mapper, clause)
Example 18
def test_update_sqlatable(mocker: MockFixture, app_context: None,
                          session: Session) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # add a column to the original ``SqlaTable`` instance
    sqla_table.columns.append(
        TableColumn(column_name="user_id", type="INTEGER"))
    session.flush()

    # check that the column was added to the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 2

    # delete the column in the original instance
    sqla_table.columns = sqla_table.columns[1:]
    session.flush()

    # check that the column was also removed from the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # modify the attribute in a column
    sqla_table.columns[0].is_dttm = True
    session.flush()

    # check that the dataset column was modified
    dataset = session.query(Dataset).one()
    assert dataset.columns[0].is_temporal is True
Example 19
 def get_bind(self, mapper, clause=None):
     # mapper is None if someone tries to just get a connection
     if mapper is not None:
         info = getattr(mapper.mapped_table, 'info', {})
         bind_key = info.get('bind_key')
         if bind_key is not None:
             state = get_state(self.app)
             return state.db.get_engine(self.app, bind=bind_key)
         elif self._flushing:
             self.autocommit = self._autocommit
             return Session.get_bind(self, mapper, clause)
         else:
             state = get_state(self.app)
             slaves = [x for x in self.app.config.get('SQLALCHEMY_BINDS', dict()).keys() if x.startswith('slave_')]
             if slaves:
                 self.autocommit = True
                 return state.db.get_engine(self.app, bind=random.choice(slaves))
     return Session.get_bind(self, mapper, clause)
Example 20
 def get_bind(self, mapper, clause=None):
     # mapper is None if someone tries to just get a connection
     if mapper is not None:
         info = getattr(mapper.mapped_table, 'info', {})
         bind_key = info.get('bind_key')
         if bind_key is not None:
             state = get_state(self.app)
             return state.db.get_engine(self.app, bind=bind_key)
     return SessionBase.get_bind(self, mapper, clause)
Example 21
def test_dataset_model(app_context: None, session: Session) -> None:
    """
    Test basic attributes of a ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    dataset = Dataset(
        database=table.database,
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    assert dataset.id == 1
    assert dataset.uuid is not None

    assert dataset.name == "positions"
    assert (
        dataset.expression
        == """
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
"""
    )

    assert [table.name for table in dataset.tables] == ["my_table"]
    assert [column.name for column in dataset.columns] == ["position"]
Example 22
    def session_close(self, session: Session):
        """
        The session must be closed manually in a ``finally`` block.
        :param session:
        :return:
        """
        eng = session.get_bind()
        session.close()
        if eng is not None:
            eng.dispose()
Example 23
def test_cascade_delete_dataset(app_context: None, session: Session) -> None:
    """
    Test that deleting ``Dataset`` also deletes its columns.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    dataset = Dataset(
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        database=table.database,
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    columns = session.query(Column).all()
    assert len(columns) == 3

    session.delete(dataset)
    session.flush()

    # test that dataset columns were deleted (but not table columns)
    columns = session.query(Column).all()
    assert len(columns) == 2
Example 24
    def session_execute(self, session: Session, sql: str, params=None):
        """
        Execute a SQL statement through the session.
        :param session:
        :param sql:
        :param params:
        :return:
        """
        session.execute(
            sql,
            self._prepare_params_of_execute_sql(session.get_bind(), sql,
                                                params))
Example 25
    def get_bind(self, mapper=None, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, "info", {})
            bind_key = info.get("bind_key")
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        if not hasattr(using_master, "disabled") or not using_master.disabled:
            return SessionBase.get_bind(self, mapper, clause)
        else:
            state = get_state(self.app)
            return state.db.get_engine(self.app, bind="slaves")
Example 26
    def get_bind(self, mapper=None, clause=None):
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            engine_key = info.get('bind_key')

            if engine_key is not None:
                engine = registry.get_engine(engine_key)

                if not engine:
                    raise InvalidBindKey(engine_key)

                return engine

        return BaseSession.get_bind(self, mapper, clause)
Example 27
    def get_bind(self, mapper=None, clause=None):
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            engine_key = info.get('bind_key')

            if engine_key is not None:
                engine = registry.get_engine(engine_key)

                if not engine:
                    raise InvalidBindKey(engine_key)

                return engine

        return BaseSession.get_bind(self, mapper, clause)
Example 28
    def session_all_row(self, session: Session, sql, params=None) -> CDataSet:
        """
        Execute a SQL statement and return all matching rows.
        :param session:
        :param sql:
        :param params:
        :return:
        """
        cursor = session.execute(
            sql,
            self._prepare_params_of_execute_sql(session.get_bind(), sql,
                                                params))
        data = cursor.fetchall()
        return CDataSet(data)
Example 29
    def get_bind(self, mapper, clause=None):
        # mapper is None if someone tries to just get a connection

        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        if isinstance(clause, sqlalchemy.sql.expression.Select):
            state = get_state(self.app)
            return state.db.get_engine(self.app, bind='__slave__') 

        return Session.get_bind(self, mapper, clause)
Example 30
    def get_bind(self, mapper, clause=None):
        # mapper is None if someone tries to just get a connection

        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        if isinstance(clause, sqlalchemy.sql.expression.Select):
            state = get_state(self.app)
            return state.db.get_engine(self.app, bind='__slave__')

        return Session.get_bind(self, mapper, clause)
Example 31
def exists_plain(session: Session, tablename: str, *criteria: Any) -> bool:
    exists_clause = exists().select_from(table(tablename))
    # ... EXISTS (SELECT * FROM tablename)
    for criterion in criteria:
        exists_clause = exists_clause.where(criterion)
    # ... EXISTS (SELECT * FROM tablename WHERE ...)

    if session.get_bind().dialect.name == 'mssql':
        query = select([literal(True)]).where(exists_clause)
        # ... SELECT 1 WHERE EXISTS (SELECT * FROM tablename WHERE ...)
    else:
        query = select([exists_clause])
        # ... SELECT EXISTS (SELECT * FROM tablename WHERE ...)

    result = session.execute(query).scalar()
    return bool(result)
Example 32
    def get_bind(self, mapper=None, clause=None):
        try:
            state = get_state(self.app)
        except (AssertionError, AttributeError, TypeError) as err:
            log.info(
                "Unable to get Flask-SQLAlchemy configuration. "
                "Outputting default bind. Error: %s", err)
            return orm.Session.get_bind(self, mapper, clause)

        # If there are no binds configured, connect using the default SQLALCHEMY_DATABASE_URI
        if state is None or not self.app.config['SQLALCHEMY_BINDS']:
            if not self.app.debug:
                log.debug(
                    "Connecting -> DEFAULT. Unable to get Flask-SQLAlchemy bind configuration. Outputting default bind."
                )
            return orm.Session.get_bind(self, mapper, clause)

        has_bind = 'read_or_write' in read_or_write.__dict__
        if has_bind and read_or_write.__dict__['read_or_write'] == 'read':

            if 'read_type' in read_or_write.__dict__ and read_or_write.__dict__[
                    'read_type'] == 'force_read':
                bind_key = read_or_write.__dict__['bind_key']
                force_db_name = read_or_write.__dict__['force_db_name']
                mapper_db_name = self.get_dbname(mapper, clause)
                if mapper_db_name == bind_key:
                    # if a specific slave was forced, use that slave
                    db_name = force_db_name
                else:
                    db_name = mapper_db_name
            else:
                db_name = self.get_dbname(mapper, clause)

            if len(self.db_map[db_name]) == 0:
                # @read cannot be used without a configured slave
                raise Exception('Please configure a slave database for: ' + db_name)

            slave_for_read = self.get_slave(db_name)
            return state.db.get_engine(self.app, bind=slave_for_read)
        else:
            if mapper is not None:
                info = getattr(mapper.mapped_table, 'info', {})
                bind_key = info.get('bind_key')
                if bind_key is not None:
                    return state.db.get_engine(self.app, bind=bind_key)
            return SessionBase.get_bind(self, mapper, clause)
Example 33
def test_update_sqlatable_metric(mocker: MockFixture, app_context: None,
                                 session: Session) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.

    For this test we check that updating the SQL expression in a metric belonging to a
    ``SqlaTable`` is reflected in the ``Dataset`` metric.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    metrics = [
        SqlMetric(metric_name="cnt", expression="COUNT(*)"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the metric was created
    column = session.query(Column).filter_by(is_physical=False).one()
    assert column.expression == "COUNT(*)"

    # change the metric definition
    sqla_table.metrics[0].expression = "MAX(ds)"
    session.flush()

    assert column.expression == "MAX(ds)"
Example 34
def test_import_database_managed_externally(session: Session) -> None:
    """
    Test importing a database that is managed externally.
    """
    from superset.databases.commands.importers.v1.utils import import_database
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import database_config

    engine = session.get_bind()
    Database.metadata.create_all(engine)  # pylint: disable=no-member

    config = copy.deepcopy(database_config)
    config["is_managed_externally"] = True
    config["external_url"] = "https://example.org/my_database"

    database = import_database(session, config)
    assert database.is_managed_externally is True
    assert database.external_url == "https://example.org/my_database"
Example 35
def session_with_data(session: Session) -> Iterator[Session]:
    from superset.models.slice import Slice

    engine = session.get_bind()
    Slice.metadata.create_all(engine)  # pylint: disable=no-member

    slice_obj = Slice(
        id=1,
        datasource_id=1,
        datasource_type=DatasourceType.TABLE,
        datasource_name="tmp_perm_table",
        slice_name="slice_name",
    )

    session.add(slice_obj)
    session.commit()
    yield session
    session.rollback()
Example 36
    def session_one_row(self, session: Session, sql, params=None) -> CDataSet:
        """
        Execute a SQL statement and return the first matching row.
        :param session:
        :param sql:
        :param params:
        :return:
        """
        cursor = session.execute(
            sql,
            self._prepare_params_of_execute_sql(session.get_bind(), sql,
                                                params))
        data = cursor.fetchone()
        if data is None:
            return CDataSet()
        else:
            row_data = [data]
            return CDataSet(row_data)
Example 37
    def get_bind(self, mapper=None, clause=None):
        """Return the engine or connection for a given model or
        table, using the ``__bind_key__`` if it is set.
        """
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            try:
                # SA >= 1.3
                persist_selectable = mapper.persist_selectable
            except AttributeError:
                # SA < 1.3
                persist_selectable = mapper.mapped_table

            info = getattr(persist_selectable, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)
        return SessionBase.get_bind(self, mapper, clause)
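For context, the `bind_key` read back above is what Flask-SQLAlchemy stores in the table's `info` dict when a model declares `__bind_key__`. A minimal sketch (the database URIs and the model are illustrative):

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///default.db'
app.config['SQLALCHEMY_BINDS'] = {'users': 'sqlite:///users.db'}
db = SQLAlchemy(app)


class User(db.Model):
    # copied into __table__.info['bind_key'], which get_bind() consults
    __bind_key__ = 'users'
    id = db.Column(db.Integer, primary_key=True)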
Example 38
    def get_bind(self, mapper=None, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        if not isinstance(clause, Select):
            return SessionBase.get_bind(self, mapper, clause)
        else:
            state = get_state(self.app)
            slaves = self.app.config['SQLALCHEMY_DATABASE_SLAVE_URIS']
            if slaves:
                random_index = random.randrange(len(slaves))
                bind = 'slaves_{}'.format(random_index)
            else:
                bind = None
            return state.db.get_engine(self.app, bind=bind)
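For the slave routing in the last few examples to resolve, each replica URI has to be registered under a 'slaves_<n>' bind key. How the original projects do this is not shown, so the helper below is only an assumption built on the snippet's own config key:

def register_slave_binds(app):
    """Expose each replica URI as a 'slaves_<n>' bind so get_bind() can pick one at random."""
    slaves = app.config.get('SQLALCHEMY_DATABASE_SLAVE_URIS') or []
    binds = app.config.setdefault('SQLALCHEMY_BINDS', {})
    for index, uri in enumerate(slaves):
        binds['slaves_{}'.format(index)] = uri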