    def test_grains_dict(self):
        uri = 'mysql://root@localhost'
        database = Database(sqlalchemy_uri=uri)
        d = database.grains_dict()
        self.assertEquals(d.get('day').function, 'DATE({col})')
        self.assertEquals(d.get('P1D').function, 'DATE({col})')
        self.assertEquals(d.get('Time Column').function, '{col}')
    def test_database_schema_hive(self):
        sqlalchemy_uri = 'hive://[email protected]:10000/default?auth=NOSASL'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)
        db = make_url(model.get_sqla_engine().url).database
        self.assertEquals('default', db)

        db = make_url(model.get_sqla_engine(schema='core_db').url).database
        self.assertEquals('core_db', db)
    def test_database_schema_mysql(self):
        sqlalchemy_uri = 'mysql://root@localhost/superset'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEquals('superset', db)

        db = make_url(model.get_sqla_engine(schema='staging').url).database
        self.assertEquals('staging', db)
    def test_database_schema_postgres(self):
        sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEquals('prod', db)

        db = make_url(model.get_sqla_engine(schema='foo').url).database
        self.assertEquals('prod', db)
    def test_postgres_mixedcase_col_time_grain(self):
        uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
        database = Database(sqlalchemy_uri=uri)
        pdf, time_grain = '', 'P1D'
        expression, column_name = '', 'MixedCaseCol'
        grain = database.grains_dict().get(time_grain)
        col = database.db_engine_spec.get_timestamp_column(expression, column_name)
        grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
        grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"')
        self.assertEqual(grain_expr, grain_expr_expected)
    def test_database_impersonate_user(self):
        uri = 'mysql://root@localhost'
        example_user = '******'
        model = Database(sqlalchemy_uri=uri)

        model.impersonate_user = True
        user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
        self.assertEquals(example_user, user_name)

        model.impersonate_user = False
        user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
        self.assertNotEquals(example_user, user_name)
Example #7
def create_table_for_dashboard(
    df: DataFrame,
    table_name: str,
    database: Database,
    dtype: Dict[str, Any],
    table_description: str = "",
    fetch_values_predicate: Optional[str] = None,
) -> SqlaTable:
    df.to_sql(
        table_name,
        database.get_sqla_engine(),
        if_exists="replace",
        chunksize=500,
        dtype=dtype,
        index=False,
        method="multi",
    )

    table_source = ConnectorRegistry.sources["table"]
    table = (
        db.session.query(table_source).filter_by(table_name=table_name).one_or_none()
    )
    if not table:
        table = table_source(table_name=table_name)
    if fetch_values_predicate:
        table.fetch_values_predicate = fetch_values_predicate
    table.database = database
    table.description = table_description
    db.session.merge(table)
    db.session.commit()

    return table
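
A minimal usage sketch for the helper above. The example DataFrame, the dtype mapping, and the example database are illustrative assumptions, and the call presumes an initialised Superset app context so that db.session and ConnectorRegistry are available.

import pandas as pd
from sqlalchemy import Float, String
from superset.models.core import Database

# Hypothetical example database; in practice this would be an existing
# Database row from the metadata store.
example_db = Database(database_name="examples", sqlalchemy_uri="sqlite://")

df = pd.DataFrame({"city": ["Paris", "Oslo"], "population_m": [2.1, 0.7]})
tbl = create_table_for_dashboard(
    df,
    table_name="city_population",
    database=example_db,
    dtype={"city": String(64), "population_m": Float()},
    table_description="Sample population figures",
)
print(tbl.table_name)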
Example #8
    def test_labels_expected_on_mutated_query(self):
        query_obj = {
            "granularity": None,
            "from_dttm": None,
            "to_dttm": None,
            "groupby": ["user"],
            "metrics": [{
                "expressionType": "SIMPLE",
                "column": {"column_name": "user"},
                "aggregate": "COUNT_DISTINCT",
                "label": "COUNT_DISTINCT(user)",
            }],
            "is_timeseries": False,
            "filter": [],
            "extras": {},
        }

        database = Database(database_name="testdb", sqlalchemy_uri="sqlite://")
        table = SqlaTable(table_name="bq_table", database=database)
        db.session.add(database)
        db.session.add(table)
        db.session.commit()
        sqlaq = table.get_sqla_query(**query_obj)
        assert sqlaq.labels_expected == ["user", "COUNT_DISTINCT(user)"]
        sql = table.database.compile_sqla_query(sqlaq.sqla_query)
        assert "COUNT_DISTINCT_user__00db1" in sql
        db.session.delete(table)
        db.session.delete(database)
        db.session.commit()
    def test_set_perm_database(self):
        session = db.session
        database = Database(database_name="tmp_database",
                            sqlalchemy_uri="sqlite://test")
        session.add(database)

        stored_db = (session.query(Database).filter_by(
            database_name="tmp_database").one())
        self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("database_access",
                                                       stored_db.perm))

        stored_db.database_name = "tmp_database2"
        session.commit()
        stored_db = (session.query(Database).filter_by(
            database_name="tmp_database2").one())
        self.assertEqual(stored_db.perm,
                         f"[tmp_database2].(id:{stored_db.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("database_access",
                                                       stored_db.perm))

        session.delete(stored_db)
        session.commit()
def test_table_model(session: Session) -> None:
    """
    Test basic attributes of a ``Table``.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Table.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="test://"),
        columns=[
            Column(
                name="ds",
                type="TIMESTAMP",
                expression="ds",
            )
        ],
    )
    session.add(table)
    session.flush()

    assert table.id == 1
    assert table.uuid is not None
    assert table.database_id == 1
    assert table.catalog == "my_catalog"
    assert table.schema == "my_schema"
    assert table.name == "my_table"
    assert [column.name for column in table.columns] == ["ds"]
Example #11
def get_indexes_metadata(
    database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
    indexes = database.get_indexes(table_name, schema_name)
    for idx in indexes:
        idx["type"] = "index"
    return indexes
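
A short, hypothetical call to the helper above, assuming database is an existing Database instance; the table name is illustrative, and the printed keys follow SQLAlchemy's get_indexes contract.

# Hypothetical call: fetch index metadata and print the added "type" marker.
indexes = get_indexes_metadata(database, "birth_names", schema_name=None)
for idx in indexes:
    print(idx["type"], idx.get("name"), idx.get("column_names"))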
Example #12
def test_cascade_delete_table(app_context: None, session: Session) -> None:
    """
    Test that deleting ``Table`` also deletes its columns.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Table.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    columns = session.query(Column).all()
    assert len(columns) == 2

    session.delete(table)
    session.flush()

    # test that columns were deleted
    columns = session.query(Column).all()
    assert len(columns) == 0
def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
    pdf = pd.read_json(get_example_data("birth_names.json.gz"))
    # TODO(bkyryliuk): move load examples data into the pytest fixture
    if database.backend == "presto":
        pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
        pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
    else:
        pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
    pdf = pdf.head(100) if sample else pdf

    pdf.to_sql(
        tbl_name,
        database.get_sqla_engine(),
        if_exists="replace",
        chunksize=500,
        dtype={
            # TODO(bkyryliuk): use TIMESTAMP type for presto
            "ds": DateTime if database.backend != "presto" else String(255),
            "gender": String(16),
            "state": String(10),
            "name": String(255),
        },
        method="multi",
        index=False,
    )
    print("Done loading table!")
    print("-" * 80)
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    datasets = session.query(Dataset).all()
    assert len(datasets) == 1

    session.delete(sqla_table)
    session.flush()

    # test that dataset was also deleted
    datasets = session.query(Dataset).all()
    assert len(datasets) == 0
Example #15
def schema_allows_csv_upload(database: Database, schema: Optional[str]) -> bool:
    if not database.allow_csv_upload:
        return False
    schemas = database.get_schema_access_for_csv_upload()
    if schemas:
        return schema in schemas
    return security_manager.can_access_database(database)
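
An illustrative check of the helper above. Setting allow_csv_upload directly mirrors the enable_csv_upload example further down this page; the extra JSON key consumed by get_schema_access_for_csv_upload is an assumption.

from superset.models.core import Database

database = Database(database_name="uploads_db", sqlalchemy_uri="sqlite://")
database.allow_csv_upload = True
# Assumed shape of the allow-list read by get_schema_access_for_csv_upload.
database.extra = '{"schemas_allowed_for_csv_upload": ["staging"]}'

print(schema_allows_csv_upload(database, "staging"))  # expected: True
print(schema_allows_csv_upload(database, "public"))   # expected: False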
Example #16
 def at_least_one_schema_is_allowed(database: Database) -> bool:
     """
     If the user has access to the database or all datasources
         1. if schemas_allowed_for_csv_upload is empty
             a) if database does not support schema
                 user is able to upload csv without specifying schema name
             b) if database supports schema
                 user is able to upload csv to any schema
         2. if schemas_allowed_for_csv_upload is not empty
             a) if database does not support schema
                 This situation is impossible and upload will fail
             b) if database supports schema
                 user is able to upload to schema in schemas_allowed_for_csv_upload
     elif the user does not have access to the database or all datasources
         1. if schemas_allowed_for_csv_upload is empty
             a) if database does not support schema
                 user is unable to upload csv
             b) if database supports schema
                 user is unable to upload csv
         2. if schemas_allowed_for_csv_upload is not empty
             a) if database does not support schema
                 This situation is impossible and user is unable to upload csv
             b) if database supports schema
                 user is able to upload to schema in schemas_allowed_for_csv_upload
     """
     if security_manager.can_access_database(database):
         return True
     schemas = database.get_schema_access_for_csv_upload()
     if schemas and security_manager.get_schemas_accessible_by_user(
             database, schemas, False):
         return True
     return False
def load_data(data_uri: str, dataset: SqlaTable, example_database: Database,
              session: Session) -> None:
    data = request.urlopen(data_uri)  # pylint: disable=consider-using-with
    if data_uri.endswith(".gz"):
        data = gzip.open(data)
    df = pd.read_csv(data, encoding="utf-8")
    dtype = get_dtype(df, dataset)

    # convert temporal columns
    for column_name, sqla_type in dtype.items():
        if isinstance(sqla_type, (Date, DateTime)):
            df[column_name] = pd.to_datetime(df[column_name])

    # reuse session when loading data if possible, to make import atomic
    if example_database.sqlalchemy_uri == current_app.config.get(
            "SQLALCHEMY_DATABASE_URI"
    ) or not current_app.config.get("SQLALCHEMY_EXAMPLES_URI"):
        logger.info("Loading data inside the import transaction")
        connection = session.connection()
    else:
        logger.warning("Loading data outside the import transaction")
        connection = example_database.get_sqla_engine()

    df.to_sql(
        dataset.table_name,
        con=connection,
        schema=dataset.schema,
        if_exists="replace",
        chunksize=CHUNKSIZE,
        dtype=dtype,
        index=False,
        method="multi",
    )
def test_query_dao_save_metadata(app_context: None, session: Session) -> None:
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    engine = session.get_bind()
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    session.add(db)
    session.add(query_obj)

    from superset.queries.dao import QueryDAO

    query = session.query(Query).one()
    QueryDAO.save_metadata(query=query, payload={"columns": []})
    assert query.extra.get("columns", None) == []
Example #19
    def validate(cls, sql: str, schema: Optional[str],
                 database: Database) -> List[SQLValidationAnnotation]:
        """
        Presto supports query-validation queries by running them with a
        prepended explain.

        For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
        VALIDATE) SELECT 1 FROM default.mytable".
        """
        user_name = g.user.username if g.user and hasattr(g.user,
                                                          "username") else None
        parsed_query = ParsedQuery(sql)
        statements = parsed_query.get_statements()

        logger.info("Validating %i statement(s)", len(statements))
        engine = database.get_sqla_engine(
            schema=schema,
            nullpool=True,
            user_name=user_name,
            source=QuerySource.SQL_LAB,
        )
        # Sharing a single connection and cursor across the
        # execution of all statements (if many)
        annotations: List[SQLValidationAnnotation] = []
        with closing(engine.raw_connection()) as conn:
            cursor = conn.cursor()
            for statement in parsed_query.get_statements():
                annotation = cls.validate_statement(statement, database,
                                                    cursor, user_name)
                if annotation:
                    annotations.append(annotation)
        logger.debug("Validation found %i error(s)", len(annotations))

        return annotations
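
A tiny sketch of the rewrite described in the docstring above: each statement is checked by prepending an EXPLAIN (TYPE VALIDATE) clause before it is sent to Presto (the actual prepending happens inside validate_statement, which is not shown here).

sql = "SELECT 1 FROM default.mytable"
validation_sql = f"EXPLAIN (TYPE VALIDATE) {sql}"
print(validation_sql)  # EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable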
def test_table(session: Session) -> "SqlaTable":
    """
    Fixture that generates an in-memory table.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="event_time", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="id", type="INTEGER"),
        TableColumn(column_name="dttm", type="INTEGER"),
        TableColumn(column_name="duration_ms", type="INTEGER"),
    ]

    return SqlaTable(
        table_name="test_table",
        columns=columns,
        metrics=[],
        main_dttm_col=None,
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
Example #21
def test_quote_expressions(app_context: None, session: Session) -> None:
    """
    Test that expressions are quoted appropriately in columns and datasets.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="has space", type="INTEGER"),
        TableColumn(column_name="no_need", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert dataset.expression == '"old dataset"'
    assert dataset.columns[0].expression == '"has space"'
    assert dataset.columns[1].expression == "no_need"
Example #22
def test_import_dataset_managed_externally(session: Session) -> None:
    """
    Test importing a dataset that is managed externally.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import dataset_config

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    config = copy.deepcopy(dataset_config)
    config["is_managed_externally"] = True
    config["external_url"] = "https://example.org/my_table"
    config["database_id"] = database.id

    sqla_table = import_dataset(session, config)
    assert sqla_table.is_managed_externally is True
    assert sqla_table.external_url == "https://example.org/my_table"
Example #23
def test_get_metrics(mocker: MockFixture) -> None:
    """
    Tests for ``get_metrics``.
    """
    from superset.db_engine_specs.base import MetricType
    from superset.db_engine_specs.sqlite import SqliteEngineSpec
    from superset.models.core import Database

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    assert database.get_metrics("table") == [
        {
            "expression": "COUNT(*)",
            "metric_name": "count",
            "metric_type": "count",
            "verbose_name": "COUNT(*)",
        }
    ]

    class CustomSqliteEngineSpec(SqliteEngineSpec):
        @classmethod
        def get_metrics(
            cls,
            database: Database,
            inspector: Inspector,
            table_name: str,
            schema: Optional[str],
        ) -> List[MetricType]:
            return [
                {
                    "expression": "COUNT(DISTINCT user_id)",
                    "metric_name": "count_distinct_user_id",
                    "metric_type": "count_distinct",
                    "verbose_name": "COUNT(DISTINCT user_id)",
                },
            ]

    database.get_db_engine_spec_for_backend = mocker.MagicMock(  # type: ignore
        return_value=CustomSqliteEngineSpec
    )
    assert database.get_metrics("table") == [
        {
            "expression": "COUNT(DISTINCT user_id)",
            "metric_name": "count_distinct_user_id",
            "metric_type": "count_distinct",
            "verbose_name": "COUNT(DISTINCT user_id)",
        },
    ]
Example #24
def get_foreign_keys_metadata(
        database: Database, table_name: str,
        schema_name: Optional[str]) -> List[Dict[str, Any]]:
    foreign_keys = database.get_foreign_keys(table_name, schema_name)
    for fk in foreign_keys:
        fk["column_names"] = fk.pop("constrained_columns")
        fk["type"] = "fk"
    return foreign_keys
 def setUp(self):
     super(SqlaConnectorTestCase, self).setUp()
     sqlalchemy_uri = 'sqlite:////tmp/test.db'
     database = Database(
         database_name='test_database',
         sqlalchemy_uri=sqlalchemy_uri)
     self.connection = database.get_sqla_engine().connect()
     self.datasource = SqlaTable(table_name='test_datasource',
                                 database=database,
                                 columns=self.columns,
                                 metrics=self.metrics)
     with database.get_sqla_engine().begin() as connection:
         self.df.to_sql(self.datasource.table_name,
                        connection,
                        if_exists='replace',
                        index=False,
                        dtype={'received': Date})
def import_from_dict(session, data, sync=[]):
    """Imports databases and druid clusters from dictionary"""
    if isinstance(data, dict):
        logging.info('Importing %d %s',
                     len(data.get(DATABASES_KEY, [])),
                     DATABASES_KEY)
        for database in data.get(DATABASES_KEY, []):
            Database.import_from_dict(session, database, sync=sync)

        logging.info('Importing %d %s',
                     len(data.get(DRUID_CLUSTERS_KEY, [])),
                     DRUID_CLUSTERS_KEY)
        for datasource in data.get(DRUID_CLUSTERS_KEY, []):
            DruidCluster.import_from_dict(session, datasource, sync=sync)
        session.commit()
    else:
        logging.info('Supplied object is not a dictionary.')
Example #27
 def select_star(self,
                 database: Database,
                 table_name: str,
                 schema_name: Optional[str] = None) -> FlaskResponse:
     """ Table schema info
     ---
     get:
       description: Get database select star for table
       parameters:
       - in: path
         schema:
           type: integer
         name: pk
         description: The database id
       - in: path
         schema:
           type: string
         name: table_name
         description: Table name
       - in: path
         schema:
           type: string
         name: schema_name
         description: Table schema
       responses:
         200:
           description: select star for table
           content:
             text/plain:
               schema:
                 type: object
                 properties:
                   result:
                     type: string
                     description: SQL select star
         400:
           $ref: '#/components/responses/400'
         401:
           $ref: '#/components/responses/401'
         404:
           $ref: '#/components/responses/404'
         422:
           $ref: '#/components/responses/422'
         500:
           $ref: '#/components/responses/500'
     """
     self.incr_stats("init", self.select_star.__name__)
     try:
         result = database.select_star(table_name,
                                       schema_name,
                                       latest_partition=True,
                                       show_cols=True)
     except NoSuchTableError:
         self.incr_stats("error", self.select_star.__name__)
         return self.response(404,
                              message="Table not found on the database")
     self.incr_stats("success", self.select_star.__name__)
     return self.response(200, result=result)
    def test_apply_limit(self):
        def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
            return str(
                qry.compile(
                    dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
                )
            )

        database = Database(
            database_name="mssql_test",
            sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
        )
        db.session.add(database)
        db.session.commit()

        with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
            test_sql = "SELECT COUNT(*) FROM FOO_TABLE"

            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)

            expected_sql = (
                "SELECT TOP 1000 * \n"
                "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
            )
            self.assertEqual(expected_sql, limited_sql)

            test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)

            expected_sql = (
                "SELECT TOP 1000 * \n"
                "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
                "AS inner_qry"
            )
            self.assertEqual(expected_sql, limited_sql)

            test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)

            expected_sql = (
                "SELECT TOP 1000 * \n"
                "FROM (SELECT COUNT(*) AS COUNT_1, "
                "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
                " AS inner_qry"
            )
            self.assertEqual(expected_sql, limited_sql)

            test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
            limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
            expected_sql = (
                "SELECT TOP 1000 * \n"
                "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
                " AS inner_qry"
            )
            self.assertEqual(expected_sql, limited_sql)

        db.session.delete(database)
        db.session.commit()
def test_update_sqlatable(mocker: MockFixture, app_context: None,
                          session: Session) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # add a column to the original ``SqlaTable`` instance
    sqla_table.columns.append(
        TableColumn(column_name="user_id", type="INTEGER"))
    session.flush()

    # check that the column was added to the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 2

    # delete the column in the original instance
    sqla_table.columns = sqla_table.columns[1:]
    session.flush()

    # check that the column was also removed from the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # modify the attribute in a column
    sqla_table.columns[0].is_dttm = True
    session.flush()

    # check that the dataset column was modified
    dataset = session.query(Dataset).one()
    assert dataset.columns[0].is_temporal is True
Example #30
def create_test_table_context(database: Database):
    database.get_sqla_engine().execute(
        "CREATE TABLE test_table AS SELECT 1 as first, 2 as second")
    database.get_sqla_engine().execute(
        "INSERT INTO test_table (first, second) VALUES (1, 2)")
    database.get_sqla_engine().execute(
        "INSERT INTO test_table (first, second) VALUES (3, 4)")

    yield db.session
    database.get_sqla_engine().execute("DROP TABLE test_table")
 def test_alter_new_orm_column(self):
     """
     DB Eng Specs (crate): Test alter orm column
     """
     database = Database(database_name="crate", sqlalchemy_uri="crate://db")
     tbl = SqlaTable(table_name="druid_tbl", database=database)
     col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
     CrateEngineSpec.alter_new_orm_column(col)
     assert col.python_date_format == "epoch_ms"
Example #32
    def enable_csv_upload(self, database: models.Database) -> None:
        """Enables csv upload in the given database."""
        database.allow_csv_upload = True
        db.session.commit()
        add_datasource_page = self.get_resp("/databaseview/list/")
        self.assertIn("Upload a CSV", add_datasource_page)

        form_get = self.get_resp("/csvtodatabaseview/form")
        self.assertIn("CSV to Database configuration", form_get)
 def build_db_for_connection_test(server_cert: str, extra: str,
                                  impersonate_user: bool,
                                  encrypted_extra: str) -> Database:
     return Database(
         server_cert=server_cert,
         extra=extra,
         impersonate_user=impersonate_user,
         encrypted_extra=encrypted_extra,
     )
    def test_database_connection_test_mutator(self):
        database = Database(sqlalchemy_uri="snowflake://abc")
        SnowflakeEngineSpec.mutate_db_for_connection_test(database)
        engine_params = json.loads(database.extra or "{}")

        self.assertDictEqual(
            {"engine_params": {"connect_args": {"validate_default_parameters": True}}},
            engine_params,
        )
Example #35
def import_from_dict(data: Dict[str, Any],
                     sync: Optional[List[str]] = None) -> None:
    """Imports databases and druid clusters from dictionary"""
    if not sync:
        sync = []
    if isinstance(data, dict):
        logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])),
                    DATABASES_KEY)
        for database in data.get(DATABASES_KEY, []):
            Database.import_from_dict(database, sync=sync)

        logger.info("Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])),
                    DRUID_CLUSTERS_KEY)
        for datasource in data.get(DRUID_CLUSTERS_KEY, []):
            DruidCluster.import_from_dict(datasource, sync=sync)
        db.session.commit()
    else:
        logger.info("Supplied object is not a dictionary.")
Example #36
def test_dataset_model(app_context: None, session: Session) -> None:
    """
    Test basic attributes of a ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    dataset = Dataset(
        database=table.database,
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    assert dataset.id == 1
    assert dataset.uuid is not None

    assert dataset.name == "positions"
    assert (
        dataset.expression
        == """
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
"""
    )

    assert [table.name for table in dataset.tables] == ["my_table"]
    assert [column.name for column in dataset.columns] == ["position"]
Example #37
def apply_limit_if_exists(database: Database, increased_limit: Optional[int],
                          query: Query, sql: str) -> str:
    if query.limit and increased_limit:
        # We are fetching one more than the requested limit in order
        # to test whether there are more rows than the limit.
        # Later, the extra row will be dropped before sending
        # the results back to the user.
        sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
    return sql
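
A sketch of the fetch-one-extra-row pattern this helper supports; query and database are assumed to be existing Query and Database instances, and the increased limit would typically be query.limit + 1.

# Illustrative use: fetch limit + 1 rows so the caller can detect whether more
# rows exist, then drop the surplus row before returning results to the user.
increased_limit = query.limit + 1 if query.limit else None
sql = apply_limit_if_exists(
    database, increased_limit, query, "SELECT * FROM birth_names"
)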
def export_schema_to_dict(back_references):
    """Exports the supported import/export schema to a dictionary"""
    databases = [Database.export_schema(recursive=True,
                 include_parent_ref=back_references)]
    clusters = [DruidCluster.export_schema(recursive=True,
                include_parent_ref=back_references)]
    data = dict()
    if databases:
        data[DATABASES_KEY] = databases
    if clusters:
        data[DRUID_CLUSTERS_KEY] = clusters
    return data
Example #39
    def test_database_schema_presto(self):
        sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEquals('hive/default', db)

        db = make_url(model.get_sqla_engine(schema='core_db').url).database
        self.assertEquals('hive/core_db', db)

        sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEquals('hive', db)

        db = make_url(model.get_sqla_engine(schema='core_db').url).database
        self.assertEquals('hive/core_db', db)
Example #40
    def test_database_for_various_backend(self):
        sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)
        url = make_url(model.sqlalchemy_uri)
        db = model.get_database_for_various_backend(url, None)
        assert db == 'hive/default'
        db = model.get_database_for_various_backend(url, 'raw_data')
        assert db == 'hive/raw_data'

        sqlalchemy_uri = 'redshift+psycopg2://superset:[email protected]:5439/prod'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)
        url = make_url(model.sqlalchemy_uri)
        db = model.get_database_for_various_backend(url, None)
        assert db == 'prod'
        db = model.get_database_for_various_backend(url, 'test')
        assert db == 'prod'

        sqlalchemy_uri = 'postgresql+psycopg2://superset:[email protected]:5439/prod'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)
        url = make_url(model.sqlalchemy_uri)
        db = model.get_database_for_various_backend(url, None)
        assert db == 'prod'
        db = model.get_database_for_various_backend(url, 'adhoc')
        assert db == 'prod'

        sqlalchemy_uri = 'hive://[email protected]:10000/raw_data'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)
        url = make_url(model.sqlalchemy_uri)
        db = model.get_database_for_various_backend(url, None)
        assert db == 'raw_data'
        db = model.get_database_for_various_backend(url, 'adhoc')
        assert db == 'adhoc'

        sqlalchemy_uri = 'mysql://*****:*****@mysql.airbnb.io/superset'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)
        url = make_url(model.sqlalchemy_uri)
        db = model.get_database_for_various_backend(url, None)
        assert db == 'superset'
        db = model.get_database_for_various_backend(url, 'adhoc')
        assert db == 'adhoc'