Пример #1
0
 def _flatten_sqlparse(self):
     """Yield a flat stream of tokens built from ``self.sqlparse_tokens``.

     sqlparse returns MySQL identifiers that start with a digit as a group
     made of an Integer token followed by a group whose first token is a
     Name (see https://github.com/andialbrecht/sqlparse/issues/337).  Such a
     group is merged back into one Name token, followed by the flattened
     rest of the inner group.  Any other group is flattened; plain tokens
     are yielded unchanged.
     """
     for tok in self.sqlparse_tokens:
         if not tok.is_group:
             yield tok
             continue

         children = tok.tokens
         digit_name = (
             len(children) == 2
             and children[0].ttype is Number.Integer
             and children[1].is_group
             and children[1].tokens[0].ttype is Name
         )
         if not digit_name:
             yield from tok.flatten()
             continue

         head, tail_group = children
         name_tok = tail_group.tokens[0]
         merged = Token(
             value=f"{head.normalized}{name_tok.normalized}",
             ttype=name_tok.ttype,
         )
         merged.parent = tok.parent
         yield merged

         # the inner group may carry more tokens, some of them nested groups
         for rest in tail_group.tokens[1:]:
             if rest.is_group:
                 yield from rest.flatten()
             else:
                 yield rest
Пример #2
0
def sql_recursively_simplify(node, hide_columns=True):
    """Simplify a parsed SQL token tree in place.

    Applied transformations:

    - ``UPDATE``: everything after ``SET`` up to the ``WHERE`` clause (or up
      to the end of the statement when there is no ``WHERE``) is replaced by
      `` ... ``.
    - ``SAVEPOINT x``, ``RELEASE SAVEPOINT x`` and ``ROLLBACK TO SAVEPOINT x``:
      the savepoint name is replaced by ```#``` and nothing else is touched.
    - A first token starting with ``"_django_curs_`` is replaced by
      ``"_django_curs_#"``.
    - Child tokens:
      - IdentifierLists are collapsed to ``...`` when *hide_columns* is true
        and the list does not directly follow ``ORDER BY``, ``GROUP BY`` or
        ``HAVING``.
      - Groups are simplified recursively.
      - Tokens whose ttype is in ``sql_deleteable_tokens`` and ``NULL``
        values become ``#``.

    :param node: sqlparse token list (e.g. a ``Statement``), modified in place.
    :param hide_columns: collapse column lists to ``...`` when true.
    """
    # Erase which fields are being updated in an UPDATE
    if node.tokens[0].value == "UPDATE":
        i_set = [i for (i, t) in enumerate(node.tokens) if t.value == "SET"][0]
        # ``UPDATE ... SET ...`` without a WHERE clause is valid SQL: fall back
        # to the end of the statement instead of raising IndexError.
        i_where = next(
            (
                i for (i, t) in enumerate(node.tokens)
                if t.is_group and t.tokens[0].value == "WHERE"
            ),
            len(node.tokens),
        )
        middle = [Token(tokens.Punctuation, " ... ")]
        node.tokens = node.tokens[:i_set + 1] + middle + node.tokens[i_where:]

    # Erase the names of savepoints since they are non-deterministic.
    # SAVEPOINT x
    if str(node.tokens[0]) == "SAVEPOINT":
        node.tokens[2].tokens[0].value = "`#`"
        return
    # RELEASE SAVEPOINT x -- the name is token 4, so make sure it exists
    if len(node.tokens) >= 5 and node.tokens[2].value == "SAVEPOINT":
        node.tokens[4].tokens[0].value = "`#`"
        return
    # ROLLBACK TO SAVEPOINT X
    token_values = [getattr(t, "value", "") for t in node.tokens]
    if len(node.tokens) == 7 and token_values[:6] == [
            "ROLLBACK",
            " ",
            "TO",
            " ",
            "SAVEPOINT",
            " ",
    ]:
        node.tokens[6].tokens[0].value = "`#`"
        return

    # Erase volatile part of PG cursor name
    if node.tokens[0].value.startswith('"_django_curs_'):
        node.tokens[0].value = '"_django_curs_#"'

    # Last non-whitespace token seen, used to detect the clause we are in.
    prev_word_token = None

    for token in node.tokens:
        ttype = getattr(token, "ttype", None)

        # Detect IdentifierList tokens within an ORDER BY, GROUP BY or HAVING
        # clauses
        inside_order_group_having = match_keyword(
            prev_word_token, ["ORDER BY", "GROUP BY", "HAVING"])
        replace_columns = not inside_order_group_having and hide_columns

        if isinstance(token, IdentifierList) and replace_columns:
            token.tokens = [Token(tokens.Punctuation, "...")]
        elif hasattr(token, "tokens"):
            sql_recursively_simplify(token, hide_columns=hide_columns)
        elif ttype in sql_deleteable_tokens:
            token.value = "#"
        elif getattr(token, "value", None) == "NULL":
            token.value = "#"

        if not token.is_whitespace:
            prev_word_token = token
Пример #3
0
    def handle_read_operation(
        self,
        idx: int,
        next_token: Token,
        columns: Columns,
        dynamic_tables: Dict[str, Schema],
    ) -> Tuple[Schema, int, Columns]:
        """Extract the schema read through *next_token*.

        - ``(<sub_query>)`` without an alias: extracted as a sub-query named
          ``"anon"``.
        - ``(<sub_query>) AS <alias>``: extracted as a sub-query named after
          the alias.
        - A keyword: nothing is extracted (``{}``).
        - Anything else: passed to ``generate_schema`` and enriched with
          *dynamic_tables*. The columns returned by ``generate_schema``
          replace *columns*.

        :return: ``(schema, idx, columns)``; *idx* is returned unchanged.
        """
        # looking for a sub-query first - it can be the current token or the
        # first son of this token
        if isinstance(next_token, Parenthesis):
            # `... from (<sub_query>) where ...`
            subquery = self.extract_from_subquery(next_token, "anon", columns)
            return subquery, idx, columns

        first_child = (
            next_token.token_first() if isinstance(next_token, TokenList) else None
        )
        if isinstance(first_child, Parenthesis):
            # `... from (<sub_query>) as <alias_name> where ...`
            subquery = self.extract_from_subquery(
                first_child, next_token.get_alias(), columns
            )
            return subquery, idx, columns

        if next_token.ttype in Keyword:
            return {}, idx, columns

        # no subquery - just parse the source/dest name of the operator
        schema, left_columns = self.generate_schema(columns, next_token)
        return self.enrich_with_dynamic(schema, dynamic_tables), idx, left_columns
Пример #4
0
def _add_tenant_to_sql(rule_sql):
    """Prefix topic paths in a rule's SQL with ``/+/<g.tenant_uid>``.

    Only the first statement of *rule_sql* is processed. Two forms are
    rewritten among its top-level tokens:

    * a quoted topic identifier, e.g. ``FROM "/productID/deviceID/"``;
    * expressions that ``_replace_func_sql`` rewrites, e.g.
      ``getMetadataPropertyValue('/productID/deviceID/','topic') as topic``,
      either standalone or as a member of a select list.

    :param rule_sql: SQL text of the rule.
    :return: the rewritten SQL text.
    """
    prefix = f'/+/{g.tenant_uid}'
    stmt = sqlparse.parse(rule_sql)[0]
    # Top-level replacements are collected first and applied after the scan.
    token_dict = {}
    for token in stmt.tokens:
        if isinstance(token, Identifier):
            if token.value.startswith('"/'):
                # FROM "/productID/deviceID/"
                new_value = f'"{prefix}{token.value[1:]}'
                token_dict[stmt.token_index(token)] = new_value
            else:
                # SELECT getMetadataPropertyValue('/productID/deviceID/','topic') as topic FROM
                new_value = _replace_func_sql(prefix, token.value)
                if new_value:
                    token_dict[stmt.token_index(token)] = new_value
        elif isinstance(token, IdentifierList):
            # SELECT getMetadataPropertyValue('/productID/deviceID/','topic') as topic,* FROM
            for identifier in token.get_identifiers():
                new_value = _replace_func_sql(prefix, identifier.value)
                if new_value:
                    # get_identifiers() skips commas and whitespace, so its
                    # position is not an index into token.tokens; look the
                    # identifier up to replace the right child.
                    token.tokens[token.token_index(identifier)] = Token(
                        None, new_value)
    for index, value in token_dict.items():
        stmt.tokens[index] = Token(None, value)
    return str(stmt)
Пример #5
0
 def get_value(self, tok: Token):
     """Return the value represented by *tok*.

     Placeholder tokens are resolved through ``self.placeholder_index``.
     The ``NULL`` keyword gives ``None`` and ``DEFAULT`` gives the string
     ``'DEFAULT'``.

     :raises SQLDecodeError: for any other token.
     """
     if tok.ttype == tokens.Name.Placeholder:
         return self.placeholder_index(tok)
     for keyword, result in (('NULL', None), ('DEFAULT', 'DEFAULT')):
         if tok.match(tokens.Keyword, keyword):
             return result
     raise SQLDecodeError
Пример #6
0
    def _token2op(self,
                  tok: Token,
                  statement: SQLStatement) -> '_Op':
        """Translate a single token of *statement* into an operator object.

        :param tok: the token to translate.
        :param statement: the statement being walked. ``NOT IN`` and
            ``BETWEEN`` advance it past the tokens they consume.
        :return: the operator, or ``None`` for these tokens:
            - brackets;
            - identifiers;
            - parenthesised groups whose first inner token is a placeholder,
              ``NULL``, an IdentifierList or a DML keyword.
        :raises SQLDecodeError: for any other token.
        """
        kw = {'statement': statement, 'query': self.query}

        # Keywords mapping directly onto an operator class. A keyword token
        # equals exactly one of these values, so the order does not matter.
        keyword_ops = (
            ('AND', AndOp),
            ('OR', OrOp),
            ('IN', InOp),
            ('LIKE', LikeOp),
            ('iLIKE', iLikeOp),
            ('IS', IsOp),
        )
        for keyword, op_class in keyword_ops:
            if tok.match(tokens.Keyword, keyword):
                return op_class(**kw)

        if tok.match(tokens.Keyword, 'NOT'):
            if not statement.next_token.match(tokens.Keyword, 'IN'):
                return NotOp(**kw)
            op = NotInOp(**kw)
            statement.skip(1)  # consume the IN keyword
            return op

        if tok.match(tokens.Keyword, 'BETWEEN'):
            op = BetweenOp(**kw)
            statement.skip(3)
            return op

        if isinstance(tok, Comparison):
            return CmpOp(tok, self.query)

        if isinstance(tok, Parenthesis):
            inner = tok[1]
            yields_no_op = (
                inner.match(tokens.Name.Placeholder, '.*', regex=True)
                or inner.match(tokens.Keyword, 'Null')
                or isinstance(inner, IdentifierList)
                or inner.ttype == tokens.DML
            )
            if yields_no_op:
                return None
            return ParenthesisOp(SQLStatement(tok), self.query)

        if tok.match(tokens.Punctuation, (')', '(')) or isinstance(tok, Identifier):
            return None

        raise SQLDecodeError
Пример #7
0
 def token2sql(
     token: Token, query: 'query_module.BaseQuery'
 ) -> U['CountFuncAll', 'CountFuncSingle']:
     """Build the COUNT wrapper for *token*.

     The parameter list is read from *token* itself when its first child is
     an Identifier, otherwise from that first child.  If reading the first
     parameter raises ``IndexError``, the result is ``CountFuncAll``;
     otherwise it is ``CountFuncSingle``.
     """
     try:
         # FIX: COUNT(DISTINCT COL)
         # TODO: This just gets the parser through the token, but distinct
         # logic is not actually handled yet.
         func = token if isinstance(token[0], Identifier) else token[0]
         func.get_parameters()[0]
     except IndexError:
         return CountFuncAll(token, query)
     return CountFuncSingle(token, query)
Пример #8
0
 def handle_into(
     self, idx: int, next_token: Token, statement: TokenList
 ) -> Tuple[Schema, int]:
     """Extract the write schema for the target of an INTO clause.

     For a Function token, its parameters are used as the written columns.
     The identifier name is passed as both table name and alias in that
     case.  Any other token writes ``*`` into its full name.

     :return: ``(schema, idx)``; *idx* is returned unchanged.
     """
     alias = self.get_identifier_name(next_token)
     full_name = self.get_full_name(next_token)
     if not isinstance(next_token, Function):
         wildcard = [Token(Wildcard, "*")]
         return self.extract_write_schema(wildcard, full_name, alias), idx
     written_columns = next_token.get_parameters()
     return self.extract_write_schema(written_columns, alias, alias), idx
Пример #9
0
def sql_recursively_simplify(node):
    """Simplify a parsed SQL token tree in place.

    Applied transformations:

    - ``UPDATE``: everything after ``SET`` up to the ``WHERE`` clause (or up
      to the end of the statement when there is no ``WHERE``) is replaced by
      `` ... ``.
    - ``SAVEPOINT x``, ``RELEASE SAVEPOINT x`` and ``ROLLBACK TO SAVEPOINT x``:
      the savepoint name is replaced by ```#``` and nothing else is touched.
    - A first token starting with ``"_django_curs_`` is replaced by
      ``"_django_curs_#"``.
    - Child tokens:
      - Every IdentifierList is collapsed to ``...``.
      - Groups are simplified recursively.
      - Tokens whose ttype is in ``sql_deleteable_tokens`` and ``NULL``
        values become ``#``.
      - Newlines are erased and other whitespace becomes a single space.

    :param node: sqlparse token list (e.g. a ``Statement``), modified in place.
    """
    # Erase which fields are being updated in an UPDATE
    if node.tokens[0].value == 'UPDATE':
        i_set = [i for (i, t) in enumerate(node.tokens) if t.value == 'SET'][0]
        # ``UPDATE ... SET ...`` without a WHERE clause is valid SQL: fall back
        # to the end of the statement instead of raising IndexError.
        i_where = next(
            (
                i for (i, t) in enumerate(node.tokens)
                if _is_group(t) and t.tokens[0].value == 'WHERE'
            ),
            len(node.tokens),
        )
        middle = [Token(tokens.Punctuation, ' ... ')]
        node.tokens = node.tokens[:i_set + 1] + middle + node.tokens[i_where:]

    # Erase the names of savepoints since they are non-deterministic.
    # SAVEPOINT x
    if str(node.tokens[0]) == 'SAVEPOINT':
        node.tokens[2].tokens[0].value = '`#`'
        return
    # RELEASE SAVEPOINT x -- the name is token 4, so make sure it exists
    if len(node.tokens) >= 5 and node.tokens[2].value == 'SAVEPOINT':
        node.tokens[4].tokens[0].value = "`#`"
        return
    # ROLLBACK TO SAVEPOINT X
    token_values = [getattr(t, 'value', '') for t in node.tokens]
    if len(node.tokens) == 7 and token_values[:6] == [
            'ROLLBACK', ' ', 'TO', ' ', 'SAVEPOINT', ' '
    ]:
        node.tokens[6].tokens[0].value = '`#`'
        return

    # Erase volatile part of PG cursor name
    if node.tokens[0].value.startswith('"_django_curs_'):
        node.tokens[0].value = '"_django_curs_#"'

    for token in node.tokens:
        ttype = getattr(token, 'ttype', None)

        if isinstance(token, IdentifierList):
            token.tokens = [Token(tokens.Punctuation, '...')]
        elif hasattr(token, 'tokens'):
            sql_recursively_simplify(token)
        elif ttype in sql_deleteable_tokens:
            token.value = '#'
        elif ttype == tokens.Whitespace.Newline:
            token.value = ''  # Erase newlines
        elif ttype == tokens.Whitespace:
            token.value = ' '
        elif getattr(token, 'value', None) == 'NULL':
            token.value = '#'
Пример #10
0
    def test_simple_join2(self):
        """A three-table join (aa=ab, ec=eb) produces non-empty code."""
        def int_table(name, *column_names):
            # every column of these fixtures is an int
            return Table(table_name=name,
                         columns=[Column(name=col, column_type=TypeEnum.int)
                                  for col in column_names],
                         data_sizes=[100], data_paths=[""],
                         annotations=[])

        join_list = [JoinData(left_key="aa", right_key="ab"),
                     JoinData(left_key="ec", right_key="eb")]
        table_a = int_table("A", "aa", "b", "c")
        table_b = int_table("B", "ab", "eb")
        table_c = int_table("C", "ec", "f")
        tables = [table_a, table_b, table_c]

        root = SelectNode(tables=tables, annotation_name="demo")
        root.set_identifier_list(
            [Identifier(tokens=[Token(None, col)]) for col in ("ec", "f")])

        root.next = JoinNode(join_list=join_list, tables=tables)
        root.next.prev = root

        root.next.merge()
        result = root.next.to_code(table_a.get_root())
        self.assertGreater(len(result), 0)
Пример #11
0
 def handle_into(self, idx: int, next_token: Token,
                 statement: TokenList) -> Tuple[Schema, int]:
     """Extract the write schema for the target of an INTO clause.

     When *next_token* is a group, its last direct Function child (if any)
     is used as the target.  For a Function target, its parameters are the
     written columns and the identifier name serves as table name and
     alias.  Otherwise ``*`` is written into the full name.

     :return: ``(schema, idx)``; *idx* is returned unchanged.
     """
     if next_token.is_group:
         inner_functions = [child for child in next_token.tokens
                            if isinstance(child, Function)]
         if inner_functions:
             next_token = inner_functions[-1]
     alias = self.get_identifier_name(next_token)
     full_name = self.get_full_name(next_token)
     if not isinstance(next_token, Function):
         return self.extract_write_schema([Token(Wildcard, "*")],
                                          full_name, alias), idx
     written_columns = next_token.get_parameters()
     return self.extract_write_schema(written_columns, alias, alias), idx
Пример #12
0
def add_table_name(rls: TokenList, table: str) -> None:
    """
    Modify a RLS expression inplace ensuring columns are fully qualified.

    Every Identifier without a parent name found anywhere in *rls* has its
    tokens replaced by ``<table>.<token.get_name()>``.  Nested token lists
    are visited breadth-first.
    """
    queue = list(rls.tokens)
    cursor = 0
    while cursor < len(queue):
        token = queue[cursor]
        cursor += 1

        if isinstance(token, Identifier) and token.get_parent_name() is None:
            token.tokens = [
                Token(Name, table),
                Token(Punctuation, "."),
                Token(Name, token.get_name()),
            ]
        elif isinstance(token, TokenList):
            queue.extend(token.tokens)
Пример #13
0
 def tokens2sql(token: Token,
                query: 'query_module.BaseQuery'
                ) -> Iterator[all_token_types]:
     """Yield the SQL wrapper object(s) corresponding to *token*.

     IdentifierLists are expanded recursively, one item per identifier.

     :raises SQLDecodeError: for token kinds that are not handled.
     """
     from .functions import SQLFunc

     if isinstance(token, Identifier):
         head = token[0]
         if isinstance(head, Function):
             yield SQLFunc.token2sql(token, query)
         elif not isinstance(head, Parenthesis):
             yield SQLIdentifier(token, query)
         else:
             # Bug fix for sql parse: an Identifier may start with a
             # parenthesised token. An integer inside it is a constant,
             # anything else is treated as the inner identifier.
             inner = head[1]
             try:
                 int(inner.value)
             except ValueError:
                 yield SQLIdentifier(inner, query)
             else:
                 yield SQLConstIdentifier(token, query)
         return

     if isinstance(token, Function):
         yield SQLFunc.token2sql(token, query)
     elif isinstance(token, Comparison):
         yield SQLComparison(token, query)
     elif isinstance(token, IdentifierList):
         for member in token.get_identifiers():
             yield from SQLToken.tokens2sql(member, query)
     elif isinstance(token, Parenthesis):
         yield SQLPlaceholder(token, query)
     else:
         raise SQLDecodeError(f'Unsupported: {token.value}')
Пример #14
0
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
    """
    Tests for ``get_rls_for_table``.

    Each row-level filter is qualified with the table name and filters are
    joined with AND; no filters gives ``None``.
    """
    candidate = Identifier([Token(Name, "some_table")])
    db = mocker.patch("superset.db")
    dataset = db.session.query().filter().one_or_none()
    dataset.__str__.return_value = "some_table"

    cases = [
        ([text("organization_id = 1")], "some_table.organization_id = 1"),
        (
            [text("organization_id = 1"), text("foo = 'bar'")],
            "some_table.organization_id = 1 AND some_table.foo = 'bar'",
        ),
    ]
    for filters, expected in cases:
        dataset.get_sqla_row_level_filters.return_value = filters
        assert str(get_rls_for_table(candidate, 1, "public")) == expected

    dataset.get_sqla_row_level_filters.return_value = []
    assert get_rls_for_table(candidate, 1, "public") is None
Пример #15
0
 def process(self, stack, stream):
     """Split a stream of ``(ttype, value)`` pairs into Statement objects.

     A statement ends at a ``;`` seen at split level 0 or below.
     Whitespace and single-line comments right after it still belong to
     that statement.  ``self._reset()`` is called after each statement is
     yielded, and *stack* is not used.
     """
     current = None
     collected = []
     depth = 0
     pending_flush = False
     for ttype, value in stream:
         trailing = ttype is T.Whitespace or ttype is T.Comment.Single
         if pending_flush and not trailing:
             current.tokens = collected
             yield current
             self._reset()
             current, depth, pending_flush = None, 0, False
         if current is None:
             current, collected = Statement(), []
         depth += self._change_splitlevel(ttype, value)
         collected.append(Token(ttype, value))
         if depth <= 0 and ttype is T.Punctuation and value == ';':
             pending_flush = True
     if current is not None:
         current.tokens = collected
         yield current
Пример #16
0
 def handle_select(self, idx: int, next_token: Token) -> Tuple[Columns, int]:
     """Extract the selected columns from the token after SELECT.

     An IdentifierList contributes each of its identifiers; any other token
     is treated as a single selected item.

     :return: ``(columns, idx)``; *idx* is returned unchanged.
     """
     if isinstance(next_token, IdentifierList):
         selected: List[Token] = next_token.get_identifiers()
     else:
         selected = [next_token]
     return self._extract_columns(selected), idx
Пример #17
0
    def _extract_from_token(self, token: Token) -> None:
        """
        Recursively find table references in :param token: and pass them to
        self._process_tokenlist, which populates self._tables.

        <Identifier> stores a list of subtokens and <IdentifierList> stores lists of
        subtoken lists.

        The method loops through all subtokens of :param token:, recursing into
        groups. After a keyword in PRECEDES_TABLE_NAME (or one ending in " JOIN")
        it sets table_name_preceding_token. It then hands the following items to
        self._process_tokenlist:

        - an <Identifier>;
        - the TokenList members of an <IdentifierList>;
        - a bare keyword, wrapped in an <Identifier>.

        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        # Leaf tokens have no children to search.
        if not hasattr(token, "tokens"):
            return

        # True while the items being scanned follow a table-introducing keyword.
        table_name_preceding_token = False

        # If the table name is a reserved word (eg, "table_name") it won't be returned. We
        # fix this by ensuring that at least one identifier is returned after the FROM
        # before stopping on a keyword.
        # Note: never reset to False within this loop once set.
        has_processed_identifier = False

        for item in token.tokens:
            # Recurse into groups that self._is_identifier rejects, and into
            # identifiers whose first child is a Parenthesis (presumably a
            # sub-query). There is no `continue`: the item is still examined below.
            if item.is_group and (
                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
            ):
                self._extract_from_token(item)

            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME
                or item.normalized.endswith(" JOIN")
            ):
                table_name_preceding_token = True
                continue

            # If we haven't processed any identifiers it means the table name is a
            # reserved keyword (eg, "table_name") and we shouldn't skip it.
            if item.ttype in Keyword and has_processed_identifier:
                table_name_preceding_token = False
                continue
            if table_name_preceding_token:
                # Items of other kinds (e.g. whitespace) are ignored here.
                if isinstance(item, Identifier):
                    self._process_tokenlist(item)
                    has_processed_identifier = True
                elif isinstance(item, IdentifierList):
                    for token2 in item.get_identifiers():
                        if isinstance(token2, TokenList):
                            self._process_tokenlist(token2)
                    has_processed_identifier = True
                elif item.ttype in Keyword:
                    # convert into an identifier
                    fixed = Identifier([Token(Name, item.value)])
                    self._process_tokenlist(fixed)
                    has_processed_identifier = True
            elif isinstance(item, IdentifierList):
                # Outside a table-name position, only descend into lists that
                # contain something self._is_identifier rejects.
                if any(not self._is_identifier(token2) for token2 in item.tokens):
                    self._extract_from_token(item)
Пример #18
0
 def __parse_from__(self):
     """Append a FromNode listing the driver's select-from tables to the node chain.

     One Identifier is built per entry of
     ``self.db_driver.perform_select_from()``, named after its
     ``variable_table_name``.
     """
     identifiers = []
     for source in self.db_driver.perform_select_from():
         name_token = Token(value=source.variable_table_name, ttype="")
         identifiers.append(Identifier(tokens=[name_token]))
     node = FromNode(tables=self.tables)
     node.set_identifier_list(identifiers)
     tail = self.root.get_last_node()
     tail.next = node
     tail.next.prev = tail
Пример #19
0
    def test_simple_select(self):
        """Code generated for a single-table select contains the expected fragments."""
        node = SelectNode(tables=[self.table_a], annotation_name="demo")
        node.from_tables = [Identifier(tokens=[Token("int", "A")])]

        code = node.to_code(self.table_a.get_root())
        print(code)
        expected_fragments = (
            "Relation a(a_ri, a_ai);",
            "CLIENT",
            "{ Relation::INT,Relation::STRING }",
        )
        for fragment in expected_fragments:
            self.assert_content_in_arr(code, fragment)
Пример #20
0
    def test_select2(self):
        """Selecting from a client table and a server table yields one relation each."""
        def make_table(name, int_column, string_column, owner):
            return PythonFreeConnexTable(
                table_name=name,
                columns=[
                    Column(name=int_column, column_type=TypeEnum.int),
                    Column(name=string_column, column_type=TypeEnum.string),
                ],
                owner=owner,
                data_sizes=[100],
                data_paths=[""],
                annotations=[])

        table_a = make_table("A", "a", "b", CharacterEnum.client)
        table_b = make_table("B", "c", "d", CharacterEnum.server)

        node = SelectNodePython(tables=[table_a, table_b],
                                annotation_name="demo")
        node.from_tables = [Identifier(tokens=[Token("int", name)])
                            for name in ("A", "B")]
        relations = node.to_code(table_a.get_root(), should_load_data=False)
        self.assertEqual(len(relations), 2)

        for table, expected_name in zip(relations, ("a", "b")):
            self.assertEqual(table.variable_table_name, expected_name)
            self.assertIsNotNone(table.relation)
Пример #21
0
    def add_pidal(self, value: int):
        """Append a ``pidal_c`` column and its *value* to this statement.

        For every IdentifierList directly inside a top-level Parenthesis,
        ``, pidal_c`` is appended in ``self.table_f`` and ``, <value>`` in
        ``self.values``.  Nothing happens when ``self.pidal_c`` is already
        truthy.

        NOTE(review): a previously stored value of 0 is falsy, so a second
        call would append the column again -- confirm this is intended.
        """
        if getattr(self, "pidal_c", None):
            return
        self.pidal_c = value

        def append_to_lists(statement, make_item):
            # Append ", <make_item()>" to each IdentifierList inside a
            # top-level Parenthesis of *statement*; fresh tokens every time.
            for paren in statement.tokens:
                if not isinstance(paren, Parenthesis):
                    continue
                for child in paren.tokens:
                    if isinstance(child, IdentifierList):
                        child.insert_after(len(child.tokens),
                                           Token(token.Punctuation, ","))
                        child.insert_after(len(child.tokens), make_item())

        append_to_lists(self.table_f,
                        lambda: sqlparse.parse("pidal_c")[0].tokens[0])
        append_to_lists(self.values,
                        lambda: Token(token.Number.Integer, value))
Пример #22
0
    def token2sql(
            token: Token,
            query: 'query_module.BaseQuery') -> U['CountFunc', 'SimpleFunc']:
        """Build the wrapper for an aggregate/function call *token*.

        ``token.is_agg_distinct`` is set when the first child is an
        Identifier and the first token inside the call's parentheses is the
        ``DISTINCT`` keyword.  COUNT is delegated to ``CountFunc.token2sql``;
        every other function becomes a ``SimpleFunc``.
        """
        # FIX: agg(distinct)
        if isinstance(token[0], Identifier):
            token.is_agg_distinct = token[1][1].match(tokens.Keyword, 'DISTINCT')
        else:
            token.is_agg_distinct = False

        if token[0].get_name() == 'COUNT':
            return CountFunc.token2sql(token, query)
        return SimpleFunc(token, query)
Пример #23
0
def matches_table_name(candidate: Token, table: str) -> bool:
    """
    Returns if the token represents a reference to the table.

    Tables can be fully qualified with periods. Tokens are compared from
    right to left, only as far as the shorter of the two goes, so
    ``schema.table`` matches ``table``.

    Note that in theory a table should be represented as an identifier, but due to
    sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
    classified as a keyword.
    """
    def as_identifier(token: Token) -> Identifier:
        # keywords and other plain tokens become a one-name Identifier
        if isinstance(token, Identifier):
            return token
        return Identifier([Token(Name, token.value)])

    candidate = as_identifier(candidate)
    target = as_identifier(sqlparse.parse(table)[0].tokens[0])

    # match from right to left, splitting on the period, eg, schema.table == table
    return all(
        left.value == right.value
        for left, right in zip(reversed(candidate.tokens), reversed(target.tokens))
    )
Пример #24
0
    def build_query(self, sql, limit=2):
        """
        Parses sql query and changes LIMIT statement value.

        Only the first statement of *sql* is kept, and trailing ``;`` are
        removed first.

        - If a ``LIMIT`` keyword exists, the token following it is replaced
          by *limit*.
        - Otherwise `` LIMIT <limit>`` is appended, after dropping a final
          Punctuation token.

        :param sql: SQL text.
        :param limit: row limit to apply.
        :return: the rewritten SQL text.
        """
        import re
        from sqlparse import parse, tokens
        from sqlparse.sql import Token

        # It's important to have a whitespace right after every LIMIT, so a
        # LIMIT glued to its value ("limit10", "limit(5)") gets one inserted.
        # The leading word boundary, and requiring the next character to be a
        # digit or a non-word, non-space character, keep identifiers such as
        # "unlimited", "limits" or "rate_limit_x" untouched.
        pattern = re.compile(r'\blimit([^\s\w]|\d)', re.IGNORECASE)
        sql = pattern.sub(r'LIMIT \1', sql)

        query = parse(sql.rstrip(';'))[0]

        # Find LIMIT statement
        token = query.token_next_match(0, tokens.Keyword, 'LIMIT')
        if token:
            # Find and replace LIMIT value
            value = query.token_next(query.token_index(token), skip_ws=True)
            if value:
                new_token = Token(value.ttype, str(limit))
                query.tokens[query.token_index(value)] = new_token
        else:
            # If limit is not found, append one
            new_tokens = [
                Token(tokens.Whitespace, ' '),
                Token(tokens.Keyword, 'LIMIT'),
                Token(tokens.Whitespace, ' '),
                Token(tokens.Number, str(limit)),
            ]
            last_token = query.tokens[-1]
            if last_token.ttype == tokens.Punctuation:
                query.tokens.remove(last_token)
            query.tokens.extend(new_tokens)

        return str(query)
Пример #25
0
def _execute_copy_from_local_sql(sql_tree, cursor):
    # Search for 'LOCAL' keyword
    for i, token in enumerate(sql_tree.tokens):
        if token.is_keyword and token.value.lower() == 'local':
            break

    file_path = sql_tree.tokens[i + 2].value.strip('\'"')

    # Replace "LOCAL <file_path>" with "stdin"
    sql_tree.tokens = sql_tree.tokens[0:i] + [Token(_Token.Keyword, 'stdin')
                                              ] + sql_tree.tokens[i + 3:]
    new_sql = sql_tree.to_unicode()

    cursor.flush_to_query_ready()
    with open(file_path, 'rb') as f:
        cursor.copy(new_sql, f)

    cursor.flush_to_query_ready()
Пример #26
0
def get_rls_for_table(
    candidate: Token,
    database_id: int,
    default_schema: Optional[str],
    username: Optional[str] = None,
) -> Optional[TokenList]:
    """
    Given a table name, return any associated RLS predicates.

    The dataset is looked up by database id, schema (falling back to
    *default_schema*) and table name. Its row-level filters are joined with
    `` AND ``, parsed, and qualified with the dataset name.

    :return: the parsed predicate, or ``None`` when the table cannot be
        resolved, no dataset matches, or the dataset has no filters.
    """
    # pylint: disable=import-outside-toplevel
    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    if not isinstance(candidate, Identifier):
        candidate = Identifier([Token(Name, candidate.value)])

    table = ParsedQuery.get_table(candidate)
    if not table:
        return None

    query = db.session.query(SqlaTable)
    criteria = and_(
        SqlaTable.database_id == database_id,
        SqlaTable.schema == (table.schema or default_schema),
        SqlaTable.table_name == table.table,
    )
    dataset = query.filter(criteria).one_or_none()
    if not dataset:
        return None

    filters = dataset.get_sqla_row_level_filters(
        dataset.get_template_processor(), username)
    predicate = " AND ".join(str(filter_) for filter_ in filters)
    if not predicate:
        return None

    rls = sqlparse.parse(predicate)[0]
    add_table_name(rls, str(dataset))
    return rls
Пример #27
0
    def process(self, stack, stream):
        """Split *stream* of ``(ttype, value)`` pairs into Statement objects.

        A statement ends at a ``;`` seen at split level 0 or below.
        Whitespace and single-line comments right after it still belong to
        that statement.  ``self._reset()`` is called after each statement is
        yielded, and *stack* is not used.
        """
        current = None
        gathered = []
        level = 0
        terminated = False

        for ttype, value in stream:
            filler = ttype in (T.Whitespace, T.Comment.Single)
            if terminated and not filler:
                # The previous statement is complete: hand it out and start
                # over with fresh splitter state.
                current.tokens = gathered
                yield current
                self._reset()
                terminated = False
                level = 0
                current = None

            if current is None:
                current = Statement()
                gathered = []

            level += self._change_splitlevel(ttype, value)
            gathered.append(Token(ttype, value))

            if level <= 0 and ttype is T.Punctuation and value == ';':
                terminated = True

        # Flush whatever statement is still open at end of stream.
        if current is not None:
            current.tokens = gathered
            yield current
Пример #28
0
    def test_simple_join1(self):
        """Joining A and B on aa = ba emits an Aggregate on A's join key."""
        def int_table(name, *column_names):
            # every column of these fixtures is an int
            return Table(table_name=name,
                         columns=[Column(name=col, column_type=TypeEnum.int)
                                  for col in column_names],
                         data_sizes=[100], data_paths=[""],
                         annotations=[])

        join_list = [JoinData(left_key="aa", right_key="ba")]
        table_a = int_table("A", "aa", "b", "c")
        table_b = int_table("B", "ba", "e")

        root = SelectNode(tables=[table_a, table_b], annotation_name="demo")
        root.set_identifier_list(
            [Identifier(tokens=[Token(None, col)]) for col in ("b", "c")])

        root.next = JoinNode(join_list=join_list, tables=[table_a, table_b])
        root.next.prev = root

        root.next.merge()
        result = root.next.to_code(table_a.get_root())
        self.assertIn('a.Aggregate({ "aa" });', result[0])
Пример #29
0
    def handle_update(
        self, idx: int, next_token: Token, statement: TokenList
    ) -> Tuple[Schema, int]:
        """Extract the write schema of an ``UPDATE <table> SET ...`` clause.

        If the token after the table is not ``SET``, nothing is extracted.
        Otherwise the assignments after ``SET`` are the written columns.

        :return: ``(schema, idx)``, where *idx* points at the last token
            consumed.
        """
        alias = next_token.get_name()
        full_name = next_token.normalized

        idx, keyword = self._next_non_empty_token(idx, statement)
        if keyword.value.upper() != "SET":
            return {}, idx

        idx, assignments = self._next_non_empty_token(idx, statement)
        # A single SET assignment is parsed as a bare Comparison rather than
        # an IdentifierList of comparisons.
        if isinstance(assignments, Comparison):
            targets = [assignments]
        else:
            targets = assignments.get_identifiers()
        return self.extract_write_schema(targets, full_name, alias), idx
Пример #30
0
    def _token2op(self, tok: Token, statement: SQLStatement) -> '_Op':
        """Translate a single token of *statement* into an operator object.

        :param tok: the token to translate.
        :param statement: the statement being walked. ``NOT IN`` and
            ``BETWEEN`` advance it past the tokens they consume.
        :return: the operator, or ``None`` for these tokens:
            - placeholders matching ``%(.+)s``;
            - brackets;
            - parenthesised groups whose first inner token is a placeholder,
              ``NULL``, an IdentifierList or a DML keyword;
            - identifiers followed by something other than end-of-statement,
              a bracket, ``AND`` or ``OR``.
        :raises SQLDecodeError: for any other token.
        """
        op = None
        kw = {'statement': statement, 'query': self.query}
        if tok.match(tokens.Name.Placeholder, '%(.+)s', regex=True):
            return op

        if tok.match(tokens.Keyword, 'AND'):
            op = AndOp(**kw)

        elif tok.match(tokens.Keyword, 'OR'):
            op = OrOp(**kw)

        elif any(
                t.match(tokens.Comparison, 'IN')
                for t in [tok, *getattr(tok, 'tokens', [])]):
            op = InOp(**kw, token='current_token')

        elif tok.match(tokens.Keyword, 'NOT'):
            if statement.next_token.match(tokens.Keyword, 'IN'):
                op = NotInOp(**kw)
                statement.skip(1)
            else:
                op = NotOp(**kw)
        elif tok.value.endswith("REGEXP"):
            op = RegexpOp(**kw)

        # 'iLIKE' must be tested before 'LIKE': every comparison containing
        # 'iLIKE' also contains 'LIKE'. In the opposite order the iLikeOp
        # branch is unreachable and case-insensitive matches become LikeOp.
        elif isinstance(tok, Comparison) and 'iLIKE' in tok.normalized:
            op = iLikeOp(**kw)

        elif isinstance(tok, Comparison) and 'LIKE' in tok.normalized:
            op = LikeOp(**kw)

        elif tok.match(tokens.Keyword, 'BETWEEN'):
            op = BetweenOp(**kw)
            statement.skip(3)

        elif tok.match(tokens.Keyword, 'IS'):
            op = IsOp(**kw)

        elif tok.value in JSON_OPERATORS:
            op = JSONOp(**kw)

        elif isinstance(tok, Comparison):
            op = CmpOp(tok, self.query)

        elif isinstance(tok, Parenthesis):
            if (tok[1].match(tokens.Name.Placeholder, '.*', regex=True)
                    or tok[1].match(tokens.Keyword, 'Null')
                    or isinstance(tok[1], IdentifierList)
                    or tok[1].ttype == tokens.DML):
                pass
            else:
                op = ParenthesisOp(SQLStatement(tok), self.query)

        elif tok.match(tokens.Punctuation, (')', '(')):
            pass

        elif isinstance(tok, Identifier):
            # A bare column reference only becomes an operator when it stands
            # alone: at the end, next to a bracket, or next to AND/OR.
            t = statement.next_token
            if not t or t.match(tokens.Punctuation,
                                (')', '(')) or t.match(tokens.Keyword,
                                                       ('AND', 'OR')):
                op = ColOp(tok, self.query)
        else:
            raise SQLDecodeError

        return op