Exemplo n.º 1
0
def test_default_data_serialization():
    """Default (non-new) serialization with the fourth flag set calls
    ``expand_data`` exactly once and produces list-shaped data."""
    engine_spec = BaseEngineSpec()
    result_set = SupersetResultSet(SERIALIZATION_DATA, CURSOR_DESCR, engine_spec)

    # Spy on expand_data while still running the real implementation.
    spy = mock.patch.object(engine_spec, "expand_data", wraps=engine_spec.expand_data)
    with spy as wrapped_expand:
        serialized, *_ = sql_lab._serialize_and_expand_data(
            result_set, engine_spec, False, True
        )
        wrapped_expand.assert_called_once()

    assert isinstance(serialized, list)
Exemplo n.º 2
0
    def test_results_msgpack_deserialization(self):
        """Round-trip a result payload through the new (msgpack) path.

        The payload is serialized with the new-deserialization flag on, which
        must produce ``bytes``. It is then deserialized through
        ``superset.views.utils._deserialize_results_payload``. The result must
        equal the original payload once ``data`` is replaced by the records
        form of the result set's DataFrame. ``expand_data`` must be called
        exactly once along the way.
        """
        use_msgpack = True
        rows = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")]
        description = (
            ("a", "string"),
            ("b", "int"),
            ("c", "float"),
            ("d", "datetime"),
        )
        engine_spec = BaseEngineSpec()
        result_set = SupersetResultSet(rows, description, engine_spec)
        query_info = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": utils.QueryStatus.PENDING,
        }
        data, selected, columns, expanded = sql_lab._serialize_and_expand_data(
            result_set, engine_spec, use_msgpack
        )
        success = utils.QueryStatus.SUCCESS
        payload = dict(
            query_id=1,
            status=success,
            state=success,
            data=data,
            columns=columns,
            selected_columns=selected,
            expanded_columns=expanded,
            query=query_info,
        )

        blob = sql_lab._serialize_payload(payload, use_msgpack)
        self.assertIsInstance(blob, bytes)

        spy = mock.patch.object(
            engine_spec, "expand_data", wraps=engine_spec.expand_data
        )
        with spy as wrapped_expand:
            # The deserializer reaches expand_data through the query's database.
            fake_query = mock.Mock()
            fake_query.database.db_engine_spec.expand_data = wrapped_expand

            restored = superset.views.utils._deserialize_results_payload(
                blob, fake_query, use_msgpack
            )
            payload["data"] = dataframe.df_to_records(result_set.to_pandas_df())

            self.assertDictEqual(restored, payload)
            wrapped_expand.assert_called_once()
Exemplo n.º 3
0
 def get_view_names(
     cls,
     database: "Database",
     inspector: Inspector,
     schema: Optional[str],
 ) -> List[str]:
     """Delegate view-name lookup to the generic ``BaseEngineSpec`` implementation.

     All arguments are forwarded by keyword, unchanged.
     """
     delegate = BaseEngineSpec.get_view_names
     return delegate(schema=schema, inspector=inspector, database=database)
Exemplo n.º 4
0
def test_get_text_clause_with_colon(app_context: AppContext) -> None:
    """
    Make sure text clauses are correctly escaped.

    A literal colon inside a quoted value must come back backslash-escaped in
    the clause text, so it is not read as a SQLAlchemy bind-parameter marker.
    """

    from superset.db_engine_specs.base import BaseEngineSpec

    # Use a well-formed query (the old fixture had a stray closing paren) so
    # the only difference between input and expected text is the escaped colon.
    text_clause = BaseEngineSpec.get_text_clause(
        "SELECT foo FROM tbl WHERE foo = '123:456'"
    )
    assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456'"
Exemplo n.º 5
0
 def test_pyodbc_rows_to_tuples(self):
     """``pyodbc_rows_to_tuples`` unwraps ``Row`` objects (as returned by an
     ODBC driver) into plain tuples, keeping value order."""
     stamps = [
         datetime.datetime(2017, 10, 19, 23, 39, 16, 660000),
         datetime.datetime(2018, 10, 19, 23, 39, 16, 660000),
     ]
     expected = [(n, n, stamp) for n, stamp in enumerate(stamps, start=1)]
     rows = [Row(values) for values in expected]
     self.assertListEqual(BaseEngineSpec.pyodbc_rows_to_tuples(rows), expected)
Exemplo n.º 6
0
 def test_get_table_names(self):
     """Make sure base engine spec removes schema name from table name,
     i.e. when try_remove_schema_from_table_name == True."""
     inspector = mock.Mock()
     inspector.get_table_names = mock.Mock(
         return_value=["schema.table", "table_2"])
     # Foreign tables are stubbed too; the expected list shows the base spec
     # does not include them.
     inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
     base_result_expected = ["table", "table_2"]
     base_result = BaseEngineSpec.get_table_names(database=mock.ANY,
                                                  schema="schema",
                                                  inspector=inspector)
     self.assertListEqual(base_result_expected, base_result)
Exemplo n.º 7
0
def test_parse_sql_multi_statement(app_context: AppContext) -> None:
    """
    ``parse_sql`` splits a string holding several SQL statements into a list
    with one entry per statement, without the trailing semicolons.
    """

    from superset.db_engine_specs.base import BaseEngineSpec

    script = "SELECT foo FROM tbl1; SELECT bar FROM tbl2;"
    expected = ["SELECT foo FROM tbl1", "SELECT bar FROM tbl2"]
    assert BaseEngineSpec.parse_sql(script) == expected
Exemplo n.º 8
0
    def test_results_default_deserialization(self):
        """Round-trip a result payload through the default (non-msgpack) path.

        With the new-deserialization flag off, the serialized payload must be
        a ``str``. It must deserialize back to an equal dict without the query
        object itself ever being called.
        """
        use_msgpack = False
        rows = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")]
        description = (
            ("a", "string"),
            ("b", "int"),
            ("c", "float"),
            ("d", "datetime"),
        )
        engine_spec = BaseEngineSpec()
        result_set = SupersetResultSet(rows, description, engine_spec)
        query_info = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": utils.QueryStatus.PENDING,
        }
        data, selected, columns, expanded = sql_lab._serialize_and_expand_data(
            result_set, engine_spec, use_msgpack
        )
        success = utils.QueryStatus.SUCCESS
        payload = dict(
            query_id=1,
            status=success,
            state=success,
            data=data,
            columns=columns,
            selected_columns=selected,
            expanded_columns=expanded,
            query=query_info,
        )

        blob = sql_lab._serialize_payload(payload, use_msgpack)
        self.assertIsInstance(blob, str)

        fake_query = mock.Mock()
        restored = superset.views.utils._deserialize_results_payload(
            blob, fake_query, use_msgpack
        )

        self.assertDictEqual(restored, payload)
        fake_query.assert_not_called()
Exemplo n.º 9
0
 def test_get_table_names(self):
     """Check schema-prefix handling in ``get_table_names``.

     The base engine spec removes the schema name from table names
     (try_remove_schema_from_table_name == True). Postgres keeps them as-is
     (try_remove_schema_from_table_name == False) and its result also
     contains the foreign table names.
     """
     inspector = mock.Mock()
     inspector.get_table_names = mock.Mock(
         return_value=['schema.table', 'table_2'])
     inspector.get_foreign_table_names = mock.Mock(return_value=['table_3'])
     # Base engine spec: schema prefix stripped, foreign tables not included.
     base_result_expected = ['table', 'table_2']
     base_result = BaseEngineSpec.get_table_names(schema='schema',
                                                  inspector=inspector)
     self.assertListEqual(base_result_expected, base_result)
     # Postgres: names kept verbatim, foreign tables appended.
     pg_result_expected = ['schema.table', 'table_2', 'table_3']
     pg_result = PostgresEngineSpec.get_table_names(schema='schema',
                                                    inspector=inspector)
     self.assertListEqual(pg_result_expected, pg_result)
Exemplo n.º 10
0
    def get_extra_params(database: "Database") -> Dict[str, Any]:
        """
        Build connection ``extra`` params, adding TLS settings when a server
        certificate is configured.

        Starts from ``BaseEngineSpec.get_extra_params``. It always ensures that
        ``engine_params`` and ``engine_params.connect_args`` exist. When
        ``database.server_cert`` is truthy, it sets ``http_scheme`` to
        ``"https"``. It also sets ``verify`` to the value returned by
        ``utils.create_ssl_cert_file``.

        :param database: database instance from which to extract extras
        :raises CertificateException: If certificate is not valid/unparseable
        :return: the extras dictionary, possibly augmented
        """
        extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
        connect_args: Dict[str, Any] = extra.setdefault(
            "engine_params", {}
        ).setdefault("connect_args", {})

        if not database.server_cert:
            return extra

        connect_args["http_scheme"] = "https"
        connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)
        return extra
Exemplo n.º 11
0
    def test_new_data_serialization(self):
        """With the new-serialization flag on, data comes back as ``bytes`` and
        ``expand_data`` is never called."""
        stamp = datetime.datetime(2019, 8, 18, 16, 39, 16, 660000)
        rows = [("a", 4, 4.0, stamp)]
        description = (
            ("a", "string"),
            ("b", "int"),
            ("c", "float"),
            ("d", "datetime"),
        )
        engine_spec = BaseEngineSpec()
        frame = SupersetDataFrame(rows, description, engine_spec)

        spy = mock.patch.object(
            engine_spec, "expand_data", wraps=engine_spec.expand_data
        )
        with spy as wrapped_expand:
            serialized, _selected, _all, _expanded = (
                sql_lab._serialize_and_expand_data(frame, engine_spec, True)
            )
            wrapped_expand.assert_not_called()

        self.assertIsInstance(serialized, bytes)
Exemplo n.º 12
0
    def test_msgpack_payload_serialization(self):
        """Serializing a full result payload with the new (msgpack) flag on
        yields ``bytes``."""
        use_msgpack = True
        stamp = datetime.datetime(2019, 8, 18, 16, 39, 16, 660000)
        rows = [("a", 4, 4.0, stamp)]
        description = (
            ("a", "string"),
            ("b", "int"),
            ("c", "float"),
            ("d", "datetime"),
        )
        engine_spec = BaseEngineSpec()
        result_set = SupersetResultSet(rows, description, engine_spec)
        query_info = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": QueryStatus.PENDING,
        }
        data, selected, columns, expanded = sql_lab._serialize_and_expand_data(
            result_set, engine_spec, use_msgpack
        )
        payload = dict(
            query_id=1,
            status=QueryStatus.SUCCESS,
            state=QueryStatus.SUCCESS,
            data=data,
            columns=columns,
            selected_columns=selected,
            expanded_columns=expanded,
            query=query_info,
        )

        serialized = sql_lab._serialize_payload(payload, use_msgpack)
        self.assertIsInstance(serialized, bytes)
Exemplo n.º 13
0
 def select_star(
     cls,
     my_db,
     table_name: str,
     engine: Engine,
     schema: str = None,
     limit: int = 100,
     show_cols: bool = False,
     indent: bool = True,
     latest_partition: bool = True,
     cols: List[dict] = None,
 ) -> str:
     """Delegate ``select_star`` to the generic ``BaseEngineSpec`` implementation.

     ``cols`` defaults to ``None`` instead of a ``[]`` literal. A list default
     is created once at definition time and shared by every call, so any
     mutation downstream would leak between calls. ``None`` is normalised to
     a fresh empty list, so the delegate still receives a list as before.
     """
     if cols is None:
         cols = []
     return BaseEngineSpec.select_star(
         my_db,
         table_name,
         engine,
         schema,
         limit,
         show_cols,
         indent,
         latest_partition,
         cols,
     )
Exemplo n.º 14
0
    def test_default_data_serialization(self):
        """With the new-serialization flag off and the fourth flag on,
        ``expand_data`` runs exactly once and the data comes back as a list."""
        stamp = datetime.datetime(2019, 8, 18, 16, 39, 16, 660000)
        rows = [("a", 4, 4.0, stamp)]
        description = (
            ("a", "string"),
            ("b", "int"),
            ("c", "float"),
            ("d", "datetime"),
        )
        engine_spec = BaseEngineSpec()
        result_set = SupersetResultSet(rows, description, engine_spec)

        spy = mock.patch.object(
            engine_spec, "expand_data", wraps=engine_spec.expand_data
        )
        with spy as wrapped_expand:
            serialized, _selected, _all, _expanded = (
                sql_lab._serialize_and_expand_data(result_set, engine_spec, False, True)
            )
            wrapped_expand.assert_called_once()

        self.assertIsInstance(serialized, list)
Exemplo n.º 15
0
 def select_star(
     cls,
     database,
     table_name: str,
     engine: Engine,
     schema: str = None,
     limit: int = 100,
     show_cols: bool = False,
     indent: bool = True,
     latest_partition: bool = True,
     cols: Optional[List[Dict[str, Any]]] = None,
 ) -> str:
     """Delegate ``select_star`` to the generic ``BaseEngineSpec`` implementation.

     Every argument is forwarded positionally, in declaration order, unchanged.
     """
     forwarded = (
         database,
         table_name,
         engine,
         schema,
         limit,
         show_cols,
         indent,
         latest_partition,
         cols,
     )
     return BaseEngineSpec.select_star(*forwarded)
 def test_get_datatype(self):
     """The base engine spec passes the ``"VARCHAR"`` type name through unchanged."""
     # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
     self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
Exemplo n.º 17
0
 def test_convert_dttm(self):
     """On the base engine spec, ``convert_dttm`` returns ``None`` for an empty
     target type."""
     self.assertIsNone(BaseEngineSpec.convert_dttm("", self.get_dttm()))
Exemplo n.º 18
0
 def get_all_datasource_names(
     cls, db, datasource_type: str
 ) -> List[utils.DatasourceName]:
     """Delegate datasource-name listing to ``BaseEngineSpec`` unchanged."""
     names = BaseEngineSpec.get_all_datasource_names(db, datasource_type)
     return names
Exemplo n.º 19
0
 def is_readonly(sql: str) -> bool:
     """Return whether ``sql`` is read-only, as judged by ``BaseEngineSpec``."""
     parsed = ParsedQuery(sql)
     return BaseEngineSpec.is_readonly_query(parsed)
Exemplo n.º 20
0
 def test_get_datatype(self):
     """Check ``get_datatype`` across engine specs.

     Presto maps ``'string'`` to ``'STRING'``; MySQL maps type codes 1 and 15
     to ``'TINY'`` and ``'VARCHAR'``; the base spec returns ``'VARCHAR'`` as-is.
     """
     # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
     self.assertEqual('STRING', PrestoEngineSpec.get_datatype('string'))
     self.assertEqual('TINY', MySQLEngineSpec.get_datatype(1))
     self.assertEqual('VARCHAR', MySQLEngineSpec.get_datatype(15))
     self.assertEqual('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))
Exemplo n.º 21
0
def test_cte_query_parsing(app_context: AppContext, original: str,
                           expected: str) -> None:
    """``BaseEngineSpec.get_cte_query`` returns ``expected`` for SQL ``original``.

    NOTE(review): ``original`` is query text handed to ``get_cte_query``. The
    previous ``TypeEngine`` annotation (a SQLAlchemy column type) looked like
    a copy-paste slip, so confirm it against the parametrize fixture.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    actual = BaseEngineSpec.get_cte_query(original)
    assert actual == expected
Exemplo n.º 22
0
 def test_get_datatype(self):
     """The base engine spec passes the ``'VARCHAR'`` type name through unchanged."""
     # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
     self.assertEqual('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))
Exemplo n.º 23
0
 def handle_cursor(
     cls, cursor: Any, query: Query, session: Session
 ) -> None:
     """Update progress information by deferring to ``BaseEngineSpec.handle_cursor``."""
     BaseEngineSpec.handle_cursor(session=session, query=query, cursor=cursor)
Exemplo n.º 24
0
 def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
     """Build column clauses for ``cols`` with the generic ``BaseEngineSpec`` logic."""
     base_impl = BaseEngineSpec._get_fields  # pylint: disable=protected-access
     return base_impl(cols)
Exemplo n.º 25
0
 def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
     """Return the column clauses ``BaseEngineSpec`` builds for ``cols``."""
     fields = BaseEngineSpec._get_fields(cols)
     return fields
Exemplo n.º 26
0
 def get_all_datasource_names(
     cls, database: "Database", datasource_type: str
 ) -> List[utils.DatasourceName]:
     """Forward datasource-name listing to ``BaseEngineSpec`` unchanged."""
     names = BaseEngineSpec.get_all_datasource_names(database, datasource_type)
     return names
Exemplo n.º 27
0
 def get_view_names(cls, inspector, schema):
     """Forward view-name lookup for ``schema`` to ``BaseEngineSpec``."""
     lookup = BaseEngineSpec.get_view_names
     return lookup(inspector, schema)