def test_db_column_types(self):
    """Verify that raw DB type strings map onto the expected generic category."""
    string_types = ["CHAR", "VARCHAR", "NVARCHAR", "STRING", "TEXT", "NTEXT"]
    numeric_types = ["INTEGER", "BIGINT", "DECIMAL"]
    temporal_types = ["DATE", "DATETIME", "TIME", "TIMESTAMP"]

    test_cases: Dict[str, GenericDataType] = {}
    for type_str in string_types:
        test_cases[type_str] = GenericDataType.STRING
    for type_str in numeric_types:
        test_cases[type_str] = GenericDataType.NUMERIC
    for type_str in temporal_types:
        test_cases[type_str] = GenericDataType.TEMPORAL

    tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())

    # Each flag must be True exactly when the type belongs to that category.
    for str_type, db_col_type in test_cases.items():
        col = TableColumn(column_name="foo", type=str_type, table=tbl)
        self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL)
        self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC)
        self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING)

    # An explicit ``is_dttm`` flag makes any type temporal, regardless of the
    # SQL type string.
    for str_type in test_cases:
        col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True)
        self.assertTrue(col.is_temporal)
def test_table(session: Session) -> "SqlaTable":
    """
    Fixture that generates an in-memory table.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    # Two temporal columns, then three plain integer columns.
    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="event_time", is_dttm=1, type="TIMESTAMP"),
    ]
    columns.extend(
        TableColumn(column_name=name, type="INTEGER")
        for name in ("id", "dttm", "duration_ms")
    )

    return SqlaTable(
        table_name="test_table",
        columns=columns,
        metrics=[],
        main_dttm_col=None,
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
def test_quote_expressions(app_context: None, session: Session) -> None:
    """
    Test that expressions are quoted appropriately in columns and datasets.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    # "has space" requires quoting; "no_need" is a plain identifier.
    columns = [
        TableColumn(column_name=name, type="INTEGER")
        for name in ("has space", "no_need")
    ]
    sqla_table = SqlaTable(
        table_name="old dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    # Names containing spaces are double-quoted; plain identifiers are not.
    assert dataset.expression == '"old dataset"'
    assert dataset.columns[0].expression == '"has space"'
    assert dataset.columns[1].expression == "no_need"
def test_update_sqlatable(mocker: MockFixture, app_context: None, session: Session) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session
    )

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    def current_dataset() -> Dataset:
        # Re-query after every flush so we observe what the sync wrote.
        return session.query(Dataset).one()

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=[TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP")],
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    assert len(current_dataset().columns) == 1

    # add a column to the original ``SqlaTable`` instance
    sqla_table.columns.append(TableColumn(column_name="user_id", type="INTEGER"))
    session.flush()

    # check that the column was added to the dataset
    assert len(current_dataset().columns) == 2

    # delete the column in the original instance
    sqla_table.columns = sqla_table.columns[1:]
    session.flush()

    # check that the column was also removed from the dataset
    assert len(current_dataset().columns) == 1

    # modify the attribute in a column
    sqla_table.columns[0].is_dttm = True
    session.flush()

    # check that the dataset column was modified
    assert current_dataset().columns[0].is_temporal is True
def test_is_time_druid_time_col(self):
    """Druid has a special __time column"""
    time_col = TableColumn(column_name="__time", type="INTEGER")
    # Before the engine-spec hook runs, ``is_dttm`` is unset.
    self.assertIsNone(time_col.is_dttm)
    DruidEngineSpec.alter_new_orm_column(time_col)
    self.assertEqual(time_col.is_dttm, True)

    # Columns with any other name are not treated as temporal.
    other_col = TableColumn(column_name="__not_time", type="INTEGER")
    self.assertEqual(other_col.is_time, False)
def test_is_time_by_type(self):
    """``is_time`` is derived from the column's SQL type string."""
    expectations = (
        ("DATE", True),
        ("DATETIME", True),
        ("STRING", False),
    )
    for type_str, expected in expectations:
        col = TableColumn(column_name="foo", type=type_str)
        self.assertEqual(col.is_time, expected)
def test_is_time_by_type(self):
    """``is_time`` is derived from the column's SQL type string.

    Uses ``assertEqual``: ``assertEquals`` is a deprecated alias that emits
    a ``DeprecationWarning`` and was removed entirely in Python 3.12.
    """
    col = TableColumn(column_name='foo', type='DATE')
    self.assertEqual(col.is_time, True)

    col = TableColumn(column_name='foo', type='DATETIME')
    self.assertEqual(col.is_time, True)

    col = TableColumn(column_name='foo', type='STRING')
    self.assertEqual(col.is_time, False)
def physical_dataset():
    """Pytest fixture: a physical (table-backed) dataset in the examples DB.

    Creates the ``physical_dataset`` table, seeds ten rows, registers a
    ``SqlaTable`` with five typed columns plus a ``count`` metric, yields it,
    then drops the table and deletes the dataset metadata on teardown.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    example_database = get_example_database()
    engine = example_database.get_sqla_engine()
    # sqlite can only execute one statement at a time
    engine.execute(
        """
        CREATE TABLE IF NOT EXISTS physical_dataset(
          col1 INTEGER,
          col2 VARCHAR(255),
          col3 DECIMAL(4,2),
          col4 VARCHAR(255),
          col5 TIMESTAMP
        );
        """
    )
    engine.execute(
        """
        INSERT INTO physical_dataset values
        (0, 'a', 1.0, NULL, '2000-01-01 00:00:00'),
        (1, 'b', 1.1, NULL, '2000-01-02 00:00:00'),
        (2, 'c', 1.2, NULL, '2000-01-03 00:00:00'),
        (3, 'd', 1.3, NULL, '2000-01-04 00:00:00'),
        (4, 'e', 1.4, NULL, '2000-01-05 00:00:00'),
        (5, 'f', 1.5, NULL, '2000-01-06 00:00:00'),
        (6, 'g', 1.6, NULL, '2000-01-07 00:00:00'),
        (7, 'h', 1.7, NULL, '2000-01-08 00:00:00'),
        (8, 'i', 1.8, NULL, '2000-01-09 00:00:00'),
        (9, 'j', 1.9, NULL, '2000-01-10 00:00:00');
        """
    )

    dataset = SqlaTable(
        table_name="physical_dataset",
        database=example_database,
    )
    # Columns and the metric attach themselves to ``dataset`` via ``table=``;
    # the return values are intentionally discarded.
    TableColumn(column_name="col1", type="INTEGER", table=dataset)
    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
    TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)
    db.session.commit()

    yield dataset

    # Teardown: drop the physical table first, then remove every dataset row
    # that points at it.
    engine.execute(
        """
        DROP TABLE physical_dataset;
        """
    )
    dataset = db.session.query(SqlaTable).filter_by(
        table_name="physical_dataset").all()
    for ds in dataset:
        db.session.delete(ds)
    db.session.commit()
def test_temporal_varchar(self):
    """Ensure a column with is_dttm set to true evaluates to is_temporal == True"""
    tbl = SqlaTable(table_name="test_tbl", database=get_example_database())
    col = TableColumn(column_name="ds", type="VARCHAR", table=tbl)

    # by default, VARCHAR should not be assumed to be temporal
    assert col.is_temporal is False

    # once `is_dttm` is flipped on, `is_temporal` must follow suit
    col.is_dttm = True
    assert col.is_temporal is True
def test_is_time_druid_time_col(self):
    """Druid has a special __time column"""
    tbl = SqlaTable(
        table_name="druid_tbl",
        database=Database(database_name="druid_db", sqlalchemy_uri="druid://db"),
    )

    time_col = TableColumn(column_name="__time", type="INTEGER", table=tbl)
    # Unset before the engine-spec hook runs.
    self.assertIsNone(time_col.is_dttm)
    DruidEngineSpec.alter_new_orm_column(time_col)
    self.assertEqual(time_col.is_dttm, True)

    # Any other column name is left alone.
    other_col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl)
    self.assertEqual(other_col.is_temporal, False)
def create_chart(type, gamecode, column):  # pylint: disable=redefined-builtin
    """Create (or refresh) the table reference and slice for one chart template.

    :param type: name of the template module under
        ``superset.chartbuilder.templates`` (kept as ``type`` for caller
        compatibility even though it shadows the builtin)
    :param gamecode: game code substituted into the template SQL
    :param column: column name substituted into the template SQL
    """
    template = importlib.import_module("superset.chartbuilder.templates." + type)
    TBL = ConnectorRegistry.sources['table']
    tbl_name = chart_name_template.format(type=type, gamecode=gamecode, column=column)
    print('Creating table {} reference'.format(tbl_name))

    # Upsert the table reference, then re-query so we hold the persisted row
    # (merge() does not back-fill the generated primary key on our instance).
    tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
    if not tbl:
        tbl = TBL(table_name=tbl_name)
    tbl.database_id = (
        db.session.query(Database)
        .filter_by(database_name=template.database_name)
        .first()
        .id
    )
    tbl.sql = template.table_sql.format(gamecode=gamecode, column=column)
    db.session.merge(tbl)
    tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()

    # Renamed from ``slice`` so the builtin is no longer shadowed.
    slc = db.session.query(Slice).filter_by(datasource_id=tbl.id).first()
    if not slc:
        slc = Slice(
            datasource_id=tbl.id,
            slice_name=tbl_name,
            datasource_name=tbl_name,
            datasource_type='table',
            viz_type=template.viz_type,
            created_by_fk=1,
        )

    if template.viz_type == 'line':
        # Line charts need the time column flagged as a datetime.
        bucket_dt = db.session.query(TableColumn).filter_by(
            table_id=tbl.id, column_name=template.time_column).first()
        if not bucket_dt:
            bucket_dt = TableColumn(
                table_id=tbl.id, column_name=template.time_column, is_dttm=1)
        db.session.merge(bucket_dt)

        metric_json = []
        for metric in template.metrics:
            metric_obj = db.session.query(TableColumn).filter_by(
                table_id=tbl.id, column_name=metric).first()
            if not metric_obj:
                metric_obj = TableColumn(table_id=tbl.id, column_name=metric)
            db.session.merge(metric_obj)
            # Re-query so the generated id is available for the params JSON.
            metric_obj = db.session.query(TableColumn).filter_by(
                table_id=tbl.id, column_name=metric).first()
            metric_json.append(
                '{{"column": {{"id": {metric_id}, "column_name": "{metric}"}}, '
                '"label": "{metric}", "aggregate": "MAX", "expressionType": "SIMPLE"}}'
                .format(metric_id=metric_obj.id, metric=metric))
        slc.params = '''
{{"datasource": "{datasource}",
"granularity_sqla": "{time_column}",
"time_grain_sqla": "PT1M",
"time_range": "Last week",
"metrics": [
{metric_json}
]}}
'''.format(datasource=str(tbl.id) + '__table',
           time_column=template.time_column,
           metric_json=',\n'.join(metric_json))
    elif template.viz_type == 'table':
        for metric in template.metrics:
            metric_obj = db.session.query(TableColumn).filter_by(
                table_id=tbl.id, column_name=metric).first()
            if not metric_obj:
                metric_obj = TableColumn(table_id=tbl.id, column_name=metric)
            db.session.merge(metric_obj)
        slc.params = '{{"datasource": "{datasource}", "metrics": [], "all_columns": ["{metric_csv}"]}}' \
            .format(datasource=str(tbl.id) + '__table',
                    metric_csv='", "'.join(template.metrics))

    db.session.merge(slc)
    db.session.commit()
def _add_column(
    self,
    connection: Connection,
    table: SqlaTable,
    column_name: str,
    column_type: str,
) -> None:
    """
    Add new column to table

    :param connection: The connection to work with
    :param table: The SqlaTable
    :param column_name: The name of the column
    :param column_type: The type of the column
    """
    column = Column(column_name, column_type)
    # Hoisted: one engine/dialect lookup instead of two.
    dialect = table.database.get_sqla_engine().dialect
    # NOTE: previously ``column_name`` was passed positionally as ``bind``;
    # since ``dialect`` takes precedence in ClauseElement.compile() the arg
    # was silently ignored, so it has been dropped.
    name = column.compile(dialect=dialect)
    col_type = column.type.compile(dialect)
    sql = text(
        f"ALTER TABLE {self._get_from_clause(table)} ADD {name} {col_type}"
    )
    connection.execute(sql)
    # Mirror the schema change in the Superset metadata.
    table.columns.append(
        TableColumn(column_name=column_name, type=col_type))
def test_calculated_column_in_order_by_base_engine_spec(self):
    """Ordering by a calculated column must inline its SQL expression.

    A calculated column has no physical counterpart, so the generated
    ORDER BY clause must contain the column's ``expression`` rather than
    just its name.
    """
    table = self.get_table(name="birth_names")
    # Register a calculated (expression-backed) column on the table.
    TableColumn(
        column_name="gender_cc",
        type="VARCHAR(255)",
        table=table,
        expression=""" case when gender=true then "male" else "female" end """,
    )
    # Force the base engine spec (sqlite) regardless of the example DB.
    table.database.sqlalchemy_uri = "sqlite://"
    query_obj = {
        "groupby": ["gender_cc"],
        "is_timeseries": False,
        "filter": [],
        # second element True => ascending order
        "orderby": [["gender_cc", True]],
    }
    sql = table.get_query_str(query_obj)
    # The expression, not the alias, must appear in the ORDER BY clause.
    assert ("""ORDER BY case when gender=true then "male" else "female" end ASC;""" in sql)
def test_boolean_type_where_operators(self):
    """IN filters on a BOOLEAN calculated column compile to valid operands."""
    table = self.get_table(name="birth_names")
    db.session.add(
        TableColumn(
            column_name="boolean_gender",
            expression="case when gender = 'boy' then True else False end",
            type="BOOLEAN",
            table=table,
        )
    )
    query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["boolean_gender"],
        "metrics": ["count"],
        "is_timeseries": False,
        "filter": [
            {
                "col": "boolean_gender",
                "op": FilterOperator.IN,
                "val": ["true", "false"],
            }
        ],
        "extras": {},
    }
    sql = table.database.compile_sqla_query(
        table.get_sqla_query(**query_obj).sqla_query
    )
    dialect = table.database.get_dialect()
    if dialect.supports_native_boolean or dialect.name == "mysql":
        # override native_boolean=False behavior in MySQLCompiler
        # https://github.com/sqlalchemy/sqlalchemy/blob/master/lib/sqlalchemy/dialects/mysql/base.py
        operand = "(true, false)"
    else:
        # dialects without native booleans render 1/0 instead
        operand = "(1, 0)"
    self.assertIn(f"IN {operand}", sql)
def create_table(self, name, schema='', id=0, cols_names=None, metric_names=None):
    """Build a ``SqlaTable`` plus its dict representation for import tests.

    ``cols_names`` / ``metric_names`` now default to ``None`` instead of
    ``[]``: mutable default arguments are shared across calls and can leak
    state between tests. Passing an explicit list behaves exactly as before.
    """
    cols_names = cols_names if cols_names is not None else []
    metric_names = metric_names if metric_names is not None else []
    database_name = 'main'
    name = '{0}{1}'.format(NAME_PREFIX, name)
    params = {DBREF: id, 'database_name': database_name}
    dict_rep = {
        'database_id': get_main_database(db.session).id,
        'table_name': name,
        'schema': schema,
        'id': id,
        'params': json.dumps(params),
        'columns': [{'column_name': c} for c in cols_names],
        'metrics': [{'metric_name': c} for c in metric_names],
    }
    table = SqlaTable(
        id=id,
        schema=schema,
        table_name=name,
        params=json.dumps(params),
    )
    for col_name in cols_names:
        table.columns.append(TableColumn(column_name=col_name))
    for metric_name in metric_names:
        table.metrics.append(SqlMetric(metric_name=metric_name))
    return table, dict_rep
def create_table(
    self, name, schema=None, id=0, cols_names=None, cols_uuids=None, metric_names=None
):
    """Build a ``SqlaTable`` plus its dict representation for import tests.

    ``cols_names`` / ``metric_names`` now default to ``None`` instead of
    ``[]``: mutable default arguments are shared across calls and can leak
    state between tests. Passing an explicit list behaves exactly as before.
    """
    cols_names = cols_names if cols_names is not None else []
    metric_names = metric_names if metric_names is not None else []
    database_name = "main"
    name = "{0}{1}".format(NAME_PREFIX, name)
    params = {DBREF: id, "database_name": database_name}

    if cols_uuids is None:
        cols_uuids = [None] * len(cols_names)

    dict_rep = {
        "database_id": get_example_database().id,
        "table_name": name,
        "schema": schema,
        "id": id,
        "params": json.dumps(params),
        "columns": [
            {"column_name": c, "uuid": u} for c, u in zip(cols_names, cols_uuids)
        ],
        "metrics": [{"metric_name": c, "expression": ""} for c in metric_names],
    }
    table = SqlaTable(
        id=id, schema=schema, table_name=name, params=json.dumps(params)
    )
    for col_name, uuid in zip(cols_names, cols_uuids):
        table.columns.append(TableColumn(column_name=col_name, uuid=uuid))
    for metric_name in metric_names:
        table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
    return table, dict_rep
def decode_dashboards(o: Dict[str, Any]) -> Any:
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    from superset.connectors.druid.models import (
        DruidCluster,
        DruidColumn,
        DruidDatasource,
        DruidMetric,
    )

    # Marker key -> constructor; the matching payload is splatted as kwargs.
    constructors = {
        "__Dashboard__": Dashboard,
        "__Slice__": Slice,
        "__TableColumn__": TableColumn,
        "__SqlaTable__": SqlaTable,
        "__SqlMetric__": SqlMetric,
        "__DruidCluster__": DruidCluster,
        "__DruidColumn__": DruidColumn,
        "__DruidDatasource__": DruidDatasource,
        "__DruidMetric__": DruidMetric,
    }
    for marker, constructor in constructors.items():
        if marker in o:
            return constructor(**o[marker])
    if "__datetime__" in o:
        return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S")
    return o
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=[TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP")],
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # the shadow dataset was created alongside the table
    assert len(session.query(Dataset).all()) == 1

    session.delete(sqla_table)
    session.flush()

    # test that dataset was also deleted
    assert len(session.query(Dataset).all()) == 0
def _add_table_metrics(datasource: SqlaTable) -> None:
    # By accessing the attribute first, we make sure `datasource.columns` and
    # `datasource.metrics` are already loaded. Otherwise accessing them later
    # may trigger an unnecessary and unexpected `after_update` event.
    columns, metrics = datasource.columns, datasource.metrics

    existing_columns = {col.column_name for col in columns}
    if "num_california" not in existing_columns:
        col_state = str(column("state").compile(db.engine))
        col_num = str(column("num").compile(db.engine))
        columns.append(
            TableColumn(
                column_name="num_california",
                expression=
                f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
            ))

    existing_metrics = {metric.metric_name for metric in metrics}
    if "sum__num" not in existing_metrics:
        col = str(column("num").compile(db.engine))
        metrics.append(
            SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))

    # Flag the first ``ds`` column (if any) as the temporal column.
    ds_col = next((col for col in columns if col.column_name == "ds"), None)
    if ds_col is not None:
        ds_col.is_dttm = True
def decode_dashboards(o):
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    import superset.models.core as models
    from superset.connectors.sqla.models import (
        SqlaTable,
        SqlMetric,
        TableColumn,
    )

    # Marker key -> zero-arg factory; state is restored straight into
    # ``__dict__`` (bypassing the constructor), exactly as before.
    factories = {
        '__Dashboard__': models.Dashboard,
        '__Slice__': models.Slice,
        '__TableColumn__': TableColumn,
        '__SqlaTable__': SqlaTable,
        '__SqlMetric__': SqlMetric,
    }
    for marker, factory in factories.items():
        if marker in o:
            obj = factory()
            obj.__dict__.update(o[marker])
            return obj
    if '__datetime__' in o:
        return datetime.strptime(o['__datetime__'], '%Y-%m-%dT%H:%M:%S')
    return o
def test_fetch_metadata_for_updated_virtual_table(self):
    """``fetch_metadata`` must reconcile a virtual table's columns with its SQL.

    Stale physical columns are dropped, new ones from the query are added,
    calculated columns are kept, and a calculated column whose name collides
    with a physical query column loses its expression.
    """
    table = SqlaTable(
        table_name="updated_sql_table",
        database=get_example_database(),
        sql="select 123 as intcol, 'abc' as strcol, 'abc' as mycase",
    )
    # Pre-existing physical columns: one stale type, one no longer in the SQL.
    TableColumn(column_name="intcol", type="FLOAT", table=table)
    TableColumn(column_name="oldcol", type="INT", table=table)
    # Calculated column that does NOT collide with the SQL output.
    TableColumn(
        column_name="expr",
        expression="case when 1 then 1 else 0 end",
        type="INT",
        table=table,
    )
    # Calculated column whose name collides with a physical query column.
    TableColumn(
        column_name="mycase",
        expression="case when 1 then 1 else 0 end",
        type="INT",
        table=table,
    )

    # make sure the columns have been mapped properly
    assert len(table.columns) == 4
    table.fetch_metadata(commit=False)

    # assert that the removed column has been dropped and
    # the physical and calculated columns are present
    assert {col.column_name for col in table.columns} == {
        "intcol",
        "strcol",
        "mycase",
        "expr",
    }
    cols: Dict[str, TableColumn] = {
        col.column_name: col for col in table.columns
    }
    # assert that the type for intcol has been updated (asserting CI types)
    backend = get_example_database().backend
    assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type)

    # assert that the expression has been replaced with the new physical column
    assert cols["mycase"].expression == ""
    assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
    assert cols["expr"].expression == "case when 1 then 1 else 0 end"

    db.session.delete(table)
def test_alter_new_orm_column(self):
    """
    DB Eng Specs (crate): Test alter orm column
    """
    tbl = SqlaTable(
        table_name="druid_tbl",
        database=Database(database_name="crate", sqlalchemy_uri="crate://db"),
    )
    ts_col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
    CrateEngineSpec.alter_new_orm_column(ts_col)
    # Crate timestamps are parsed as epoch milliseconds.
    assert ts_col.python_date_format == "epoch_ms"
def create_table(self, name, schema=None, id=0, cols_names=None, metric_names=None):
    """Build a ``SqlaTable`` with the given columns and metrics.

    ``cols_names`` / ``metric_names`` now default to ``None`` instead of
    ``[]``: mutable default arguments are shared across calls and can leak
    state between tests. Passing an explicit list behaves exactly as before.
    """
    cols_names = cols_names if cols_names is not None else []
    metric_names = metric_names if metric_names is not None else []
    params = {"remote_id": id, "database_name": "examples"}
    table = SqlaTable(
        id=id, schema=schema, table_name=name, params=json.dumps(params)
    )
    for col_name in cols_names:
        table.columns.append(TableColumn(column_name=col_name))
    for metric_name in metric_names:
        table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
    return table
def virtual_dataset():
    """Pytest fixture: a virtual (SQL-defined) dataset with ten UNIONed rows.

    Registers five columns and a ``count`` metric on the dataset, yields it,
    and deletes it on teardown.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    dataset = SqlaTable(
        table_name="virtual_dataset",
        sql=
        ("SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 "
         "UNION ALL "
         "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' "
         "UNION ALL "
         "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' "
         "UNION ALL "
         "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' "
         "UNION ALL "
         "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' "
         "UNION ALL "
         "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' "
         "UNION ALL "
         "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' "
         "UNION ALL "
         "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' "
         "UNION ALL "
         "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' "
         "UNION ALL "
         "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' "),
        database=get_example_database(),
    )
    # Columns/metric attach themselves to ``dataset`` via the ``table=`` kwarg.
    TableColumn(column_name="col1", type="INTEGER", table=dataset)
    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
    TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
    # Different database dialect datetime type is not consistent, so temporarily use varchar
    TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)

    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)

    yield dataset

    # Teardown: remove the dataset metadata.
    db.session.delete(dataset)
    db.session.commit()
def sample_columns() -> Dict["TableColumn", Dict[str, Any]]:
    """Map sample ``TableColumn`` instances to their expected serialized props.

    Each key is a column fixture; each value is the property dict the
    serialization layer is expected to produce for it.
    """
    from superset.connectors.sqla.models import TableColumn

    return {
        # physical temporal column
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"): {
            "name": "ds",
            "expression": "ds",
            "type": "TIMESTAMP",
            "advanced_data_type": None,
            "is_temporal": True,
            "is_physical": True,
        },
        # physical groupable (dimensional) columns
        TableColumn(column_name="num_boys", type="INTEGER", groupby=True): {
            "name": "num_boys",
            "expression": "num_boys",
            "type": "INTEGER",
            "advanced_data_type": None,
            "is_dimensional": True,
            "is_physical": True,
        },
        TableColumn(column_name="region", type="VARCHAR", groupby=True): {
            "name": "region",
            "expression": "region",
            "type": "VARCHAR",
            "advanced_data_type": None,
            "is_dimensional": True,
            "is_physical": True,
        },
        # calculated (expression-backed, non-physical) column
        TableColumn(
            column_name="profit",
            type="INTEGER",
            groupby=False,
            expression="revenue-expenses",
        ): {
            "name": "profit",
            "expression": "revenue-expenses",
            "type": "INTEGER",
            "advanced_data_type": None,
            "is_physical": False,
        },
    }
def virtual_dataset_comma_in_column_value():
    """Fixture: virtual dataset whose values contain commas (CSV edge cases)."""
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    dataset = SqlaTable(
        table_name="virtual_dataset",
        sql=("SELECT 'col1,row1' as col1, 'col2, row1' as col2 "
             "UNION ALL "
             "SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
             "UNION ALL "
             "SELECT 'col1,row3' as col1, 'col2, row3' as col2 "),
        database=get_example_database(),
    )
    for col_name in ("col1", "col2"):
        TableColumn(column_name=col_name, type="VARCHAR(255)", table=dataset)
    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)

    yield dataset

    db.session.delete(dataset)
    db.session.commit()
class SqlaConnectorTestCase(BaseConnectorTestCase):
    # Physical columns of the test datasource; ``received`` is the temporal one.
    columns = [
        TableColumn(column_name='region', type='VARCHAR(20)'),
        TableColumn(column_name='district', type='VARCHAR(20)'),
        TableColumn(column_name='project', type='VARCHAR(20)'),
        TableColumn(column_name='received', type='DATE', is_dttm=True),
        TableColumn(column_name='value', type='BIGINT'),
    ]
    # Metrics covering plain aggregates and custom ratio/percentage expressions.
    metrics = [
        SqlMetric(metric_name='sum__value', metric_type='sum',
                  expression='SUM(value)'),
        SqlMetric(metric_name='avg__value', metric_type='avg',
                  expression='AVG(value)'),
        SqlMetric(metric_name='ratio', metric_type='avg',
                  expression='AVG(value/value2)'),
        SqlMetric(metric_name='value_percentage', metric_type='custom',
                  expression='SUM(value)/SUM(value + value2)'),
        SqlMetric(metric_name='category_percentage', metric_type='custom',
                  expression="SUM(CASE WHEN category='CategoryA' THEN 1 ELSE 0 END)/"
                             'CAST(COUNT(*) AS REAL)'),
    ]

    def setUp(self):
        """Create a sqlite-backed datasource and load the test dataframe into it.

        NOTE(review): ``self.df`` is assumed to be provided by
        ``BaseConnectorTestCase`` — confirm against the base class.
        """
        super(SqlaConnectorTestCase, self).setUp()
        sqlalchemy_uri = 'sqlite:////tmp/test.db'
        database = Database(
            database_name='test_database',
            sqlalchemy_uri=sqlalchemy_uri)
        self.connection = database.get_sqla_engine().connect()
        self.datasource = SqlaTable(table_name='test_datasource',
                                    database=database,
                                    columns=self.columns,
                                    metrics=self.metrics)
        # Replace any previous run's table; ``received`` gets a DATE dtype.
        with database.get_sqla_engine().begin() as connection:
            self.df.to_sql(self.datasource.table_name,
                           connection,
                           if_exists='replace',
                           index=False,
                           dtype={'received': Date})
def create_table(
        self, name, schema='', id=0, cols_names=None, metric_names=None):
    """Build a ``SqlaTable`` with the given columns and metrics.

    ``cols_names`` / ``metric_names`` now default to ``None`` instead of
    ``[]``: mutable default arguments are shared across calls and can leak
    state between tests. Passing an explicit list behaves exactly as before.
    """
    cols_names = cols_names if cols_names is not None else []
    metric_names = metric_names if metric_names is not None else []
    params = {'remote_id': id, 'database_name': 'main'}
    table = SqlaTable(
        id=id,
        schema=schema,
        table_name=name,
        params=json.dumps(params)
    )
    for col_name in cols_names:
        table.columns.append(
            TableColumn(column_name=col_name))
    for metric_name in metric_names:
        table.metrics.append(SqlMetric(metric_name=metric_name))
    return table
def test_jinja_metrics_and_calc_columns(self, flask_g):
    """Jinja templates must render in both adhoc metrics and calculated columns.

    NOTE(review): the ``"******"`` literals look like redacted values from the
    original source — they are kept byte-for-byte since the assertions depend
    on them.
    """
    flask_g.user.username = "******"
    base_query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["user", "expr"],
        # adhoc SQL metric containing a raw SQL expression
        "metrics": [{
            "expressionType": AdhocMetricExpressionType.SQL,
            "sqlExpression": "SUM(case when user = '******' "
                             "then 1 else 0 end)",
            "label": "SUM(userid)",
        }],
        "is_timeseries": False,
        "filter": [],
    }
    # Virtual table whose SQL itself contains a Jinja macro.
    table = SqlaTable(
        table_name="test_has_jinja_metric_and_expr",
        sql="SELECT '{{ current_username() }}' as user",
        database=get_example_database(),
    )
    # Calculated column whose expression also contains a Jinja macro.
    TableColumn(
        column_name="expr",
        expression="case when '{{ current_username() }}' = 'abc' "
                   "then 'yes' else 'no' end",
        type="VARCHAR(100)",
        table=table,
    )
    db.session.commit()

    sqla_query = table.get_sqla_query(**base_query_obj)
    query = table.database.compile_sqla_query(sqla_query.sqla_query)

    # assert expression
    assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
    # assert metric
    assert "SUM(case when user = '******' then 1 else 0 end)" in query

    # Cleanup
    db.session.delete(table)
    db.session.commit()
def text_column_table():
    """Fixture: virtual table exercising tricky text values.

    Covers empty strings, NULL, the literal string 'null', and embedded
    single/double quotes.
    """
    with app.app_context():
        tbl = SqlaTable(
            table_name="text_column_table",
            sql=("SELECT 'foo' as foo "
                 "UNION SELECT '' "
                 "UNION SELECT NULL "
                 "UNION SELECT 'null' "
                 "UNION SELECT '\"text in double quotes\"' "
                 "UNION SELECT '''text in single quotes''' "
                 "UNION SELECT 'double quotes \" in text' "
                 "UNION SELECT 'single quotes '' in text' "),
            database=get_example_database(),
        )
        TableColumn(column_name="foo", type="VARCHAR(255)", table=tbl)
        SqlMetric(metric_name="count", expression="count(*)", table=tbl)
        yield tbl