def test_import_table_override_sync(self):
    table, dict_table = self.create_table(
        'table_override', id=ID_PREFIX + 3,
        cols_names=['col1'],
        metric_names=['m1'])
    imported_table = SqlaTable.import_from_dict(db.session, dict_table)
    db.session.commit()

    table_over, dict_table_over = self.create_table(
        'table_override', id=ID_PREFIX + 3,
        cols_names=['new_col1', 'col2', 'col3'],
        metric_names=['new_metric1'])
    imported_over_table = SqlaTable.import_from_dict(
        session=db.session,
        dict_rep=dict_table_over,
        sync=['metrics', 'columns'])
    db.session.commit()

    imported_over = self.get_table(imported_over_table.id)
    self.assertEquals(imported_table.id, imported_over.id)
    expected_table, _ = self.create_table(
        'table_override', id=ID_PREFIX + 3,
        metric_names=['new_metric1'],
        cols_names=['new_col1', 'col2', 'col3'])
    self.assert_table_equals(expected_table, imported_over)
    self.yaml_compare(
        expected_table.export_to_dict(), imported_over.export_to_dict())
def test_import_table_override_identical(self):
    table = self.create_table(
        'copy_cat', id=10004,
        cols_names=['new_col1', 'col2', 'col3'],
        metric_names=['new_metric1'])
    imported_id = SqlaTable.import_obj(table, import_time=1993)

    copy_table = self.create_table(
        'copy_cat', id=10004,
        cols_names=['new_col1', 'col2', 'col3'],
        metric_names=['new_metric1'])
    imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)

    self.assertEquals(imported_id, imported_id_copy)
    self.assert_table_equals(copy_table, self.get_table(imported_id))
def test_import_table_2_col_2_met(self):
    table = self.create_table(
        'table_2_col_2_met', id=10003,
        cols_names=['c1', 'c2'],
        metric_names=['m1', 'm2'])
    imported_id = SqlaTable.import_obj(table, import_time=1991)
    imported = self.get_table(imported_id)
    self.assert_table_equals(table, imported)
def test_import_table_no_metadata(self):
    table, dict_table = self.create_table('pure_table', id=ID_PREFIX + 1)
    new_table = SqlaTable.import_from_dict(db.session, dict_table)
    db.session.commit()
    imported_id = new_table.id
    imported = self.get_table(imported_id)
    self.assert_table_equals(table, imported)
    self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
def test_import_table_override_identical(self):
    table, dict_table = self.create_table(
        'copy_cat', id=ID_PREFIX + 4,
        cols_names=['new_col1', 'col2', 'col3'],
        metric_names=['new_metric1'])
    imported_table = SqlaTable.import_from_dict(db.session, dict_table)
    db.session.commit()

    copy_table, dict_copy_table = self.create_table(
        'copy_cat', id=ID_PREFIX + 4,
        cols_names=['new_col1', 'col2', 'col3'],
        metric_names=['new_metric1'])
    imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
    db.session.commit()

    self.assertEquals(imported_table.id, imported_copy_table.id)
    self.assert_table_equals(copy_table, self.get_table(imported_table.id))
    self.yaml_compare(
        imported_copy_table.export_to_dict(), imported_table.export_to_dict())
def test_import_table_override(self):
    table = self.create_table(
        'table_override', id=10003,
        cols_names=['col1'],
        metric_names=['m1'])
    imported_id = SqlaTable.import_obj(table, import_time=1991)

    table_over = self.create_table(
        'table_override', id=10003,
        cols_names=['new_col1', 'col2', 'col3'],
        metric_names=['new_metric1'])
    imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)

    imported_over = self.get_table(imported_over_id)
    self.assertEquals(imported_id, imported_over.id)
    expected_table = self.create_table(
        'table_override', id=10003,
        metric_names=['new_metric1', 'm1'],
        cols_names=['col1', 'new_col1', 'col2', 'col3'])
    self.assert_table_equals(expected_table, imported_over)
def test_import_table_2_col_2_met(self):
    table, dict_table = self.create_table(
        'table_2_col_2_met', id=ID_PREFIX + 3,
        cols_names=['c1', 'c2'],
        metric_names=['m1', 'm2'])
    imported_table = SqlaTable.import_from_dict(db.session, dict_table)
    db.session.commit()
    imported = self.get_table(imported_table.id)
    self.assert_table_equals(table, imported)
    self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
def test_import_table_1_col_1_met(self):
    table = self.create_table(
        'table_1_col_1_met', id=10002,
        cols_names=["col1"],
        metric_names=["metric1"])
    imported_id = SqlaTable.import_obj(table, import_time=1990)
    imported = self.get_table(imported_id)
    self.assert_table_equals(table, imported)
    self.assertEquals(
        {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'},
        json.loads(imported.params))
def test_import_table_1_col_1_met(self):
    table, dict_table = self.create_table(
        'table_1_col_1_met', id=ID_PREFIX + 2,
        cols_names=['col1'],
        metric_names=['metric1'])
    imported_table = SqlaTable.import_from_dict(db.session, dict_table)
    db.session.commit()
    imported = self.get_table(imported_table.id)
    self.assert_table_equals(table, imported)
    self.assertEquals(
        {DBREF: ID_PREFIX + 2, 'database_name': 'main'},
        json.loads(imported.params))
    self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
def test_export(session: Session) -> None:
    """
    Test exporting a dataset.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.export import ExportDatasetsCommand
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(
            column_name="profit",
            type="INTEGER",
            expression="revenue-expenses",
            extra=json.dumps({"certified_by": "User"}),
        ),
    ]
    metrics = [
        SqlMetric(
            metric_name="cnt",
            expression="COUNT(*)",
            extra=json.dumps({"warning_markdown": None}),
        ),
    ]

    sqla_table = SqlaTable(
        table_name="my_table",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )

    export = list(
        ExportDatasetsCommand._export(sqla_table)  # pylint: disable=protected-access
    )
    assert export == [
        (
            "datasets/my_database/my_table.yaml",
            f"""table_name: my_table
main_dttm_col: ds
description: This is the description
default_endpoint: null
offset: -8
cache_timeout: 3600
schema: my_schema
sql: null
params:
  remote_id: 64
  database_name: examples
  import_time: 1606677834
template_params:
  answer: '42'
filter_select_enabled: 1
fetch_values_predicate: foo IN (1, 2)
extra:
  warning_markdown: '*WARNING*'
uuid: null
metrics:
- metric_name: cnt
  verbose_name: null
  metric_type: null
  expression: COUNT(*)
  description: null
  d3format: null
  extra:
    warning_markdown: null
  warning_text: null
columns:
- column_name: profit
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: revenue-expenses
  description: null
  python_date_format: null
  extra:
    certified_by: User
- column_name: ds
  verbose_name: null
  is_dttm: 1
  is_active: null
  type: TIMESTAMP
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: user_id
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: expenses
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: revenue
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
version: 1.0.0
database_uuid: {database.uuid}
""",
        ),
        (
            "databases/my_database.yaml",
            f"""database_name: my_database
sqlalchemy_uri: sqlite://
cache_timeout: null
expose_in_sqllab: true
allow_run_async: false
allow_ctas: false
allow_cvas: false
allow_file_upload: false
extra:
  metadata_params: {{}}
  engine_params: {{}}
  metadata_cache_timeout: {{}}
  schemas_allowed_for_file_upload: []
uuid: {database.uuid}
version: 1.0.0
""",
        ),
    ]
def test_create_virtual_sqlatable(
    app_context: None,
    mocker: MockFixture,
    session: Session,
    sample_columns: Dict["TableColumn", Dict[str, Any]],
    sample_metrics: Dict["SqlMetric", Dict[str, Any]],
    columns_default: Dict[str, Any],
) -> None:
    """
    Test shadow write when creating a new ``SqlaTable``.

    When a new virtual ``SqlaTable`` is created, new models should also be created
    for ``Dataset`` and ``Column``.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session)

    from superset.columns.models import Column
    from superset.columns.schemas import ColumnSchema
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.models import Dataset
    from superset.datasets.schemas import DatasetSchema
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    user1 = get_test_user(1, "abc")

    physical_table_columns: List[Dict[str, Any]] = [
        dict(name="ds", is_temporal=True, type="TIMESTAMP",
             advanced_data_type=None, expression="ds", is_physical=True),
        dict(name="num_boys", type="INTEGER",
             advanced_data_type=None, expression="num_boys", is_physical=True),
        dict(name="revenue", type="INTEGER",
             advanced_data_type=None, expression="revenue", is_physical=True),
        dict(name="expenses", type="INTEGER",
             advanced_data_type=None, expression="expenses", is_physical=True),
    ]

    # create a physical ``Table`` that the virtual dataset points to
    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    table = Table(
        name="some_table",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[
            Column(**props, created_by=user1, changed_by=user1)
            for props in physical_table_columns
        ],
    )
    session.add(table)
    session.commit()

    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 0

    # create virtual dataset
    columns = list(sample_columns.keys())
    metrics = list(sample_metrics.keys())
    expected_table_columns = list(sample_columns.values())
    expected_metric_columns = list(sample_metrics.values())

    sqla_table = SqlaTable(
        created_by=user1,
        changed_by=user1,
        owners=[user1],
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql="""
SELECT
  ds,
  num_boys,
  revenue,
  expenses,
  revenue - expenses AS profit
FROM
  some_table""",
        params=json.dumps(
            {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    session.add(sqla_table)
    session.flush()

    # should not add a new table
    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 1

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on"}
    column_schema = ColumnSchema()
    actual_columns = [
        {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys}
        for column in session.query(Column).all()
    ]
    num_physical_columns = len(physical_table_columns)
    num_dataset_table_columns = len(columns)
    num_dataset_metric_columns = len(metrics)
    assert (
        len(actual_columns)
        == num_physical_columns + num_dataset_table_columns + num_dataset_metric_columns
    )

    for i, column in enumerate(table.columns):
        assert actual_columns[i] == {
            **columns_default,
            **physical_table_columns[i],
            "id": i + 1,
            "uuid": str(column.uuid),
            "tables": [1],
        }

    offset = num_physical_columns
    for i, column in enumerate(sqla_table.columns):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_table_columns[i],
            "id": i + offset + 1,
            "uuid": str(column.uuid),
            "is_physical": False,
            "datasets": [1],
        }

    offset = num_physical_columns + num_dataset_table_columns
    for i, metric in enumerate(sqla_table.metrics):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_metric_columns[i],
            "id": i + offset + 1,
            "uuid": str(metric.uuid),
            "datasets": [1],
        }

    # check that dataset was created, and has a reference to the table
    dataset_schema = DatasetSchema()
    datasets = [
        {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys}
        for dataset in session.query(Dataset).all()
    ]
    assert len(datasets) == 1
    assert datasets[0] == {
        "id": 1,
        "database": 1,
        "uuid": str(sqla_table.uuid),
        "name": "old_dataset",
        "changed_by": 1,
        "created_by": 1,
        "owners": [1],
        "columns": [5, 6, 7, 8, 9, 10],
        "is_physical": False,
        "tables": [1],
        "extra_json": "{}",
        "external_url": None,
        "is_managed_externally": False,
        "expression": """
SELECT
  ds,
  num_boys,
  revenue,
  expenses,
  revenue - expenses AS profit
FROM
  some_table""",
    }
def form_post(self, form: CsvToDatabaseForm) -> Response:
    database = form.con.data
    csv_table = Table(table=form.name.data, schema=form.schema.data)

    if not schema_allows_csv_upload(database, csv_table.schema):
        message = _(
            'Database "%(database_name)s" schema "%(schema_name)s" '
            "is not allowed for csv uploads. Please contact your Superset Admin.",
            database_name=database.database_name,
            schema_name=csv_table.schema,
        )
        flash(message, "danger")
        return redirect("/csvtodatabaseview/form")

    if "." in csv_table.table and csv_table.schema:
        message = _(
            "You cannot specify a namespace both in the name of the table: "
            '"%(csv_table.table)s" and in the schema field: '
            '"%(csv_table.schema)s". Please remove one',
            table=csv_table.table,
            schema=csv_table.schema,
        )
        flash(message, "danger")
        return redirect("/csvtodatabaseview/form")

    try:
        df = pd.concat(
            pd.read_csv(
                chunksize=1000,
                encoding="utf-8",
                filepath_or_buffer=form.csv_file.data,
                header=form.header.data if form.header.data else 0,
                index_col=form.index_col.data,
                infer_datetime_format=form.infer_datetime_format.data,
                iterator=True,
                keep_default_na=not form.null_values.data,
                mangle_dupe_cols=form.mangle_dupe_cols.data,
                na_values=form.null_values.data if form.null_values.data else None,
                nrows=form.nrows.data,
                parse_dates=form.parse_dates.data,
                sep=form.sep.data,
                skip_blank_lines=form.skip_blank_lines.data,
                skipinitialspace=form.skipinitialspace.data,
                skiprows=form.skiprows.data,
            )
        )

        database = (
            db.session.query(models.Database)
            .filter_by(id=form.data.get("con").data.get("id"))
            .one()
        )

        database.db_engine_spec.df_to_sql(
            database,
            csv_table,
            df,
            to_sql_kwargs={
                "chunksize": 1000,
                "if_exists": form.if_exists.data,
                "index": form.index.data,
                "index_label": form.index_label.data,
            },
        )

        # Connect table to the database that should be used for exploration.
        # E.g. if hive was used to upload a csv, presto will be a better option
        # to explore the table.
        explore_database = database
        explore_database_id = database.explore_database_id
        if explore_database_id:
            explore_database = (
                db.session.query(models.Database)
                .filter_by(id=explore_database_id)
                .one_or_none()
                or database
            )

        sqla_table = (
            db.session.query(SqlaTable)
            .filter_by(
                table_name=csv_table.table,
                schema=csv_table.schema,
                database_id=explore_database.id,
            )
            .one_or_none()
        )

        if sqla_table:
            sqla_table.fetch_metadata()
        if not sqla_table:
            sqla_table = SqlaTable(table_name=csv_table.table)
            sqla_table.database = explore_database
            sqla_table.database_id = database.id
            sqla_table.user_id = g.user.get_id()
            sqla_table.schema = csv_table.schema
            sqla_table.fetch_metadata()
            db.session.add(sqla_table)
        db.session.commit()
    except Exception as ex:  # pylint: disable=broad-except
        db.session.rollback()
        message = _(
            'Unable to upload CSV file "%(filename)s" to table '
            '"%(table_name)s" in database "%(db_name)s". '
            "Error message: %(error_msg)s",
            filename=form.csv_file.data.filename,
            table_name=form.name.data,
            db_name=database.database_name,
            error_msg=str(ex),
        )
        flash(message, "danger")
        stats_logger.incr("failed_csv_upload")
        return redirect("/csvtodatabaseview/form")

    # Go back to welcome page / splash screen
    message = _(
        'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in '
        'database "%(db_name)s"',
        csv_filename=form.csv_file.data.filename,
        table_name=str(csv_table),
        db_name=sqla_table.database.database_name,
    )
    flash(message, "info")
    stats_logger.incr("successful_csv_upload")
    return redirect("/tablemodelview/list/")
def import_dataset(
    session: Session,
    config: Dict[str, Any],
    overwrite: bool = False,
    force_data: bool = False,
) -> SqlaTable:
    existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
    if existing:
        if not overwrite:
            return existing
        config["id"] = existing.id

    # TODO (betodealmeida): move this logic to import_from_dict
    config = config.copy()
    for key in JSON_KEYS:
        if config.get(key) is not None:
            try:
                config[key] = json.dumps(config[key])
            except TypeError:
                logger.info("Unable to encode `%s` field: %s", key, config[key])
    for metric in config.get("metrics", []):
        if metric.get("extra") is not None:
            try:
                metric["extra"] = json.dumps(metric["extra"])
            except TypeError:
                logger.info("Unable to encode `extra` field: %s", metric["extra"])
                metric["extra"] = None

    # should we delete columns and metrics not present in the current import?
    sync = ["columns", "metrics"] if overwrite else []

    # should we also load data into the dataset?
    data_uri = config.get("data")

    # import recursively to include columns and metrics
    dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
    if dataset.id is None:
        session.flush()

    example_database = get_example_database()
    try:
        table_exists = example_database.has_table_by_name(dataset.table_name)
    except Exception:  # pylint: disable=broad-except
        # MySQL doesn't play nice with GSheets table names
        logger.warning(
            "Couldn't check if table %s exists, assuming it does", dataset.table_name)
        table_exists = True

    if data_uri and (not table_exists or force_data):
        logger.info("Downloading data from %s", data_uri)
        load_data(data_uri, dataset, example_database, session)

    if hasattr(g, "user") and g.user:
        dataset.owners.append(g.user)

    return dataset
def test_update_physical_sqlatable_metrics(
    mocker: MockFixture,
    app_context: None,
    session: Session,
    get_session: Callable[[], Session],
) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.

    For this test we check that updating the SQL expression in a metric belonging
    to a ``SqlaTable`` is reflected in the ``Dataset`` metric.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    metrics = [
        SqlMetric(metric_name="cnt", expression="COUNT(*)"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the metric was created
    # 1 physical column for table + (1 column + 1 metric for datasets)
    assert session.query(Column).count() == 3

    column = session.query(Column).filter_by(is_physical=False).one()
    assert column.expression == "COUNT(*)"

    # change the metric definition
    sqla_table.metrics[0].expression = "MAX(ds)"
    session.flush()

    assert column.expression == "MAX(ds)"

    # in a new session, update new columns and metrics at the same time
    # reload the sqla_table so we can test the case that accessing a not already
    # loaded attribute (`sqla_table.metrics`) while there are updates on the instance
    # may trigger `after_update` before the attribute is loaded
    session = get_session()
    sqla_table = session.query(SqlaTable).filter(SqlaTable.id == sqla_table.id).one()
    sqla_table.columns.append(
        TableColumn(
            column_name="another_column",
            is_dttm=0,
            type="TIMESTAMP",
            expression="concat('a', 'b')",
        ))
    # Here `SqlaTable.after_update` is triggered
    # before `sqla_table.metrics` is loaded
    sqla_table.metrics.append(
        SqlMetric(metric_name="another_metric", expression="COUNT(*)"))
    # `SqlaTable.after_update` will trigger again at flushing
    session.flush()
    assert session.query(Column).count() == 5
def test_set_perm_sqla_table(self):
    table = SqlaTable(
        schema="tmp_schema",
        table_name="tmp_perm_table",
        database=get_example_database(),
    )
    db.session.add(table)
    db.session.commit()

    stored_table = (
        db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one())
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # table name change
    stored_table.table_name = "tmp_perm_table_v2"
    db.session.commit()
    stored_table = (
        db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # schema name change
    stored_table.schema = "tmp_schema_v2"
    db.session.commit()
    stored_table = (
        db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema_v2]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # database change
    new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
    db.session.add(new_db)
    stored_table.database = (
        db.session.query(Database).filter_by(database_name="tmp_db").one())
    db.session.commit()
    stored_table = (
        db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # no schema
    stored_table.schema = None
    db.session.commit()
    stored_table = (
        db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    self.assertIsNone(stored_table.schema_perm)

    db.session.delete(new_db)
    db.session.delete(stored_table)
    db.session.commit()
def create_table_permissions(table: models.SqlaTable) -> None:
    security_manager.add_permission_view_menu("datasource_access", table.get_perm())
    if table.schema:
        security_manager.add_permission_view_menu("schema_access", table.schema_perm)
def test_update_physical_sqlatable_no_dataset(
    mocker: MockFixture, app_context: None, session: Session
) -> None:
    """
    Test that updating the table on a physical dataset creates a new dataset if one
    didn't already exist.

    When updating the table on a physical dataset by pointing it somewhere else
    (change in database ID, schema, or table name) we should point the ``Dataset``
    to an existing ``Table`` if possible, and create a new one otherwise.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session)
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the table was created
    table = session.query(Table).one()
    assert table.id == 1

    dataset = session.query(Dataset).one()
    assert dataset.tables == [table]

    # point ``SqlaTable`` to a different database
    new_database = Database(
        database_name="my_other_database", sqlalchemy_uri="sqlite://")
    session.add(new_database)
    session.flush()
    sqla_table.database = new_database
    session.flush()

    new_dataset = session.query(Dataset).one()

    # check that dataset now points to the new table
    assert new_dataset.tables[0].database_id == 2

    # point ``SqlaTable`` back
    sqla_table.database_id = 1
    session.flush()

    # check that dataset points to the original table
    assert new_dataset.tables[0].database_id == 1
def test_dataset_attributes(app_context: None, session: Session) -> None:
    """
    Test that checks attributes in the dataset.

    If this check fails it means new attributes were added to ``SqlaTable``, and
    ``SqlaTable.after_insert`` should be updated to handle them!
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(column_name="profit", type="INTEGER", expression="revenue-expenses"),
    ]
    metrics = [
        SqlMetric(metric_name="cnt", expression="COUNT(*)"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(SqlaTable).one()

    # If this test fails because attributes changed, make sure to update
    # ``SqlaTable.after_insert`` accordingly.
    assert sorted(dataset.__dict__.keys()) == [
        "_sa_instance_state",
        "cache_timeout",
        "changed_by_fk",
        "changed_on",
        "columns",
        "created_by_fk",
        "created_on",
        "database",
        "database_id",
        "default_endpoint",
        "description",
        "external_url",
        "extra",
        "fetch_values_predicate",
        "filter_select_enabled",
        "id",
        "is_featured",
        "is_managed_externally",
        "is_sqllab_view",
        "main_dttm_col",
        "metrics",
        "offset",
        "params",
        "perm",
        "schema",
        "schema_perm",
        "sql",
        "table_name",
        "template_params",
        "uuid",
    ]
def test_update_virtual_sqlatable_references(
    mocker: MockFixture, app_context: None, session: Session
) -> None:
    """
    Test that changing the SQL of a virtual ``SqlaTable`` updates ``Dataset``.

    When the SQL is modified the list of referenced tables should be updated in the
    new ``Dataset`` model.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    table1 = Table(
        name="table_a",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[Column(name="a", type="INTEGER")],
    )
    table2 = Table(
        name="table_b",
        schema="my_schema",
        catalog=None,
        database=database,
        columns=[Column(name="b", type="INTEGER")],
    )
    session.add(table1)
    session.add(table2)
    session.commit()

    # create virtual dataset
    columns = [TableColumn(column_name="a", type="INTEGER")]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        database=database,
        schema="my_schema",
        sql="SELECT a FROM table_a",
    )
    session.add(sqla_table)
    session.flush()

    # check that new dataset has table1
    dataset = session.query(Dataset).one()
    assert dataset.tables == [table1]

    # change SQL
    sqla_table.sql = "SELECT a, b FROM table_a JOIN table_b"
    session.flush()

    # check that new dataset has both tables
    new_dataset = session.query(Dataset).one()
    assert new_dataset.tables == [table1, table2]
    assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
def test_update_physical_sqlatable(
    mocker: MockFixture, app_context: None, session: Session
) -> None:
    """
    Test updating the table on a physical dataset.

    When updating the table on a physical dataset by pointing it somewhere else
    (change in database ID, schema, or table name) we should point the ``Dataset``
    to an existing ``Table`` if possible, and create a new one otherwise.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session)
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # check that the table was created, and that the created dataset points to it
    table = session.query(Table).one()
    assert table.id == 1
    assert table.name == "old_dataset"
    assert table.schema is None
    assert table.database_id == 1

    dataset = session.query(Dataset).one()
    assert dataset.tables == [table]

    # point ``SqlaTable`` to a different database
    new_database = Database(
        database_name="my_other_database", sqlalchemy_uri="sqlite://")
    session.add(new_database)
    session.flush()
    sqla_table.database = new_database
    session.flush()

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on", "uuid"}

    # check that the old table still exists, and that the dataset points to the newly
    # created table (id=2) and column (id=2), on the new database (also id=2)
    table_schema = TableSchema()
    tables = [
        {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys}
        for table in session.query(Table).all()
    ]
    assert tables == [
        {
            "created_by": None,
            "extra_json": "{}",
            "name": "old_dataset",
            "changed_by": None,
            "catalog": None,
            "columns": [1],
            "database": 1,
            "external_url": None,
            "schema": None,
            "id": 1,
            "is_managed_externally": False,
        },
        {
            "created_by": None,
            "extra_json": "{}",
            "name": "old_dataset",
            "changed_by": None,
            "catalog": None,
            "columns": [2],
            "database": 2,
            "external_url": None,
            "schema": None,
            "id": 2,
            "is_managed_externally": False,
        },
    ]

    # check that dataset now points to the new table
    assert dataset.tables[0].database_id == 2

    # point ``SqlaTable`` back
    sqla_table.database_id = 1
    session.flush()

    # check that dataset points to the original table
    assert dataset.tables[0].database_id == 1
def test_import_dataset_duplicate_column(session: Session) -> None:
    """
    Test importing a dataset with a column that already exists.
    """
    from superset.columns.models import Column as NewColumn
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    dataset_uuid = uuid.uuid4()

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset = SqlaTable(
        uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id)
    column = TableColumn(column_name="existing_column")
    session.add(dataset)
    session.add(column)
    session.flush()

    config = {
        "table_name": dataset.table_name,
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [
            {
                "metric_name": "cnt",
                "verbose_name": None,
                "metric_type": None,
                "expression": "COUNT(*)",
                "description": None,
                "d3format": None,
                "extra": {"warning_markdown": None},
                "warning_text": None,
            }
        ],
        "columns": [
            {
                "column_name": column.column_name,
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "INTEGER",
                "groupby": None,
                "filterable": None,
                "expression": "revenue-expenses",
                "description": None,
                "python_date_format": None,
                "extra": {"certified_by": "User"},
            }
        ],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config, overwrite=True)

    assert sqla_table.table_name == dataset.table_name
    assert sqla_table.main_dttm_col == "ds"
    assert sqla_table.description == "This is the description"
    assert sqla_table.default_endpoint is None
    assert sqla_table.offset == -8
    assert sqla_table.cache_timeout == 3600
    assert sqla_table.schema == "my_schema"
    assert sqla_table.sql is None
    assert sqla_table.params == json.dumps(
        {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
    )
    assert sqla_table.template_params == json.dumps({"answer": "42"})
    assert sqla_table.filter_select_enabled is True
    assert sqla_table.fetch_values_predicate == "foo IN (1, 2)"
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
    assert sqla_table.uuid == dataset_uuid
    assert len(sqla_table.metrics) == 1
    assert sqla_table.metrics[0].metric_name == "cnt"
    assert sqla_table.metrics[0].verbose_name is None
    assert sqla_table.metrics[0].metric_type is None
    assert sqla_table.metrics[0].expression == "COUNT(*)"
    assert sqla_table.metrics[0].description is None
    assert sqla_table.metrics[0].d3format is None
    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.metrics[0].warning_text is None
    assert len(sqla_table.columns) == 1
    assert sqla_table.columns[0].column_name == column.column_name
    assert sqla_table.columns[0].verbose_name is None
    assert sqla_table.columns[0].is_dttm is False
    assert sqla_table.columns[0].is_active is True
    assert sqla_table.columns[0].type == "INTEGER"
    assert sqla_table.columns[0].groupby is True
    assert sqla_table.columns[0].filterable is True
    assert sqla_table.columns[0].expression == "revenue-expenses"
    assert sqla_table.columns[0].description is None
    assert sqla_table.columns[0].python_date_format is None
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
def test_import_table_no_metadata(self):
    db_id = get_example_database().id
    table = self.create_table("pure_table", id=10001)
    imported_id = SqlaTable.import_obj(table, db_id, import_time=1989)
    imported = self.get_table_by_id(imported_id)
    self.assert_table_equals(table, imported)
def test_set_perm_sqla_table(self):
    security_manager.on_view_menu_after_insert = Mock()
    security_manager.on_permission_view_after_insert = Mock()

    session = db.session
    table = SqlaTable(
        schema="tmp_schema",
        table_name="tmp_perm_table",
        database=get_example_database(),
    )
    session.add(table)
    session.commit()

    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one())
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})")
    pvm_dataset = security_manager.find_permission_view_menu(
        "datasource_access", stored_table.perm)
    pvm_schema = security_manager.find_permission_view_menu(
        "schema_access", stored_table.schema_perm)
    self.assertIsNotNone(pvm_dataset)
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
    self.assertIsNotNone(pvm_schema)

    # assert on permission hooks
    view_menu_dataset = security_manager.find_view_menu(
        f"[examples].[tmp_perm_table](id:{stored_table.id})")
    view_menu_schema = security_manager.find_view_menu(f"[examples].[tmp_schema]")
    security_manager.on_view_menu_after_insert.assert_has_calls([
        call(ANY, ANY, view_menu_dataset),
        call(ANY, ANY, view_menu_schema),
    ])
    security_manager.on_permission_view_after_insert.assert_has_calls([
        call(ANY, ANY, pvm_dataset),
        call(ANY, ANY, pvm_schema),
    ])

    # table name change
    stored_table.table_name = "tmp_perm_table_v2"
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # schema name change
    stored_table.schema = "tmp_schema_v2"
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema_v2]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # database change
    new_db = Database(sqlalchemy_uri="sqlite://", database_name="tmp_db")
    session.add(new_db)
    stored_table.database = (
        session.query(Database).filter_by(database_name="tmp_db").one())
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm))

    # no schema
    stored_table.schema = None
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one())
    self.assertEqual(
        stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm))
    self.assertIsNone(stored_table.schema_perm)

    session.delete(new_db)
    session.delete(stored_table)
    session.commit()
def test_comments_in_sqlatable_query(self):
    clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
    commented_query = '/* comment 1 */' + clean_query + '-- comment 2'
    table = SqlaTable(sql=commented_query)
    rendered_query = str(table.get_from_clause())
    self.assertEqual(clean_query, rendered_query)
def create_table_in_view(self):
    self.test_database = utils.get_example_database()
    self.test_database.allow_dml = True
    params = {"remote_id": 1234, "database_name": self.test_database.name}
    self.test_table = SqlaTable(
        id=1234, table_name="Departments", params=json.dumps(params))
    self.test_table.columns.append(
        TableColumn(column_name="department_id", type="INTEGER"))
    self.test_table.columns.append(TableColumn(column_name="name", type="STRING"))
    self.test_table.columns.append(TableColumn(column_name="street", type="STRING"))
    self.test_table.columns.append(TableColumn(column_name="city", type="STRING"))
    self.test_table.columns.append(TableColumn(column_name="country", type="STRING"))
    self.test_table.columns.append(TableColumn(column_name="lat", type="FLOAT"))
    self.test_table.columns.append(TableColumn(column_name="lon", type="FLOAT"))
    self.test_table.database = self.test_database
    self.test_table.database_id = self.test_table.database.id
    db.session.add(self.test_table)
    db.session.commit()

    data = {
        "department_id": [1, 2, 3, 4, 5],
        "name": [
            "Logistics",
            "Marketing",
            "Facility Management",
            "Personal",
            "Finances",
        ],
        "street": [
            "Oberseestrasse 10",
            "Grossmünsterplatz",
            "Uetliberg",
            "Zürichbergstrasse 221",
            "Bahnhofstrasse",
        ],
        "city": ["Rapperswil", "Zürich", "Zürich", "Zürich", "Zürich"],
        "country": [
            "Switzerland",
            "Switzerland",
            "Switzerland",
            "Switzerland",
            "Switzerland",
        ],
        "lat": [None, None, None, None, 4.789],
        "lon": [None, None, None, 1.234, None],
    }
    df = pd.DataFrame(data=data)

    # because of caching problem with postgres load database a second time
    # without this, the sqla engine throws an exception
    database = utils.get_example_database()
    if database:
        engine = database.get_sqla_engine()
        df.to_sql(
            self.test_table.table_name,
            engine,
            if_exists="replace",
            chunksize=500,
            dtype={
                "department_id": Integer,
                "name": String(60),
                "street": String(60),
                "city": String(60),
                "country": String(60),
                "lat": Float,
                "lon": Float,
            },
            index=False,
        )
def test_extra_cache_keys(self, flask_g):
    # the username must match the value expected by the assertions below
    flask_g.user.username = "abc"
    base_query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["user"],
        "metrics": [],
        "is_timeseries": False,
        "filter": [],
    }

    # Table with Jinja callable.
    table1 = SqlaTable(
        table_name="test_has_extra_cache_keys_table",
        sql="SELECT '{{ current_username() }}' as user",
        database=get_example_database(),
    )
    query_obj = dict(**base_query_obj, extras={})
    extra_cache_keys = table1.get_extra_cache_keys(query_obj)
    self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
    assert extra_cache_keys == ["abc"]

    # Table with Jinja callable disabled.
    table2 = SqlaTable(
        table_name="test_has_extra_cache_keys_disabled_table",
        sql="SELECT '{{ current_username(False) }}' as user",
        database=get_example_database(),
    )
    query_obj = dict(**base_query_obj, extras={})
    extra_cache_keys = table2.get_extra_cache_keys(query_obj)
    self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
    self.assertListEqual(extra_cache_keys, [])

    # Table with no Jinja callable.
    query = "SELECT 'abc' as user"
    table3 = SqlaTable(
        table_name="test_has_no_extra_cache_keys_table",
        sql=query,
        database=get_example_database(),
    )
    query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
    extra_cache_keys = table3.get_extra_cache_keys(query_obj)
    self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
    self.assertListEqual(extra_cache_keys, [])

    # With Jinja callable in SQL expression.
    query_obj = dict(
        **base_query_obj, extras={"where": "(user != '{{ current_username() }}')"})
    extra_cache_keys = table3.get_extra_cache_keys(query_obj)
    self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
    assert extra_cache_keys == ["abc"]

    # Cleanup
    for table in [table1, table2, table3]:
        db.session.delete(table)
    db.session.commit()
def test__normalize_prequery_result_type(
    app_context: Flask,
    mocker: MockFixture,
    row: pd.Series,
    dimension: str,
    result: Any,
) -> None:
    def _convert_dttm(
        target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        if target_type.upper() == TemporalType.TIMESTAMP:
            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
        return None

    table = SqlaTable(table_name="foobar", database=get_example_database())
    mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)

    columns_by_name = {
        "foo": TableColumn(
            column_name="foo", is_dttm=False, table=table, type="STRING"),
        "bar": TableColumn(
            column_name="bar", is_dttm=False, table=table, type="BOOLEAN"),
        "baz": TableColumn(
            column_name="baz", is_dttm=False, table=table, type="INTEGER"),
        "qux": TableColumn(
            column_name="qux", is_dttm=False, table=table, type="FLOAT"),
        "quux": TableColumn(
            column_name="quuz", is_dttm=True, table=table, type="STRING"),
        "quuz": TableColumn(
            column_name="quux", is_dttm=True, table=table, type="TIMESTAMP"),
    }

    normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name)

    assert type(normalized) == type(result)

    if isinstance(normalized, TextClause):
        assert str(normalized) == str(result)
    else:
        assert normalized == result
def invalidate(self) -> Response:
    """
    Takes a list of datasources, finds the associated cache records and
    invalidates them and removes the database records
    ---
    post:
      description: >-
        Takes a list of datasources, finds the associated cache records and
        invalidates them and removes the database records
      requestBody:
        description: >-
          A list of datasources uuid or the tuples of database and datasource names
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/CacheInvalidationRequestSchema"
      responses:
        201:
          description: cache was successfully invalidated
        400:
          $ref: '#/components/responses/400'
        500:
          $ref: '#/components/responses/500'
    """
    try:
        datasources = CacheInvalidationRequestSchema().load(request.json)
    except KeyError:
        return self.response_400(message="Request is incorrect")
    except ValidationError as error:
        return self.response_400(message=str(error))
    datasource_uids = set(datasources.get("datasource_uids", []))
    for ds in datasources.get("datasources", []):
        ds_obj = SqlaTable.get_datasource_by_name(
            session=db.session,
            datasource_name=ds.get("datasource_name"),
            schema=ds.get("schema"),
            database_name=ds.get("database_name"),
        )
        if ds_obj:
            datasource_uids.add(ds_obj.uid)

    cache_key_objs = (
        db.session.query(CacheKey)
        .filter(CacheKey.datasource_uid.in_(datasource_uids))
        .all()
    )
    cache_keys = [c.cache_key for c in cache_key_objs]
    if cache_key_objs:
        all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)

        if not all_keys_deleted:
            # expected behavior as keys may expire and cache is not a
            # persistent storage
            logger.info(
                "Some of the cache keys were not deleted in the list %s", cache_keys)

        try:
            delete_stmt = (
                CacheKey.__table__.delete().where(  # pylint: disable=no-member
                    CacheKey.cache_key.in_(cache_keys)))
            db.session.execute(delete_stmt)
            db.session.commit()
            self.stats_logger.gauge("invalidated_cache", len(cache_keys))
            logger.info(
                "Invalidated %s cache records for %s datasources",
                len(cache_keys),
                len(datasource_uids),
            )
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.error(ex, exc_info=True)
            db.session.rollback()
            return self.response_500(str(ex))
        db.session.commit()
    return self.response(201)
def test_create_physical_sqlatable(
    app_context: None,
    session: Session,
    sample_columns: Dict["TableColumn", Dict[str, Any]],
    sample_metrics: Dict["SqlMetric", Dict[str, Any]],
    columns_default: Dict[str, Any],
) -> None:
    """
    Test shadow write when creating a new ``SqlaTable``.

    When a new physical ``SqlaTable`` is created, new models should also be created
    for ``Dataset``, ``Table``, and ``Column``.
    """
    from superset.columns.models import Column
    from superset.columns.schemas import ColumnSchema
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.models import Dataset
    from superset.datasets.schemas import DatasetSchema
    from superset.models.core import Database
    from superset.tables.models import Table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    user1 = get_test_user(1, "abc")
    columns = list(sample_columns.keys())
    metrics = list(sample_metrics.keys())
    expected_table_columns = list(sample_columns.values())
    expected_metric_columns = list(sample_metrics.values())

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
        ),
        created_by=user1,
        changed_by=user1,
        owners=[user1],
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    session.add(sqla_table)
    session.flush()

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on"}

    # check that columns were created
    column_schema = ColumnSchema()
    actual_columns = [
        {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys}
        for column in session.query(Column).all()
    ]
    num_physical_columns = len(
        [col for col in expected_table_columns if col.get("is_physical") == True])
    num_dataset_table_columns = len(columns)
    num_dataset_metric_columns = len(metrics)
    assert (
        len(actual_columns)
        == num_physical_columns + num_dataset_table_columns + num_dataset_metric_columns
    )

    # table columns are created before dataset columns are created
    offset = 0
    for i in range(num_physical_columns):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_table_columns[i],
            "id": i + offset + 1,
            # physical columns for table have its own uuid
            "uuid": actual_columns[i + offset]["uuid"],
            "is_physical": True,
            # table columns do not have creators
            "created_by": None,
            "tables": [1],
        }

    offset += num_physical_columns
    for i, column in enumerate(sqla_table.columns):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_table_columns[i],
            "id": i + offset + 1,
            # columns for dataset reuses the same uuid of TableColumn
            "uuid": str(column.uuid),
            "datasets": [1],
        }

    offset += num_dataset_table_columns
    for i, metric in enumerate(sqla_table.metrics):
        assert actual_columns[i + offset] == {
            **columns_default,
            **expected_metric_columns[i],
            "id": i + offset + 1,
            "uuid": str(metric.uuid),
            "datasets": [1],
        }

    # check that table was created
    table_schema = TableSchema()
    tables = [
        {
            k: v
            for k, v in table_schema.dump(table).items()
            if k not in (ignored_keys | {"uuid"})
        }
        for table in session.query(Table).all()
    ]
    assert len(tables) == 1
    assert tables[0] == {
        "id": 1,
        "database": 1,
        "created_by": 1,
        "changed_by": 1,
        "datasets": [1],
        "columns": [1, 2, 3],
        "extra_json": "{}",
        "catalog": None,
        "schema": "my_schema",
        "name": "old_dataset",
        "is_managed_externally": False,
        "external_url": None,
    }

    # check that dataset was created
    dataset_schema = DatasetSchema()
    datasets = [
        {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys}
        for dataset in session.query(Dataset).all()
    ]
    assert len(datasets) == 1
    assert datasets[0] == {
        "id": 1,
        "uuid": str(sqla_table.uuid),
        "created_by": 1,
        "changed_by": 1,
        "owners": [1],
        "name": "old_dataset",
        "columns": [4, 5, 6, 7, 8, 9],
        "is_physical": True,
        "database": 1,
        "tables": [1],
        "extra_json": "{}",
        "expression": "old_dataset",
        "is_managed_externally": False,
        "external_url": None,
    }
def test_import_table_no_metadata(self):
    table = self.create_table('pure_table', id=10001)
    imported_id = SqlaTable.import_obj(table, import_time=1989)
    imported = self.get_table(imported_id)
    self.assert_table_equals(table, imported)
def test_import_table_no_metadata(self):
    table = self.create_table('pure_table', id=10001)
    imported_id = SqlaTable.import_obj(table)
    imported = self.get_table(imported_id)
    self.assert_table_equals(table, imported)
def test_update_physical_sqlatable_database(
    mocker: MockFixture,
    app_context: None,
    session: Session,
    get_session: Callable[[], Session],
) -> None:
    """
    Test updating the table on a physical dataset.

    When updating the table on a physical dataset by pointing it somewhere else
    (change in database ID, schema, or table name) we should point the ``Dataset``
    to an existing ``Table`` if possible, and create a new one otherwise.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session)
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset, dataset_column_association_table
    from superset.models.core import Database
    from superset.tables.models import Table, table_column_association_table
    from superset.tables.schemas import TableSchema

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    original_database = Database(
        database_name="my_database", sqlalchemy_uri="sqlite://")
    sqla_table = SqlaTable(
        table_name="original_table",
        columns=columns,
        metrics=[],
        database=original_database,
    )
    session.add(sqla_table)
    session.flush()

    assert session.query(Table).count() == 1
    assert session.query(Dataset).count() == 1
    assert session.query(Column).count() == 2  # 1 for table, 1 for dataset

    # check that the table was created, and that the created dataset points to it
    table = session.query(Table).one()
    assert table.id == 1
    assert table.name == "original_table"
    assert table.schema is None
    assert table.database_id == 1

    dataset = session.query(Dataset).one()
    assert dataset.tables == [table]

    # point ``SqlaTable`` to a different database
    new_database = Database(
        database_name="my_other_database", sqlalchemy_uri="sqlite://")
    session.add(new_database)
    session.flush()
    sqla_table.database = new_database
    sqla_table.table_name = "new_table"
    session.flush()

    assert session.query(Dataset).count() == 1
    assert session.query(Table).count() == 2
    # <Column:id=1> is kept for the old table
    # <Column:id=2> is kept for the updated dataset
    # <Column:id=3> is created for the new table
    assert session.query(Column).count() == 3

    # ignore these keys when comparing results
    ignored_keys = {"created_on", "changed_on", "uuid"}

    # check that the old table still exists, and that the dataset points to the newly
    # created table, column and dataset
    table_schema = TableSchema()
    tables = [
        {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys}
        for table in session.query(Table).all()
    ]
    assert tables[0] == {
        "id": 1,
        "database": 1,
        "columns": [1],
        "datasets": [],
        "created_by": None,
        "changed_by": None,
        "extra_json": "{}",
        "catalog": None,
        "schema": None,
        "name": "original_table",
        "external_url": None,
        "is_managed_externally": False,
    }
    assert tables[1] == {
        "id": 2,
        "database": 2,
        "datasets": [1],
        "columns": [3],
        "created_by": None,
        "changed_by": None,
        "catalog": None,
        "schema": None,
        "name": "new_table",
        "is_managed_externally": False,
        "extra_json": "{}",
        "external_url": None,
    }

    # check that dataset now points to the new table
    assert dataset.tables[0].database_id == 2

    # and a new column is created
    assert len(dataset.columns) == 1
    assert dataset.columns[0].id == 2

    # point ``SqlaTable`` back
    sqla_table.database = original_database
    sqla_table.table_name = "original_table"
    session.flush()

    # should not create more table and datasets
    assert session.query(Dataset).count() == 1
    assert session.query(Table).count() == 2

    # <Column:id=1> is deleted for the old table
    # <Column:id=2> is kept for the updated dataset
    # <Column:id=3> is kept for the new table
    assert session.query(Column.id).order_by(Column.id).all() == [
        (1,),
        (2,),
        (3,),
    ]
    assert session.query(dataset_column_association_table).all() == [(1, 2)]
    assert session.query(table_column_association_table).all() == [(1, 1), (2, 3)]
    assert session.query(Dataset).filter_by(id=1).one().columns[0].id == 2
    assert session.query(Table).filter_by(id=2).one().columns[0].id == 3
    assert session.query(Table).filter_by(id=1).one().columns[0].id == 1

    # the dataset points back to the original table
    assert dataset.tables[0].database_id == 1
    assert dataset.tables[0].name == "original_table"
    # kept the original column
    assert dataset.columns[0].id == 2

    session.commit()
    session.close()

    # querying in a new session should still return the same result
    session = get_session()
    assert session.query(table_column_association_table).all() == [(1, 1), (2, 3)]
def form_post(self, form: CsvToDatabaseForm) -> Response: database = form.con.data csv_table = Table(table=form.name.data, schema=form.schema.data) if not schema_allows_csv_upload(database, csv_table.schema): message = _( 'Database "%(database_name)s" schema "%(schema_name)s" ' "is not allowed for csv uploads. Please contact your Superset Admin.", database_name=database.database_name, schema_name=csv_table.schema, ) flash(message, "danger") return redirect("/csvtodatabaseview/form") if "." in csv_table.table and csv_table.schema: message = _( "You cannot specify a namespace both in the name of the table: " '"%(csv_table.table)s" and in the schema field: ' '"%(csv_table.schema)s". Please remove one', table=csv_table.table, schema=csv_table.schema, ) flash(message, "danger") return redirect("/csvtodatabaseview/form") uploaded_tmp_file_path = tempfile.NamedTemporaryFile( dir=app.config["UPLOAD_FOLDER"], suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(), delete=False, ).name try: utils.ensure_path_exists(config["UPLOAD_FOLDER"]) upload_stream_write(form.csv_file.data, uploaded_tmp_file_path) con = form.data.get("con") database = ( db.session.query(models.Database).filter_by(id=con.data.get("id")).one() ) # More can be found here: # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html csv_to_df_kwargs = { "sep": form.sep.data, "header": form.header.data if form.header.data else 0, "index_col": form.index_col.data, "mangle_dupe_cols": form.mangle_dupe_cols.data, "skipinitialspace": form.skipinitialspace.data, "skiprows": form.skiprows.data, "nrows": form.nrows.data, "skip_blank_lines": form.skip_blank_lines.data, "parse_dates": form.parse_dates.data, "infer_datetime_format": form.infer_datetime_format.data, "chunksize": 1000, } if form.null_values.data: csv_to_df_kwargs["na_values"] = form.null_values.data csv_to_df_kwargs["keep_default_na"] = False # More can be found here: # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html df_to_sql_kwargs = { "name": csv_table.table, "if_exists": form.if_exists.data, "index": form.index.data, "index_label": form.index_label.data, "chunksize": 1000, } database.db_engine_spec.create_table_from_csv( uploaded_tmp_file_path, csv_table, database, csv_to_df_kwargs, df_to_sql_kwargs, ) # Connect table to the database that should be used for exploration. # E.g. if hive was used to upload a csv, presto will be a better option # to explore the table. expore_database = database explore_database_id = database.explore_database_id if explore_database_id: expore_database = ( db.session.query(models.Database) .filter_by(id=explore_database_id) .one_or_none() or database ) sqla_table = ( db.session.query(SqlaTable) .filter_by( table_name=csv_table.table, schema=csv_table.schema, database_id=expore_database.id, ) .one_or_none() ) if sqla_table: sqla_table.fetch_metadata() if not sqla_table: sqla_table = SqlaTable(table_name=csv_table.table) sqla_table.database = expore_database sqla_table.database_id = database.id sqla_table.user_id = g.user.id sqla_table.schema = csv_table.schema sqla_table.fetch_metadata() db.session.add(sqla_table) db.session.commit() except Exception as ex: # pylint: disable=broad-except db.session.rollback() try: os.remove(uploaded_tmp_file_path) except OSError: pass message = _( 'Unable to upload CSV file "%(filename)s" to table ' '"%(table_name)s" in database "%(db_name)s". 
' "Error message: %(error_msg)s", filename=form.csv_file.data.filename, table_name=form.name.data, db_name=database.database_name, error_msg=str(ex), ) flash(message, "danger") stats_logger.incr("failed_csv_upload") return redirect("/csvtodatabaseview/form") os.remove(uploaded_tmp_file_path) # Go back to welcome page / splash screen message = _( 'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in ' 'database "%(db_name)s"', csv_filename=form.csv_file.data.filename, table_name=str(csv_table), db_name=sqla_table.database.database_name, ) flash(message, "info") stats_logger.incr("successful_csv_upload") return redirect("/tablemodelview/list/")
def test_update_physical_sqlatable_columns(mocker: MockFixture, session: Session) -> None: """ Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. """ # patch session mocker.patch("superset.security.SupersetSecurityManager.get_session", return_value=session) from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.datasets.models import Dataset from superset.models.core import Database from superset.tables.models import Table engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), ] sqla_table = SqlaTable( table_name="old_dataset", columns=columns, metrics=[], database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), ) session.add(sqla_table) session.flush() assert session.query(Table).count() == 1 assert session.query(Dataset).count() == 1 assert session.query(Column).count() == 2 # 1 for table, 1 for dataset dataset = session.query(Dataset).one() assert len(dataset.columns) == 1 # add a column to the original ``SqlaTable`` instance sqla_table.columns.append( TableColumn(column_name="num_boys", type="INTEGER")) session.flush() assert session.query(Column).count() == 3 dataset = session.query(Dataset).one() assert len(dataset.columns) == 2 # check that both lists have the same uuids (use sorted(), since list.sort() returns None) assert sorted(col.uuid for col in sqla_table.columns) == sorted( col.uuid for col in dataset.columns) # delete the column in the original instance sqla_table.columns = sqla_table.columns[1:] session.flush() # check that the deleted column is also gone from the ``SqlaTable`` side assert session.query(TableColumn).count() == 1 # the extra Dataset.column is deleted, but Table.column is kept assert session.query(Column).count() == 2 # check that the column was also removed from the dataset dataset = session.query(Dataset).one() assert len(dataset.columns) == 1 # modify the attribute in a column sqla_table.columns[0].is_dttm = True session.flush() # check that the dataset column was modified dataset = session.query(Dataset).one() assert dataset.columns[0].is_temporal is True
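# A hypothetical helper (not in the source) expressing the UUID check above as
# a set comparison, which avoids ordering concerns altogether.
def same_column_uuids(sqla_table, dataset) -> bool:
    return {col.uuid for col in sqla_table.columns} == {
        col.uuid for col in dataset.columns
    }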
def form_post(self, form): database = form.con.data excel_table = Table(table=form.name.data, schema=form.schema.data) if not self.is_schema_allowed(database, excel_table.schema): message = _( 'Database "%(database_name)s" schema "%(schema_name)s" ' "is not allowed for excel uploads. Please contact your Superset Admin.", database_name=database.database_name, schema_name=excel_table.schema, ) flash(message, "danger") return redirect("/exceltodatabaseview/form") if "." in excel_table.table and excel_table.schema: message = _( "You cannot specify a namespace both in the name of the table: " '"%(excel_table.table)s" and in the schema field: ' '"%(excel_table.schema)s". Please remove one', table=excel_table.table, schema=excel_table.schema, ) flash(message, "danger") return redirect("/exceltodatabaseview/form") uploaded_tmp_file_path = tempfile.NamedTemporaryFile( dir=app.config["UPLOAD_FOLDER"], suffix=os.path.splitext(form.excel_file.data.filename)[1].lower(), delete=False, ).name try: utils.ensure_path_exists(config["UPLOAD_FOLDER"]) upload_stream_write(form.excel_file.data, uploaded_tmp_file_path) con = form.data.get("con") database = ( db.session.query(models.Database).filter_by(id=con.data.get("id")).one() ) excel_to_df_kwargs = { "header": form.header.data if form.header.data else 0, "index_col": form.index_col.data, "mangle_dupe_cols": form.mangle_dupe_cols.data, "skipinitialspace": form.skipinitialspace.data, "skiprows": form.skiprows.data, "nrows": form.nrows.data, "sheet_name": form.sheet_name.data, "chunksize": 1000, } df_to_sql_kwargs = { "name": excel_table.table, "if_exists": form.if_exists.data, "index": form.index.data, "index_label": form.index_label.data, "chunksize": 1000, } database.db_engine_spec.create_table_from_excel( uploaded_tmp_file_path, excel_table, database, excel_to_df_kwargs, df_to_sql_kwargs, ) # Connect table to the database that should be used for exploration. # E.g. if hive was used to upload a excel, presto will be a better option # to explore the table. expore_database = database explore_database_id = database.get_extra().get("explore_database_id", None) if explore_database_id: expore_database = ( db.session.query(models.Database) .filter_by(id=explore_database_id) .one_or_none() or database ) sqla_table = ( db.session.query(SqlaTable) .filter_by( table_name=excel_table.table, schema=excel_table.schema, database_id=expore_database.id, ) .one_or_none() ) if sqla_table: sqla_table.fetch_metadata() if not sqla_table: sqla_table = SqlaTable(table_name=excel_table.table) sqla_table.database = expore_database sqla_table.database_id = database.id sqla_table.user_id = g.user.id sqla_table.schema = excel_table.schema sqla_table.fetch_metadata() db.session.add(sqla_table) db.session.commit() except Exception as ex: # pylint: disable=broad-except db.session.rollback() try: os.remove(uploaded_tmp_file_path) except OSError: pass message = _( 'Unable to upload Excel file "%(filename)s" to table ' '"%(table_name)s" in database "%(db_name)s". 
' "Error message: %(error_msg)s", filename=form.excel_file.data.filename, table_name=form.name.data, db_name=database.database_name, error_msg=str(ex), ) flash(message, "danger") stats_logger.incr("failed_excel_upload") return redirect("/exceltodatabaseview/form") os.remove(uploaded_tmp_file_path) # Go back to welcome page / splash screen message = _( 'CSV file "%(excel_filename)s" uploaded to table "%(table_name)s" in ' 'database "%(db_name)s"', excel_filename=form.excel_file.data.filename, table_name=str(excel_table), db_name=sqla_table.database.database_name, ) flash(message, "info") stats_logger.incr("successful_excel_upload") return redirect("/tablemodelview/list/")
def test_create_virtual_sqlatable(mocker: MockFixture, app_context: None, session: Session) -> None: """ Test shadow write when creating a new ``SqlaTable``. When a new virtual ``SqlaTable`` is created, new models should also be created for ``Dataset`` and ``Column``. """ # patch session mocker.patch("superset.security.SupersetSecurityManager.get_session", return_value=session) from superset.columns.models import Column from superset.columns.schemas import ColumnSchema from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.datasets.models import Dataset from superset.datasets.schemas import DatasetSchema from superset.models.core import Database from superset.tables.models import Table engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member # create the ``Table`` that the virtual dataset points to database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") table = Table( name="some_table", schema="my_schema", catalog=None, database=database, columns=[ Column(name="ds", is_temporal=True, type="TIMESTAMP"), Column(name="user_id", type="INTEGER"), Column(name="revenue", type="INTEGER"), Column(name="expenses", type="INTEGER"), ], ) session.add(table) session.commit() # create virtual dataset columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), TableColumn(column_name="user_id", type="INTEGER"), TableColumn(column_name="revenue", type="INTEGER"), TableColumn(column_name="expenses", type="INTEGER"), TableColumn(column_name="profit", type="INTEGER", expression="revenue-expenses"), ] metrics = [ SqlMetric(metric_name="cnt", expression="COUNT(*)"), ] sqla_table = SqlaTable( table_name="old_dataset", columns=columns, metrics=metrics, main_dttm_col="ds", default_endpoint= "https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used database=database, offset=-8, description="This is the description", is_featured=1, cache_timeout=3600, schema="my_schema", sql=""" SELECT ds, user_id, revenue, expenses, revenue - expenses AS profit FROM some_table""", params=json.dumps({ "remote_id": 64, "database_name": "examples", "import_time": 1606677834, }), perm=None, filter_select_enabled=1, fetch_values_predicate="foo IN (1, 2)", is_sqllab_view=0, # no longer used? 
template_params=json.dumps({"answer": "42"}), schema_perm=None, extra=json.dumps({"warning_markdown": "*WARNING*"}), ) session.add(sqla_table) session.flush() # ignore these keys when comparing results ignored_keys = {"created_on", "changed_on", "uuid"} # check that columns were created column_schema = ColumnSchema() column_schemas = [{ k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys } for column in session.query(Column).all()] assert column_schemas == [ { "type": "TIMESTAMP", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": None, "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "ds", "is_physical": True, "changed_by": None, "is_temporal": True, "id": 1, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": None, "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "user_id", "is_physical": True, "changed_by": None, "is_temporal": False, "id": 2, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": None, "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "revenue", "is_physical": True, "changed_by": None, "is_temporal": False, "id": 3, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": None, "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "expenses", "is_physical": True, "changed_by": None, "is_temporal": False, "id": 4, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "TIMESTAMP", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": "ds", "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "ds", "is_physical": False, "changed_by": None, "is_temporal": True, "id": 5, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": "user_id", "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "user_id", "is_physical": False, "changed_by": None, "is_temporal": False, "id": 6, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": "revenue", "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "revenue", "is_physical": False, "changed_by": None, "is_temporal": False, "id": 7, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": "expenses", "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, 
"is_spatial": False, "name": "expenses", "is_physical": False, "changed_by": None, "is_temporal": False, "id": 8, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "INTEGER", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": "revenue-expenses", "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "profit", "is_physical": False, "changed_by": None, "is_temporal": False, "id": 9, "is_aggregation": False, "external_url": None, "is_managed_externally": False, }, { "type": "Unknown", "is_additive": False, "extra_json": "{}", "is_partition": False, "expression": "COUNT(*)", "unit": None, "warning_text": None, "created_by": None, "is_increase_desired": True, "description": None, "is_spatial": False, "name": "cnt", "is_physical": False, "changed_by": None, "is_temporal": False, "id": 10, "is_aggregation": True, "external_url": None, "is_managed_externally": False, }, ] # check that dataset was created, and has a reference to the table dataset_schema = DatasetSchema() datasets = [{ k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys } for dataset in session.query(Dataset).all()] assert datasets == [{ "id": 1, "sqlatable_id": 1, "name": "old_dataset", "changed_by": None, "created_by": None, "columns": [5, 6, 7, 8, 9, 10], "is_physical": False, "tables": [1], "extra_json": "{}", "external_url": None, "is_managed_externally": False, "expression": """ SELECT ds, user_id, revenue, expenses, revenue - expenses AS profit FROM some_table""", }]
def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None: datasource.main_dttm_col = "ds" datasource.database = database datasource.filter_select_enabled = True datasource.fetch_metadata()
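# Hedged usage sketch for the helper above: it assumes ``database`` is an
# already-configured Database instance and that the named physical table
# exists, so ``fetch_metadata()`` can introspect its columns.
from superset.connectors.sqla.models import SqlaTable

datasource = SqlaTable(table_name="some_table", schema="my_schema")
_set_table_metadata(datasource, database)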
def test_sql_lab_insert_rls( mocker: MockerFixture, session: Session, ) -> None: """ Integration test for `insert_rls`. """ from flask_appbuilder.security.sqla.models import Role, User from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.models.core import Database from superset.models.sql_lab import Query from superset.security.manager import SupersetSecurityManager from superset.sql_lab import execute_sql_statement from superset.utils.core import RowLevelSecurityFilterType engine = session.connection().engine Query.metadata.create_all(engine) # pylint: disable=no-member connection = engine.raw_connection() connection.execute("CREATE TABLE t (c INTEGER)") for i in range(10): connection.execute("INSERT INTO t VALUES (?)", (i,)) cursor = connection.cursor() query = Query( sql="SELECT c FROM t", client_id="abcde", database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"), schema=None, limit=5, select_as_cta_used=False, ) session.add(query) session.commit() admin = User( first_name="Alice", last_name="Doe", email="*****@*****.**", username="******", roles=[Role(name="Admin")], ) # first without RLS with override_user(admin): superset_result_set = execute_sql_statement( sql_statement=query.sql, query=query, session=session, cursor=cursor, log_params=None, apply_ctas=False, ) assert ( superset_result_set.to_pandas_df().to_markdown() == """ | | c | |---:|----:| | 0 | 0 | | 1 | 1 | | 2 | 2 | | 3 | 3 | | 4 | 4 |""".strip() ) assert query.executed_sql == "SELECT c FROM t\nLIMIT 6" # now with RLS rls = RowLevelSecurityFilter( name="sqllab_rls1", filter_type=RowLevelSecurityFilterType.REGULAR, tables=[SqlaTable(database_id=1, schema=None, table_name="t")], roles=[admin.roles[0]], group_key=None, clause="c > 5", ) session.add(rls) session.flush() mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) with override_user(admin): superset_result_set = execute_sql_statement( sql_statement=query.sql, query=query, session=session, cursor=cursor, log_params=None, apply_ctas=False, ) assert ( superset_result_set.to_pandas_df().to_markdown() == """ | | c | |---:|----:| | 0 | 6 | | 1 | 7 | | 2 | 8 | | 3 | 9 |""".strip() ) assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
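# A minimal sketch of the ``override_user`` context manager used above; this is
# an assumption about its behaviour, not Superset's implementation. It
# temporarily attaches a user to ``flask.g`` (inside an application context) so
# RLS and ownership checks resolve against that user.
from contextlib import contextmanager

from flask import g


@contextmanager
def override_user(user):
    original = getattr(g, "user", None)
    g.user = user
    try:
        yield
    finally:
        g.user = original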
def raise_for_access( # pylint: disable=too-many-arguments,too-many-branches, # pylint: disable=too-many-locals self, database: Optional["Database"] = None, datasource: Optional["BaseDatasource"] = None, query: Optional["Query"] = None, query_context: Optional["QueryContext"] = None, table: Optional["Table"] = None, viz: Optional["BaseViz"] = None, ) -> None: """ Raise an exception if the user cannot access the resource. :param database: The Superset database :param datasource: The Superset datasource :param query: The SQL Lab query :param query_context: The query context :param table: The Superset table (requires database) :param viz: The visualization :raises SupersetSecurityException: If the user cannot access the resource """ from superset.connectors.sqla.models import SqlaTable from superset.sql_parse import Table if database and table or query: if query: database = query.database database = cast("Database", database) if self.can_access_database(database): return if query: tables = { Table(table_.table, table_.schema or query.schema) for table_ in sql_parse.ParsedQuery(query.sql).tables } elif table: tables = {table} denied = set() for table_ in tables: schema_perm = self.get_schema_perm(database, schema=table_.schema) if not (schema_perm and self.can_access("schema_access", schema_perm)): datasources = SqlaTable.query_datasources_by_name( self.get_session, database, table_.table, schema=table_.schema) # Access to any datasource is suffice. for datasource_ in datasources: if self.can_access("datasource_access", datasource_.perm): break else: denied.add(table_) if denied: raise SupersetSecurityException( self.get_table_access_error_object(denied)) if datasource or query_context or viz: if query_context: datasource = query_context.datasource elif viz: datasource = viz.datasource assert datasource from superset.extensions import feature_flag_manager if not (self.can_access_schema(datasource) or self.can_access( "datasource_access", datasource.perm or "") or (feature_flag_manager.is_feature_enabled("DASHBOARD_RBAC") and self.can_access_based_on_dashboard(datasource))): raise SupersetSecurityException( self.get_datasource_access_error_object(datasource))
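# Hedged usage sketch for ``raise_for_access``: guarding a table-level read
# before running a query. ``database`` stands in for an existing Database
# instance; the import paths follow the patterns used elsewhere in this file.
from superset import security_manager
from superset.exceptions import SupersetSecurityException
from superset.sql_parse import Table

try:
    security_manager.raise_for_access(
        database=database, table=Table("some_table", "my_schema")
    )
except SupersetSecurityException:
    ...  # deny the request or surface an error to the user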
def test_comments_in_sqlatable_query(self): clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl" commented_query = '/* comment 1 */' + clean_query + '-- comment 2' table = SqlaTable(sql=commented_query) rendered_query = text_type(table.get_from_clause()) self.assertEqual(clean_query, rendered_query)
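# Illustrative sketch of the comment stripping the test above relies on. The
# exact mechanism inside ``get_from_clause`` is not shown here; sqlparse (which
# Superset uses for SQL handling) can strip comments while leaving string
# literals that merely look like comments untouched.
import sqlparse

sql = "/* comment 1 */SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl-- comment 2"
print(sqlparse.format(sql, strip_comments=True))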
def export_dashboards( # pylint: disable=too-many-locals cls, dashboard_ids: List[int]) -> str: copied_dashboards = [] datasource_ids = set() for dashboard_id in dashboard_ids: # make sure that dashboard_id is an integer dashboard_id = int(dashboard_id) dashboard = (db.session.query(Dashboard).options( subqueryload( Dashboard.slices)).filter_by(id=dashboard_id).first()) # remove ids and relations (like owners, created by, slices, ...) copied_dashboard = dashboard.copy() for slc in dashboard.slices: datasource_ids.add((slc.datasource_id, slc.datasource_type)) copied_slc = slc.copy() # save original id into json # we need it to update dashboard's json metadata on import copied_slc.id = slc.id # add extra params for the import copied_slc.alter_params( remote_id=slc.id, datasource_name=slc.datasource.datasource_name, schema=slc.datasource.schema, database_name=slc.datasource.database.name, ) # set slices without creating ORM relations slices = copied_dashboard.__dict__.setdefault("slices", []) slices.append(copied_slc) json_metadata = json.loads(dashboard.json_metadata) native_filter_configuration: List[Dict[ str, Any]] = json_metadata.get("native_filter_configuration", []) for native_filter in native_filter_configuration: session = db.session() for target in native_filter.get("targets", []): id_ = target.get("datasetId") if id_ is None: continue datasource = DatasourceDAO.get_datasource( session, utils.DatasourceType.TABLE, id_) datasource_ids.add((datasource.id, datasource.type)) copied_dashboard.alter_params(remote_id=dashboard_id) copied_dashboards.append(copied_dashboard) eager_datasources = [] for datasource_id, _ in datasource_ids: eager_datasource = SqlaTable.get_eager_sqlatable_datasource( db.session, datasource_id) copied_datasource = eager_datasource.copy() copied_datasource.alter_params( remote_id=eager_datasource.id, database_name=eager_datasource.database.name, ) datasource_class = copied_datasource.__class__ for field_name in datasource_class.export_children: field_val = getattr(eager_datasource, field_name).copy() # set children without creating ORM relations copied_datasource.__dict__[field_name] = field_val eager_datasources.append(copied_datasource) return json.dumps( { "dashboards": copied_dashboards, "datasources": eager_datasources }, cls=utils.DashboardEncoder, indent=4, )
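# Hedged usage sketch: assuming ``export_dashboards`` is exposed as a
# classmethod on the Dashboard model (as the ``cls`` signature above suggests),
# the resulting JSON payload can be written straight to disk.
payload = Dashboard.export_dashboards([1, 2])
with open("dashboards_export.json", "w") as fp:
    fp.write(payload)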