def test_extract_tables_mixed_from_clause() -> None:
    """
    Check that a ``FROM`` clause mixing plain tables and a subselect is parsed.

    Both aliased tables and the table referenced inside the aliased subquery
    must be reported.
    """
    sql = """
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    found = extract_tables(sql)
    assert found == {Table("table_a"), Table("table_b"), Table("table_c")}
def test_extract_tables_select_in_expression() -> None:
    """
    Check that a ``SELECT`` used as a column expression contributes its tables.

    The scalar subquery is exercised both without and with a column alias.
    """
    expected = {Table("t1"), Table("t2")}
    for sql in (
        "SELECT f1, (SELECT count(1) FROM t2) FROM t1",
        "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1",
    ):
        assert extract_tables(sql) == expected
def test_extract_tables_multistatement() -> None:
    """
    Check that tables are collected across ``;``-separated statements, with
    and without a trailing semicolon.
    """
    expected = {Table("t1"), Table("t2")}
    for sql in (
        "SELECT * FROM t1; SELECT * FROM t2",
        "SELECT * FROM t1; SELECT * FROM t2;",
    ):
        assert extract_tables(sql) == expected
def test_table() -> None:
    """
    Exercise ``Table`` and its ``str`` rendering.

    Parts are joined with dots from catalog to schema to table. Special
    characters in any part are percent-encoded (``.`` -> ``%2E``,
    ``/`` -> ``%2F``, newline -> ``%0A``).
    """
    cases = [
        (Table("tbname"), "tbname"),
        (Table("tbname", "schemaname"), "schemaname.tbname"),
        (
            Table("tbname", "schemaname", "catalogname"),
            "catalogname.schemaname.tbname",
        ),
        (
            Table("table.name", "schema/name", "catalog\nname"),
            "catalog%0Aname.schema%2Fname.table%2Ename",
        ),
    ]
    for table, rendered in cases:
        assert str(table) == rendered
示例#5
0
    def test_table(self):
        """
        Check ``str(Table(...))`` for one-, two- and three-part names.

        Special characters in any part are percent-encoded: ``.`` becomes
        ``%2E``, ``/`` becomes ``%2F`` and a newline becomes ``%0A``.
        """
        self.assertEqual(str(Table("tbname")), "tbname")
        self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname")

        self.assertEqual(
            str(Table("tbname", "schemaname", "catalogname")),
            "catalogname.schemaname.tbname",
        )

        # The catalog name contains an explicit newline ("\n"), which must be
        # rendered as %0A; spelled out so it cannot be misread as a backslash.
        self.assertEqual(
            str(Table("tb.name", "schema/name", "catalog\nname")),
            "catalog%0Aname.schema%2Fname.tb%2Ename",
        )
示例#6
0
def test_extract_table_references(mocker: MockerFixture) -> None:
    """
    Test the ``extract_table_references`` helper function.

    Covers:
    - unqualified and catalog/schema-qualified names;
    - backtick (``mysql``) and double-quote (``trino``) identifier quoting;
    - Jinja placeholders and joins;
    - the sqlparse fallback, which logs a warning unless
      ``show_warning=False``.
    """
    assert extract_table_references("SELECT 1", "trino") == set()
    assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
        Table(table="some_table", schema=None, catalog=None)
    }
    assert extract_table_references(
        "SELECT {{ jinja }} FROM some_table",
        "trino") == {Table(table="some_table", schema=None, catalog=None)}
    assert extract_table_references(
        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino") == {
            Table(table="some_table",
                  schema="some_schema",
                  catalog="some_catalog")
        }

    # with identifier quotes
    assert extract_table_references(
        "SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`",
        "mysql") == {
            Table(table="some_table",
                  schema="some_schema",
                  catalog="some_catalog")
        }
    assert extract_table_references(
        'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino") == {
            Table(table="some_table",
                  schema="some_schema",
                  catalog="some_catalog")
        }

    assert extract_table_references(
        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
        "trino",
    ) == {
        Table(table="some_table", schema=None, catalog=None),
        Table(table="other_table", schema=None, catalog=None),
    }

    # test falling back to sqlparse
    logger = mocker.patch("superset.sql_parse.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(
        sql,
        "trino",
    ) == {Table(table="other_table", schema=None, catalog=None)}
    logger.warning.assert_called_once()

    # Re-patch the logger the fallback actually writes to (shown by the
    # assert_called_once above). Patching a different module's logger would
    # make assert_not_called pass regardless of show_warning.
    logger = mocker.patch("superset.sql_parse.logger")
    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
    assert extract_table_references(sql, "trino", show_warning=False) == {
        Table(table="other_table", schema=None, catalog=None)
    }
    logger.warning.assert_not_called()
def test_extract_tables_semi_join() -> None:
    """
    Check ``LEFT SEMI JOIN`` against an aliased subquery.

    The table on the left and the table inside the joined subquery must
    both be returned.
    """
    semi_join_sql = """
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
    SELECT
        CAST((b.year) as VARCHAR) date,
        name
    FROM right_table
) b
ON a.data = b.date
"""
    tables = extract_tables(semi_join_sql)
    assert tables == {Table("left_table"), Table("right_table")}
def test_extract_tables_nested_select() -> None:
    """
    Check that subselects nested inside function calls are discovered.

    Both queries hide an ``INFORMATION_SCHEMA.COLUMNS`` lookup inside
    ``extractvalue(...)``; the table must be found together with its schema.
    """
    expected = {Table("COLUMNS", "INFORMATION_SCHEMA")}
    table_names_query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
    column_names_query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
"""
    for sql in (table_names_query, column_names_query):
        assert extract_tables(sql) == expected
示例#9
0
 def test_nested_selects(self):
     """Subselects nested inside function calls still yield their table."""
     expected = {Table("COLUMNS", "INFORMATION_SCHEMA")}
     table_name_probe = """
         select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
         from INFORMATION_SCHEMA.COLUMNS
         WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
     """
     column_name_probe = """
         select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
         from INFORMATION_SCHEMA.COLUMNS
         WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
     """
     for sql in (table_name_probe, column_name_probe):
         self.assertEqual(expected, self.extract_tables(sql))
def test_extract_tables_with_catalog() -> None:
    """
    Check that a three-part ``catalog.schema.table`` name is split correctly.
    """
    parsed = extract_tables("SELECT * FROM catalogname.schemaname.tbname")
    assert parsed == {Table("tbname", "schemaname", "catalogname")}
示例#11
0
def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
    """
    ``df_to_sql`` with ``if_exists="replace"`` issues a schema-qualified
    ``DROP TABLE IF EXISTS`` for the target table.

    ``CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC`` is stubbed for the duration of the
    test. The application config is restored afterwards, even on failure.
    """
    mock_upload_to_s3.return_value = "mock-location"
    mock_g.user = True
    mock_database = mock.MagicMock()
    mock_database.get_df.return_value.empty = False
    mock_execute = mock.MagicMock(return_value=True)
    mock_database.get_sqla_engine.return_value.execute = mock_execute
    table_name = "foobar"
    schema = "schema"

    original_config = app.config.copy()
    # Must be an assignment: ``config[key]: value`` is only an annotation
    # and leaves the config untouched.
    app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"] = lambda *args: ""
    try:
        with app.app_context():
            HiveEngineSpec.df_to_sql(
                mock_database,
                Table(table=table_name, schema=schema),
                pd.DataFrame(),
                {
                    "if_exists": "replace",
                    "header": 1,
                    "na_values": "mock",
                    "sep": "mock"
                },
            )

        mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
    finally:
        # Restore in place so app.config stays the same object (and type)
        # rather than being replaced by the copy.
        app.config.clear()
        app.config.update(original_config)
示例#12
0
def test_df_to_sql_if_exists_fail(mock_g):
    """
    ``df_to_sql`` with ``if_exists="fail"`` raises "Table already exists".
    """
    mock_g.user = True
    database = mock.MagicMock()
    # Presumably df_to_sql treats a non-empty get_df result as an existing
    # table -- confirm against HiveEngineSpec.
    database.get_df.return_value.empty = False
    with pytest.raises(SupersetException, match="Table already exists"):
        HiveEngineSpec.df_to_sql(
            database, Table("foobar"), pd.DataFrame(), {"if_exists": "fail"}
        )
示例#13
0
def test_upload_to_s3_no_bucket_path():
    """``upload_to_s3`` raises the "no upload bucket" error in this setup."""
    expected_message = (
        "No upload bucket specified. You can specify one in the config file."
    )
    with pytest.raises(Exception, match=expected_message):
        upload_to_s3("filename", "prefix", Table("table"))
示例#14
0
 def test_reusing_aliases(self):
     """CTEs that reference other CTEs resolve to the base table ``src``."""
     sql = """
         with q1 as ( select key from q2 where key = '5'),
         q2 as ( select key from src where key = '5')
         select * from (select key from q1) a;
     """
     found = self.extract_tables(sql)
     self.assertEqual({Table("src")}, found)
示例#15
0
 def wraps(
     self: BaseSupersetModelRestApi,
     pk: int,
     table_name: str,
     schema_name: Optional[str] = None,
 ) -> Any:
     """
     Resolve the database and table named in the URL, then call ``f``.

     Responds 422 when the table name is missing. Responds 404 both when no
     database matches ``pk`` and when ``can_access_table`` denies access.
     """
     parsed_schema = parse_js_uri_path_item(schema_name, eval_undefined=True)
     parsed_table = parse_js_uri_path_item(table_name)
     if not parsed_table:
         return self.response_422(message=_("Table name undefined"))

     api_name = self.__class__.__name__
     database: Database = self.datamodel.get(pk)
     if not database:
         self.stats_logger.incr(f"database_not_found_{api_name}.select_star")
         return self.response_404()

     if self.appbuilder.sm.can_access_table(
         database, Table(parsed_table, parsed_schema)
     ):
         return f(self, database, parsed_table, parsed_schema)

     self.stats_logger.incr(f"permisssion_denied_{api_name}.select_star")
     logger.warning(
         "Permission denied for user %s on table: %s schema: %s",
         g.user,
         parsed_table,
         parsed_schema,
     )
     return self.response_404()
def test_extract_tables_with_schema() -> None:
    """
    Check that ``schema.table`` names are parsed whether quoted or not, and
    with or without an alias.
    """
    expected = {Table("tbname", "schemaname")}
    for sql in (
        "SELECT * FROM schemaname.tbname",
        'SELECT * FROM "schemaname"."tbname"',
        'SELECT * FROM "schemaname"."tbname" foo',
        'SELECT * FROM "schemaname"."tbname" AS foo',
    ):
        assert extract_tables(sql) == expected
示例#17
0
    def _get_table(tlist: TokenList) -> Optional[Table]:
        """
        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
        construct.

        If the token list has an alias, everything from the first whitespace
        token onward is dropped. The remaining tokens must alternate name/string
        parts with ``.`` punctuation, using one, two or three name parts.

        :param tlist: The SQL tokens
        :returns: The table if the name conforms, otherwise ``None``
        """

        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            # sqlparse's token_next_by reports "no match" as (None, None),
            # not -1, so test for None explicitly.
            if ws_idx is not None:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        odd_token_number = len(tokens) in (1, 3, 5)
        qualified_name_parts = all(
            imt(token, t=[Name, String]) for token in tokens[::2])
        dot_separators = all(
            imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
        if odd_token_number and qualified_name_parts and dot_separators:
            # Walk the name parts from last to first so they line up with
            # Table(table, schema, catalog).
            return Table(
                *[remove_quotes(token.value) for token in tokens[::-2]])

        return None
 def wraps(self,
           pk: int,
           table_name: str,
           schema_name: Optional[str] = None):
     """
     Resolve the database and table named in the URL, then call ``f``.

     Responds 422 when the table name is missing. Responds 404 when no
     database matches ``pk`` or when ``can_access_datasource`` denies access.
     """
     schema_name_parsed = parse_js_uri_path_item(schema_name,
                                                 eval_undefined=True)
     table_name_parsed = parse_js_uri_path_item(table_name)
     if not table_name_parsed:
         return self.response_422(message=_("Table name undefined"))
     database: Database = self.datamodel.get(pk)
     if not database:
         self.stats_logger.incr(
             f"database_not_found_{self.__class__.__name__}.select_star")
         return self.response_404()
     # Check that the user can access the datasource
     if not self.appbuilder.sm.can_access_datasource(
             database, Table(table_name_parsed, schema_name_parsed),
             schema_name_parsed):
         self.stats_logger.incr(
             f"permisssion_denied_{self.__class__.__name__}.select_star")
         # Lazy %-style arguments: formatting is deferred to the logging
         # framework; the emitted text matches the former f-string exactly.
         logger.warning(
             "Permission denied for user %s on table: %s schema: %s",
             g.user,
             table_name_parsed,
             schema_name_parsed,
         )
         return self.response_404()
     return f(self, database, table_name_parsed, schema_name_parsed)
def test_extract_tables_select_if() -> None:
    """
    Check that an ``IF(...)`` call in the select list does not hide the table.
    """
    sql = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
    assert extract_tables(sql) == {Table("t1")}
示例#20
0
def test_upload_to_s3_client_error(client):
    """
    A ``ClientError`` raised by the client's ``upload_file`` propagates out of
    ``upload_to_s3``.
    """
    from botocore.exceptions import ClientError

    failure = ClientError({"Error": {}}, "operation_name")
    client.return_value.upload_file.side_effect = failure

    with pytest.raises(ClientError):
        upload_to_s3("filename", "prefix", Table("table"))
示例#21
0
def test_df_to_csv() -> None:
    """
    ``HiveEngineSpec.df_to_sql`` raises ``SupersetException`` for
    ``if_exists="append"``.

    Despite its name, this test exercises ``df_to_sql``.
    """
    options = {"if_exists": "append"}
    with pytest.raises(SupersetException):
        HiveEngineSpec.df_to_sql(
            mock.MagicMock(), Table("foobar"), pd.DataFrame(), options
        )
示例#22
0
def test_extract_tables_keyword() -> None:
    """
    Check table names that ``sqlparse`` treats as reserved keywords.

    Names such as ``table_name`` are keywords to ``sqlparse``, so the parser
    needs extra logic to recognize them as identifiers. They are tested alone,
    aliased, and as every part of a three-part name.
    """
    assert extract_tables("SELECT * FROM table_name") == {Table("table_name")}

    aliased = extract_tables("SELECT * FROM table_name AS foo")
    assert aliased == {Table("table_name")}

    # catalog_name, schema_name and table_name are all keywords here.
    qualified = extract_tables("SELECT * FROM catalog_name.schema_name.table_name")
    assert qualified == {Table("table_name", "schema_name", "catalog_name")}
def test_extract_tables_select_array() -> None:
    """
    Check that an ``ARRAY[...]`` literal in the select list does not hide the
    table.
    """
    sql = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
    assert extract_tables(sql) == {Table("t1")}
def test_extract_tables_show_partitions() -> None:
    """
    Check that ``SHOW PARTITIONS FROM <table>`` reports the table.
    """
    sql = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
    assert extract_tables(sql) == {Table("orders")}
示例#25
0
 def test_identifier_list_with_keyword_as_alias(self):
     """A CTE named with a keyword (``match``) is still treated as an alias."""
     sql = """
     WITH
         f AS (SELECT * FROM foo),
         match AS (SELECT * FROM f)
     SELECT * FROM match
     """
     expected_tables = {Table("foo")}
     self.assertEqual(expected_tables, self.extract_tables(sql))
    def test_guest_token_does_not_grant_access_to_underlying_table(self):
        """
        With ``g.user`` set to ``self.authorized_guest``, ``raise_for_access``
        must raise for the table behind the dashboard's first slice.
        """
        first_slice = self.dash.slices[0]
        sqla_table = first_slice.table
        target = Table(table=sqla_table.table_name)

        g.user = self.authorized_guest

        with self.assertRaises(Exception):
            security_manager.raise_for_access(
                table=target, database=sqla_table.database
            )
def test_extract_tables_reusing_aliases() -> None:
    """
    Check that CTE aliases are followed back to their base table.

    ``q1`` reads from ``q2``, which reads from ``src``; only ``src`` is
    expected in the result.
    """
    sql = """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
    assert extract_tables(sql) == {Table("src")}
示例#28
0
def test_create_table_from_csv_if_exists_fail(mock_table, mock_g):
    """
    ``create_table_from_csv`` with ``if_exists="fail"`` raises
    "Table already exists".
    """
    mock_table.infer.return_value = {}
    mock_g.user = True
    database = mock.MagicMock()
    database.get_df.return_value.empty = False
    with pytest.raises(SupersetException, match="Table already exists"):
        HiveEngineSpec.create_table_from_csv(
            "foo.csv", Table("foobar"), database, {}, {"if_exists": "fail"}
        )
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
    """
    Check that CTE aliases that are keywords (``match``) are handled inside an
    identifier list.
    """
    sql = """
WITH
    f AS (SELECT * FROM foo),
    match AS (SELECT * FROM f)
SELECT * FROM match
"""
    assert extract_tables(sql) == {Table("foo")}
示例#30
0
 def test_create_table_from_csv_append(self) -> None:
     """``create_table_from_csv`` with ``if_exists="append"`` raises."""
     with self.assertRaises(SupersetException):
         HiveEngineSpec.create_table_from_csv(
             "foo.csv",
             Table("foobar"),
             None,
             {},
             {"if_exists": "append"},
         )