示例#1
0
def setup_presto_if_needed():
    """Adjust the examples database for the configured examples backend.

    The backend is taken from the scheme of ``SQLALCHEMY_EXAMPLES_URI``.
    For Presto a ``poll_interval`` (``PRESTO_POLL_INTERVAL``) is put into the
    connection args; for any other backend ``engine_params`` is reset to an
    empty dict. The updated ``extra`` is committed. For Presto and Hive the
    CTAS and admin schemas are then emptied, dropped and recreated.
    """
    backend = app.config["SQLALCHEMY_EXAMPLES_URI"].split("://")[0]
    database = get_example_database()
    extra = database.get_extra()

    if backend == "presto":
        # decrease poll interval for tests
        engine_params = {
            "connect_args": {"poll_interval": app.config["PRESTO_POLL_INTERVAL"]}
        }
    else:
        # remove `poll_interval` from databases that do not support it
        # (note: this resets every engine param, not just poll_interval)
        engine_params = {}
    database.extra = json_dumps_w_dates({**extra, "engine_params": engine_params})
    db.session.commit()

    if backend not in {"presto", "hive"}:
        return

    database = get_example_database()
    engine = database.get_sqla_engine()
    for schema_name in (CTAS_SCHEMA_NAME, ADMIN_SCHEMA_NAME):
        drop_from_schema(engine, schema_name)
        engine.execute(f"DROP SCHEMA IF EXISTS {schema_name}")
        engine.execute(f"CREATE SCHEMA {schema_name}")
    def test_api_database(self):
        """Database API: list databases exposed in SQL Lab, ordered by name.

        Creates a fake database, queries ``api/v1/database/`` filtered on
        ``expose_in_sqllab == True`` and checks the set of returned names.
        The fake database is deleted even if the assertion fails.
        """
        self.login("admin")
        self.create_fake_db()
        try:
            # Results unused; presumably these ensure the examples and main
            # databases exist (get-or-create) — TODO confirm.
            get_example_database()
            get_main_database()

            arguments = {
                "keys": [],
                "filters": [{
                    "col": "expose_in_sqllab",
                    "opr": "eq",
                    "value": True
                }],
                "order_column": "database_name",
                "order_direction": "asc",
                "page": 0,
                "page_size": -1,
            }
            url = f"api/v1/database/?q={prison.dumps(arguments)}"

            self.assertEqual(
                {"examples", "fake_db_100", "main"},
                {
                    r.get("database_name")
                    for r in self.get_json_resp(url)["result"]
                },
            )
        finally:
            # Always remove the fake db so a failure does not leak it into
            # other tests.
            self.delete_fake_db()
    def test_sqllab_table_viz(self):
        """``get_or_create_table`` returns a table owned by the logged-in user.

        Creates the physical table ``test_sqllab_table_viz`` in the examples
        database, calls the endpoint, and checks that the resulting
        ``SqlaTable`` is owned by ``admin``. The dataset row and the physical
        table are cleaned up even when an assertion fails.
        """
        self.login("admin")
        examples_db = get_example_database()
        # One engine for both setup and teardown.
        engine = examples_db.get_sqla_engine()
        engine.execute("DROP TABLE IF EXISTS test_sqllab_table_viz")
        engine.execute("CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col")

        payload = {
            "datasourceName": "test_sqllab_table_viz",
            "columns": [],
            "dbId": examples_db.id,
        }
        data = {"data": json.dumps(payload)}
        try:
            resp = self.get_json_resp("/superset/get_or_create_table/", data=data)
            self.assertIn("table_id", resp)

            table = db.session.query(SqlaTable).filter_by(id=resp["table_id"]).one()
            owners = [owner.username for owner in table.owners]
            # Delete the dataset row before asserting so it is removed either way.
            db.session.delete(table)
            # ensure owner is set correctly
            self.assertEqual(owners, ["admin"])
        finally:
            engine.execute("DROP TABLE IF EXISTS test_sqllab_table_viz")
            db.session.commit()
示例#4
0
    def test_get_query_no_data_access(self):
        """
        Query API: Test get query without data access

        Two Gamma users each insert one query on the examples database. Each
        user gets 404 for the other user's query and 200 for their own; an
        admin gets 200 for both. The created queries and users are deleted at
        the end.
        """
        gamma1 = self.create_user(
            "gamma_1", "password", "Gamma", email="*****@*****.**"
        )
        gamma2 = self.create_user(
            "gamma_2", "password", "Gamma", email="*****@*****.**"
        )

        gamma1_client_id = self.get_random_string()
        gamma2_client_id = self.get_random_string()
        query_gamma1 = self.insert_query(
            get_example_database().id, gamma1.id, gamma1_client_id
        )
        query_gamma2 = self.insert_query(
            get_example_database().id, gamma2.id, gamma2_client_id
        )

        # NOTE(review): usernames, passwords and emails ("******", "*****@*****.**")
        # appear redacted in this copy — confirm the real values before running.
        # Gamma1 user, only sees his own queries
        self.login(username="******", password="******")
        uri = f"api/v1/query/{query_gamma2.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)
        uri = f"api/v1/query/{query_gamma1.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)

        # Gamma2 user, only sees his own queries
        self.logout()
        self.login(username="******", password="******")
        uri = f"api/v1/query/{query_gamma1.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)
        uri = f"api/v1/query/{query_gamma2.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)

        # Admin's have the "all query access" permission
        self.logout()
        self.login(username="******")
        uri = f"api/v1/query/{query_gamma1.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        uri = f"api/v1/query/{query_gamma2.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)

        # rollback changes
        # NOTE(review): not reached if any assertion above fails; the created
        # queries and users would then remain in the database.
        db.session.delete(query_gamma1)
        db.session.delete(query_gamma2)
        db.session.delete(gamma1)
        db.session.delete(gamma2)
        db.session.commit()
    def test_sql_json_schema_access(self):
        """A user with only schema-level access can query tables in that schema.

        Grants ``schema_access`` on ``CTAS_SCHEMA_NAME`` of the examples
        database to a new role, creates a user with that role, creates a
        one-row test table in the schema, and runs queries against it: fully
        qualified (with and without a selected schema) and, on MySQL only,
        unqualified with the schema selected. Skipped on SQLite.
        """
        examples_db = get_example_database()
        db_backend = examples_db.backend
        if db_backend == "sqlite":
            # sqlite doesn't support database creation
            return

        sqllab_test_db_schema_permission_view = (
            security_manager.add_permission_view_menu(
                "schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]"
            )
        )
        schema_perm_role = security_manager.add_role("SchemaPermission")
        security_manager.add_permission_role(
            schema_perm_role, sqllab_test_db_schema_permission_view
        )
        self.create_user_with_roles(
            "SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"]
        )

        examples_db.get_sqla_engine().execute(
            f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
        )

        # NOTE(review): username "******" appears redacted in this copy —
        # presumably the "SchemaUser" created above; confirm.
        data = self.run_sql(
            f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="******"
        )
        self.assertEqual(1, len(data["data"]))

        data = self.run_sql(
            f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table",
            "4",
            username="******",
            schema=CTAS_SCHEMA_NAME,
        )
        self.assertEqual(1, len(data["data"]))

        # postgres needs a schema as a part of the table name.
        # The unqualified query below is therefore only exercised on MySQL.
        if db_backend == "mysql":
            data = self.run_sql(
                "SELECT * FROM test_table",
                "5",
                username="******",
                schema=CTAS_SCHEMA_NAME,
            )
            self.assertEqual(1, len(data["data"]))

        # NOTE(review): this deletes *all* Query rows, not only the ones run
        # above; the role, permission and user created above are not removed
        # in this block.
        db.session.query(Query).delete()
        get_example_database().get_sqla_engine().execute(
            f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table"
        )
        db.session.commit()
    def test_sql_json_to_saved_query_info(self):
        """
        SQLLab: Test SQLLab query execution info propagation to saved queries

        Running a statement whose SQL and database match an existing saved
        query must fill in that saved query's ``rows`` and set ``last_run`` to
        the current (frozen) time. The saved query is deleted even if an
        assertion fails.
        """
        self.login("admin")

        sql_statement = "SELECT * FROM birth_names LIMIT 10"
        examples_db_id = get_example_database().id
        saved_query = SavedQuery(db_id=examples_db_id, sql=sql_statement)
        db.session.add(saved_query)
        db.session.commit()

        try:
            # Freeze the clock at whole-second precision so ``last_run`` can be
            # compared exactly against ``datetime.now()``.
            with freeze_time(datetime.now().isoformat(timespec="seconds")):
                self.run_sql(sql_statement, "1", username="******")
                saved_query_ = (
                    db.session.query(SavedQuery)
                    .filter(
                        SavedQuery.db_id == examples_db_id,
                        SavedQuery.sql == sql_statement,
                    )
                    .one_or_none()
                )
                assert saved_query_ is not None
                assert saved_query_.rows is not None
                assert saved_query_.last_run == datetime.now()
        finally:
            # Rollback changes: remove the row created above in every case.
            db.session.delete(saved_query)
            db.session.commit()
    def test_import_table_override(self):
        """Re-importing a table with the same id overrides the first import.

        The second import must resolve to the same table id, and the result
        must equal a table holding the columns and metrics of both imports.
        """
        schema = get_example_default_schema()

        def build(cols_names, metric_names):
            # Every table in this test shares name, id and schema.
            return self.create_table(
                "table_override",
                id=10003,
                cols_names=cols_names,
                metric_names=metric_names,
                schema=schema,
            )

        first_table = build(["col1"], ["m1"])
        db_id = get_example_database().id
        imported_id = import_dataset(first_table, db_id, import_time=1991)

        override_table = build(["new_col1", "col2", "col3"], ["new_metric1"])
        imported_over_id = import_dataset(override_table, db_id, import_time=1992)

        imported_over = self.get_table_by_id(imported_over_id)
        self.assertEqual(imported_id, imported_over.id)
        expected_table = build(
            ["col1", "new_col1", "col2", "col3"], ["new_metric1", "m1"]
        )
        self.assert_table_equals(expected_table, imported_over)
 def test_import_table_no_metadata(self):
     """A table created without explicit columns/metrics imports unchanged."""
     schema = get_example_default_schema()
     db_id = get_example_database().id
     source_table = self.create_table("pure_table", id=10001, schema=schema)
     new_id = import_dataset(source_table, db_id, import_time=1989)
     self.assert_table_equals(source_table, self.get_table_by_id(new_id))
示例#9
0
    def test_override_role_permissions_drops_absent_perms(self):
        """Overriding role permissions removes permissions not in the payload.

        ``override_me`` is first given ``datasource_access`` on
        ``energy_usage``. After posting ``ROLE_TABLES_PERM_DATA`` (with the
        schema set to the examples database's default schema), the role must
        hold exactly one permission: ``datasource_access`` on ``birth_names``.
        """
        import copy

        database = get_example_database()
        engine = database.get_sqla_engine()
        schema = inspect(engine).default_schema_name

        override_me = security_manager.find_role("override_me")
        override_me.permissions.append(
            security_manager.find_permission_view_menu(
                view_menu_name=self.get_table(name="energy_usage").perm,
                permission_name="datasource_access",
            ))
        db.session.flush()

        # Deep copy: the schema name lives in nested dicts/lists, so a shallow
        # ``.copy()`` would rewrite the shared module-level fixture in place.
        perm_data = copy.deepcopy(ROLE_TABLES_PERM_DATA)
        perm_data["database"][0]["schema"][0]["name"] = schema

        response = self.client.post(
            "/superset/override_role_permissions/",
            data=json.dumps(perm_data),
            content_type="application/json",
        )
        self.assertEqual(201, response.status_code)
        updated_override_me = security_manager.find_role("override_me")
        self.assertEqual(1, len(updated_override_me.permissions))
        birth_names = self.get_table(name="birth_names")
        self.assertEqual(birth_names.perm,
                         updated_override_me.permissions[0].view_menu.name)
        self.assertEqual("datasource_access",
                         updated_override_me.permissions[0].permission.name)
示例#10
0
    def test_validate_sql_endpoint_mocked(self, get_validator_by_name):
        """The validate_sql_json endpoint returns the validator's annotations.

        ``get_validator_by_name`` is a mock whose validator yields a single
        annotation. The endpoint response must be a one-element list whose
        message carries that annotation's text. Skipped on Hive.
        """
        if get_example_database().backend == "hive":
            pytest.skip("Hive validator is not implemented")
        self.login("admin")

        fake_validator = MagicMock()
        get_validator_by_name.return_value = fake_validator
        annotation = SQLValidationAnnotation(
            message="I don't know what I expected, but it wasn't this",
            line_number=4,
            start_column=12,
            end_column=42,
        )
        fake_validator.validate.return_value = [annotation]

        annotations = self.validate_sql(
            "SELECT * FROM somewhere_over_the_rainbow",
            client_id="1",
            raise_on_error=False,
        )

        self.assertEqual(1, len(annotations))
        self.assertIn("expected,", annotations[0]["message"])
示例#11
0
    def create_table(
        self,
        name,
        schema=None,
        id=0,
        cols_names=None,
        cols_uuids=None,
        metric_names=None,
    ):
        """Build an unsaved ``SqlaTable`` together with its dict representation.

        :param name: table name; ``NAME_PREFIX`` is prepended.
        :param schema: schema of the table.
        :param id: table id; also stored under ``DBREF`` in ``params``.
        :param cols_names: column names (default: no columns).
        :param cols_uuids: uuids paired positionally with ``cols_names``;
            defaults to ``None`` for every column. Pairing uses ``zip``, so
            surplus entries in the longer list are ignored.
        :param metric_names: metric names (default: no metrics); every metric
            gets an empty expression.
        :returns: ``(table, dict_rep)``.
        """
        # ``None`` defaults avoid one list object being shared across calls.
        cols_names = [] if cols_names is None else cols_names
        metric_names = [] if metric_names is None else metric_names
        if cols_uuids is None:
            cols_uuids = [None] * len(cols_names)

        database_name = "main"
        name = f"{NAME_PREFIX}{name}"
        params = {DBREF: id, "database_name": database_name}

        dict_rep = {
            "database_id": get_example_database().id,
            "table_name": name,
            "schema": schema,
            "id": id,
            "params": json.dumps(params),
            "columns": [
                {"column_name": c, "uuid": u} for c, u in zip(cols_names, cols_uuids)
            ],
            "metrics": [{"metric_name": c, "expression": ""} for c in metric_names],
        }

        table = SqlaTable(
            id=id, schema=schema, table_name=name, params=json.dumps(params)
        )
        for col_name, uuid in zip(cols_names, cols_uuids):
            table.columns.append(TableColumn(column_name=col_name, uuid=uuid))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
        return table, dict_rep
示例#12
0
    def test_adhoc_metrics_and_calc_columns(self):
        """An adhoc SQL metric containing nested sub-selects is rejected.

        Building the query for such a metric must raise
        ``SupersetSecurityException``. The temporary table is deleted afterwards.
        """
        adhoc_metric = {
            "expressionType": AdhocMetricExpressionType.SQL,
            "sqlExpression": (
                "(SELECT (SELECT * from birth_names) from test_validate_adhoc_sql)"
            ),
            "label": "adhoc_metrics",
        }
        base_query_obj = {
            "granularity": None,
            "from_dttm": None,
            "to_dttm": None,
            "groupby": ["user", "expr"],
            "metrics": [adhoc_metric],
            "is_timeseries": False,
            "filter": [],
        }

        table = SqlaTable(
            table_name="test_validate_adhoc_sql", database=get_example_database()
        )
        db.session.commit()

        with pytest.raises(SupersetSecurityException):
            table.get_sqla_query(**base_query_obj)

        # Cleanup
        db.session.delete(table)
        db.session.commit()
示例#13
0
    def test_set_perm_sqla_table_none(self):
        """A table created via ``database_id`` gets no datasource permission.

        Neither a permission for the table's ``perm`` nor a bogus one named
        with a ``None`` database may exist after the table is committed.
        """
        session = db.session
        table = SqlaTable(
            schema="tmp_schema",
            table_name="tmp_perm_table",
            # Setting database_id instead of database will skip permission creation
            database_id=get_example_database().id,
        )
        session.add(table)
        session.commit()

        stored_table = (
            session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
        )
        candidate_view_menus = (
            # the table's own permission name
            stored_table.perm,
            # a bogus name that would result from a missing database
            f"[None].[tmp_perm_table](id:{stored_table.id})",
        )
        for view_menu_name in candidate_view_menus:
            self.assertIsNone(
                security_manager.find_permission_view_menu(
                    "datasource_access", view_menu_name
                )
            )

        session.delete(table)
        session.commit()
示例#14
0
    def create_queries(self):
        """Insert ``QUERIES_FIXTURE_COUNT`` queries, yield them, then delete them.

        The first ``QUERIES_FIXTURE_COUNT - 1`` queries belong to ``admin`` on
        the examples database, with status SUCCESS for even indices and
        RUNNING for odd ones. The last query belongs to ``alpha`` on the main
        database and has status SUCCESS.
        """
        with self.create_app().app_context():
            admin_id = self.get_user("admin").id
            alpha_id = self.get_user("alpha").id
            example_database_id = get_example_database().id
            main_database_id = get_main_database().id

            queries = [
                self.insert_query(
                    example_database_id,
                    admin_id,
                    self.get_random_string(),
                    sql=f"SELECT col1, col2 from table{idx}",
                    rows=idx,
                    status=QueryStatus.RUNNING if idx % 2 else QueryStatus.SUCCESS,
                )
                for idx in range(QUERIES_FIXTURE_COUNT - 1)
            ]
            queries.append(
                self.insert_query(
                    main_database_id,
                    alpha_id,
                    self.get_random_string(),
                    sql=f"SELECT col1, col2 from table{QUERIES_FIXTURE_COUNT}",
                    rows=QUERIES_FIXTURE_COUNT,
                    status=QueryStatus.SUCCESS,
                )
            )

            yield queries

            # rollback changes
            for created_query in queries:
                db.session.delete(created_query)
            db.session.commit()
示例#15
0
    def test_get_list_query_no_data_access(self):
        """
        Query API: Test get queries no data access

        An admin-owned query must not be listed for the user logged in below:
        filtering on SQL starting with "SELECT col1" returns ``count == 0``.
        """
        owner = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(
            get_example_database().id,
            owner.id,
            client_id,
            sql="SELECT col1, col2 from table1",
        )

        self.login(username="******")
        sql_prefix_filter = {"col": "sql", "opr": "sw", "value": "SELECT col1"}
        rison_args = prison.dumps({"filters": [sql_prefix_filter]})
        response = self.client.get(f"api/v1/query/?q={rison_args}")
        assert response.status_code == 200
        payload = json.loads(response.data.decode("utf-8"))
        assert payload["count"] == 0

        # rollback changes
        db.session.delete(query)
        db.session.commit()
 def test_schemas_accessible_by_user_admin(self, mock_g):
     """For the admin user, every requested schema is returned unchanged."""
     mock_g.user = security_manager.find_user("admin")
     requested = ["1", "2", "3"]
     with self.client.application.test_request_context():
         accessible = security_manager.get_schemas_accessible_by_user(
             get_example_database(), requested
         )
         self.assertEqual(accessible, ["1", "2", "3"])  # no changes
示例#17
0
    def test_create_saved_query(self):
        """
        Saved Query API: Test create

        POSTs a saved query and checks that every posted field on the stored
        model equals the value echoed in the response ``result``. The created
        row is deleted even if a comparison fails.
        """
        example_db = get_example_database()

        post_data = {
            "schema": "schema1",
            "label": "label1",
            "description": "some description",
            "sql": "SELECT col1, col2 from table1",
            "db_id": example_db.id,
        }

        self.login(username="******")
        uri = "api/v1/saved_query/"
        rv = self.client.post(uri, json=post_data)
        data = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 201

        saved_query_id = data.get("id")
        model = db.session.query(SavedQuery).get(saved_query_id)
        # Fail clearly instead of with AttributeError on ``None``.
        assert model is not None
        try:
            for key in post_data:
                assert getattr(model, key) == data["result"][key]
        finally:
            # Rollback changes
            db.session.delete(model)
            db.session.commit()
示例#18
0
    def test_with_row_limit_and_offset__row_limit_and_offset_were_applied(
        self,
    ):
        """
        Chart data API: Test chart data query with limit and offset

        With a limit of 5 ordered by name, moving the offset to 2 must return
        5 rows whose first row is the third row of the offset-0 result. The
        offset check is skipped on Presto.
        """
        first_query = self.query_context_payload["queries"][0]
        first_query["row_limit"] = 5
        first_query["row_offset"] = 0
        first_query["orderby"] = [["name", True]]

        rv = self.post_assert_metric(
            CHART_DATA_URI, self.query_context_payload, "data"
        )
        self.assert_row_count(rv, 5)
        unshifted = rv.json["result"][0]

        # TODO: fix offset for presto DB
        if get_example_database().backend == "presto":
            return

        # ensure that offset works properly
        offset = 2
        expected_name = unshifted["data"][offset]["name"]
        self.query_context_payload["queries"][0]["row_offset"] = offset
        rv = self.post_assert_metric(
            CHART_DATA_URI, self.query_context_payload, "data"
        )
        shifted = rv.json["result"][0]
        assert shifted["rowcount"] == 5
        assert shifted["data"][0]["name"] == expected_name
    def test_export_dataset_command_key_order(self, mock_g):
        """Keys in the exported dataset YAML appear in ``export_fields`` order."""
        expected_keys = [
            "table_name",
            "main_dttm_col",
            "description",
            "default_endpoint",
            "offset",
            "cache_timeout",
            "schema",
            "sql",
            "params",
            "template_params",
            "filter_select_enabled",
            "fetch_values_predicate",
            "extra",
            "uuid",
            "metrics",
            "columns",
            "version",
            "database_uuid",
        ]
        mock_g.user = security_manager.find_user("admin")

        example_db = get_example_database()
        energy_usage = _get_table_from_list_by_name(
            "energy_usage", example_db.tables
        )
        exported_files = dict(ExportDatasetsCommand([energy_usage.id]).run())

        metadata = yaml.safe_load(
            exported_files["datasets/examples/energy_usage.yaml"]
        )
        assert list(metadata) == expected_keys
示例#20
0
    def test_override_role_permissions_druid_and_table(self):
        """Overriding role permissions grants both druid datasources and the table.

        After ``ROLE_ALL_PERM_DATA`` is posted (with the schema set to the
        examples database's default schema), ``override_me`` must hold exactly
        three ``datasource_access`` permissions. Sorted by view-menu name, they
        must be ``druid_ds_1``, ``druid_ds_2`` and ``birth_names``.
        """
        import copy

        database = get_example_database()
        engine = database.get_sqla_engine()
        schema = inspect(engine).default_schema_name

        # Deep copy so the nested schema name is set on a private copy; a
        # shallow ``.copy()`` would mutate ROLE_ALL_PERM_DATA in place.
        perm_data = copy.deepcopy(ROLE_ALL_PERM_DATA)
        perm_data["database"][0]["schema"][0]["name"] = schema
        response = self.client.post(
            "/superset/override_role_permissions/",
            data=json.dumps(perm_data),
            content_type="application/json",
        )
        self.assertEqual(201, response.status_code)

        updated_role = security_manager.find_role("override_me")
        # Sorted so positions are deterministic; every check below uses
        # ``perms`` (not the unsorted ``updated_role.permissions``).
        perms = sorted(updated_role.permissions,
                       key=lambda p: p.view_menu.name)
        # Check the count first so a mismatch fails clearly, not with IndexError.
        self.assertEqual(3, len(perms))

        druid_ds_1 = self.get_druid_ds_by_name("druid_ds_1")
        self.assertEqual(druid_ds_1.perm, perms[0].view_menu.name)
        self.assertEqual("datasource_access", perms[0].permission.name)

        druid_ds_2 = self.get_druid_ds_by_name("druid_ds_2")
        self.assertEqual(druid_ds_2.perm, perms[1].view_menu.name)
        self.assertEqual("datasource_access", perms[1].permission.name)

        birth_names = self.get_table(name="birth_names")
        self.assertEqual(birth_names.perm, perms[2].view_menu.name)
        self.assertEqual("datasource_access", perms[2].permission.name)
    def test_sqllab_viz(self):
        """The ``sqllab_viz`` endpoint creates a table owned by the caller.

        The created ``SqlaTable`` must be owned by ``admin`` and have a view
        menu registered for its permission. The table is deleted afterwards.
        """
        self.login("admin")
        examples_dbid = get_example_database().id
        datasource_name = f"test_viz_flow_table_{random()}"
        columns = [
            {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
            {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
        ]
        payload = {
            "chartType": "dist_bar",
            "datasourceName": datasource_name,
            "schema": "superset",
            "columns": columns,
            "sql": """\
                SELECT *
                FROM birth_names
                LIMIT 10""",
            "dbId": examples_dbid,
        }
        resp = self.get_json_resp(
            "/superset/sqllab_viz/", data={"data": json.dumps(payload)}
        )
        self.assertIn("table_id", resp)

        # ensure owner is set correctly
        table = db.session.query(SqlaTable).filter_by(id=resp["table_id"]).one()
        owner_names = [owner.username for owner in table.owners]
        self.assertEqual(owner_names, ["admin"])
        assert security_manager.find_view_menu(table.get_perm()) is not None

        # Cleanup
        db.session.delete(table)
        db.session.commit()
示例#22
0
    def test_external_metadata_by_name_for_virtual_table(self):
        """External metadata looked up by name reflects a virtual table's SQL.

        The columns reported for ``dummy_sql_table`` must be exactly the
        columns its SQL selects: ``intcol`` and ``strcol``.
        """
        self.login(username="******")
        session = db.session
        session.add(
            SqlaTable(
                table_name="dummy_sql_table",
                database=get_example_database(),
                schema=get_example_default_schema(),
                sql="select 123 as intcol, 'abc' as strcol",
            )
        )
        session.commit()

        tbl = self.get_table(name="dummy_sql_table")
        lookup = {
            "datasource_type": "table",
            "database_name": tbl.database.database_name,
            "schema_name": tbl.schema,
            "table_name": tbl.table_name,
        }
        url = f"/datasource/external_metadata_by_name/?q={prison.dumps(lookup)}"
        column_names = {column.get("name") for column in self.get_json_resp(url)}
        assert column_names == {"intcol", "strcol"}

        session.delete(tbl)
        session.commit()
    def test_get_timestamp_expression(self):
        """``get_timestamp_expression`` keeps the column type across time grains.

        Checks the raw ``ds`` column, then the ``P1D`` grain on the plain
        column and on a custom expression; the MySQL output is asserted
        exactly. The column's ``expression`` is temporarily overwritten and
        restored in ``finally``, so a failing assertion does not leave the
        ``birth_names`` column modified (presumably shared through the session
        with other tests — TODO confirm).
        """
        col_type = (
            "VARCHAR"
            if get_example_database().backend == "presto"
            else "TemporalWrapperType"
        )
        tbl = self.get_table(name="birth_names")
        ds_col = tbl.get_column("ds")
        sqla_literal = ds_col.get_timestamp_expression(None)
        self.assertEqual(str(sqla_literal.compile()), "ds")
        assert type(sqla_literal.type).__name__ == col_type

        sqla_literal = ds_col.get_timestamp_expression("P1D")
        assert type(sqla_literal.type).__name__ == col_type
        compiled = str(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "DATE(ds)")

        prev_ds_expr = ds_col.expression
        ds_col.expression = "DATE_ADD(ds, 1)"
        try:
            sqla_literal = ds_col.get_timestamp_expression("P1D")
            assert type(sqla_literal.type).__name__ == col_type
            compiled = str(sqla_literal.compile())
            if tbl.database.backend == "mysql":
                self.assertEqual(compiled, "DATE(DATE_ADD(ds, 1))")
        finally:
            # Always undo the temporary expression change.
            ds_col.expression = prev_ds_expr
示例#24
0
def setup_sample_data() -> Any:
    """Prepare sample data before the tests run and tear it down afterwards.

    Generator: before ``yield`` it runs ``setup_presto_if_needed``, loads the
    test users and loads the CSS templates; after ``yield`` it tears down.
    """
    # TODO(john-bodley): Determine a cleaner way of setting up the sample data without
    # relying on `tests.integration_tests.test_app.app` leveraging an  `app` fixture which is purposely
    # scoped to the function level to ensure tests remain idempotent.
    with app.app_context():
        setup_presto_if_needed()

        from superset.cli.test import load_test_users_run

        load_test_users_run()

        from superset.examples.css_templates import load_css_templates

        load_css_templates()

    yield

    with app.app_context():
        engine = get_example_database().get_sqla_engine()

        # drop sqlachemy tables

        db.session.commit()
        from sqlalchemy.ext import declarative

        sqla_base = declarative.declarative_base()
        # NOTE(review): a freshly created declarative base has no tables
        # registered on its metadata, so this loop currently drops nothing.
        # Confirm which metadata was meant to be torn down here.
        # ``sorted_tables`` lists parents before dependents; iterate it in
        # reverse so dependents are dropped first without violating foreign
        # key constraints.
        for table in reversed(sqla_base.metadata.sorted_tables):
            # ``sorted_tables`` yields ``Table`` objects (which have no
            # ``__table__``); bind the drop to the examples engine explicitly.
            table.drop(bind=engine)
        db.session.commit()
示例#25
0
        def __call__(self) -> Database:
            """Return the examples ``Database``, fetching it lazily on first use.

            Inside an app context, the first call fetches the database and runs
            ``_load_lazy_data_to_decouple_from_session``. Later calls return
            the cached instance.
            """
            with app.app_context():
                if self._db is not None:
                    return self._db
                self._db = get_example_database()
                self._load_lazy_data_to_decouple_from_session()
                return self._db
示例#26
0
    def test_fetch_metadata_for_updated_virtual_table(self):
        """``fetch_metadata`` reconciles a virtual table's columns with its SQL.

        Starts with four columns: ``intcol`` (stale type), ``oldcol`` (no
        longer selected), and the calculated columns ``expr`` and ``mycase``.
        After ``fetch_metadata(commit=False)``:
        - ``oldcol`` is gone and ``strcol`` appears;
        - ``intcol`` has an integer type for the backend;
        - ``mycase`` becomes a physical string column with an empty expression;
        - ``expr`` keeps its expression.
        """
        table = SqlaTable(
            table_name="updated_sql_table",
            database=get_example_database(),
            sql="select 123 as intcol, 'abc' as strcol, 'abc' as mycase",
        )
        # Passing ``table=table`` attaches each column to ``table.columns``
        # (verified by the length assertion below); return values are unused.
        TableColumn(column_name="intcol", type="FLOAT", table=table)
        TableColumn(column_name="oldcol", type="INT", table=table)
        TableColumn(
            column_name="expr",
            expression="case when 1 then 1 else 0 end",
            type="INT",
            table=table,
        )
        TableColumn(
            column_name="mycase",
            expression="case when 1 then 1 else 0 end",
            type="INT",
            table=table,
        )

        # make sure the columns have been mapped properly
        assert len(table.columns) == 4
        table.fetch_metadata(commit=False)

        # assert that the removed column has been dropped and
        # the physical and calculated columns are present
        assert {col.column_name
                for col in table.columns} == {
                    "intcol",
                    "strcol",
                    "mycase",
                    "expr",
                }
        cols: Dict[str, TableColumn] = {
            col.column_name: col
            for col in table.columns
        }
        # assert that the type for intcol has been updated (asserting CI types)
        backend = get_example_database().backend
        assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type)
        # assert that the expression has been replaced with the new physical column
        assert cols["mycase"].expression == ""
        assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
        assert cols["expr"].expression == "case when 1 then 1 else 0 end"

        # NOTE(review): the deletion is not committed here, and it is not
        # shown whether ``table`` was ever flushed/persisted — confirm how
        # the surrounding harness finalizes the session.
        db.session.delete(table)
示例#27
0
    def test_get_query(self):
        """
        Query API: Test get query

        Fetches one query by id and compares every returned field with the
        expected values. Timestamps and the id are not compared. Any returned
        key missing from the expectations raises ``KeyError``.
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        example_db = get_example_database()
        base_sql = "SELECT col1, col2 from table1"
        executed_sql = "SELECT col1, col2 from table1 LIMIT 100"
        query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql=base_sql,
            select_sql=base_sql,
            executed_sql=executed_sql,
        )
        self.login(username="******")
        rv = self.client.get(f"api/v1/query/{query.id}")
        self.assertEqual(rv.status_code, 200)

        expected_result = {
            "database": {"id": example_db.id},
            "client_id": client_id,
            "end_result_backend_time": None,
            "error_message": None,
            "executed_sql": executed_sql,
            "limit": 100,
            "progress": 100,
            "results_key": None,
            "rows": 100,
            "schema": None,
            "select_as_cta": None,
            "select_as_cta_used": False,
            "select_sql": base_sql,
            "sql": base_sql,
            "sql_editor_id": None,
            "status": "success",
            "tab_name": "",
            "tmp_schema_name": None,
            "tmp_table_name": None,
            "tracking_url": None,
        }
        # We can't assert timestamps (nor the generated id)
        unchecked_keys = {
            "changed_on",
            "end_time",
            "start_running_time",
            "start_time",
            "id",
        }
        result = json.loads(rv.data.decode("utf-8"))["result"]
        self.assertIn("changed_on", result)
        for key, value in result.items():
            if key in unchecked_keys:
                continue
            self.assertEqual(value, expected_result[key])

        # rollback changes
        db.session.delete(query)
        db.session.commit()
示例#28
0
 def test_select_star(self):
     """``select_star`` generates the expected SQL for ``energy_usage``.

     Without columns it must contain ``SELECT * ... LIMIT 100``; the table
     name is quoted on Presto/Hive. With ``show_cols=True`` the SQL must match
     the backend-specific expectation exactly on Presto and Hive; on other
     backends the expectation only needs to be contained in it.
     """
     # Named ``database`` (not ``db``) to avoid shadowing the session object.
     database = get_example_database()
     table_name = "energy_usage"
     sql = database.select_star(table_name, show_cols=False, latest_partition=False)
     quote = database.inspector.engine.dialect.identifier_preparer.quote_identifier
     if database.backend in {"presto", "hive"}:
         from_target = quote(table_name)
     else:
         from_target = table_name
     assert f"SELECT *\nFROM {from_target}\nLIMIT 100" in sql

     sql = database.select_star(table_name, show_cols=True, latest_partition=False)
     # TODO(bkyryliuk): unify sql generation
     exact_expectations = {
         "presto": (
             'SELECT "source" AS "source",\n'
             '       "target" AS "target",\n'
             '       "value" AS "value"\n'
             'FROM "energy_usage"\n'
             "LIMIT 100"
         ),
         "hive": (
             "SELECT `source`,\n"
             "       `target`,\n"
             "       `value`\n"
             "FROM `energy_usage`\n"
             "LIMIT 100"
         ),
     }
     if database.backend in exact_expectations:
         assert exact_expectations[database.backend] == sql
     else:
         default_expected = (
             "SELECT source,\n"
             "       target,\n"
             "       value\n"
             "FROM energy_usage\n"
             "LIMIT 100"
         )
         assert default_expected in sql
示例#29
0
def load_unicode_data():
    """Create the unicode test table around a test, then drop it.

    Before ``yield``, writes ``_get_dataframe()`` to ``UNICODE_TBL_NAME`` in the
    examples database's default schema. The existing table is replaced, and
    ``phrase`` is stored as ``String(500)``. After ``yield``, drops that table.
    """
    with app.app_context():
        _get_dataframe().to_sql(
            UNICODE_TBL_NAME,
            get_example_database().get_sqla_engine(),
            if_exists="replace",
            chunksize=500,
            dtype={"phrase": String(500)},
            index=False,
            method="multi",
            schema=get_example_default_schema(),
        )

    yield
    with app.app_context():
        engine = get_example_database().get_sqla_engine()
        # Drop the table created above rather than a hard-coded name, so the
        # two stay in sync if UNICODE_TBL_NAME changes.
        engine.execute(f"DROP TABLE IF EXISTS {UNICODE_TBL_NAME}")
示例#30
0
def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard:
    """Create the "unicode-test" dashboard, optionally with one unicode slice.

    :param slice_title: title of the slice to create. When falsy (empty or
        ``None``), the dashboard is created without slices.
    :param position: forwarded unchanged to ``create_dashboard``.
    :returns: the created dashboard.
    """
    table = create_table_metadata(UNICODE_TBL_NAME, get_example_database())
    table.fetch_metadata()

    # Build the list explicitly: a falsy title must yield no slices rather
    # than referencing an unassigned name. This also avoids shadowing the
    # ``slice`` builtin.
    slices = []
    if slice_title:
        slices.append(_create_and_commit_unicode_slice(table, slice_title))

    return create_dashboard("unicode-test", "Unicode Test", position, slices)