def test_import_table_override_sync(self):
        table, dict_table = self.create_table(
            'table_override', id=ID_PREFIX + 3,
            cols_names=['col1'],
            metric_names=['m1'])
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        table_over, dict_table_over = self.create_table(
            'table_override', id=ID_PREFIX + 3,
            cols_names=['new_col1', 'col2', 'col3'],
            metric_names=['new_metric1'])
        imported_over_table = SqlaTable.import_from_dict(
            session=db.session,
            dict_rep=dict_table_over,
            sync=['metrics', 'columns'])
        db.session.commit()

        imported_over = self.get_table(imported_over_table.id)
        self.assertEqual(imported_table.id, imported_over.id)
        expected_table, _ = self.create_table(
            'table_override', id=ID_PREFIX + 3,
            metric_names=['new_metric1'],
            cols_names=['new_col1', 'col2', 'col3'])
        self.assert_table_equals(expected_table, imported_over)
        self.yaml_compare(
            expected_table.export_to_dict(),
            imported_over.export_to_dict())
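
The sync argument is what distinguishes this override from a plain re-import: with sync=['metrics', 'columns'] the child metrics and columns missing from the imported dict are deleted, whereas without it they are merged and kept (the import_obj override in example no. 6 below shows the merging behaviour). A rough, hypothetical sketch of the two calls, assuming dict_rep was produced by SqlaTable.export_to_dict():

# merge mode: pre-existing columns/metrics not present in dict_rep are kept
merged = SqlaTable.import_from_dict(db.session, dict_rep)

# sync mode: columns/metrics absent from dict_rep are removed from the table
replaced = SqlaTable.import_from_dict(
    session=db.session,
    dict_rep=dict_rep,
    sync=['metrics', 'columns'],
)
db.session.commit()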
Example no. 2
    def test_import_table_override_identical(self):
        table = self.create_table(
            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
            metric_names=['new_metric1'])
        imported_id = SqlaTable.import_obj(table, import_time=1993)

        copy_table = self.create_table(
            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
            metric_names=['new_metric1'])
        imported_id_copy = SqlaTable.import_obj(
            copy_table, import_time=1994)

        self.assertEqual(imported_id, imported_id_copy)
        self.assert_table_equals(copy_table, self.get_table(imported_id))
Example no. 3
    def test_import_table_2_col_2_met(self):
        table = self.create_table(
            'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
            metric_names=['m1', 'm2'])
        imported_id = SqlaTable.import_obj(table, import_time=1991)

        imported = self.get_table(imported_id)
        self.assert_table_equals(table, imported)
    def test_import_table_no_metadata(self):
        table, dict_table = self.create_table('pure_table', id=ID_PREFIX + 1)
        new_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        imported_id = new_table.id
        imported = self.get_table(imported_id)
        self.assert_table_equals(table, imported)
        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
    def test_import_table_override_identical(self):
        table, dict_table = self.create_table(
            'copy_cat', id=ID_PREFIX + 4,
            cols_names=['new_col1', 'col2', 'col3'],
            metric_names=['new_metric1'])
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        copy_table, dict_copy_table = self.create_table(
            'copy_cat', id=ID_PREFIX + 4,
            cols_names=['new_col1', 'col2', 'col3'],
            metric_names=['new_metric1'])
        imported_copy_table = SqlaTable.import_from_dict(db.session,
                                                         dict_copy_table)
        db.session.commit()
        self.assertEqual(imported_table.id, imported_copy_table.id)
        self.assert_table_equals(copy_table, self.get_table(imported_table.id))
        self.yaml_compare(imported_copy_table.export_to_dict(),
                          imported_table.export_to_dict())
Example no. 6
    def test_import_table_override(self):
        table = self.create_table(
            'table_override', id=10003, cols_names=['col1'],
            metric_names=['m1'])
        imported_id = SqlaTable.import_obj(table, import_time=1991)

        table_over = self.create_table(
            'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
            metric_names=['new_metric1'])
        imported_over_id = SqlaTable.import_obj(
            table_over, import_time=1992)

        imported_over = self.get_table(imported_over_id)
        self.assertEqual(imported_id, imported_over.id)
        expected_table = self.create_table(
            'table_override', id=10003, metric_names=['new_metric1', 'm1'],
            cols_names=['col1', 'new_col1', 'col2', 'col3'])
        self.assert_table_equals(expected_table, imported_over)
    def test_import_table_2_col_2_met(self):
        table, dict_table = self.create_table(
            'table_2_col_2_met', id=ID_PREFIX + 3, cols_names=['c1', 'c2'],
            metric_names=['m1', 'm2'])
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        imported = self.get_table(imported_table.id)
        self.assert_table_equals(table, imported)
        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
Example no. 8
    def test_import_table_1_col_1_met(self):
        table = self.create_table(
            'table_1_col_1_met', id=10002,
            cols_names=["col1"], metric_names=["metric1"])
        imported_id = SqlaTable.import_obj(table, import_time=1990)
        imported = self.get_table(imported_id)
        self.assert_table_equals(table, imported)
        self.assertEqual(
            {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'},
            json.loads(imported.params))
    def test_import_table_1_col_1_met(self):
        table, dict_table = self.create_table(
            'table_1_col_1_met', id=ID_PREFIX + 2,
            cols_names=['col1'], metric_names=['metric1'])
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        imported = self.get_table(imported_table.id)
        self.assert_table_equals(table, imported)
        self.assertEqual(
            {DBREF: ID_PREFIX + 2, 'database_name': 'main'},
            json.loads(imported.params))
        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
Example no. 10
def test_export(session: Session) -> None:
    """
    Test exporting a dataset.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.export import ExportDatasetsCommand
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(
            column_name="profit",
            type="INTEGER",
            expression="revenue-expenses",
            extra=json.dumps({"certified_by": "User"}),
        ),
    ]
    metrics = [
        SqlMetric(
            metric_name="cnt",
            expression="COUNT(*)",
            extra=json.dumps({"warning_markdown": None}),
        ),
    ]

    sqla_table = SqlaTable(
        table_name="my_table",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {
                "remote_id": 64,
                "database_name": "examples",
                "import_time": 1606677834,
            }
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )

    export = list(
        ExportDatasetsCommand._export(sqla_table)  # pylint: disable=protected-access
    )
    assert export == [
        (
            "datasets/my_database/my_table.yaml",
            f"""table_name: my_table
main_dttm_col: ds
description: This is the description
default_endpoint: null
offset: -8
cache_timeout: 3600
schema: my_schema
sql: null
params:
  remote_id: 64
  database_name: examples
  import_time: 1606677834
template_params:
  answer: '42'
filter_select_enabled: 1
fetch_values_predicate: foo IN (1, 2)
extra:
  warning_markdown: '*WARNING*'
uuid: null
metrics:
- metric_name: cnt
  verbose_name: null
  metric_type: null
  expression: COUNT(*)
  description: null
  d3format: null
  extra:
    warning_markdown: null
  warning_text: null
columns:
- column_name: profit
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: revenue-expenses
  description: null
  python_date_format: null
  extra:
    certified_by: User
- column_name: ds
  verbose_name: null
  is_dttm: 1
  is_active: null
  type: TIMESTAMP
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: user_id
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: expenses
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: revenue
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
version: 1.0.0
database_uuid: {database.uuid}
""",
        ),
        (
            "databases/my_database.yaml",
            f"""database_name: my_database
sqlalchemy_uri: sqlite://
cache_timeout: null
expose_in_sqllab: true
allow_run_async: false
allow_ctas: false
allow_cvas: false
allow_file_upload: false
extra:
  metadata_params: {{}}
  engine_params: {{}}
  metadata_cache_timeout: {{}}
  schemas_allowed_for_file_upload: []
uuid: {database.uuid}
version: 1.0.0
""",
        ),
    ]
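
The private _export helper yields (file path, YAML contents) pairs, which is what the assertion above enumerates. A minimal sketch of writing those pairs to disk; the output directory name is an assumption, not part of the test:

import os

for file_path, file_contents in export:
    full_path = os.path.join("dataset_export", file_path)  # hypothetical target dir
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    with open(full_path, "w", encoding="utf-8") as fp:
        fp.write(file_contents)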
Example no. 11
def test_create_virtual_sqlatable(
    app_context: None,
    mocker: MockFixture,
    session: Session,
    sample_columns: Dict["TableColumn", Dict[str, Any]],
    sample_metrics: Dict["SqlMetric", Dict[str, Any]],
    columns_default: Dict[str, Any],
) -> None:
    """
    Test shadow write when creating a new ``SqlaTable``.

    When a new virtual ``SqlaTable`` is created, new models should also be created for
    ``Dataset`` and ``Column``.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.columns.schemas import ColumnSchema
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.models import Dataset
    from superset.datasets.schemas import DatasetSchema
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member
    user1 = get_test_user(1, "abc")
    physical_table_columns: List[Dict[str, Any]] = [
        dict(
            name="ds",
            is_temporal=True,
            type="TIMESTAMP",
            advanced_data_type=None,
            expression="ds",
            is_physical=True,
        ),
        dict(
            name="num_boys",
            type="INTEGER",
            advanced_data_type=None,
            expression="num_boys",
            is_physical=True,
        ),
        dict(
            name="revenue",
            type="INTEGER",
            advanced_data_type=None,
            expression="revenue",
            is_physical=True,
        ),
        dict(
            name="expenses",
            type="INTEGER",
            advanced_data_type=None,
            expression="expenses",
            is_physical=True,
        ),
    ]
    # create a physical ``Table`` that the virtual dataset points to
    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    table = Table(
        name="some_table",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[
            Column(**props, created_by=user1, changed_by=user1)
            for props in physical_table_columns
        ],
    )
    session.add(table)
    session.commit()

    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 0

    # create virtual dataset
    columns = list(sample_columns.keys())
    metrics = list(sample_metrics.keys())
    expected_table_columns = list(sample_columns.values())
    expected_metric_columns = list(sample_metrics.values())

    sqla_table = SqlaTable(
        created_by=user1,
        changed_by=user1,
        owners=[user1],
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint=
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql="""
SELECT
  ds,
  num_boys,
  revenue,
  expenses,
  revenue - expenses AS profit
FROM
  some_table""",
        params=json.dumps({
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        }),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    session.add(sqla_table)
    session.flush()

    # should not add a new table
    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 1

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on"}
    column_schema = ColumnSchema()
    actual_columns = [{
        k: v
        for k, v in column_schema.dump(column).items() if k not in ignored_keys
    } for column in session.query(Column).all()]
    num_physical_columns = len(physical_table_columns)
    num_dataset_table_columns = len(columns)
    num_dataset_metric_columns = len(metrics)
    assert (len(actual_columns) == num_physical_columns +
            num_dataset_table_columns + num_dataset_metric_columns)

    for i, column in enumerate(table.columns):
        assert actual_columns[i] == {
            **columns_default,
            **physical_table_columns[i],
            "id": i + 1,
            "uuid": str(column.uuid),
            "tables": [1],
        }

    offset = num_physical_columns
    for i, column in enumerate(sqla_table.columns):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_table_columns[i],
            "id": i + offset + 1,
            "uuid": str(column.uuid),
            "is_physical": False,
            "datasets": [1],
        }

    offset = num_physical_columns + num_dataset_table_columns
    for i, metric in enumerate(sqla_table.metrics):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_metric_columns[i],
            "id": i + offset + 1,
            "uuid": str(metric.uuid),
            "datasets": [1],
        }

    # check that dataset was created, and has a reference to the table
    dataset_schema = DatasetSchema()
    datasets = [{
        k: v
        for k, v in dataset_schema.dump(dataset).items()
        if k not in ignored_keys
    } for dataset in session.query(Dataset).all()]
    assert len(datasets) == 1
    assert datasets[0] == {
        "id": 1,
        "database": 1,
        "uuid": str(sqla_table.uuid),
        "name": "old_dataset",
        "changed_by": 1,
        "created_by": 1,
        "owners": [1],
        "columns": [5, 6, 7, 8, 9, 10],
        "is_physical": False,
        "tables": [1],
        "extra_json": "{}",
        "external_url": None,
        "is_managed_externally": False,
        "expression": """
SELECT
  ds,
  num_boys,
  revenue,
  expenses,
  revenue - expenses AS profit
FROM
  some_table""",
    }
Example no. 12
    def form_post(self, form: CsvToDatabaseForm) -> Response:
        database = form.con.data
        csv_table = Table(table=form.name.data, schema=form.schema.data)

        if not schema_allows_csv_upload(database, csv_table.schema):
            message = _(
                'Database "%(database_name)s" schema "%(schema_name)s" '
                "is not allowed for csv uploads. Please contact your Superset Admin.",
                database_name=database.database_name,
                schema_name=csv_table.schema,
            )
            flash(message, "danger")
            return redirect("/csvtodatabaseview/form")

        if "." in csv_table.table and csv_table.schema:
            message = _(
                "You cannot specify a namespace both in the name of the table: "
                '"%(csv_table.table)s" and in the schema field: '
                '"%(csv_table.schema)s". Please remove one',
                table=csv_table.table,
                schema=csv_table.schema,
            )
            flash(message, "danger")
            return redirect("/csvtodatabaseview/form")

        try:
            df = pd.concat(
                pd.read_csv(
                    chunksize=1000,
                    encoding="utf-8",
                    filepath_or_buffer=form.csv_file.data,
                    header=form.header.data if form.header.data else 0,
                    index_col=form.index_col.data,
                    infer_datetime_format=form.infer_datetime_format.data,
                    iterator=True,
                    keep_default_na=not form.null_values.data,
                    mangle_dupe_cols=form.mangle_dupe_cols.data,
                    na_values=form.null_values.data if form.null_values.data else None,
                    nrows=form.nrows.data,
                    parse_dates=form.parse_dates.data,
                    sep=form.sep.data,
                    skip_blank_lines=form.skip_blank_lines.data,
                    skipinitialspace=form.skipinitialspace.data,
                    skiprows=form.skiprows.data,
                )
            )

            database = (
                db.session.query(models.Database)
                .filter_by(id=form.data.get("con").data.get("id"))
                .one()
            )

            database.db_engine_spec.df_to_sql(
                database,
                csv_table,
                df,
                to_sql_kwargs={
                    "chunksize": 1000,
                    "if_exists": form.if_exists.data,
                    "index": form.index.data,
                    "index_label": form.index_label.data,
                },
            )

            # Connect table to the database that should be used for exploration.
            # E.g. if hive was used to upload a csv, presto will be a better option
            # to explore the table.
            explore_database = database
            explore_database_id = database.explore_database_id
            if explore_database_id:
                explore_database = (
                    db.session.query(models.Database)
                    .filter_by(id=explore_database_id)
                    .one_or_none()
                    or database
                )

            sqla_table = (
                db.session.query(SqlaTable)
                .filter_by(
                    table_name=csv_table.table,
                    schema=csv_table.schema,
                    database_id=explore_database.id,
                )
                .one_or_none()
            )

            if sqla_table:
                sqla_table.fetch_metadata()
            if not sqla_table:
                sqla_table = SqlaTable(table_name=csv_table.table)
                sqla_table.database = explore_database
                sqla_table.database_id = database.id
                sqla_table.user_id = g.user.get_id()
                sqla_table.schema = csv_table.schema
                sqla_table.fetch_metadata()
                db.session.add(sqla_table)
            db.session.commit()
        except Exception as ex:  # pylint: disable=broad-except
            db.session.rollback()
            message = _(
                'Unable to upload CSV file "%(filename)s" to table '
                '"%(table_name)s" in database "%(db_name)s". '
                "Error message: %(error_msg)s",
                filename=form.csv_file.data.filename,
                table_name=form.name.data,
                db_name=database.database_name,
                error_msg=str(ex),
            )

            flash(message, "danger")
            stats_logger.incr("failed_csv_upload")
            return redirect("/csvtodatabaseview/form")

        # Go back to welcome page / splash screen
        message = _(
            'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in '
            'database "%(db_name)s"',
            csv_filename=form.csv_file.data.filename,
            table_name=str(csv_table),
            db_name=sqla_table.database.database_name,
        )
        flash(message, "info")
        stats_logger.incr("successful_csv_upload")
        return redirect("/tablemodelview/list/")
Example no. 13
def import_dataset(
    session: Session,
    config: Dict[str, Any],
    overwrite: bool = False,
    force_data: bool = False,
) -> SqlaTable:
    existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
    if existing:
        if not overwrite:
            return existing
        config["id"] = existing.id

    # TODO (betodealmeida): move this logic to import_from_dict
    config = config.copy()
    for key in JSON_KEYS:
        if config.get(key) is not None:
            try:
                config[key] = json.dumps(config[key])
            except TypeError:
                logger.info("Unable to encode `%s` field: %s", key,
                            config[key])
    for metric in config.get("metrics", []):
        if metric.get("extra") is not None:
            try:
                metric["extra"] = json.dumps(metric["extra"])
            except TypeError:
                logger.info("Unable to encode `extra` field: %s",
                            metric["extra"])
                metric["extra"] = None

    # should we delete columns and metrics not present in the current import?
    sync = ["columns", "metrics"] if overwrite else []

    # should we also load data into the dataset?
    data_uri = config.get("data")

    # import recursively to include columns and metrics
    dataset = SqlaTable.import_from_dict(session,
                                         config,
                                         recursive=True,
                                         sync=sync)
    if dataset.id is None:
        session.flush()

    example_database = get_example_database()
    try:
        table_exists = example_database.has_table_by_name(dataset.table_name)
    except Exception:  # pylint: disable=broad-except
        # MySQL doesn't play nice with GSheets table names
        logger.warning("Couldn't check if table %s exists, assuming it does",
                       dataset.table_name)
        table_exists = True

    if data_uri and (not table_exists or force_data):
        logger.info("Downloading data from %s", data_uri)
        load_data(data_uri, dataset, example_database, session)

    if hasattr(g, "user") and g.user:
        dataset.owners.append(g.user)

    return dataset
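
A hypothetical call site for import_dataset: load the config from an exported dataset YAML and import it, overwriting any existing dataset with the same UUID. The file path, and the use of db.session, are assumptions for illustration only:

import yaml

with open("datasets/examples/my_table.yaml", encoding="utf-8") as fp:  # hypothetical path
    config = yaml.safe_load(fp)

dataset = import_dataset(db.session, config, overwrite=True, force_data=False)
db.session.commit()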
Example no. 14
def test_update_physical_sqlatable_metrics(
    mocker: MockFixture,
    app_context: None,
    session: Session,
    get_session: Callable[[], Session],
) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.

    For this test we check that updating the SQL expression in a metric belonging to a
    ``SqlaTable`` is reflected in the ``Dataset`` metric.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    metrics = [
        SqlMetric(metric_name="cnt", expression="COUNT(*)"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the metric was created
    # 1 physical column for table + (1 column + 1 metric for datasets)
    assert session.query(Column).count() == 3

    column = session.query(Column).filter_by(is_physical=False).one()
    assert column.expression == "COUNT(*)"

    # change the metric definition
    sqla_table.metrics[0].expression = "MAX(ds)"
    session.flush()

    assert column.expression == "MAX(ds)"

    # in a new session, update new columns and metrics at the same time
    # reload the sqla_table so we can test the case where accessing a not-yet-loaded
    # attribute (`sqla_table.metrics`) while there are pending updates on the instance
    # may trigger `after_update` before the attribute is loaded
    session = get_session()
    sqla_table = session.query(SqlaTable).filter(
        SqlaTable.id == sqla_table.id).one()
    sqla_table.columns.append(
        TableColumn(
            column_name="another_column",
            is_dttm=0,
            type="TIMESTAMP",
            expression="concat('a', 'b')",
        ))
    # Here `SqlaTable.after_update` is triggered
    # before `sqla_table.metrics` is loaded
    sqla_table.metrics.append(
        SqlMetric(metric_name="another_metric", expression="COUNT(*)"))
    # `SqlaTable.after_update` will trigger again at flushing
    session.flush()
    assert session.query(Column).count() == 5
Example no. 15
    def test_set_perm_sqla_table(self):
        table = SqlaTable(
            schema="tmp_schema",
            table_name="tmp_perm_table",
            database=get_example_database(),
        )
        db.session.add(table)
        db.session.commit()

        stored_table = (db.session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table").one())
        self.assertEqual(stored_table.perm,
                         f"[examples].[tmp_perm_table](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # table name change
        stored_table.table_name = "tmp_perm_table_v2"
        db.session.commit()
        stored_table = (db.session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        # no changes in schema
        self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # schema name change
        stored_table.schema = "tmp_schema_v2"
        db.session.commit()
        stored_table = (db.session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        # no changes in schema
        self.assertEqual(stored_table.schema_perm,
                         "[examples].[tmp_schema_v2]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # database change
        new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
        db.session.add(new_db)
        stored_table.database = (db.session.query(Database).filter_by(
            database_name="tmp_db").one())
        db.session.commit()
        stored_table = (db.session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        # no changes in schema
        self.assertEqual(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # no schema
        stored_table.schema = None
        db.session.commit()
        stored_table = (db.session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        self.assertIsNone(stored_table.schema_perm)

        db.session.delete(new_db)
        db.session.delete(stored_table)
        db.session.commit()
Example no. 16
def create_table_permissions(table: models.SqlaTable) -> None:
    security_manager.add_permission_view_menu("datasource_access",
                                              table.get_perm())
    if table.schema:
        security_manager.add_permission_view_menu("schema_access",
                                                  table.schema_perm)
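
A small hypothetical usage of the helper above, granting datasource and schema access for a table created in a test fixture (names are placeholders):

table = SqlaTable(
    schema="tmp_schema",            # hypothetical schema
    table_name="tmp_perm_table",    # hypothetical table name
    database=get_example_database(),
)
db.session.add(table)
db.session.commit()
create_table_permissions(table)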
def test_update_physical_sqlatable_no_dataset(mocker: MockFixture,
                                              app_context: None,
                                              session: Session) -> None:
    """
    Test that updating the table on a physical dataset creates a new dataset
    if one didn't already exist.

    When updating the table on a physical dataset by pointing it somewhere else (change
    in database ID, schema, or table name) we should point the ``Dataset`` to an
    existing ``Table`` if possible, and create a new one otherwise.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the table was created
    table = session.query(Table).one()
    assert table.id == 1

    dataset = session.query(Dataset).one()
    assert dataset.tables == [table]

    # point ``SqlaTable`` to a different database
    new_database = Database(database_name="my_other_database",
                            sqlalchemy_uri="sqlite://")
    session.add(new_database)
    session.flush()
    sqla_table.database = new_database
    session.flush()

    new_dataset = session.query(Dataset).one()

    # check that dataset now points to the new table
    assert new_dataset.tables[0].database_id == 2

    # point ``SqlaTable`` back
    sqla_table.database_id = 1
    session.flush()

    # check that dataset points to the original table
    assert new_dataset.tables[0].database_id == 1
def test_dataset_attributes(app_context: None, session: Session) -> None:
    """
    Test that checks attributes in the dataset.

    If this check fails it means new attributes were added to ``SqlaTable``, and
    ``SqlaTable.after_insert`` should be updated to handle them!
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(column_name="profit",
                    type="INTEGER",
                    expression="revenue-expenses"),
    ]
    metrics = [
        SqlMetric(metric_name="cnt", expression="COUNT(*)"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint=
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps({
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        }),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )

    session.add(sqla_table)
    session.flush()

    dataset = session.query(SqlaTable).one()
    # If this test fails because attributes changed, make sure to update
    # ``SqlaTable.after_insert`` accordingly.
    assert sorted(dataset.__dict__.keys()) == [
        "_sa_instance_state",
        "cache_timeout",
        "changed_by_fk",
        "changed_on",
        "columns",
        "created_by_fk",
        "created_on",
        "database",
        "database_id",
        "default_endpoint",
        "description",
        "external_url",
        "extra",
        "fetch_values_predicate",
        "filter_select_enabled",
        "id",
        "is_featured",
        "is_managed_externally",
        "is_sqllab_view",
        "main_dttm_col",
        "metrics",
        "offset",
        "params",
        "perm",
        "schema",
        "schema_perm",
        "sql",
        "table_name",
        "template_params",
        "uuid",
    ]
def test_update_virtual_sqlatable_references(mocker: MockFixture,
                                             app_context: None,
                                             session: Session) -> None:
    """
    Test that changing the SQL of a virtual ``SqlaTable`` updates ``Dataset``.

    When the SQL is modified the list of referenced tables should be updated in the new
    ``Dataset`` model.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    table1 = Table(
        name="table_a",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[Column(name="a", type="INTEGER")],
    )
    table2 = Table(
        name="table_b",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[Column(name="b", type="INTEGER")],
    )
    session.add(table1)
    session.add(table2)
    session.commit()

    # create virtual dataset
    columns = [TableColumn(column_name="a", type="INTEGER")]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        database=database,
        schema="my_schema",
        sql="SELECT a FROM table_a",
    )
    session.add(sqla_table)
    session.flush()

    # check that new dataset has table1
    dataset = session.query(Dataset).one()
    assert dataset.tables == [table1]

    # change SQL
    sqla_table.sql = "SELECT a, b FROM table_a JOIN table_b"
    session.flush()

    # check that new dataset has both tables
    new_dataset = session.query(Dataset).one()
    assert new_dataset.tables == [table1, table2]
    assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
def test_update_physical_sqlatable(mocker: MockFixture, app_context: None,
                                   session: Session) -> None:
    """
    Test updating the table on a physical dataset.

    When updating the table on a physical dataset by pointing it somewhere else (change
    in database ID, schema, or table name) we should point the ``Dataset`` to an
    existing ``Table`` if possible, and create a new one otherwise.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the table was created, and that the created dataset points to it
    table = session.query(Table).one()
    assert table.id == 1
    assert table.name == "old_dataset"
    assert table.schema is None
    assert table.database_id == 1

    dataset = session.query(Dataset).one()
    assert dataset.tables == [table]

    # point ``SqlaTable`` to a different database
    new_database = Database(database_name="my_other_database",
                            sqlalchemy_uri="sqlite://")
    session.add(new_database)
    session.flush()
    sqla_table.database = new_database
    session.flush()

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on", "uuid"}

    # check that the old table still exists, and that the dataset points to the newly
    # created table (id=2) and column (id=2), on the new database (also id=2)
    table_schema = TableSchema()
    tables = [{
        k: v
        for k, v in table_schema.dump(table).items() if k not in ignored_keys
    } for table in session.query(Table).all()]
    assert tables == [
        {
            "created_by": None,
            "extra_json": "{}",
            "name": "old_dataset",
            "changed_by": None,
            "catalog": None,
            "columns": [1],
            "database": 1,
            "external_url": None,
            "schema": None,
            "id": 1,
            "is_managed_externally": False,
        },
        {
            "created_by": None,
            "extra_json": "{}",
            "name": "old_dataset",
            "changed_by": None,
            "catalog": None,
            "columns": [2],
            "database": 2,
            "external_url": None,
            "schema": None,
            "id": 2,
            "is_managed_externally": False,
        },
    ]

    # check that dataset now points to the new table
    assert dataset.tables[0].database_id == 2

    # point ``SqlaTable`` back
    sqla_table.database_id = 1
    session.flush()

    # check that dataset points to the original table
    assert dataset.tables[0].database_id == 1
Example no. 21
def test_import_dataset_duplicate_column(session: Session) -> None:
    """
    Test importing a dataset with a column that already exists.
    """
    from superset.columns.models import Column as NewColumn
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    dataset_uuid = uuid.uuid4()

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")

    session.add(database)
    session.flush()

    dataset = SqlaTable(uuid=dataset_uuid,
                        table_name="existing_dataset",
                        database_id=database.id)
    column = TableColumn(column_name="existing_column")
    session.add(dataset)
    session.add(column)
    session.flush()

    config = {
        "table_name": dataset.table_name,
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [
            {
                "metric_name": "cnt",
                "verbose_name": None,
                "metric_type": None,
                "expression": "COUNT(*)",
                "description": None,
                "d3format": None,
                "extra": {"warning_markdown": None},
                "warning_text": None,
            }
        ],
        "columns": [
            {
                "column_name": column.column_name,
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "INTEGER",
                "groupby": None,
                "filterable": None,
                "expression": "revenue-expenses",
                "description": None,
                "python_date_format": None,
                "extra": {"certified_by": "User"},
            }
        ],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config, overwrite=True)
    assert sqla_table.table_name == dataset.table_name
    assert sqla_table.main_dttm_col == "ds"
    assert sqla_table.description == "This is the description"
    assert sqla_table.default_endpoint is None
    assert sqla_table.offset == -8
    assert sqla_table.cache_timeout == 3600
    assert sqla_table.schema == "my_schema"
    assert sqla_table.sql is None
    assert sqla_table.params == json.dumps({
        "remote_id": 64,
        "database_name": "examples",
        "import_time": 1606677834
    })
    assert sqla_table.template_params == json.dumps({"answer": "42"})
    assert sqla_table.filter_select_enabled is True
    assert sqla_table.fetch_values_predicate == "foo IN (1, 2)"
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
    assert sqla_table.uuid == dataset_uuid
    assert len(sqla_table.metrics) == 1
    assert sqla_table.metrics[0].metric_name == "cnt"
    assert sqla_table.metrics[0].verbose_name is None
    assert sqla_table.metrics[0].metric_type is None
    assert sqla_table.metrics[0].expression == "COUNT(*)"
    assert sqla_table.metrics[0].description is None
    assert sqla_table.metrics[0].d3format is None
    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.metrics[0].warning_text is None
    assert len(sqla_table.columns) == 1
    assert sqla_table.columns[0].column_name == column.column_name
    assert sqla_table.columns[0].verbose_name is None
    assert sqla_table.columns[0].is_dttm is False
    assert sqla_table.columns[0].is_active is True
    assert sqla_table.columns[0].type == "INTEGER"
    assert sqla_table.columns[0].groupby is True
    assert sqla_table.columns[0].filterable is True
    assert sqla_table.columns[0].expression == "revenue-expenses"
    assert sqla_table.columns[0].description is None
    assert sqla_table.columns[0].python_date_format is None
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
Example no. 22
    def test_import_table_no_metadata(self):
        db_id = get_example_database().id
        table = self.create_table("pure_table", id=10001)
        imported_id = SqlaTable.import_obj(table, db_id, import_time=1989)
        imported = self.get_table_by_id(imported_id)
        self.assert_table_equals(table, imported)
Example no. 23
    def test_set_perm_sqla_table(self):
        security_manager.on_view_menu_after_insert = Mock()
        security_manager.on_permission_view_after_insert = Mock()

        session = db.session
        table = SqlaTable(
            schema="tmp_schema",
            table_name="tmp_perm_table",
            database=get_example_database(),
        )
        session.add(table)
        session.commit()

        stored_table = (session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table").one())
        self.assertEqual(stored_table.perm,
                         f"[examples].[tmp_perm_table](id:{stored_table.id})")

        pvm_dataset = security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm)
        pvm_schema = security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm)

        self.assertIsNotNone(pvm_dataset)
        self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
        self.assertIsNotNone(pvm_schema)

        # assert on permission hooks
        view_menu_dataset = security_manager.find_view_menu(
            f"[examples].[tmp_perm_table](id:{stored_table.id})")
        view_menu_schema = security_manager.find_view_menu(
            "[examples].[tmp_schema]")
        security_manager.on_view_menu_after_insert.assert_has_calls([
            call(ANY, ANY, view_menu_dataset),
            call(ANY, ANY, view_menu_schema),
        ])
        security_manager.on_permission_view_after_insert.assert_has_calls([
            call(ANY, ANY, pvm_dataset),
            call(ANY, ANY, pvm_schema),
        ])

        # table name change
        stored_table.table_name = "tmp_perm_table_v2"
        session.commit()
        stored_table = (session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        # no changes in schema
        self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # schema name change
        stored_table.schema = "tmp_schema_v2"
        session.commit()
        stored_table = (session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        # no changes in schema
        self.assertEqual(stored_table.schema_perm,
                         "[examples].[tmp_schema_v2]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # database change
        new_db = Database(sqlalchemy_uri="sqlite://", database_name="tmp_db")
        session.add(new_db)
        stored_table.database = (session.query(Database).filter_by(
            database_name="tmp_db").one())
        session.commit()
        stored_table = (session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        # no changes in schema
        self.assertEqual(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu(
                "schema_access", stored_table.schema_perm))

        # no schema
        stored_table.schema = None
        session.commit()
        stored_table = (session.query(SqlaTable).filter_by(
            table_name="tmp_perm_table_v2").one())
        self.assertEqual(
            stored_table.perm,
            f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
        self.assertIsNotNone(
            security_manager.find_permission_view_menu("datasource_access",
                                                       stored_table.perm))
        self.assertIsNone(stored_table.schema_perm)

        session.delete(new_db)
        session.delete(stored_table)
        session.commit()
Example no. 24
    def test_comments_in_sqlatable_query(self):
        clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
        commented_query = '/* comment 1 */' + clean_query + '-- comment 2'
        table = SqlaTable(sql=commented_query)
        rendered_query = str(table.get_from_clause())
        self.assertEqual(clean_query, rendered_query)
Example no. 25
    def create_table_in_view(self):
        self.test_database = utils.get_example_database()
        self.test_database.allow_dml = True

        params = {"remote_id": 1234, "database_name": self.test_database.name}
        self.test_table = SqlaTable(id=1234,
                                    table_name="Departments",
                                    params=json.dumps(params))
        self.test_table.columns.append(
            TableColumn(column_name="department_id", type="INTEGER"))
        self.test_table.columns.append(
            TableColumn(column_name="name", type="STRING"))
        self.test_table.columns.append(
            TableColumn(column_name="street", type="STRING"))
        self.test_table.columns.append(
            TableColumn(column_name="city", type="STRING"))
        self.test_table.columns.append(
            TableColumn(column_name="country", type="STRING"))
        self.test_table.columns.append(
            TableColumn(column_name="lat", type="FLOAT"))
        self.test_table.columns.append(
            TableColumn(column_name="lon", type="FLOAT"))
        self.test_table.database = self.test_database
        self.test_table.database_id = self.test_table.database.id
        db.session.add(self.test_table)
        db.session.commit()

        data = {
            "department_id": [1, 2, 3, 4, 5],
            "name": [
                "Logistics",
                "Marketing",
                "Facility Management",
                "Personal",
                "Finances",
            ],
            "street": [
                "Oberseestrasse 10",
                "Grossmünsterplatz",
                "Uetliberg",
                "Zürichbergstrasse 221",
                "Bahnhofstrasse",
            ],
            "city": ["Rapperswil", "Zürich", "Zürich", "Zürich", "Zürich"],
            "country": [
                "Switzerland",
                "Switzerland",
                "Switzerland",
                "Switzerland",
                "Switzerland",
            ],
            "lat": [None, None, None, None, 4.789],
            "lon": [None, None, None, 1.234, None],
        }
        df = pd.DataFrame(data=data)

        # because of a caching problem with Postgres, load the database a second
        # time; without this, the sqla engine throws an exception
        database = utils.get_example_database()
        if database:
            engine = database.get_sqla_engine()
            df.to_sql(
                self.test_table.table_name,
                engine,
                if_exists="replace",
                chunksize=500,
                dtype={
                    "department_id": Integer,
                    "name": String(60),
                    "street": String(60),
                    "city": String(60),
                    "country": String(60),
                    "lat": Float,
                    "lon": Float,
                },
                index=False,
            )
Example no. 26
    def test_extra_cache_keys(self, flask_g):
        flask_g.user.username = "******"
        base_query_obj = {
            "granularity": None,
            "from_dttm": None,
            "to_dttm": None,
            "groupby": ["user"],
            "metrics": [],
            "is_timeseries": False,
            "filter": [],
        }

        # Table with Jinja callable.
        table1 = SqlaTable(
            table_name="test_has_extra_cache_keys_table",
            sql="SELECT '{{ current_username() }}' as user",
            database=get_example_database(),
        )

        query_obj = dict(**base_query_obj, extras={})
        extra_cache_keys = table1.get_extra_cache_keys(query_obj)
        self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
        assert extra_cache_keys == ["abc"]

        # Table with Jinja callable disabled.
        table2 = SqlaTable(
            table_name="test_has_extra_cache_keys_disabled_table",
            sql="SELECT '{{ current_username(False) }}' as user",
            database=get_example_database(),
        )
        query_obj = dict(**base_query_obj, extras={})
        extra_cache_keys = table2.get_extra_cache_keys(query_obj)
        self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
        self.assertListEqual(extra_cache_keys, [])

        # Table with no Jinja callable.
        query = "SELECT 'abc' as user"
        table3 = SqlaTable(
            table_name="test_has_no_extra_cache_keys_table",
            sql=query,
            database=get_example_database(),
        )

        query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
        extra_cache_keys = table3.get_extra_cache_keys(query_obj)
        self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
        self.assertListEqual(extra_cache_keys, [])

        # With Jinja callable in SQL expression.
        query_obj = dict(**base_query_obj,
                         extras={
                             "where": "(user != '{{ current_username() }}')"
                         })
        extra_cache_keys = table3.get_extra_cache_keys(query_obj)
        self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
        assert extra_cache_keys == ["abc"]

        # Cleanup
        for table in [table1, table2, table3]:
            db.session.delete(table)
        db.session.commit()
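# A standalone sketch using plain jinja2 (an illustration, not Superset's own
# template processor) of why "{{ current_username() }}" produces the
# user-dependent cache key asserted above: rendering calls the callable, so the
# result varies per user and must therefore be part of the cache key.
from jinja2 import Template

sql = "SELECT '{{ current_username() }}' as user"
rendered = Template(sql).render(
    current_username=lambda add_to_cache_keys=True: "abc")
assert rendered == "SELECT 'abc' as user"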
Exemplo n.º 27
0
def test__normalize_prequery_result_type(
    app_context: Flask,
    mocker: MockFixture,
    row: pd.Series,
    dimension: str,
    result: Any,
) -> None:
    def _convert_dttm(
            target_type: str,
            dttm: datetime,
            db_extra: Optional[Dict[str, Any]] = None) -> Optional[str]:
        if target_type.upper() == TemporalType.TIMESTAMP:
            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""

        return None

    table = SqlaTable(table_name="foobar", database=get_example_database())
    mocker.patch.object(table.db_engine_spec,
                        "convert_dttm",
                        new=_convert_dttm)

    columns_by_name = {
        "foo":
        TableColumn(
            column_name="foo",
            is_dttm=False,
            table=table,
            type="STRING",
        ),
        "bar":
        TableColumn(
            column_name="bar",
            is_dttm=False,
            table=table,
            type="BOOLEAN",
        ),
        "baz":
        TableColumn(
            column_name="baz",
            is_dttm=False,
            table=table,
            type="INTEGER",
        ),
        "qux":
        TableColumn(
            column_name="qux",
            is_dttm=False,
            table=table,
            type="FLOAT",
        ),
        "quux":
        TableColumn(
            column_name="quuz",
            is_dttm=True,
            table=table,
            type="STRING",
        ),
        "quuz":
        TableColumn(
            column_name="quux",
            is_dttm=True,
            table=table,
            type="TIMESTAMP",
        ),
    }

    normalized = table._normalize_prequery_result_type(
        row,
        dimension,
        columns_by_name,
    )

    assert type(normalized) == type(result)

    if isinstance(normalized, TextClause):
        assert str(normalized) == str(result)
    else:
        assert normalized == result
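# The function above is a parametrized pytest test; the
# ``@pytest.mark.parametrize`` decorator supplying ``row``, ``dimension`` and
# ``result`` is not shown in this snippet. A hypothetical set of cases
# (illustrative only, not the original parametrization) could look like this:
import pandas as pd

_example_cases = [
    (pd.Series({"foo": "abc"}), "foo", "abc"),        # STRING stays a string
    (pd.Series({"bar": True}), "bar", True),          # BOOLEAN stays a bool
    (pd.Series({"baz": 123}), "baz", 123),            # INTEGER stays an int
    (pd.Series({"qux": 123.456}), "qux", 123.456),    # FLOAT stays a float
]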
Exemplo n.º 28
0
    def invalidate(self) -> Response:
        """
        Takes a list of datasources, finds the associated cache records and
        invalidates them and removes the database records

        ---
        post:
          description: >-
            Takes a list of datasources, finds the associated cache records and
            invalidates them and removes the database records
          requestBody:
            description: >-
              A list of datasources uuid or the tuples of database and datasource names
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/CacheInvalidationRequestSchema"
          responses:
            201:
              description: cache was successfully invalidated
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            datasources = CacheInvalidationRequestSchema().load(request.json)
        except KeyError:
            return self.response_400(message="Request is incorrect")
        except ValidationError as error:
            return self.response_400(message=str(error))
        datasource_uids = set(datasources.get("datasource_uids", []))
        for ds in datasources.get("datasources", []):
            ds_obj = SqlaTable.get_datasource_by_name(
                session=db.session,
                datasource_name=ds.get("datasource_name"),
                schema=ds.get("schema"),
                database_name=ds.get("database_name"),
            )

            if ds_obj:
                datasource_uids.add(ds_obj.uid)

        cache_key_objs = (db.session.query(CacheKey).filter(
            CacheKey.datasource_uid.in_(datasource_uids)).all())
        cache_keys = [c.cache_key for c in cache_key_objs]
        if cache_key_objs:
            all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)

            if not all_keys_deleted:
                # expected behavior as keys may expire and cache is not a
                # persistent storage
                logger.info(
                    "Some of the cache keys were not deleted in the list %s",
                    cache_keys)

            try:
                delete_stmt = (
                    CacheKey.__table__.delete().where(  # pylint: disable=no-member
                        CacheKey.cache_key.in_(cache_keys)))
                db.session.execute(delete_stmt)
                db.session.commit()
                self.stats_logger.gauge("invalidated_cache", len(cache_keys))
                logger.info(
                    "Invalidated %s cache records for %s datasources",
                    len(cache_keys),
                    len(datasource_uids),
                )
            except SQLAlchemyError as ex:  # pragma: no cover
                logger.error(ex, exc_info=True)
                db.session.rollback()
                return self.response_500(str(ex))
            db.session.commit()
        return self.response(201)
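# A hedged example of the request body this endpoint accepts, inferred from the
# schema usage above (``datasource_uids`` plus ``datasources`` name tuples);
# the concrete values and the route path are illustrative assumptions.
example_payload = {
    "datasource_uids": ["3__table"],
    "datasources": [
        {
            "datasource_name": "my_table",
            "database_name": "examples",
            "schema": "public",
        }
    ],
}
# e.g. client.post("/api/v1/cachekey/invalidate", json=example_payload)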
Exemplo n.º 29
0
def test_create_physical_sqlatable(
    app_context: None,
    session: Session,
    sample_columns: Dict["TableColumn", Dict[str, Any]],
    sample_metrics: Dict["SqlMetric", Dict[str, Any]],
    columns_default: Dict[str, Any],
) -> None:
    """
    Test shadow write when creating a new ``SqlaTable``.

    When a new physical ``SqlaTable`` is created, new models should also be created for
    ``Dataset``, ``Table``, and ``Column``.
    """
    from superset.columns.models import Column
    from superset.columns.schemas import ColumnSchema
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.models import Dataset
    from superset.datasets.schemas import DatasetSchema
    from superset.models.core import Database
    from superset.tables.models import Table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member
    user1 = get_test_user(1, "abc")
    columns = list(sample_columns.keys())
    metrics = list(sample_metrics.keys())
    expected_table_columns = list(sample_columns.values())
    expected_metric_columns = list(sample_metrics.values())

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint=
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps({
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        }),
        created_by=user1,
        changed_by=user1,
        owners=[user1],
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    session.add(sqla_table)
    session.flush()

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on"}

    # check that columns were created
    column_schema = ColumnSchema()
    actual_columns = [{
        k: v
        for k, v in column_schema.dump(column).items() if k not in ignored_keys
    } for column in session.query(Column).all()]
    num_physical_columns = len([
        col for col in expected_table_columns if col.get("is_physical") == True
    ])
    num_dataset_table_columns = len(columns)
    num_dataset_metric_columns = len(metrics)
    assert (len(actual_columns) == num_physical_columns +
            num_dataset_table_columns + num_dataset_metric_columns)

    # table columns are created before dataset columns are created
    offset = 0
    for i in range(num_physical_columns):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_table_columns[i],
            "id": i + offset + 1,
            # physical columns for table have its own uuid
            "uuid": actual_columns[i + offset]["uuid"],
            "is_physical": True,
            # table columns do not have creators
            "created_by": None,
            "tables": [1],
        }

    offset += num_physical_columns
    for i, column in enumerate(sqla_table.columns):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_table_columns[i],
            "id": i + offset + 1,
            # columns for dataset reuses the same uuid of TableColumn
            "uuid": str(column.uuid),
            "datasets": [1],
        }

    offset += num_dataset_table_columns
    for i, metric in enumerate(sqla_table.metrics):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_metric_columns[i],
            "id": i + offset + 1,
            "uuid": str(metric.uuid),
            "datasets": [1],
        }

    # check that table was created
    table_schema = TableSchema()
    tables = [{
        k: v
        for k, v in table_schema.dump(table).items()
        if k not in (ignored_keys | {"uuid"})
    } for table in session.query(Table).all()]
    assert len(tables) == 1
    assert tables[0] == {
        "id": 1,
        "database": 1,
        "created_by": 1,
        "changed_by": 1,
        "datasets": [1],
        "columns": [1, 2, 3],
        "extra_json": "{}",
        "catalog": None,
        "schema": "my_schema",
        "name": "old_dataset",
        "is_managed_externally": False,
        "external_url": None,
    }

    # check that dataset was created
    dataset_schema = DatasetSchema()
    datasets = [{
        k: v
        for k, v in dataset_schema.dump(dataset).items()
        if k not in ignored_keys
    } for dataset in session.query(Dataset).all()]
    assert len(datasets) == 1
    assert datasets[0] == {
        "id": 1,
        "uuid": str(sqla_table.uuid),
        "created_by": 1,
        "changed_by": 1,
        "owners": [1],
        "name": "old_dataset",
        "columns": [4, 5, 6, 7, 8, 9],
        "is_physical": True,
        "database": 1,
        "tables": [1],
        "extra_json": "{}",
        "expression": "old_dataset",
        "is_managed_externally": False,
        "external_url": None,
    }
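# The dict comprehensions used above (and in later examples) all follow the
# same pattern: dump a model with its marshmallow schema and drop volatile
# keys. A hypothetical helper with identical behaviour:
def dump_without(schema, obj, ignored_keys):
    """Serialize ``obj`` with ``schema``, dropping the keys in ``ignored_keys``."""
    return {k: v for k, v in schema.dump(obj).items() if k not in ignored_keys}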
Exemplo n.º 30
0
 def test_import_table_no_metadata(self):
     table = self.create_table('pure_table', id=10001)
     imported_id = SqlaTable.import_obj(table, import_time=1989)
     imported = self.get_table(imported_id)
     self.assert_table_equals(table, imported)
Exemplo n.º 31
0
 def test_import_table_no_metadata(self):
     table = self.create_table('pure_table', id=10001)
     imported_id = SqlaTable.import_obj(table)
     imported = self.get_table(imported_id)
     self.assert_table_equals(table, imported)
Exemplo n.º 32
0
def test_update_physical_sqlatable_database(
    mocker: MockFixture,
    app_context: None,
    session: Session,
    get_session: Callable[[], Session],
) -> None:
    """
    Test updating the table on a physical dataset.

    When updating the table on a physical dataset by pointing it somewhere else (change
    in database ID, schema, or table name) we should point the ``Dataset`` to an
    existing ``Table`` if possible, and create a new one otherwise.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset, dataset_column_association_table
    from superset.models.core import Database
    from superset.tables.models import Table, table_column_association_table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    original_database = Database(database_name="my_database",
                                 sqlalchemy_uri="sqlite://")
    sqla_table = SqlaTable(
        table_name="original_table",
        columns=columns,
        metrics=[],
        database=original_database,
    )
    session.add(sqla_table)
    session.flush()

    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 1
    assert session.query(Column).count() == 2  # 1 for table, 1 for dataset

    # check that the table was created, and that the created dataset points to it
    table = session.query(Table).one()
    assert table.id == 1
    assert table.name == "original_table"
    assert table.schema is None
    assert table.database_id == 1

    dataset = session.query(Dataset).one()
    assert dataset.tables == [table]

    # point ``SqlaTable`` to a different database
    new_database = Database(database_name="my_other_database",
                            sqlalchemy_uri="sqlite://")
    session.add(new_database)
    session.flush()
    sqla_table.database = new_database
    sqla_table.table_name = "new_table"
    session.flush()

    assert session.query(Dataset).count() == 1
    assert session.query(Table).count() == 2
    # <Column:id=1> is kept for the old table
    # <Column:id=2> is kept for the updated dataset
    # <Column:id=3> is created for the new table
    assert session.query(Column).count() == 3

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on", "uuid"}

    # check that the old table still exists, and that the dataset points to the newly
    # created table, column and dataset
    table_schema = TableSchema()
    tables = [{
        k: v
        for k, v in table_schema.dump(table).items() if k not in ignored_keys
    } for table in session.query(Table).all()]
    assert tables[0] == {
        "id": 1,
        "database": 1,
        "columns": [1],
        "datasets": [],
        "created_by": None,
        "changed_by": None,
        "extra_json": "{}",
        "catalog": None,
        "schema": None,
        "name": "original_table",
        "external_url": None,
        "is_managed_externally": False,
    }
    assert tables[1] == {
        "id": 2,
        "database": 2,
        "datasets": [1],
        "columns": [3],
        "created_by": None,
        "changed_by": None,
        "catalog": None,
        "schema": None,
        "name": "new_table",
        "is_managed_externally": False,
        "extra_json": "{}",
        "external_url": None,
    }

    # check that dataset now points to the new table
    assert dataset.tables[0].database_id == 2
    # and a new column is created
    assert len(dataset.columns) == 1
    assert dataset.columns[0].id == 2

    # point ``SqlaTable`` back
    sqla_table.database = original_database
    sqla_table.table_name = "original_table"
    session.flush()

    # should not create more table and datasets
    assert session.query(Dataset).count() == 1
    assert session.query(Table).count() == 2
    # <Column:id=1> is deleted for the old table
    # <Column:id=2> is kept for the updated dataset
    # <Column:id=3> is kept for the new table
    assert session.query(Column.id).order_by(Column.id).all() == [
        (1, ),
        (2, ),
        (3, ),
    ]
    assert session.query(dataset_column_association_table).all() == [(1, 2)]
    assert session.query(table_column_association_table).all() == [(1, 1),
                                                                   (2, 3)]
    assert session.query(Dataset).filter_by(id=1).one().columns[0].id == 2
    assert session.query(Table).filter_by(id=2).one().columns[0].id == 3
    assert session.query(Table).filter_by(id=1).one().columns[0].id == 1

    # the dataset points back to the original table
    assert dataset.tables[0].database_id == 1
    assert dataset.tables[0].name == "original_table"

    # kept the original column
    assert dataset.columns[0].id == 2
    session.commit()
    session.close()

    # querying in a new session should still return the same result
    session = get_session()
    assert session.query(table_column_association_table).all() == [(1, 1),
                                                                   (2, 3)]
Exemplo n.º 33
0
    def form_post(self, form: CsvToDatabaseForm) -> Response:
        database = form.con.data
        csv_table = Table(table=form.name.data, schema=form.schema.data)

        if not schema_allows_csv_upload(database, csv_table.schema):
            message = _(
                'Database "%(database_name)s" schema "%(schema_name)s" '
                "is not allowed for csv uploads. Please contact your Superset Admin.",
                database_name=database.database_name,
                schema_name=csv_table.schema,
            )
            flash(message, "danger")
            return redirect("/csvtodatabaseview/form")

        if "." in csv_table.table and csv_table.schema:
            message = _(
                "You cannot specify a namespace both in the name of the table: "
                '"%(csv_table.table)s" and in the schema field: '
                '"%(csv_table.schema)s". Please remove one',
                table=csv_table.table,
                schema=csv_table.schema,
            )
            flash(message, "danger")
            return redirect("/csvtodatabaseview/form")

        uploaded_tmp_file_path = tempfile.NamedTemporaryFile(
            dir=app.config["UPLOAD_FOLDER"],
            suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(),
            delete=False,
        ).name

        try:
            utils.ensure_path_exists(config["UPLOAD_FOLDER"])
            upload_stream_write(form.csv_file.data, uploaded_tmp_file_path)

            con = form.data.get("con")
            database = (
                db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
            )

            # More can be found here:
            # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
            csv_to_df_kwargs = {
                "sep": form.sep.data,
                "header": form.header.data if form.header.data else 0,
                "index_col": form.index_col.data,
                "mangle_dupe_cols": form.mangle_dupe_cols.data,
                "skipinitialspace": form.skipinitialspace.data,
                "skiprows": form.skiprows.data,
                "nrows": form.nrows.data,
                "skip_blank_lines": form.skip_blank_lines.data,
                "parse_dates": form.parse_dates.data,
                "infer_datetime_format": form.infer_datetime_format.data,
                "chunksize": 1000,
            }
            if form.null_values.data:
                csv_to_df_kwargs["na_values"] = form.null_values.data
                csv_to_df_kwargs["keep_default_na"] = False

            # More can be found here:
            # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html
            df_to_sql_kwargs = {
                "name": csv_table.table,
                "if_exists": form.if_exists.data,
                "index": form.index.data,
                "index_label": form.index_label.data,
                "chunksize": 1000,
            }
            database.db_engine_spec.create_table_from_csv(
                uploaded_tmp_file_path,
                csv_table,
                database,
                csv_to_df_kwargs,
                df_to_sql_kwargs,
            )

            # Connect table to the database that should be used for exploration.
            # E.g. if hive was used to upload a csv, presto will be a better option
            # to explore the table.
            expore_database = database
            explore_database_id = database.explore_database_id
            if explore_database_id:
                expore_database = (
                    db.session.query(models.Database)
                    .filter_by(id=explore_database_id)
                    .one_or_none()
                    or database
                )

            sqla_table = (
                db.session.query(SqlaTable)
                .filter_by(
                    table_name=csv_table.table,
                    schema=csv_table.schema,
                    database_id=expore_database.id,
                )
                .one_or_none()
            )

            if sqla_table:
                sqla_table.fetch_metadata()
            if not sqla_table:
                sqla_table = SqlaTable(table_name=csv_table.table)
                sqla_table.database = expore_database
                sqla_table.database_id = database.id
                sqla_table.user_id = g.user.id
                sqla_table.schema = csv_table.schema
                sqla_table.fetch_metadata()
                db.session.add(sqla_table)
            db.session.commit()
        except Exception as ex:  # pylint: disable=broad-except
            db.session.rollback()
            try:
                os.remove(uploaded_tmp_file_path)
            except OSError:
                pass
            message = _(
                'Unable to upload CSV file "%(filename)s" to table '
                '"%(table_name)s" in database "%(db_name)s". '
                "Error message: %(error_msg)s",
                filename=form.csv_file.data.filename,
                table_name=form.name.data,
                db_name=database.database_name,
                error_msg=str(ex),
            )

            flash(message, "danger")
            stats_logger.incr("failed_csv_upload")
            return redirect("/csvtodatabaseview/form")

        os.remove(uploaded_tmp_file_path)
        # Go back to welcome page / splash screen
        message = _(
            'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in '
            'database "%(db_name)s"',
            csv_filename=form.csv_file.data.filename,
            table_name=str(csv_table),
            db_name=sqla_table.database.database_name,
        )
        flash(message, "info")
        stats_logger.incr("successful_csv_upload")
        return redirect("/tablemodelview/list/")
Exemplo n.º 34
0
def test_update_physical_sqlatable_columns(mocker: MockFixture,
                                           session: Session) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )

    session.add(sqla_table)
    session.flush()

    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 1
    assert session.query(Column).count() == 2  # 1 for table, 1 for dataset

    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # add a column to the original ``SqlaTable`` instance
    sqla_table.columns.append(
        TableColumn(column_name="num_boys", type="INTEGER"))
    session.flush()

    assert session.query(Column).count() == 3
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 2

    # check that both lists have the same uuids
    assert sorted(col.uuid for col in sqla_table.columns) == sorted(
        col.uuid for col in dataset.columns)

    # delete the column in the original instance
    sqla_table.columns = sqla_table.columns[1:]
    session.flush()

    # check that the column was removed from the ``SqlaTable`` column list
    assert session.query(TableColumn).count() == 1
    # the extra Dataset.column is deleted, but Table.column is kept
    assert session.query(Column).count() == 2

    # check that the column was also removed from the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # modify the attribute in a column
    sqla_table.columns[0].is_dttm = True
    session.flush()

    # check that the dataset column was modified
    dataset = session.query(Dataset).one()
    assert dataset.columns[0].is_temporal is True
Exemplo n.º 35
0
    def form_post(self, form):
        database = form.con.data
        excel_table = Table(table=form.name.data, schema=form.schema.data)

        if not self.is_schema_allowed(database, excel_table.schema):
            message = _(
                'Database "%(database_name)s" schema "%(schema_name)s" '
                "is not allowed for excel uploads. Please contact your Superset Admin.",
                database_name=database.database_name,
                schema_name=excel_table.schema,
            )
            flash(message, "danger")
            return redirect("/exceltodatabaseview/form")

        if "." in excel_table.table and excel_table.schema:
            message = _(
                "You cannot specify a namespace both in the name of the table: "
                '"%(excel_table.table)s" and in the schema field: '
                '"%(excel_table.schema)s". Please remove one',
                table=excel_table.table,
                schema=excel_table.schema,
            )
            flash(message, "danger")
            return redirect("/exceltodatabaseview/form")

        uploaded_tmp_file_path = tempfile.NamedTemporaryFile(
            dir=app.config["UPLOAD_FOLDER"],
            suffix=os.path.splitext(form.excel_file.data.filename)[1].lower(),
            delete=False,
        ).name

        try:
            utils.ensure_path_exists(config["UPLOAD_FOLDER"])
            upload_stream_write(form.excel_file.data, uploaded_tmp_file_path)

            con = form.data.get("con")
            database = (
                db.session.query(models.Database).filter_by(id=con.data.get("id")).one()
            )
            excel_to_df_kwargs = {
                "header": form.header.data if form.header.data else 0,
                "index_col": form.index_col.data,
                "mangle_dupe_cols": form.mangle_dupe_cols.data,
                "skipinitialspace": form.skipinitialspace.data,
                "skiprows": form.skiprows.data,
                "nrows": form.nrows.data,
                "sheet_name": form.sheet_name.data,
                "chunksize": 1000,
            }
            df_to_sql_kwargs = {
                "name": excel_table.table,
                "if_exists": form.if_exists.data,
                "index": form.index.data,
                "index_label": form.index_label.data,
                "chunksize": 1000,
            }
            database.db_engine_spec.create_table_from_excel(
                uploaded_tmp_file_path,
                excel_table,
                database,
                excel_to_df_kwargs,
                df_to_sql_kwargs,
            )

            # Connect table to the database that should be used for exploration.
            # E.g. if hive was used to upload an Excel file, presto will be a better option
            # to explore the table.
            expore_database = database
            explore_database_id = database.get_extra().get("explore_database_id", None)
            if explore_database_id:
                expore_database = (
                    db.session.query(models.Database)
                    .filter_by(id=explore_database_id)
                    .one_or_none()
                    or database
                )

            sqla_table = (
                db.session.query(SqlaTable)
                .filter_by(
                    table_name=excel_table.table,
                    schema=excel_table.schema,
                    database_id=expore_database.id,
                )
                .one_or_none()
            )

            if sqla_table:
                sqla_table.fetch_metadata()
            if not sqla_table:
                sqla_table = SqlaTable(table_name=excel_table.table)
                sqla_table.database = expore_database
                sqla_table.database_id = database.id
                sqla_table.user_id = g.user.id
                sqla_table.schema = excel_table.schema
                sqla_table.fetch_metadata()
                db.session.add(sqla_table)
            db.session.commit()
        except Exception as ex:  # pylint: disable=broad-except
            db.session.rollback()
            try:
                os.remove(uploaded_tmp_file_path)
            except OSError:
                pass
            message = _(
                'Unable to upload Excel file "%(filename)s" to table '
                '"%(table_name)s" in database "%(db_name)s". '
                "Error message: %(error_msg)s",
                filename=form.excel_file.data.filename,
                table_name=form.name.data,
                db_name=database.database_name,
                error_msg=str(ex),
            )

            flash(message, "danger")
            stats_logger.incr("failed_excel_upload")
            return redirect("/exceltodatabaseview/form")

        os.remove(uploaded_tmp_file_path)
        # Go back to welcome page / splash screen
        message = _(
            'CSV file "%(excel_filename)s" uploaded to table "%(table_name)s" in '
            'database "%(db_name)s"',
            excel_filename=form.excel_file.data.filename,
            table_name=str(excel_table),
            db_name=sqla_table.database.database_name,
        )
        flash(message, "info")
        stats_logger.incr("successful_excel_upload")
        return redirect("/tablemodelview/list/")
Exemplo n.º 36
0
def test_create_virtual_sqlatable(mocker: MockFixture, app_context: None,
                                  session: Session) -> None:
    """
    Test shadow write when creating a new ``SqlaTable``.

    When a new virtual ``SqlaTable`` is created, new models should also be created for
    ``Dataset`` and ``Column``.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.columns.schemas import ColumnSchema
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.models import Dataset
    from superset.datasets.schemas import DatasetSchema
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    # create the ``Table`` that the virtual dataset points to
    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    table = Table(
        name="some_table",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[
            Column(name="ds", is_temporal=True, type="TIMESTAMP"),
            Column(name="user_id", type="INTEGER"),
            Column(name="revenue", type="INTEGER"),
            Column(name="expenses", type="INTEGER"),
        ],
    )
    session.add(table)
    session.commit()

    # create virtual dataset
    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(column_name="profit",
                    type="INTEGER",
                    expression="revenue-expenses"),
    ]
    metrics = [
        SqlMetric(metric_name="cnt", expression="COUNT(*)"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint=
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql="""
SELECT
  ds,
  user_id,
  revenue,
  expenses,
  revenue - expenses AS profit
FROM
  some_table""",
        params=json.dumps({
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        }),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    session.add(sqla_table)
    session.flush()

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on", "uuid"}

    # check that columns were created
    column_schema = ColumnSchema()
    column_schemas = [{
        k: v
        for k, v in column_schema.dump(column).items() if k not in ignored_keys
    } for column in session.query(Column).all()]
    assert column_schemas == [
        {
            "type": "TIMESTAMP",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": None,
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "ds",
            "is_physical": True,
            "changed_by": None,
            "is_temporal": True,
            "id": 1,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": None,
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "user_id",
            "is_physical": True,
            "changed_by": None,
            "is_temporal": False,
            "id": 2,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": None,
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "revenue",
            "is_physical": True,
            "changed_by": None,
            "is_temporal": False,
            "id": 3,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": None,
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "expenses",
            "is_physical": True,
            "changed_by": None,
            "is_temporal": False,
            "id": 4,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "TIMESTAMP",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": "ds",
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "ds",
            "is_physical": False,
            "changed_by": None,
            "is_temporal": True,
            "id": 5,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": "user_id",
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "user_id",
            "is_physical": False,
            "changed_by": None,
            "is_temporal": False,
            "id": 6,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": "revenue",
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "revenue",
            "is_physical": False,
            "changed_by": None,
            "is_temporal": False,
            "id": 7,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": "expenses",
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "expenses",
            "is_physical": False,
            "changed_by": None,
            "is_temporal": False,
            "id": 8,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "INTEGER",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": "revenue-expenses",
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "profit",
            "is_physical": False,
            "changed_by": None,
            "is_temporal": False,
            "id": 9,
            "is_aggregation": False,
            "external_url": None,
            "is_managed_externally": False,
        },
        {
            "type": "Unknown",
            "is_additive": False,
            "extra_json": "{}",
            "is_partition": False,
            "expression": "COUNT(*)",
            "unit": None,
            "warning_text": None,
            "created_by": None,
            "is_increase_desired": True,
            "description": None,
            "is_spatial": False,
            "name": "cnt",
            "is_physical": False,
            "changed_by": None,
            "is_temporal": False,
            "id": 10,
            "is_aggregation": True,
            "external_url": None,
            "is_managed_externally": False,
        },
    ]

    # check that dataset was created, and has a reference to the table
    dataset_schema = DatasetSchema()
    datasets = [{
        k: v
        for k, v in dataset_schema.dump(dataset).items()
        if k not in ignored_keys
    } for dataset in session.query(Dataset).all()]
    assert datasets == [{
        "id": 1,
        "sqlatable_id": 1,
        "name": "old_dataset",
        "changed_by": None,
        "created_by": None,
        "columns": [5, 6, 7, 8, 9, 10],
        "is_physical": False,
        "tables": [1],
        "extra_json": "{}",
        "external_url": None,
        "is_managed_externally": False,
        "expression": """
SELECT
  ds,
  user_id,
  revenue,
  expenses,
  revenue - expenses AS profit
FROM
  some_table""",
    }]
Exemplo n.º 37
0
def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None:
    datasource.main_dttm_col = "ds"
    datasource.database = database
    datasource.filter_select_enabled = True
    datasource.fetch_metadata()
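# A minimal usage sketch of the helper above; the table name is an assumption
# for illustration, and ``get_example_database`` is the accessor used elsewhere
# in these examples.
datasource = SqlaTable(table_name="some_table")
_set_table_metadata(datasource, get_example_database())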
Exemplo n.º 38
0
def test_sql_lab_insert_rls(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    Integration test for `insert_rls`.
    """
    from flask_appbuilder.security.sqla.models import Role, User

    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    from superset.security.manager import SupersetSecurityManager
    from superset.sql_lab import execute_sql_statement
    from superset.utils.core import RowLevelSecurityFilterType

    engine = session.connection().engine
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    connection = engine.raw_connection()
    connection.execute("CREATE TABLE t (c INTEGER)")
    for i in range(10):
        connection.execute("INSERT INTO t VALUES (?)", (i,))

    cursor = connection.cursor()

    query = Query(
        sql="SELECT c FROM t",
        client_id="abcde",
        database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
        schema=None,
        limit=5,
        select_as_cta_used=False,
    )
    session.add(query)
    session.commit()

    admin = User(
        first_name="Alice",
        last_name="Doe",
        email="*****@*****.**",
        username="******",
        roles=[Role(name="Admin")],
    )

    # first without RLS
    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    assert (
        superset_result_set.to_pandas_df().to_markdown()
        == """
|    |   c |
|---:|----:|
|  0 |   0 |
|  1 |   1 |
|  2 |   2 |
|  3 |   3 |
|  4 |   4 |""".strip()
    )
    assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

    # now with RLS
    rls = RowLevelSecurityFilter(
        name="sqllab_rls1",
        filter_type=RowLevelSecurityFilterType.REGULAR,
        tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
        roles=[admin.roles[0]],
        group_key=None,
        clause="c > 5",
    )
    session.add(rls)
    session.flush()
    mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
    mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    assert (
        superset_result_set.to_pandas_df().to_markdown()
        == """
|    |   c |
|---:|----:|
|  0 |   6 |
|  1 |   7 |
|  2 |   8 |
|  3 |   9 |""".strip()
    )
    assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
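# Conceptual illustration (a simplification, not Superset's actual
# ``insert_rls``): the RLS filter above contributes a predicate that is ANDed
# into queries against table ``t`` for users holding the matching role, which
# is exactly the rewrite the two ``executed_sql`` assertions verify.
def append_rls_predicate(sql: str, table_alias: str, clause: str) -> str:
    # naive: assumes the statement has no existing WHERE clause
    return f"{sql} WHERE ({table_alias}.{clause})"


assert append_rls_predicate("SELECT c FROM t", "t", "c > 5") == \
    "SELECT c FROM t WHERE (t.c > 5)"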
Exemplo n.º 39
0
    def raise_for_access(
        # pylint: disable=too-many-arguments,too-many-branches,
        # pylint: disable=too-many-locals
        self,
        database: Optional["Database"] = None,
        datasource: Optional["BaseDatasource"] = None,
        query: Optional["Query"] = None,
        query_context: Optional["QueryContext"] = None,
        table: Optional["Table"] = None,
        viz: Optional["BaseViz"] = None,
    ) -> None:
        """
        Raise an exception if the user cannot access the resource.

        :param database: The Superset database
        :param datasource: The Superset datasource
        :param query: The SQL Lab query
        :param query_context: The query context
        :param table: The Superset table (requires database)
        :param viz: The visualization
        :raises SupersetSecurityException: If the user cannot access the resource
        """

        from superset.connectors.sqla.models import SqlaTable
        from superset.sql_parse import Table

        if database and table or query:
            if query:
                database = query.database

            database = cast("Database", database)

            if self.can_access_database(database):
                return

            if query:
                tables = {
                    Table(table_.table, table_.schema or query.schema)
                    for table_ in sql_parse.ParsedQuery(query.sql).tables
                }
            elif table:
                tables = {table}

            denied = set()

            for table_ in tables:
                schema_perm = self.get_schema_perm(database,
                                                   schema=table_.schema)

                if not (schema_perm
                        and self.can_access("schema_access", schema_perm)):
                    datasources = SqlaTable.query_datasources_by_name(
                        self.get_session,
                        database,
                        table_.table,
                        schema=table_.schema)

                    # Access to any one matching datasource is sufficient.
                    for datasource_ in datasources:
                        if self.can_access("datasource_access",
                                           datasource_.perm):
                            break
                    else:
                        denied.add(table_)

            if denied:
                raise SupersetSecurityException(
                    self.get_table_access_error_object(denied))

        if datasource or query_context or viz:
            if query_context:
                datasource = query_context.datasource
            elif viz:
                datasource = viz.datasource

            assert datasource

            from superset.extensions import feature_flag_manager

            if not (self.can_access_schema(datasource) or self.can_access(
                    "datasource_access", datasource.perm or "") or
                    (feature_flag_manager.is_feature_enabled("DASHBOARD_RBAC")
                     and self.can_access_based_on_dashboard(datasource))):
                raise SupersetSecurityException(
                    self.get_datasource_access_error_object(datasource))
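# A minimal usage sketch (requires a running Superset app context; the table
# and schema names are assumptions): guard an ad-hoc table access with the
# method above, which raises ``SupersetSecurityException`` when the current
# user can access neither the database, the schema, nor a matching datasource.
from superset import db, security_manager
from superset.models.core import Database
from superset.sql_parse import Table

database = db.session.query(Database).first()
security_manager.raise_for_access(
    database=database,
    table=Table("my_table", "my_schema"),
)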
Exemplo n.º 40
0
 def test_comments_in_sqlatable_query(self):
     clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
     commented_query = '/* comment 1 */' + clean_query + '-- comment 2'
     table = SqlaTable(sql=commented_query)
     rendered_query = text_type(table.get_from_clause())
     self.assertEqual(clean_query, rendered_query)
Exemplo n.º 41
0
    def export_dashboards(  # pylint: disable=too-many-locals
            cls, dashboard_ids: List[int]) -> str:
        copied_dashboards = []
        datasource_ids = set()
        for dashboard_id in dashboard_ids:
            # make sure that dashboard_id is an integer
            dashboard_id = int(dashboard_id)
            dashboard = (db.session.query(Dashboard).options(
                subqueryload(
                    Dashboard.slices)).filter_by(id=dashboard_id).first())
            # remove ids and relations (like owners, created by, slices, ...)
            copied_dashboard = dashboard.copy()
            for slc in dashboard.slices:
                datasource_ids.add((slc.datasource_id, slc.datasource_type))
                copied_slc = slc.copy()
                # save original id into json
                # we need it to update dashboard's json metadata on import
                copied_slc.id = slc.id
                # add extra params for the import
                copied_slc.alter_params(
                    remote_id=slc.id,
                    datasource_name=slc.datasource.datasource_name,
                    schema=slc.datasource.schema,
                    database_name=slc.datasource.database.name,
                )
                # set slices without creating ORM relations
                slices = copied_dashboard.__dict__.setdefault("slices", [])
                slices.append(copied_slc)

            json_metadata = json.loads(dashboard.json_metadata)
            native_filter_configuration: List[Dict[
                str, Any]] = json_metadata.get("native_filter_configuration",
                                               [])
            for native_filter in native_filter_configuration:
                session = db.session()
                for target in native_filter.get("targets", []):
                    id_ = target.get("datasetId")
                    if id_ is None:
                        continue
                    datasource = DatasourceDAO.get_datasource(
                        session, utils.DatasourceType.TABLE, id_)
                    datasource_ids.add((datasource.id, datasource.type))

            copied_dashboard.alter_params(remote_id=dashboard_id)
            copied_dashboards.append(copied_dashboard)

        eager_datasources = []
        for datasource_id, _ in datasource_ids:
            eager_datasource = SqlaTable.get_eager_sqlatable_datasource(
                db.session, datasource_id)
            copied_datasource = eager_datasource.copy()
            copied_datasource.alter_params(
                remote_id=eager_datasource.id,
                database_name=eager_datasource.database.name,
            )
            datasource_class = copied_datasource.__class__
            for field_name in datasource_class.export_children:
                field_val = getattr(eager_datasource, field_name).copy()
                # set children without creating ORM relations
                copied_datasource.__dict__[field_name] = field_val
            eager_datasources.append(copied_datasource)

        return json.dumps(
            {
                "dashboards": copied_dashboards,
                "datasources": eager_datasources
            },
            cls=utils.DashboardEncoder,
            indent=4,
        )
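# A minimal usage sketch, assuming the method above is exposed as a classmethod
# on ``Dashboard`` (its decorator is not shown); the dashboard ids are
# illustrative and must exist in the metadata database.
import json

from superset.models.dashboard import Dashboard

exported = Dashboard.export_dashboards([1, 2])
payload = json.loads(exported)
assert set(payload) == {"dashboards", "datasources"}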