Example #1
    def recurse(self, tokens: TokenList) -> SqlMeta:
        in_tables, out_tables = set(), set()
        idx, token = tokens.token_next_by(t=T.Keyword)
        while token:

            # Main parser switch
            if self.is_cte(token):
                cte_name, cte_intables = self.parse_cte(idx, tokens)
                for intable in cte_intables:
                    if intable not in self.ctes:
                        in_tables.add(intable)
            elif _is_in_table(token):
                idx, extracted_tables = _get_tables(tokens, idx,
                                                    self.default_schema)
                for table in extracted_tables:
                    if table.name not in self.ctes:
                        in_tables.add(table)
            elif _is_out_table(token):
                idx, extracted_tables = _get_tables(tokens, idx,
                                                    self.default_schema)
                out_tables.add(
                    extracted_tables[0])  # assuming only one out_table

            idx, token = tokens.token_next_by(t=T.Keyword, idx=idx)

        return SqlMeta(list(in_tables), list(out_tables))
Example #2
    def parse_cte(self, idx, tokens: TokenList):
        gidx, group = tokens.token_next(idx, skip_ws=True, skip_cm=True)

        # handle recursive keyword
        if group.match(T.Keyword, values=['RECURSIVE']):
            gidx, group = tokens.token_next(gidx, skip_ws=True, skip_cm=True)

        if not group.is_group:
            # Match the (cte_name, in_tables) shape that the caller unpacks.
            return None, []

        # get CTE name
        offset = 1
        cte_name = group.token_first(skip_ws=True, skip_cm=True)
        self.ctes.add(cte_name.value)

        # AS keyword
        offset, as_keyword = group.token_next(offset,
                                              skip_ws=True,
                                              skip_cm=True)
        if not as_keyword.match(T.Keyword, values=['AS']):
            raise RuntimeError(f"CTE does not have AS keyword at index {gidx}")

        offset, parens = group.token_next(offset, skip_ws=True, skip_cm=True)
        if isinstance(parens, Parenthesis) or parens.is_group:
            # Parse CTE using recursion.
            return cte_name.value, self.recurse(TokenList(
                parens.tokens)).in_tables
        raise RuntimeError(
            f"Parens {parens} are not Parenthesis at index {gidx}")
Example #3
    def _get_table(tlist: TokenList) -> Optional[Table]:
        """
        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
        construct.

        :param tlist: The SQL tokens
        :returns: The table if the name conforms
        """

        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            if ws_idx != -1:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        if (len(tokens) in (1, 3, 5)
                and all(imt(token, t=[Name, String]) for token in tokens[::2])
                and all(
                    imt(token, m=(Punctuation, "."))
                    for token in tokens[1::2])):
            return Table(
                *[remove_quotes(token.value) for token in tokens[::-2]])

        return None
Example #4
    def parse(sql: str) -> SqlMeta:
        if sql is None:
            raise ValueError("A sql statement must be provided.")

        # Tokenize the SQL statement
        statements = sqlparse.parse(sql)

        # We assume only one statement in SQL
        tokens = TokenList(statements[0].tokens)
        log.debug(f"Successfully tokenized sql statement: {tokens}")

        in_tables = []
        out_tables = []

        idx, token = tokens.token_next_by(t=T.Keyword)
        while token:
            if _is_in_table(token):
                idx, in_table = _get_table(tokens, idx)
                in_tables.append(in_table)
            elif _is_out_table(token):
                idx, out_table = _get_table(tokens, idx)
                out_tables.append(out_table)

            idx, token = tokens.token_next_by(t=T.Keyword, idx=idx)

        return SqlMeta(in_tables, out_tables)
Example #5
    def __get_full_name(tlist: TokenList) -> Optional[str]:
        """
        Return the full unquoted table name if valid, i.e., conforms to the following
        [[cluster.]schema.]table construct.

        :param tlist: The SQL tokens
        :returns: The valid full table name
        """

        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            if ws_idx != -1:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        if (len(tokens) in (1, 3, 5) and all(
                imt(token, t=[Name, String]) for token in tokens[0::2])
                and all(
                    imt(token, m=(Punctuation, "."))
                    for token in tokens[1::2])):
            return ".".join(
                [remove_quotes(token.value) for token in tokens[0::2]])

        return None
Example #6
 def __vectorize(self, tokenlist):
     token_list = TokenList(list(tokenlist.flatten()))
     # print(token_list.tokens)
     for x in token_list:
         if x.ttype is Comparison:
             idx_comp_op = token_list.token_index(x)  # index of the comparison operator
             attr = token_list.token_prev(
                 idx_comp_op, skip_ws=True,
                 skip_cm=True)[1].value  # name of the attribute
             print(attr)
             comp_op = x
             # print(comp_op)
             if comp_op.value == '<' or comp_op.value == '<=':
                 lit_dir = 'ub'
             elif comp_op.value == '>' or comp_op.value == '>=':
                 lit_dir = 'lb'
             else:
                 lit_dir = 'bi'
             # print(lit_dir)
             try:
                 lit = float(
                     token_list.token_next(
                         idx_comp_op, skip_ws=True,
                         skip_cm=True)[1].value)  #literal value
             except ValueError:
                 print("Possible join, skipping")
                 continue
             # print(lit)
             if lit_dir == 'bi':
                 self.query_vec['_'.join([attr, 'lb'])] = lit
                 self.query_vec['_'.join([attr, 'ub'])] = lit
                 continue
             self.query_vec['_'.join([attr, lit_dir])] = lit  # lit_dir is either 'lb' or 'ub'
Example #7
 def _handle_target_table_token(self, sub_token: TokenList) -> None:
     if isinstance(sub_token, Function):
         # insert into tab (col1, col2) values (val1, val2); Here tab (col1, col2) will be parsed as Function
         # referring https://github.com/andialbrecht/sqlparse/issues/483 for further information
         if not isinstance(sub_token.token_first(skip_cm=True), Identifier):
             raise SQLLineageException(
                 "An Identifier is expected, got %s[value: %s] instead" %
                 (type(sub_token).__name__, sub_token))
         self._lineage_result.write.add(
             Table.create(sub_token.token_first(skip_cm=True)))
     elif isinstance(sub_token, Comparison):
         # create table tab1 like tab2, tab1 like tab2 will be parsed as Comparison
         # referring https://github.com/andialbrecht/sqlparse/issues/543 for further information
         if not (isinstance(sub_token.left, Identifier)
                 and isinstance(sub_token.right, Identifier)):
             raise SQLLineageException(
                 "An Identifier is expected, got %s[value: %s] instead" %
                 (type(sub_token).__name__, sub_token))
         self._lineage_result.write.add(Table.create(sub_token.left))
         self._lineage_result.read.add(Table.create(sub_token.right))
     else:
         if not isinstance(sub_token, Identifier):
             raise SQLLineageException(
                 "An Identifier is expected, got %s[value: %s] instead" %
                 (type(sub_token).__name__, sub_token))
         self._lineage_result.write.add(Table.create(sub_token))
Example #8
def test_group_parentheses():
    tokens = [
        Token(T.Keyword, 'CREATE'),
        Token(T.Whitespace, ' '),
        Token(T.Keyword, 'TABLE'),
        Token(T.Whitespace, ' '),
        Token(T.Name, 'table_name'),
        Token(T.Whitespace, ' '),
        Token(T.Punctuation, '('),
        Token(T.Name, 'id'),
        Token(T.Whitespace, ' '),
        Token(T.Keyword, 'SERIAL'),
        Token(T.Whitespace, ' '),
        Token(T.Keyword, 'CHECK'),
        Token(T.Punctuation, '('),
        Token(T.Name, 'id'),
        Token(T.Operator, '='),
        Token(T.Number, '0'),
        Token(T.Punctuation, ')'),
        Token(T.Punctuation, ')'),
        Token(T.Punctuation, ';'),
    ]

    expected_tokens = TokenList([
        Token(T.Keyword, 'CREATE'),
        Token(T.Keyword, 'TABLE'),
        Token(T.Name, 'table_name'),
        Parenthesis([
            Token(T.Punctuation, '('),
            Token(T.Name, 'id'),
            Token(T.Keyword, 'SERIAL'),
            Token(T.Keyword, 'CHECK'),
            Parenthesis([
                Token(T.Punctuation, '('),
                Token(T.Name, 'id'),
                Token(T.Operator, '='),
                Token(T.Number, '0'),
                Token(T.Punctuation, ')'),
            ]),
            Token(T.Punctuation, ')'),
        ]),
        Token(T.Punctuation, ';'),
    ])

    grouped_tokens = group_parentheses(tokens)

    stdout = sys.stdout
    try:
        sys.stdout = StringIO()
        expected_tokens._pprint_tree()
        a = sys.stdout.getvalue()
        sys.stdout = StringIO()
        grouped_tokens._pprint_tree()
        b = sys.stdout.getvalue()
    finally:
        sys.stdout = stdout

    assert_multi_line_equal(a, b)
Example #9
 def filter_identifier_list(tkn_list: TokenList, token: Token):
     # debug: pprint(token)
     index = tkn_list.token_index(token)
     prev_token: Token = tkn_list.token_prev(index)[1]
     if prev_token is not None:
         # token_prev returns None when there is no previous token (index 0)
         if not prev_token.match(DML, 'SELECT'):
             return False
     next_token: Token = tkn_list.token_next(index)[1]
     if next_token is not None:
         # token_next returns None when there is no next token (end of list)
         if not next_token.match(Keyword, 'FROM'):
             return False
     return True
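A minimal usage sketch for the filter above (the query text and variable names are illustrative, and it assumes filter_identifier_list is callable as a plain function alongside sqlparse):

import sqlparse
from sqlparse.sql import IdentifierList, TokenList

# Keep only identifier lists that sit directly between SELECT and FROM.
stmt = sqlparse.parse("SELECT a, b FROM t")[0]
tkn_list = TokenList(stmt.tokens)
select_columns = [
    tok for tok in tkn_list.tokens
    if isinstance(tok, IdentifierList) and filter_identifier_list(tkn_list, tok)
]
# select_columns should hold the single IdentifierList 'a, b'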
Example #10
    def extract_from_column(self):

        '''
        cols_group collects all tokens between 'DML SELECT' and 'Keyword FROM', e.g.:
        
        [<DML 'SELECT' at 0x3655A08>, <Whitespace ' ' at 0x3655A68>, <IdentifierList 'me.Sap...' at 0x366E228>,
         <Newline ' ' at 0x3665948>, <Keyword 'FROM' at 0x36659A8>, <Whitespace ' ' at 0x3665A08>,
         <IdentifierList 'SODS2....' at 0x366E390>,
         <Whitespace ' ' at 0x3667228>, <IdentifierList 't,SHAR...' at 0x366E480>, <Newline ' ' at 0x3667528>]
        '''
        
        tokens = self.getTokens()
        tokenlist = TokenList(tokens)
        cols_idx, cols_item = [], []
        cols_group = []
        '''
            cols_item only keeps the columns between SELECT and FROM.
            Notice: a query with UNION/UNION ALL produces several such groups, so cols_group collects all of them.
        '''
        fetch_col_flag = False
        for idx, item in enumerate(tokens):
            before_idx,before_item = tokenlist.token_prev(idx,skip_ws=True)
            next_idx,next_item = tokenlist.token_next(idx,skip_ws=True)
            if not next_item :
                break
            # capture the first column index
            if (isinstance(item,IdentifierList) or isinstance(item,Identifier)) and \
                (before_item.ttype == Keyword.DML or before_item.value.upper() == 'DISTINCT'):
                cols_idx.append(idx)
                fetch_col_flag = True
                cols_item = []                
            if fetch_col_flag:
                cols_item.append(item)
            # capture the last column index
            if (isinstance(item,IdentifierList) or isinstance(item,Identifier)) and \
                next_item.ttype is Keyword and next_item.value.upper() == 'FROM':
                cols_idx.append(idx)
                fetch_col_flag = False
                cols_group.append (''.join([ item.value for item in cols_item]))
        
        '''
        cols_idx holds (start, end) index pairs, e.g. [10, 12, 24, 26]; expand each pair into its full range --> [10, 11, 12, 24, 25, 26]
        '''
        cols_idxes = sum([list(range(cols_idx[2*i],cols_idx[2*i+1]+1)) for i in range(int(len(cols_idx)/2))],[]) 
        
        keep_tokens = [ item for idx,item in enumerate(tokens) if idx not in cols_idxes ]
        self.tokens = keep_tokens
        self.tokens_val = [item.value for item in tokens]
        return cols_group
Example #11
    def __process_tokenlist(self, tlist: TokenList):
        # exclude subselects
        if '(' not in str(tlist):
            table_name = self.__get_full_name(tlist)
            if table_name and not table_name.startswith(CTE_PREFIX):
                self._table_names.add(table_name)
            return

        # store aliases
        if tlist.has_alias():
            self._alias_names.add(tlist.get_alias())

        # some aliases are not parsed properly
        if tlist.tokens[0].ttype == Name:
            self._alias_names.add(tlist.tokens[0].value)
        self.__extract_from_token(tlist)
Example #12
def _define_primary_key(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> typing.Optional[AllFieldMetadata]:
    idx, constraint_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "CONSTRAINT"))

    idx, primary_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "PRIMARY"), idx=(idx or -1))

    if constraint_keyword is not None and primary_keyword is None:
        raise exceptions.NotSupportedError(
            "When a column definition clause begins with CONSTRAINT, "
            "only a PRIMARY KEY constraint is supported")

    if primary_keyword is None:
        return None

    # If the keyword isn't followed by column name(s), then it's part of
    # a regular column definition and should be handled by _define_column
    if not _contains_column_name(column_definition_group, idx):
        return None

    new_metadata: AllFieldMetadata = deepcopy(metadata)

    while True:
        idx, primary_key_column = column_definition_group.token_next_by(
            t=token_types.Name, idx=idx)

        # 'id' is defined and managed by Fauna, so we ignore any attempts
        # to manage it from SQLAlchemy
        if primary_key_column is None or primary_key_column.value == "id":
            break

        primary_key_column_name = primary_key_column.value

        new_metadata[primary_key_column_name] = {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **new_metadata.get(primary_key_column_name, {}),  # type: ignore
            "unique":
            True,
            "not_null":
            True,
        }

    return new_metadata
Example #13
def get_query_tokens(query):
    """
    :type query: str
    :rtype: list[sqlparse.sql.Token]
    """
    tokens = TokenList(sqlparse.parse(query)[0].tokens).flatten()
    # print([(token.value, token.ttype) for token in tokens])

    return [token for token in tokens if token.ttype is not Whitespace]
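A quick usage sketch for get_query_tokens (the query string is illustrative); with whitespace filtered out, only the meaningful tokens remain:

tokens = get_query_tokens("SELECT id FROM users WHERE id = 1")
print([tok.value for tok in tokens])
# expected: ['SELECT', 'id', 'FROM', 'users', 'WHERE', 'id', '=', '1']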
Example #14
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
    """
    Extract limit clause from SQL statement.

    :param statement: SQL statement
    :return: Limit extracted from query, None if no limit present in statement
    """
    idx, _ = statement.token_next_by(m=(Keyword, "LIMIT"))
    if idx is not None:
        _, token = statement.token_next(idx=idx)
        if token:
            if isinstance(token, IdentifierList):
                # In case of "LIMIT <offset>, <limit>", find comma and extract
                # first succeeding non-whitespace token
                idx, _ = token.token_next_by(m=(sqlparse.tokens.Punctuation, ","))
                _, token = token.token_next(idx=idx)
            if token and token.ttype == sqlparse.tokens.Literal.Number.Integer:
                return int(token.value)
    return None
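A usage sketch for the limit extractor (both statements are illustrative and assume sqlparse is installed):

import sqlparse

simple = sqlparse.parse("SELECT * FROM logs LIMIT 100")[0]
offset_form = sqlparse.parse("SELECT * FROM logs LIMIT 10, 100")[0]
print(_extract_limit_from_query(simple))       # expected: 100
print(_extract_limit_from_query(offset_form))  # expected: 100 (the value after the comma)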
Example #15
 def to_clickhouse(cls, schema: str, query: str):
     """
     parse ddl query
     :param schema:
     :param query:
     :return:
     """
     token_list = TokenList()
     parsed = sqlparse.parse(query)[0]
     token_list = cls._add_token(schema, parsed, parsed.tokens, token_list)
     return str(token_list)
Example #16
def get_query_tokens(query: str) -> List[sqlparse.sql.Token]:
    query = preprocess_query(query)
    parsed = sqlparse.parse(query)

    # handle empty queries (#12)
    if not parsed:
        return []

    tokens = TokenList(parsed[0].tokens).flatten()

    return [token for token in tokens if token.ttype is not Whitespace]
Example #17
    def parse(cls, sql: str, default_schema: Optional[str] = None) -> SqlMeta:
        if sql is None:
            raise ValueError("A sql statement must be provided.")

        # Tokenize the SQL statement
        statements = sqlparse.parse(sql)

        # We assume only one statement in SQL
        tokens = TokenList(statements[0].tokens)
        log.debug(f"Successfully tokenized sql statement: {tokens}")
        parser = cls(default_schema)
        return parser.recurse(tokens)
Example #18
    def _process_tokenlist(self, token_list: TokenList):
        """
        Add table names to table set

        :param token_list: TokenList to be processed
        """
        # exclude subselects
        if "(" not in str(token_list):
            table = self._get_table(token_list)
            if table and not table.table.startswith(CTE_PREFIX):
                self._tables.add(table)
            return

        # store aliases
        if token_list.has_alias():
            self._alias_names.add(token_list.get_alias())

        # some aliases are not parsed properly
        if token_list.tokens[0].ttype == Name:
            self._alias_names.add(token_list.tokens[0].value)
        self._extract_from_token(token_list)
Example #19
 def _process_statement_tokens(cls, statement_tokens, filter_string):
     """
     This function processes the tokens in a statement to ensure that the correct parsing
     behavior occurs. In typical cases, the statement tokens will contain just the comparison -
     in this case, no additional processing occurs. In the case when a filter string contains the
     IN operator, this function parses those tokens into a Comparison object, which will be
     parsed by _get_comparison_for_model_registry.
     :param statement_tokens: List of tokens from a statement
     :param filter_string: Filter string from which the parsed statement tokens originate. Used
     for informative logging
     :return: List of tokens
     """
     expected = "Expected search filter with single comparison operator. e.g. name='myModelName'"
     token_list = []
     if len(statement_tokens) == 0:
         raise MlflowException(
             "Invalid filter '%s'. Could not be parsed. %s" %
             (filter_string, expected),
             error_code=INVALID_PARAMETER_VALUE,
         )
     elif len(statement_tokens) == 1:
         if isinstance(statement_tokens[0], Comparison):
             token_list = statement_tokens
         else:
             raise MlflowException(
                 "Invalid filter '%s'. Could not be parsed. %s" %
                 (filter_string, expected),
                 error_code=INVALID_PARAMETER_VALUE,
             )
     elif len(statement_tokens) > 1:
         comparison_subtokens = []
         for token in statement_tokens:
             if isinstance(token, Comparison):
                 raise MlflowException(
                     "Search filter '%s' contains multiple expressions. "
                     "%s " % (filter_string, expected),
                     error_code=INVALID_PARAMETER_VALUE,
                 )
             elif cls._is_list_component_token(token):
                 comparison_subtokens.append(token)
             elif not token.is_whitespace:
                 break
         # if we have fewer than 3, that means we have an incomplete statement.
         if len(comparison_subtokens) == 3:
             token_list = [Comparison(TokenList(comparison_subtokens))]
         else:
             raise MlflowException(
                 "Invalid filter '%s'. Could not be parsed. %s" %
                 (filter_string, expected),
                 error_code=INVALID_PARAMETER_VALUE,
             )
     return token_list
Example #20
    def __projections(self, token, tokenlist):
        idx = tokenlist.token_index(token)
        afs_list_idx, afs = tokenlist.token_next(idx,
                                                 skip_ws=True,
                                                 skip_cm=True)
        afs_list = TokenList(list(afs.flatten()))
        for af in afs_list:  # Get AFs

            if af.value.lower() in ['avg', 'count', 'sum', 'min', 'max']:
                # if af not in self.afs_dic:
                #     self.afs_dic[af.value] = []
                af_idx = afs_list.token_index(af)
                punc_idx, _ = afs_list.token_next(af_idx,
                                                  skip_ws=True,
                                                  skip_cm=True)
                attr_idx, attr = afs_list.token_next(punc_idx,
                                                     skip_ws=True,
                                                     skip_cm=True)
                if attr.ttype is not Wildcard:
                    self.afs.append('_'.join([af.value, attr.value]))
                else:
                    self.afs.append(af.value)
Example #21
    def extract_from_column(self):

        '''
        pick up all tokens between 'DML SELECT' and 'Keyword FROM'
        
        [<DML 'SELECT' at 0x3655A08>, <Whitespace ' ' at 0x3655A68>, <IdentifierList 'me.Sap...' at 0x366E228>,
         <Newline ' ' at 0x3665948>, <Keyword 'FROM' at 0x36659A8>, <Whitespace ' ' at 0x3665A08>,
         <IdentifierList 'SODS2....' at 0x366E390>,
         <Whitespace ' ' at 0x3667228>, <IdentifierList 't,SHAR...' at 0x366E480>, <Newline ' ' at 0x3667528>]
        '''
        
        tokens = self.getTokens()
        tokenlist = TokenList(tokens)
        cols_idx,cols_item = [] , []
        cols_group = []
        fetch_col_flag = False
        for idx, item in enumerate(tokens):
            before_idx,before_item = tokenlist.token_prev(idx,skip_ws=True)
            next_idx,next_item = tokenlist.token_next(idx,skip_ws=True)
            if not next_item :
                break
            # capture the first column index
            if (isinstance(item,IdentifierList) or isinstance(item,Identifier)) and \
                (before_item.ttype == Keyword.DML or before_item.value.upper() == 'DISTINCT'):
                cols_idx.append(idx)
                fetch_col_flag = True
                cols_item = []                
            if fetch_col_flag:
                cols_item.append(item)
            # capture the last column index
            if (isinstance(item,IdentifierList) or isinstance(item,Identifier)) and \
                next_item.ttype is Keyword and next_item.value.upper() == 'FROM':
                cols_idx.append(idx)
                fetch_col_flag = False
                cols_group.append (cols_item)
        
        cols_idxes = sum([list(range(cols_idx[2*i],cols_idx[2*i+1]+1)) for i in range(int(len(cols_idx)/2))],[]) 
        
        left_tokens = [ item for idx,item in enumerate(tokens) if idx not in cols_idxes ]
Example #22
def group_parentheses(tokens):
    stack = [[]]
    for token in tokens:
        if token.is_whitespace:
            continue
        if token.match(T.Punctuation, '('):
            stack.append([token])
        else:
            stack[-1].append(token)
            if token.match(T.Punctuation, ')'):
                group = stack.pop()
                stack[-1].append(Parenthesis(group))
    return TokenList(stack[0])
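A usage sketch for group_parentheses (the CREATE TABLE text is illustrative); it consumes a flat token stream and rebuilds nested Parenthesis groups with whitespace dropped:

import sqlparse

flat_tokens = list(sqlparse.parse("CREATE TABLE t (id SERIAL CHECK (id = 0));")[0].flatten())
grouped = group_parentheses(flat_tokens)
grouped._pprint_tree()  # prints the nested structure, as compared in test_group_parentheses above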
Example #23
def _define_unique_constraint(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> typing.Optional[AllFieldMetadata]:
    idx, unique_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "UNIQUE"))

    if unique_keyword is None:
        return None

    # If the keyword isn't followed by column name(s), then it's part of
    # a regular column definition and should be handled by _define_column
    if not _contains_column_name(column_definition_group, idx):
        return None

    new_metadata = deepcopy(metadata)

    while True:
        idx, unique_key_column = column_definition_group.token_next_by(
            t=token_types.Name, idx=idx)

        # 'id' is defined and managed by Fauna, so we ignore any attempts
        # to manage it from SQLAlchemy
        if unique_key_column is None or unique_key_column.value == "id":
            break

        unique_key_column_name = unique_key_column.value

        new_metadata[unique_key_column_name] = {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **new_metadata.get(unique_key_column_name, {}),  # type: ignore
            "unique":
            True,
        }

    return new_metadata
Example #24
def get_query_tokens(query: str) -> List[sqlparse.sql.Token]:
    """
    :type query: str
    :rtype: list[sqlparse.sql.Token]
    """
    query = preprocess_query(query)
    parsed = sqlparse.parse(query)

    # handle empty queries (#12)
    if not parsed:
        return []

    tokens = TokenList(parsed[0].tokens).flatten()
    # print([(token.value, token.ttype) for token in tokens])

    return [token for token in tokens if token.ttype is not Whitespace]
Example #25
def _define_column(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> AllFieldMetadata:
    idx, column = column_definition_group.token_next_by(t=token_types.Name)
    column_name = column.value

    # "id" is auto-generated by Fauna, so we ignore it in SQL column definitions
    if column_name == "id":
        return metadata

    idx, data_type = column_definition_group.token_next_by(t=token_types.Name,
                                                           idx=idx)
    _, not_null_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "NOT NULL"))
    _, unique_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "UNIQUE"))
    _, primary_key_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "PRIMARY KEY"))
    _, default_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "DEFAULT"))
    _, check_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "CHECK"))

    if check_keyword is not None:
        raise exceptions.NotSupportedError("CHECK keyword is not supported.")

    column_metadata: typing.Union[FieldMetadata,
                                  typing.Dict[str, str]] = metadata.get(
                                      column_name, {})
    is_primary_key = primary_key_keyword is not None
    is_not_null = (not_null_keyword is not None or is_primary_key
                   or column_metadata.get("not_null") or False)
    is_unique = (unique_keyword is not None or is_primary_key
                 or column_metadata.get("unique") or False)
    default_value = (default_keyword if default_keyword is None else
                     sql.extract_value(default_keyword.value))

    return {
        **metadata,
        column_name: {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **metadata.get(column_name, EMPTY_DICT),  # type: ignore
            "unique": is_unique,
            "not_null": is_not_null,
            "default": default_value,
            "type": DATA_TYPE_MAP[data_type.value],
        },
    }
Example #26
    def tokens(self) -> List[SQLToken]:
        """
        Tokenizes the query
        """
        if self._tokens is not None:
            return self._tokens

        parsed = sqlparse.parse(self.query)
        tokens = []
        # handle empty queries (#12)
        if not parsed:
            return tokens

        sqlparse_tokens = TokenList(parsed[0].tokens).flatten()
        non_empty_tokens = [
            token for token in sqlparse_tokens if token.ttype is not Whitespace
        ]
        last_keyword = None
        for index, tok in enumerate(non_empty_tokens):
            token = SQLToken(
                tok=tok,
                index=index,
                subquery_level=self._subquery_level,
                last_keyword=last_keyword,
            )
            if index > 0:
                # create links between consecutive tokens
                token.previous_token = tokens[index - 1]
                tokens[index - 1].next_token = token

            if token.is_left_parenthesis:
                self._determine_opening_parenthesis_type(token=token)
            elif token.is_right_parenthesis:
                self._determine_closing_parenthesis_type(token=token)

            if tok.is_keyword and tok.normalized not in KEYWORDS_IGNORED:
                last_keyword = tok.normalized
            token.is_in_nested_function = self._is_in_nested_function
            tokens.append(token)

        self._tokens = tokens
        return tokens
Example #27
    def parse(cls, sql: str, default_schema: Optional[str] = None) -> SqlMeta:
        if sql is None:
            raise ValueError("A sql statement must be provided.")

        # Tokenize the SQL statement
        sql_statements = sqlparse.parse(sql)

        sql_parser = cls(default_schema)
        sql_meta = SqlMeta([], [])

        for sql_statement in sql_statements:
            tokens = TokenList(sql_statement.tokens)
            log.debug(f"Successfully tokenized sql statement: {tokens}")

            result = sql_parser.recurse(tokens)

            # Add the in / out tables (if any) to the sql meta
            sql_meta.add_in_tables(result.in_tables)
            sql_meta.add_out_tables(result.out_tables)

        return sql_meta
Example #28
 def _handle_source_table_token(self, sub_token: TokenList) -> None:
     if isinstance(sub_token, Identifier):
         if isinstance(sub_token.token_first(skip_cm=True), Parenthesis):
             # SELECT col1 FROM (SELECT col2 FROM tab1) dt, the subquery will be parsed as Identifier
             # and this Identifier's get_real_name method would return alias name dt
             # referring https://github.com/andialbrecht/sqlparse/issues/218 for further information
             pass
         else:
             self._lineage_result.read.add(Table.create(sub_token))
     elif isinstance(sub_token, IdentifierList):
         # This is to support join in ANSI-89 syntax
         for token in sub_token.tokens:
             if isinstance(token, Identifier):
                 self._lineage_result.read.add(Table.create(token))
     elif isinstance(sub_token, Parenthesis):
         # SELECT col1 FROM (SELECT col2 FROM tab1), the subquery will be parsed as Parenthesis
         # This syntax without alias for subquery is invalid in MySQL, while valid for SparkSQL
         pass
     else:
         raise SQLLineageException(
             "An Identifier is expected, got %s[value: %s] instead" %
             (type(sub_token).__name__, sub_token))
Example #29
def _define_foreign_key_constraint(
    metadata: AllFieldMetadata, column_definition_group: token_groups.TokenList
) -> typing.Optional[AllFieldMetadata]:
    idx, foreign_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "FOREIGN"))
    if foreign_keyword is None:
        return None

    idx, _ = column_definition_group.token_next_by(m=(token_types.Name, "KEY"),
                                                   idx=idx)
    idx, foreign_key_column = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    column_name = foreign_key_column.value

    idx, _ = column_definition_group.token_next_by(m=(token_types.Keyword,
                                                      "REFERENCES"),
                                                   idx=idx)
    idx, reference_table = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    reference_table_name = reference_table.value
    idx, reference_column = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    reference_column_name = reference_column.value

    if any(
            metadata.get(column_name, EMPTY_DICT).get("references",
                                                      EMPTY_DICT)):
        raise exceptions.NotSupportedError(
            "Foreign keys with multiple references are not currently supported."
        )

    if reference_column_name != "id":
        raise exceptions.NotSupportedError(
            "Foreign keys referring to fields other than ID are not currently supported."
        )

    return {
        **metadata,
        column_name: {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **metadata.get(column_name, EMPTY_DICT),
            "references": {
                reference_table_name: reference_column_name
            },
        },
    }
Example #30
def convert_expression_to_python(token):
    if not token.is_group:
        if token.value.upper() == 'TRUE':
            return 'sql.true()'
        elif token.value.upper() == 'FALSE':
            return 'sql.false()'
        elif token.ttype == T.Name:
            return 'sql.literal_column({0!r})'.format(str(token.value))
        else:
            return 'sql.text({0!r})'.format(str(token.value))

    if isinstance(token, Parenthesis):
        return '({0})'.format(convert_expression_to_python(TokenList(token.tokens[1:-1])))

    elif len(token.tokens) == 1:
        return convert_expression_to_python(token.tokens[0])

    elif len(token.tokens) == 3 and token.tokens[1].ttype == T.Comparison:
        lhs = convert_expression_to_python(token.tokens[0])
        rhs = convert_expression_to_python(token.tokens[2])
        op = token.tokens[1].value
        if op == '=':
            op = '=='
        return '{0} {1} {2}'.format(lhs, op, rhs)

    elif len(token.tokens) == 3 and token.tokens[1].match(T.Keyword, 'IN') and isinstance(token.tokens[2], Parenthesis):
        lhs = convert_expression_to_python(token.tokens[0])
        rhs = [convert_expression_to_python(t) for t in token.tokens[2].tokens[1:-1] if not t.match(T.Punctuation, ',')]
        return '{0}.in_({1!r})'.format(lhs, tuple(rhs))

    elif len(token.tokens) == 4 and token.tokens[1].match(T.Comparison, '~') and token.tokens[2].match(T.Name, 'E') and token.tokens[3].ttype == T.String.Single:
        lhs = convert_expression_to_python(token.tokens[0])
        pattern = token.tokens[3].value.replace('\\\\', '\\')
        return 'regexp({0}, {1})'.format(lhs, pattern)

    elif len(token.tokens) == 3 and token.tokens[1].match(T.Keyword, 'IS') and token.tokens[2].match(T.Keyword, 'NULL'):
        lhs = convert_expression_to_python(token.tokens[0])
        return '{0} == None'.format(lhs)

    elif len(token.tokens) == 3 and token.tokens[1].match(T.Keyword, 'IS') and token.tokens[2].match(T.Keyword, 'NOT NULL'):
        lhs = convert_expression_to_python(token.tokens[0])
        return '{0} != None'.format(lhs)

    else:
        parts = []
        op = None
        idx = -1

        while True:
            new_idx, op_token = token.token_next_by(m=(T.Keyword, ('AND', 'OR')), idx=idx)
            if op_token is None:
                break
            if op is None:
                op = op_token.normalized
            assert op == op_token.normalized
            new_tokens = token.tokens[idx+1:new_idx]
            if len(new_tokens) == 1:
                parts.append(convert_expression_to_python(new_tokens[0]))
            else:
                parts.append(convert_expression_to_python(TokenList(new_tokens)))
            idx = new_idx + 1

        if idx == -1:
            raise ValueError('unknown expression - {0}'.format(token))

        new_tokens = token.tokens[idx:]
        if len(new_tokens) == 1:
            parts.append(convert_expression_to_python(new_tokens[0]))
        else:
            parts.append(convert_expression_to_python(TokenList(new_tokens)))

        return 'sql.{0}_({1})'.format(op.lower(), ', '.join(parts))
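A usage sketch for the converter (the WHERE-style expression is illustrative); it assumes whitespace has already been stripped, for example with a helper like group_parentheses from Example #22:

import sqlparse

flat = sqlparse.parse("x > 10 AND y = 'a'")[0].flatten()
expression = group_parentheses(flat)  # whitespace-free TokenList
print(convert_expression_to_python(expression))
# expected roughly: sql.and_(sql.literal_column('x') > sql.text('10'), sql.literal_column('y') == sql.text("'a'"))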