Example #1
 def create(identifier: Identifier):
     schema = (
         Schema(identifier.get_parent_name())
         if identifier.get_parent_name() is not None
         else Schema()
     )
     return Table(identifier.get_real_name(), schema)
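A quick usage sketch for the snippet above; Schema and Table are the surrounding project's model classes (not part of sqlparse), and create is the helper defined above:

import sqlparse
from sqlparse.sql import Identifier

# Pull the table identifier out of a parsed statement and hand it to create().
stmt = sqlparse.parse("SELECT * FROM mydb.mytable")[0]
ident = next(tok for tok in stmt.tokens if isinstance(tok, Identifier))

print(ident.get_parent_name(), ident.get_real_name())  # mydb mytable
table = create(ident)  # per the code above: Table("mytable", Schema("mydb"))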
Example #2
def replace_recr(stmt, ns):
    try:
        #if it has tokens, it is not a leaf node
        tokens = stmt.tokens

        for n, token in enumerate(tokens):
            stmt.tokens[n] = replace_recr(token, ns)
        return stmt

    #If it does not have tokens, it is a leaf node
    except AttributeError:

        if stmt.ttype is Name.Placeholder:
            name = stmt.to_unicode()
            if name[0] == ":":
                name = var_value(name[1:], ns)
            return Identifier(name)

        if stmt.ttype in (Token.Literal.String.Single, Token.Literal.String.Symbol):
            lit = stmt.to_unicode()

            #Ensure that this is a string literal
            if lit[0] == "'" or lit[0] == '"':
                if lit[1] == ":":
                    #Strip quotes and colon before lookup
                    val = var_value(lit[2:-1], ns)
                    #Re-apply quotes
                    lit = lit[0] + val + lit[-1]
            return Identifier(lit)
        #If it is not a placeholder or literal, return unchanged
        return stmt
Example #3
    def test_simple_join2(self):
        data = [JoinData(left_key="aa", right_key="ab"), JoinData(left_key="ec", right_key="eb")]
        table_a = Table(table_name="A",
                        columns=[Column(name="aa", column_type=TypeEnum.int),
                                 Column(name="b", column_type=TypeEnum.int),
                                 Column(name="c", column_type=TypeEnum.int)], data_sizes=[100], data_paths=[""],
                        annotations=[])

        table_b = Table(table_name="B",
                        columns=[Column(name="ab", column_type=TypeEnum.int),
                                 Column(name="eb", column_type=TypeEnum.int)], data_sizes=[100], data_paths=[""],
                        annotations=[])

        table_c = Table(table_name="C",
                        columns=[Column(name="ec", column_type=TypeEnum.int),
                                 Column(name="f", column_type=TypeEnum.int)], data_sizes=[100], data_paths=[""],
                        annotations=[])

        tables = [table_a, table_b, table_c]

        root = SelectNode(tables=tables, annotation_name="demo")
        root.set_identifier_list([Identifier(tokens=[Token(None, "ec")]), Identifier(tokens=[Token(None, "f")])])

        root.next = JoinNode(join_list=data, tables=tables)
        root.next.prev = root

        root.next.merge()
        result = root.next.to_code(table_a.get_root())
        self.assertTrue(len(result) > 0)
Example #4
    def from_identifier(cls,
                        identifier: token_groups.Identifier,
                        position: int = 0) -> Column:
        """Create a column from an SQL identifier token.

        Params:
        -------
        identifier: SQL token with column label.

        Returns:
        --------
        A Column object based on the given identifier token.
        """
        idx, identifier_name = identifier.token_next_by(
            t=token_types.Name, i=token_groups.Function)

        _, maybe_dot = identifier.token_next(idx, skip_ws=True, skip_cm=True)
        if maybe_dot is None or not maybe_dot.match(token_types.Punctuation,
                                                    "."):
            table_name = None
            name = identifier_name.value
        else:
            table_name = identifier_name.value
            idx, column_name_token = identifier.token_next_by(
                t=token_types.Name, idx=idx)
            # Fauna doesn't have an 'id' field, so we extract the ID value from the 'ref' included
            # in query responses, but we still want to map the field name to aliases as with other
            # fields for consistency when passing results to SQLAlchemy
            name = "ref" if column_name_token.value == "id" else column_name_token.value

        idx, as_keyword = identifier.token_next_by(m=(token_types.Keyword,
                                                      "AS"),
                                                   idx=idx)

        if as_keyword is None:
            alias = "id" if name == "ref" else name
        else:
            _, alias_identifier = identifier.token_next_by(
                i=token_groups.Identifier, idx=idx)
            alias = alias_identifier.value

        function_name: typing.Optional[Function] = None
        if re.match(COUNT_REGEX, name):
            function_name = Function.COUNT
        elif re.match(NOT_SUPPORTED_FUNCTION_REGEX, name):
            raise exceptions.NotSupportedError(
                "MIN, MAX, AVG, and SUM functions are not yet supported.")

        column_params: ColumnParams = {
            "table_name": table_name,
            "name": name,
            "alias": alias,
            "function_name": function_name,
            "position": position,
        }

        return Column(**column_params)
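A hedged sketch of how an identifier might be obtained and passed in, assuming token_groups aliases sqlparse.sql and Column is the project class this method is bound to as a classmethod:

import sqlparse
from sqlparse import sql as token_groups

# The first Identifier child of the statement is the select item
# "users.email AS user_email".
stmt = sqlparse.parse("SELECT users.email AS user_email FROM users")[0]
_, identifier = stmt.token_next_by(i=token_groups.Identifier)

column = Column.from_identifier(identifier)
# expected: table_name="users", name="email", alias="user_email"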
Example #5
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
    """
    Tests for ``get_rls_for_table``.
    """
    candidate = Identifier([Token(Name, "some_table")])
    db = mocker.patch("superset.db")
    dataset = db.session.query().filter().one_or_none()
    dataset.__str__.return_value = "some_table"

    dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
    assert (
        str(get_rls_for_table(candidate, 1, "public"))
        == "some_table.organization_id = 1"
    )

    dataset.get_sqla_row_level_filters.return_value = [
        text("organization_id = 1"),
        text("foo = 'bar'"),
    ]
    assert (
        str(get_rls_for_table(candidate, 1, "public"))
        == "some_table.organization_id = 1 AND some_table.foo = 'bar'"
    )

    dataset.get_sqla_row_level_filters.return_value = []
    assert get_rls_for_table(candidate, 1, "public") is None
Example #6
    def parse_ident(ident: Identifier) -> str:
        # Extract table name from possible schema.table naming
        token_list = ident.flatten()
        table_name = next(token_list).value
        try:
            # Determine if the table contains the schema
            # separated by a dot (format: 'schema.table')
            dot = next(token_list)
            if dot.match(Punctuation, '.'):
                table_name += dot.value
                table_name += next(token_list).value

                # And again, to match bigquery's 'database.schema.table'
                try:
                    dot = next(token_list)
                    if dot.match(Punctuation, '.'):
                        table_name += dot.value
                        table_name += next(token_list).value
                except StopIteration:
                    # Do not insert database name if it's not specified
                    pass
            elif default_schema:
                table_name = f'{default_schema}.{table_name}'
        except StopIteration:
            if default_schema:
                table_name = f'{default_schema}.{table_name}'

        table_name = table_name.replace('`', '')
        return table_name
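A usage sketch; default_schema is assumed to be a variable from the enclosing scope, since parse_ident does not define it itself:

import sqlparse
from sqlparse.sql import Identifier
from sqlparse.tokens import Punctuation

default_schema = 'public'  # assumed closure variable used by parse_ident above

stmt = sqlparse.parse('SELECT * FROM myschema.mytable')[0]
ident = next(tok for tok in stmt.tokens if isinstance(tok, Identifier))
print(parse_ident(ident))  # myschema.mytable

stmt = sqlparse.parse('SELECT * FROM mytable')[0]
ident = next(tok for tok in stmt.tokens if isinstance(tok, Identifier))
print(parse_ident(ident))  # public.mytable (default schema applied)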
Example #7
 def parse_identifier(item):
     alias = item.get_alias()
     sp_idx = item.token_next_by(t=Whitespace)[0] or len(item.tokens)
     item_rev = Identifier(list(reversed(item.tokens[:sp_idx])))
     name = item_rev._get_first_name(real_name=True)
     alias = alias or name
     dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'))
     if dot_idx is not None:
         schema_name = item_rev._get_first_name(dot_idx, real_name=True)
         dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'),
                                             idx=dot_idx)
         if dot_idx is not None:
             catalog_name = item_rev._get_first_name(dot_idx,
                                                     real_name=True)
         else:
             catalog_name = None
     else:
         schema_name = None
         catalog_name = None
     schema_quoted = schema_name and item.value[0] == '"'
     if schema_name and not schema_quoted:
         schema_name = schema_name.lower()
     quote_count = item.value.count('"')
     name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
     alias_quoted = alias and item.value[-1] == '"'
     if alias_quoted or name_quoted and not alias and name.islower():
         alias = '"' + (alias or name) + '"'
     if name and not name_quoted and not name.islower():
         if not alias:
             alias = name
         name = name.lower()
     return catalog_name, schema_name, name, alias
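Roughly how this might be driven from a parsed statement; a sketch only, since it leans on sqlparse internals (_get_first_name) whose behaviour can shift between versions:

import sqlparse
from sqlparse.sql import Identifier
from sqlparse.tokens import Whitespace, Punctuation

stmt = sqlparse.parse('SELECT * FROM myschema."MyTable" t')[0]
ident = next(tok for tok in stmt.tokens if isinstance(tok, Identifier))
print(parse_identifier(ident))
# expected to be roughly (None, 'myschema', 'MyTable', 't')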
Example #8
    def _extract_from_token(self, token: Token) -> None:
        """
        <Identifier> stores a list of subtokens and <IdentifierList> stores a list of
        subtoken lists.

        It extracts <IdentifierList> and <Identifier> from :param token: and loops
        through all subtokens recursively. It finds table_name_preceding_token and
        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
        self._tables.

        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        if not hasattr(token, "tokens"):
            return

        table_name_preceding_token = False

        # If the table name is a reserved word (eg, "table_name") it won't be returned. We
        # fix this by ensuring that at least one identifier is returned after the FROM
        # before stopping on a keyword.
        has_processed_identifier = False

        for item in token.tokens:
            if item.is_group and (
                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
            ):
                self._extract_from_token(item)

            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME
                or item.normalized.endswith(" JOIN")
            ):
                table_name_preceding_token = True
                continue

            # If we haven't processed any identifiers it means the table name is a
            # reserved keyword (eg, "table_name") and we shouldn't skip it.
            if item.ttype in Keyword and has_processed_identifier:
                table_name_preceding_token = False
                continue
            if table_name_preceding_token:
                if isinstance(item, Identifier):
                    self._process_tokenlist(item)
                    has_processed_identifier = True
                elif isinstance(item, IdentifierList):
                    for token2 in item.get_identifiers():
                        if isinstance(token2, TokenList):
                            self._process_tokenlist(token2)
                    has_processed_identifier = True
                elif item.ttype in Keyword:
                    # convert into an identifier
                    fixed = Identifier([Token(Name, item.value)])
                    self._process_tokenlist(fixed)
                    has_processed_identifier = True
            elif isinstance(item, IdentifierList):
                if any(not self._is_identifier(token2) for token2 in item.tokens):
                    self._extract_from_token(item)
Example #9
    def _extract_direction(
            cls, identifier: token_groups.Identifier) -> OrderDirection:
        _, direction_token = identifier.token_next_by(t=token_types.Keyword)

        if direction_token is None:
            PENULTIMATE_TOKEN = -2
            # For some reason, when ordering by multiple columns with a direction keyword,
            # sqlparse groups the final column with the direction in an Identifier token.
            # There is an open issue (https://github.com/andialbrecht/sqlparse/issues/606),
            # though without any response, so it seems to be a bug.
            _, direction_identifier = identifier.token_next_by(
                i=token_groups.Identifier, idx=PENULTIMATE_TOKEN)
            if direction_identifier is not None:
                _, direction_token = direction_identifier.token_next_by(
                    t=token_types.Keyword)

        return (getattr(OrderDirection, direction_token.value)
                if direction_token else None)
Example #10
 def __parse_from__(self):
     from_tables = [
         Identifier(tokens=[Token(value=f.variable_table_name, ttype="")])
         for f in self.db_driver.perform_select_from()
     ]
     from_node = FromNode(tables=self.tables)
     from_node.set_identifier_list(from_tables)
     last = self.root.get_last_node()
     last.next = from_node
     last.next.prev = last
Example #11
    def test_simple_select(self):
        select_node = SelectNode(tables=[self.table_a], annotation_name="demo")
        select_node.from_tables = [
            Identifier(tokens=[Token("int", "A")]),
        ]

        code = select_node.to_code(self.table_a.get_root())
        print(code)
        self.assert_content_in_arr(code, "Relation a(a_ri, a_ai);")
        self.assert_content_in_arr(code, "CLIENT")
        self.assert_content_in_arr(code, "{ Relation::INT,Relation::STRING }")
Example #12
    def test_select2(self):
        table_a = PythonFreeConnexTable(table_name="A",
                                        columns=[
                                            Column(name="a",
                                                   column_type=TypeEnum.int),
                                            Column(name="b",
                                                   column_type=TypeEnum.string)
                                        ],
                                        owner=CharacterEnum.client,
                                        data_sizes=[100],
                                        data_paths=[""],
                                        annotations=[])

        table_b = PythonFreeConnexTable(table_name="B",
                                        columns=[
                                            Column(name="c",
                                                   column_type=TypeEnum.int),
                                            Column(name="d",
                                                   column_type=TypeEnum.string)
                                        ],
                                        owner=CharacterEnum.server,
                                        data_sizes=[100],
                                        data_paths=[""],
                                        annotations=[])

        node = SelectNodePython(tables=[table_a, table_b],
                                annotation_name="demo")
        node.from_tables = [
            Identifier(tokens=[Token("int", "A")]),
            Identifier(tokens=[Token("int", "B")]),
        ]
        relations = node.to_code(table_a.get_root(), should_load_data=False)
        self.assertEqual(len(relations), 2)

        table = relations[0]
        self.assertEqual(table.variable_table_name, "a")
        self.assertIsNotNone(table.relation)

        table = relations[1]
        self.assertEqual(table.variable_table_name, "b")
        self.assertIsNotNone(table.relation)
Example #13
 def create(identifier: Identifier):
     # rewrite identifier's get_real_name method, by matching the last dot instead of the first dot, so that the
     # real name for a.b.c will be c instead of b
     dot_idx, _ = identifier._token_matching(
         lambda token: imt(token, m=(Punctuation, ".")),
         start=len(identifier.tokens),
         reverse=True,
     )
     real_name = identifier._get_first_name(dot_idx, real_name=True)
     # rewrite identifier's get_parent_name accordingly
     parent_name = (
         "".join(
             [
                 escape_identifier_name(token.value)
                 for token in identifier.tokens[:dot_idx]
             ]
         )
         if dot_idx
         else None
     )
     schema = Schema(parent_name) if parent_name is not None else Schema()
     return Table(real_name, schema)
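A sketch contrasting this with sqlparse's stock behaviour; Schema, Table and escape_identifier_name are the surrounding project's helpers:

import sqlparse
from sqlparse.sql import Identifier

stmt = sqlparse.parse('SELECT * FROM db1.schema1.tbl1')[0]
ident = next(tok for tok in stmt.tokens if isinstance(tok, Identifier))

print(ident.get_real_name())  # schema1 -- sqlparse stops at the first dot
table = create(ident)         # the helper above resolves real name 'tbl1', parent 'db1.schema1'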
Example #14
    def test_simple_join1(self):
        data = [JoinData(left_key="aa", right_key="ba")]
        table_a = Table(table_name="A",
                        columns=[Column(name="aa", column_type=TypeEnum.int),
                                 Column(name="b", column_type=TypeEnum.int),
                                 Column(name="c", column_type=TypeEnum.int)], data_sizes=[100], data_paths=[""],
                        annotations=[])

        table_b = Table(table_name="B",
                        columns=[Column(name="ba", column_type=TypeEnum.int),
                                 Column(name="e", column_type=TypeEnum.int)], data_sizes=[100], data_paths=[""],
                        annotations=[])

        root = SelectNode(tables=[table_a, table_b], annotation_name="demo")
        root.set_identifier_list([Identifier(tokens=[Token(None, "b")]), Identifier(tokens=[Token(None, "c")])])

        root.next = JoinNode(join_list=data, tables=[table_a, table_b])
        root.next.prev = root

        root.next.merge()
        result = root.next.to_code(table_a.get_root())
        self.assertTrue('a.Aggregate({ "aa" });' in result[0])
Example #15
def matches_table_name(candidate: Token, table: str) -> bool:
    """
    Return whether the token represents a reference to the table.

    Tables can be fully qualified with periods.

    Note that in theory a table should be represented as an identifier, but due to
    sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
    classified as a keyword.
    """
    if not isinstance(candidate, Identifier):
        candidate = Identifier([Token(Name, candidate.value)])

    target = sqlparse.parse(table)[0].tokens[0]
    if not isinstance(target, Identifier):
        target = Identifier([Token(Name, target.value)])

    # match from right to left, splitting on the period, eg, schema.table == table
    for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
        if left.value != right.value:
            return False

    return True
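A short sketch of the right-to-left matching in practice:

import sqlparse
from sqlparse.sql import Identifier

stmt = sqlparse.parse("SELECT * FROM schema1.my_table")[0]
candidate = next(tok for tok in stmt.tokens if isinstance(tok, Identifier))

print(matches_table_name(candidate, "my_table"))          # True -- suffix match
print(matches_table_name(candidate, "schema1.my_table"))  # True
print(matches_table_name(candidate, "other_table"))       # False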
Example #16
    def from_identifier(cls, identifier: token_groups.Identifier) -> Table:
        """Extract table name from an SQL identifier.

        Params:
        -------
        identifier: SQL token that contains the table's name.

        Returns:
        --------
        A new Table object.
        """
        idx, name = identifier.token_next_by(t=token_types.Name)
        assert name is not None

        idx, _ = identifier.token_next_by(m=(token_types.Keyword, "AS"),
                                          idx=idx)
        if idx is None:
            return cls(name=name.value)

        _, alias = identifier.token_next_by(i=token_groups.Identifier, idx=idx)
        if alias is None:
            return cls(name=name.value)

        return cls(name=name.value, alias=alias.value)
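A sketch of the alias path, assuming from_identifier is bound as a classmethod on the project's Table class and token_groups aliases sqlparse.sql:

import sqlparse
from sqlparse import sql as token_groups

stmt = sqlparse.parse("SELECT * FROM users AS u")[0]
_, identifier = stmt.token_next_by(i=token_groups.Identifier)

table = Table.from_identifier(identifier)
# expected: table.name == "users", table.alias == "u"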
Example #17
 def parse_ident(ident: Identifier) -> str:
     # Extract table name from possible schema.table naming
     token_list = ident.flatten()
     table_name = next(token_list).value
     try:
         # Determine if the table contains the schema
         # separated by a dot (format: 'schema.table')
         dot = next(token_list)
         if dot.match(Punctuation, '.'):
             table_name += dot.value
             table_name += next(token_list).value
         elif default_schema:
             table_name = f'{default_schema}.{table_name}'
     except StopIteration:
         if default_schema:
             table_name = f'{default_schema}.{table_name}'
     return table_name
Example #18
 def get_identifier_parents(self):
     if self.identifier is None:
         return None, None
     item_rev = Identifier(list(reversed(self.identifier.tokens)))
     name = item_rev._get_first_name(real_name=True)
     dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'))
     if dot_idx is not None:
         schema_name = item_rev._get_first_name(dot_idx, real_name=True)
         dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'), idx=dot_idx)
         if dot_idx is not None:
             catalog_name = item_rev._get_first_name(dot_idx, real_name=True)
         else:
             catalog_name = None
     else:
         schema_name = None
         catalog_name = None
     return catalog_name, schema_name
Example #19
def get_rls_for_table(
    candidate: Token,
    database_id: int,
    default_schema: Optional[str],
    username: Optional[str] = None,
) -> Optional[TokenList]:
    """
    Given a table name, return any associated RLS predicates.
    """
    # pylint: disable=import-outside-toplevel
    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    if not isinstance(candidate, Identifier):
        candidate = Identifier([Token(Name, candidate.value)])

    table = ParsedQuery.get_table(candidate)
    if not table:
        return None

    dataset = (db.session.query(SqlaTable).filter(
        and_(
            SqlaTable.database_id == database_id,
            SqlaTable.schema == (table.schema or default_schema),
            SqlaTable.table_name == table.table,
        )).one_or_none())
    if not dataset:
        return None

    template_processor = dataset.get_template_processor()
    predicate = " AND ".join(
        str(filter_) for filter_ in dataset.get_sqla_row_level_filters(
            template_processor, username))
    if not predicate:
        return None

    rls = sqlparse.parse(predicate)[0]
    add_table_name(rls, str(dataset))

    return rls
Example #20
 def parse_identifier(item):
     alias = item.get_alias()
     sp_idx = item.token_next_by(t=Whitespace)[0] or len(item.tokens)
     item_rev = Identifier(list(reversed(item.tokens[:sp_idx])))
     name = item_rev._get_first_name(real_name=True)
     alias = alias or name
     dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'))
     if dot_idx is not None:
         schema_name = item_rev._get_first_name(dot_idx, real_name=True)
         dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'),
                                             idx=dot_idx)
         if dot_idx is not None:
             catalog_name = item_rev._get_first_name(dot_idx,
                                                     real_name=True)
         else:
             catalog_name = None
     else:
         schema_name = None
         catalog_name = None
     # TODO: this business below needs help
     # for one we need to apply this logic to catalog_name
     # then the logic around name_quoted = quote_count > 2 obviously
     # doesn't work.  Finally, quotechar needs to be customized
     schema_quoted = schema_name and item.value[0] == '"'
     if schema_name and not schema_quoted:
         schema_name = schema_name.lower()
     quote_count = item.value.count('"')
     name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
     alias_quoted = alias and item.value[-1] == '"'
     if alias_quoted or name_quoted and not alias and name.islower():
         alias = '"' + (alias or name) + '"'
     if name and not name_quoted and not name.islower():
         if not alias:
             alias = name
         name = name.lower()
     return catalog_name, schema_name, name, alias
Example #21
 def construct_identifier(self, content):
     return Identifier([Token('', content)])
Example #22
    def check_query(custom_validation_param):
        statement = sqlparse.parse(custom_validation_param["query_validation"])[0]

        # sqlparse.parse() already returns Statement objects, so no re-wrapping is
        # needed; reject anything that is not a SELECT, or that selects "*"
        # (Wildcard comes from sqlparse.tokens)
        selects_wildcard = any(token.ttype is Wildcard for token in statement.flatten())
        if statement.get_type() != "SELECT" or selects_wildcard:
            raise InvalidUsage('Not valid query', status_code=400)