Example #1
    def test_db_column_types(self):
        test_cases: Dict[str, GenericDataType] = {
            # string
            "CHAR": GenericDataType.STRING,
            "VARCHAR": GenericDataType.STRING,
            "NVARCHAR": GenericDataType.STRING,
            "STRING": GenericDataType.STRING,
            "TEXT": GenericDataType.STRING,
            "NTEXT": GenericDataType.STRING,
            # numeric
            "INTEGER": GenericDataType.NUMERIC,
            "BIGINT": GenericDataType.NUMERIC,
            "DECIMAL": GenericDataType.NUMERIC,
            # temporal
            "DATE": GenericDataType.TEMPORAL,
            "DATETIME": GenericDataType.TEMPORAL,
            "TIME": GenericDataType.TEMPORAL,
            "TIMESTAMP": GenericDataType.TEMPORAL,
        }

        tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database())
        for str_type, db_col_type in test_cases.items():
            col = TableColumn(column_name="foo", type=str_type, table=tbl)
            self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL)
            self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC)
            self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING)

        for str_type, db_col_type in test_cases.items():
            col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True)
            self.assertTrue(col.is_temporal)
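For context, GenericDataType (used in the mapping above) is a small IntEnum defined in superset.utils.core. A rough sketch of its shape, for orientation only (the values match recent Superset releases, but verify against your version):

from enum import IntEnum

class GenericDataType(IntEnum):
    NUMERIC = 0
    STRING = 1
    TEMPORAL = 2
    BOOLEAN = 3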
Example #2
def test_table(session: Session) -> "SqlaTable":
    """
    Fixture that generates an in-memory table.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="event_time", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="id", type="INTEGER"),
        TableColumn(column_name="dttm", type="INTEGER"),
        TableColumn(column_name="duration_ms", type="INTEGER"),
    ]

    return SqlaTable(
        table_name="test_table",
        columns=columns,
        metrics=[],
        main_dttm_col=None,
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
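A minimal sketch of how a test might consume the fixture above; the test name and assertions are illustrative, not from the source:

def test_table_shape(test_table: "SqlaTable") -> None:
    # five columns were attached, two of them flagged as datetime
    assert len(test_table.columns) == 5
    assert {c.column_name for c in test_table.columns if c.is_dttm} == {"ds", "event_time"}
    assert test_table.main_dttm_col is None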
Example #3
def test_quote_expressions(app_context: None, session: Session) -> None:
    """
    Test that expressions are quoted appropriately in columns and datasets.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="has space", type="INTEGER"),
        TableColumn(column_name="no_need", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="old dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert dataset.expression == '"old dataset"'
    assert dataset.columns[0].expression == '"has space"'
    assert dataset.columns[1].expression == "no_need"
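The quoting rule exercised here comes from the SQLAlchemy dialect's identifier preparer, which only quotes names that need it. A quick standalone illustration, assuming the SQLite dialect used by the test:

from sqlalchemy.dialects.sqlite import dialect

preparer = dialect().identifier_preparer
assert preparer.quote("has space") == '"has space"'  # needs quoting
assert preparer.quote("no_need") == "no_need"        # left bare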
Example #4
def test_update_sqlatable(mocker: MockFixture, app_context: None,
                          session: Session) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.
    """
    # patch session
    mocker.patch("superset.security.SupersetSecurityManager.get_session",
                 return_value=session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # add a column to the original ``SqlaTable`` instance
    sqla_table.columns.append(
        TableColumn(column_name="user_id", type="INTEGER"))
    session.flush()

    # check that the column was added to the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 2

    # delete the column in the original instance
    sqla_table.columns = sqla_table.columns[1:]
    session.flush()

    # check that the column was also removed from the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # modify the attribute in a column
    sqla_table.columns[0].is_dttm = True
    session.flush()

    # check that the dataset column was modified
    dataset = session.query(Dataset).one()
    assert dataset.columns[0].is_temporal is True
Example #5
    def test_is_time_druid_time_col(self):
        """Druid has a special __time column"""
        col = TableColumn(column_name="__time", type="INTEGER")
        self.assertEqual(col.is_dttm, None)
        DruidEngineSpec.alter_new_orm_column(col)
        self.assertEqual(col.is_dttm, True)

        col = TableColumn(column_name="__not_time", type="INTEGER")
        self.assertEqual(col.is_time, False)
Example #6
    def test_is_time_by_type(self):
        col = TableColumn(column_name="foo", type="DATE")
        self.assertEqual(col.is_time, True)

        col = TableColumn(column_name="foo", type="DATETIME")
        self.assertEqual(col.is_time, True)

        col = TableColumn(column_name="foo", type="STRING")
        self.assertEqual(col.is_time, False)
Example #7
    def test_is_time_by_type(self):
        col = TableColumn(column_name='foo', type='DATE')
        self.assertEqual(col.is_time, True)

        col = TableColumn(column_name='foo', type='DATETIME')
        self.assertEqual(col.is_time, True)

        col = TableColumn(column_name='foo', type='STRING')
        self.assertEqual(col.is_time, False)
Example #8
def physical_dataset():
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    example_database = get_example_database()
    engine = example_database.get_sqla_engine()
    # sqlite can only execute one statement at a time
    engine.execute("""
        CREATE TABLE IF NOT EXISTS physical_dataset(
          col1 INTEGER,
          col2 VARCHAR(255),
          col3 DECIMAL(4,2),
          col4 VARCHAR(255),
          col5 TIMESTAMP
        );
        """)
    engine.execute("""
        INSERT INTO physical_dataset values
        (0, 'a', 1.0, NULL, '2000-01-01 00:00:00'),
        (1, 'b', 1.1, NULL, '2000-01-02 00:00:00'),
        (2, 'c', 1.2, NULL, '2000-01-03 00:00:00'),
        (3, 'd', 1.3, NULL, '2000-01-04 00:00:00'),
        (4, 'e', 1.4, NULL, '2000-01-05 00:00:00'),
        (5, 'f', 1.5, NULL, '2000-01-06 00:00:00'),
        (6, 'g', 1.6, NULL, '2000-01-07 00:00:00'),
        (7, 'h', 1.7, NULL, '2000-01-08 00:00:00'),
        (8, 'i', 1.8, NULL, '2000-01-09 00:00:00'),
        (9, 'j', 1.9, NULL, '2000-01-10 00:00:00');
    """)

    dataset = SqlaTable(
        table_name="physical_dataset",
        database=example_database,
    )
    TableColumn(column_name="col1", type="INTEGER", table=dataset)
    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
    TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col5",
                type="TIMESTAMP",
                is_dttm=True,
                table=dataset)
    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)
    db.session.commit()

    yield dataset

    engine.execute("""
        DROP TABLE physical_dataset;
    """)
    datasets = db.session.query(SqlaTable).filter_by(
        table_name="physical_dataset").all()
    for ds in datasets:
        db.session.delete(ds)
    db.session.commit()
Example #9
    def test_temporal_varchar(self):
        """Ensure a column with is_dttm set to true evaluates to is_temporal == True"""

        database = get_example_database()
        tbl = SqlaTable(table_name="test_tbl", database=database)
        col = TableColumn(column_name="ds", type="VARCHAR", table=tbl)
        # by default, VARCHAR should not be assumed to be temporal
        assert col.is_temporal is False
        # after setting `is_dttm = True`, `is_temporal` should return True
        col.is_dttm = True
        assert col.is_temporal is True
Example #10
    def test_is_time_druid_time_col(self):
        """Druid has a special __time column"""

        database = Database(database_name="druid_db", sqlalchemy_uri="druid://db")
        tbl = SqlaTable(table_name="druid_tbl", database=database)
        col = TableColumn(column_name="__time", type="INTEGER", table=tbl)
        self.assertEqual(col.is_dttm, None)
        DruidEngineSpec.alter_new_orm_column(col)
        self.assertEqual(col.is_dttm, True)

        col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl)
        self.assertEqual(col.is_temporal, False)
Example #11
def create_chart(type, gamecode, column):
    template = importlib.import_module("superset.chartbuilder.templates." + type)

    TBL = ConnectorRegistry.sources['table']
    tbl_name = chart_name_template.format(type=type, gamecode=gamecode, column=column)

    print('Creating table {} reference'.format(tbl_name))
    tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
    if not tbl:
        tbl = TBL(table_name=tbl_name)
    tbl.database_id = db.session.query(Database).filter_by(
        database_name=template.database_name).first().id
    tbl.sql = template.table_sql.format(gamecode=gamecode, column=column)
    db.session.merge(tbl)

    # re-query so the merged table reference has its id populated
    tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()

    slc = db.session.query(Slice).filter_by(datasource_id=tbl.id).first()
    if not slc:
        slc = Slice(
            datasource_id=tbl.id,
            slice_name=tbl_name,
            datasource_name=tbl_name,
            datasource_type='table',
            viz_type=template.viz_type,
            created_by_fk=1,
        )

    if template.viz_type == 'line':
        # line charts need a datetime column to bucket on
        bucket_dt = db.session.query(TableColumn).filter_by(
            table_id=tbl.id, column_name=template.time_column).first()
        if not bucket_dt:
            bucket_dt = TableColumn(
                table_id=tbl.id, column_name=template.time_column, is_dttm=1)
        db.session.merge(bucket_dt)

        metric_json = []
        for metric in template.metrics:
            metric_obj = db.session.query(TableColumn).filter_by(
                table_id=tbl.id, column_name=metric).first()
            if not metric_obj:
                metric_obj = TableColumn(table_id=tbl.id, column_name=metric)
                db.session.merge(metric_obj)
                metric_obj = db.session.query(TableColumn).filter_by(
                    table_id=tbl.id, column_name=metric).first()
            metric_json.append(
                '{{"column": {{"id": {metric_id}, "column_name": "{metric}"}}, '
                '"label": "{metric}", "aggregate": "MAX", '
                '"expressionType": "SIMPLE"}}'.format(
                    metric_id=metric_obj.id, metric=metric))

        slc.params = '''
            {{"datasource": "{datasource}", "granularity_sqla": "{time_column}", "time_grain_sqla": "PT1M", "time_range": "Last week",
            "metrics": [
            {metric_json}
            ]}}
            '''.format(datasource=str(tbl.id) + '__table',
                       time_column=template.time_column,
                       metric_json=',\n'.join(metric_json))

    elif template.viz_type == 'table':
        for metric in template.metrics:
            metric_obj = db.session.query(TableColumn).filter_by(
                table_id=tbl.id, column_name=metric).first()
            if not metric_obj:
                metric_obj = TableColumn(table_id=tbl.id, column_name=metric)
                db.session.merge(metric_obj)

        slc.params = (
            '{{"datasource": "{datasource}", "metrics": [], '
            '"all_columns": ["{metric_csv}"]}}'.format(
                datasource=str(tbl.id) + '__table',
                metric_csv='", "'.join(template.metrics)))

    db.session.merge(slc)
    db.session.commit()
Example #12
    def _add_column(
        self,
        connection: Connection,
        table: SqlaTable,
        column_name: str,
        column_type: str,
    ) -> None:
        """
        Add new column to table
        :param connection: The connection to work with
        :param table: The SqlaTable
        :param column_name: The name of the column
        :param column_type: The type of the column
        """
        column = Column(column_name, column_type)
        name = column.compile(column_name,
                              dialect=table.database.get_sqla_engine().dialect)
        col_type = column.type.compile(
            table.database.get_sqla_engine().dialect)
        sql = text(
            f"ALTER TABLE {self._get_from_clause(table)} ADD {name} {col_type}"
        )
        connection.execute(sql)

        table.columns.append(
            TableColumn(column_name=column_name, type=col_type))
Example #13
    def test_calculated_column_in_order_by_base_engine_spec(self):
        table = self.get_table(name="birth_names")
        TableColumn(
            column_name="gender_cc",
            type="VARCHAR(255)",
            table=table,
            expression="""
            case
              when gender=true then "male"
              else "female"
            end
            """,
        )

        table.database.sqlalchemy_uri = "sqlite://"
        query_obj = {
            "groupby": ["gender_cc"],
            "is_timeseries": False,
            "filter": [],
            "orderby": [["gender_cc", True]],
        }
        sql = table.get_query_str(query_obj)
        assert ("""ORDER BY case
             when gender=true then "male"
             else "female"
         end ASC;""" in sql)
Example #14
    def test_boolean_type_where_operators(self):
        table = self.get_table(name="birth_names")
        db.session.add(
            TableColumn(
                column_name="boolean_gender",
                expression="case when gender = 'boy' then True else False end",
                type="BOOLEAN",
                table=table,
            )
        )
        query_obj = {
            "granularity": None,
            "from_dttm": None,
            "to_dttm": None,
            "groupby": ["boolean_gender"],
            "metrics": ["count"],
            "is_timeseries": False,
            "filter": [
                {
                    "col": "boolean_gender",
                    "op": FilterOperator.IN,
                    "val": ["true", "false"],
                }
            ],
            "extras": {},
        }
        sqla_query = table.get_sqla_query(**query_obj)
        sql = table.database.compile_sqla_query(sqla_query.sqla_query)
        dialect = table.database.get_dialect()
        operand = "(true, false)"
        # override native_boolean=False behavior in MySQLCompiler
        # https://github.com/sqlalchemy/sqlalchemy/blob/master/lib/sqlalchemy/dialects/mysql/base.py
        if not dialect.supports_native_boolean and dialect.name != "mysql":
            operand = "(1, 0)"
        self.assertIn(f"IN {operand}", sql)
Example #15
    def create_table(self,
                     name,
                     schema='',
                     id=0,
                     cols_names=[],
                     metric_names=[]):
        database_name = 'main'
        name = '{0}{1}'.format(NAME_PREFIX, name)
        params = {DBREF: id, 'database_name': database_name}

        dict_rep = {
            'database_id': get_main_database(db.session).id,
            'table_name': name,
            'schema': schema,
            'id': id,
            'params': json.dumps(params),
            'columns': [{'column_name': c} for c in cols_names],
            'metrics': [{'metric_name': c} for c in metric_names],
        }

        table = SqlaTable(
            id=id,
            schema=schema,
            table_name=name,
            params=json.dumps(params),
        )
        for col_name in cols_names:
            table.columns.append(TableColumn(column_name=col_name))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name))
        return table, dict_rep
Example #16
    def create_table(
        self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[]
    ):
        database_name = "main"
        name = "{0}{1}".format(NAME_PREFIX, name)
        params = {DBREF: id, "database_name": database_name}

        if cols_uuids is None:
            cols_uuids = [None] * len(cols_names)

        dict_rep = {
            "database_id": get_example_database().id,
            "table_name": name,
            "schema": schema,
            "id": id,
            "params": json.dumps(params),
            "columns": [
                {"column_name": c, "uuid": u} for c, u in zip(cols_names, cols_uuids)
            ],
            "metrics": [{"metric_name": c, "expression": ""} for c in metric_names],
        }

        table = SqlaTable(
            id=id, schema=schema, table_name=name, params=json.dumps(params)
        )
        for col_name, uuid in zip(cols_names, cols_uuids):
            table.columns.append(TableColumn(column_name=col_name, uuid=uuid))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
        return table, dict_rep
Example #17
def decode_dashboards(  # pylint: disable=too-many-return-statements
        o: Dict[str, Any]) -> Any:
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    from superset.connectors.druid.models import (
        DruidCluster,
        DruidColumn,
        DruidDatasource,
        DruidMetric,
    )

    if "__Dashboard__" in o:
        return Dashboard(**o["__Dashboard__"])
    if "__Slice__" in o:
        return Slice(**o["__Slice__"])
    if "__TableColumn__" in o:
        return TableColumn(**o["__TableColumn__"])
    if "__SqlaTable__" in o:
        return SqlaTable(**o["__SqlaTable__"])
    if "__SqlMetric__" in o:
        return SqlMetric(**o["__SqlMetric__"])
    if "__DruidCluster__" in o:
        return DruidCluster(**o["__DruidCluster__"])
    if "__DruidColumn__" in o:
        return DruidColumn(**o["__DruidColumn__"])
    if "__DruidDatasource__" in o:
        return DruidDatasource(**o["__DruidDatasource__"])
    if "__DruidMetric__" in o:
        return DruidMetric(**o["__DruidMetric__"])
    if "__datetime__" in o:
        return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S")

    return o
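As the docstring notes, this function is meant to be passed as the object hook of json.loads. A minimal sketch, where payload is assumed to be a JSON dashboard export produced by the matching encoder:

import json

dashboard = json.loads(payload, object_hook=decode_dashboards)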
Example #18
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database",
                          sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    datasets = session.query(Dataset).all()
    assert len(datasets) == 1

    session.delete(sqla_table)
    session.flush()

    # test that dataset was also deleted
    datasets = session.query(Dataset).all()
    assert len(datasets) == 0
Example #19
def _add_table_metrics(datasource: SqlaTable) -> None:
    # By accessing the attribute first, we make sure `datasource.columns` and
    # `datasource.metrics` are already loaded. Otherwise accessing them later
    # may trigger an unnecessary and unexpected `after_update` event.
    columns, metrics = datasource.columns, datasource.metrics

    if not any(col.column_name == "num_california" for col in columns):
        col_state = str(column("state").compile(db.engine))
        col_num = str(column("num").compile(db.engine))
        columns.append(
            TableColumn(
                column_name="num_california",
                expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
            )
        )

    if not any(col.metric_name == "sum__num" for col in metrics):
        col = str(column("num").compile(db.engine))
        metrics.append(
            SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))

    for col in columns:
        if col.column_name == "ds":
            col.is_dttm = True
            break
Example #20
def decode_dashboards(o):
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    import superset.models.core as models
    from superset.connectors.sqla.models import (
        SqlaTable, SqlMetric, TableColumn,
    )

    if '__Dashboard__' in o:
        d = models.Dashboard()
        d.__dict__.update(o['__Dashboard__'])
        return d
    elif '__Slice__' in o:
        d = models.Slice()
        d.__dict__.update(o['__Slice__'])
        return d
    elif '__TableColumn__' in o:
        d = TableColumn()
        d.__dict__.update(o['__TableColumn__'])
        return d
    elif '__SqlaTable__' in o:
        d = SqlaTable()
        d.__dict__.update(o['__SqlaTable__'])
        return d
    elif '__SqlMetric__' in o:
        d = SqlMetric()
        d.__dict__.update(o['__SqlMetric__'])
        return d
    elif '__datetime__' in o:
        return datetime.strptime(o['__datetime__'], '%Y-%m-%dT%H:%M:%S')
    else:
        return o
Example #21
    def test_fetch_metadata_for_updated_virtual_table(self):
        table = SqlaTable(
            table_name="updated_sql_table",
            database=get_example_database(),
            sql="select 123 as intcol, 'abc' as strcol, 'abc' as mycase",
        )
        TableColumn(column_name="intcol", type="FLOAT", table=table)
        TableColumn(column_name="oldcol", type="INT", table=table)
        TableColumn(
            column_name="expr",
            expression="case when 1 then 1 else 0 end",
            type="INT",
            table=table,
        )
        TableColumn(
            column_name="mycase",
            expression="case when 1 then 1 else 0 end",
            type="INT",
            table=table,
        )

        # make sure the columns have been mapped properly
        assert len(table.columns) == 4
        table.fetch_metadata(commit=False)

        # assert that the removed column has been dropped and
        # the physical and calculated columns are present
        assert {col.column_name for col in table.columns} == {
            "intcol",
            "strcol",
            "mycase",
            "expr",
        }
        cols: Dict[str, TableColumn] = {col.column_name: col for col in table.columns}
        # assert that the type for intcol has been updated (asserting CI types)
        backend = get_example_database().backend
        assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type)
        # assert that the expression has been replaced with the new physical column
        assert cols["mycase"].expression == ""
        assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
        assert cols["expr"].expression == "case when 1 then 1 else 0 end"

        db.session.delete(table)
Example #22
    def test_alter_new_orm_column(self):
        """
        DB Eng Specs (crate): Test alter orm column
        """
        database = Database(database_name="crate", sqlalchemy_uri="crate://db")
        tbl = SqlaTable(table_name="druid_tbl", database=database)
        col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
        CrateEngineSpec.alter_new_orm_column(col)
        assert col.python_date_format == "epoch_ms"
Example #23
    def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]):
        params = {"remote_id": id, "database_name": "examples"}
        table = SqlaTable(
            id=id, schema=schema, table_name=name, params=json.dumps(params)
        )
        for col_name in cols_names:
            table.columns.append(TableColumn(column_name=col_name))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
        return table
Example #24
def virtual_dataset():
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    dataset = SqlaTable(
        table_name="virtual_dataset",
        sql=(
            "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 "
            "UNION ALL "
            "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' "
            "UNION ALL "
            "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' "
            "UNION ALL "
            "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' "
            "UNION ALL "
            "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' "
            "UNION ALL "
            "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' "
            "UNION ALL "
            "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' "
            "UNION ALL "
            "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' "
            "UNION ALL "
            "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' "
            "UNION ALL "
            "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' "
        ),
        database=get_example_database(),
    )
    TableColumn(column_name="col1", type="INTEGER", table=dataset)
    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
    TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
    # Datetime types differ across database dialects, so use VARCHAR for now
    TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)

    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)

    yield dataset

    db.session.delete(dataset)
    db.session.commit()
Example #25
def sample_columns() -> Dict["TableColumn", Dict[str, Any]]:
    from superset.connectors.sqla.models import TableColumn

    return {
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"): {
            "name": "ds",
            "expression": "ds",
            "type": "TIMESTAMP",
            "advanced_data_type": None,
            "is_temporal": True,
            "is_physical": True,
        },
        TableColumn(column_name="num_boys", type="INTEGER", groupby=True): {
            "name": "num_boys",
            "expression": "num_boys",
            "type": "INTEGER",
            "advanced_data_type": None,
            "is_dimensional": True,
            "is_physical": True,
        },
        TableColumn(column_name="region", type="VARCHAR", groupby=True): {
            "name": "region",
            "expression": "region",
            "type": "VARCHAR",
            "advanced_data_type": None,
            "is_dimensional": True,
            "is_physical": True,
        },
        TableColumn(
            column_name="profit",
            type="INTEGER",
            groupby=False,
            expression="revenue-expenses",
        ): {
            "name": "profit",
            "expression": "revenue-expenses",
            "type": "INTEGER",
            "advanced_data_type": None,
            "is_physical": False,
        },
    }
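A sketch of how such a fixture could drive a test, assuming a hypothetical to_api_attributes() helper that serializes a TableColumn (the real function under test is not shown in this snippet):

def test_sample_columns(sample_columns) -> None:
    for table_column, expected in sample_columns.items():
        actual = to_api_attributes(table_column)  # hypothetical helper
        for key, value in expected.items():
            assert actual[key] == value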
Example #26
def virtual_dataset_comma_in_column_value():
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    dataset = SqlaTable(
        table_name="virtual_dataset",
        sql=("SELECT 'col1,row1' as col1, 'col2, row1' as col2 "
             "UNION ALL "
             "SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
             "UNION ALL "
             "SELECT 'col1,row3' as col1, 'col2, row3' as col2 "),
        database=get_example_database(),
    )
    TableColumn(column_name="col1", type="VARCHAR(255)", table=dataset)
    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)

    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)

    yield dataset

    db.session.delete(dataset)
    db.session.commit()
Example #27
class SqlaConnectorTestCase(BaseConnectorTestCase):
    columns = [
        TableColumn(column_name='region', type='VARCHAR(20)'),
        TableColumn(column_name='district', type='VARCHAR(20)'),
        TableColumn(column_name='project', type='VARCHAR(20)'),
        TableColumn(column_name='received', type='DATE', is_dttm=True),
        TableColumn(column_name='value', type='BIGINT'),
    ]
    metrics = [
        SqlMetric(metric_name='sum__value', metric_type='sum',
                  expression='SUM(value)'),
        SqlMetric(metric_name='avg__value', metric_type='avg',
                  expression='AVG(value)'),
        SqlMetric(metric_name='ratio', metric_type='avg',
                  expression='AVG(value/value2)'),
        SqlMetric(metric_name='value_percentage', metric_type='custom',
                  expression='SUM(value)/SUM(value + value2)'),
        SqlMetric(metric_name='category_percentage', metric_type='custom',
                  expression="SUM(CASE WHEN category='CategoryA' THEN 1 ELSE 0 END)/"
                             'CAST(COUNT(*) AS REAL)'),
    ]

    def setUp(self):
        super(SqlaConnectorTestCase, self).setUp()
        sqlalchemy_uri = 'sqlite:////tmp/test.db'
        database = Database(
            database_name='test_database',
            sqlalchemy_uri=sqlalchemy_uri)
        self.connection = database.get_sqla_engine().connect()
        self.datasource = SqlaTable(table_name='test_datasource',
                                    database=database,
                                    columns=self.columns,
                                    metrics=self.metrics)
        with database.get_sqla_engine().begin() as connection:
            self.df.to_sql(self.datasource.table_name,
                           connection,
                           if_exists='replace',
                           index=False,
                           dtype={'received': Date})
Example #28
    def create_table(
            self, name, schema='', id=0, cols_names=[], metric_names=[]):
        params = {'remote_id': id, 'database_name': 'main'}
        table = SqlaTable(
            id=id,
            schema=schema,
            table_name=name,
            params=json.dumps(params)
        )
        for col_name in cols_names:
            table.columns.append(
                TableColumn(column_name=col_name))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name))
        return table
Example #29
    def test_jinja_metrics_and_calc_columns(self, flask_g):
        flask_g.user.username = "******"
        base_query_obj = {
            "granularity": None,
            "from_dttm": None,
            "to_dttm": None,
            "groupby": ["user", "expr"],
            "metrics": [
                {
                    "expressionType": AdhocMetricExpressionType.SQL,
                    "sqlExpression": "SUM(case when user = '******' then 1 else 0 end)",
                    "label": "SUM(userid)",
                }
            ],
            "is_timeseries": False,
            "filter": [],
        }

        table = SqlaTable(
            table_name="test_has_jinja_metric_and_expr",
            sql="SELECT '{{ current_username() }}' as user",
            database=get_example_database(),
        )
        TableColumn(
            column_name="expr",
            expression="case when '{{ current_username() }}' = 'abc' "
            "then 'yes' else 'no' end",
            type="VARCHAR(100)",
            table=table,
        )
        db.session.commit()

        sqla_query = table.get_sqla_query(**base_query_obj)
        query = table.database.compile_sqla_query(sqla_query.sqla_query)
        # assert expression
        assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
        # assert metric
        assert "SUM(case when user = '******' then 1 else 0 end)" in query
        # Cleanup
        db.session.delete(table)
        db.session.commit()
Example #30
def text_column_table():
    with app.app_context():
        table = SqlaTable(
            table_name="text_column_table",
            sql=("SELECT 'foo' as foo "
                 "UNION SELECT '' "
                 "UNION SELECT NULL "
                 "UNION SELECT 'null' "
                 "UNION SELECT '\"text in double quotes\"' "
                 "UNION SELECT '''text in single quotes''' "
                 "UNION SELECT 'double quotes \" in text' "
                 "UNION SELECT 'single quotes '' in text' "),
            database=get_example_database(),
        )
        TableColumn(column_name="foo", type="VARCHAR(255)", table=table)
        SqlMetric(metric_name="count", expression="count(*)", table=table)
        yield table