def recurse(self, tokens: TokenList) -> SqlMeta:
    in_tables, out_tables = set(), set()
    idx, token = tokens.token_next_by(t=T.Keyword)
    while token:
        # Main parser switch
        if self.is_cte(token):
            cte_name, cte_intables = self.parse_cte(idx, tokens)
            for intable in cte_intables:
                if intable not in self.ctes:
                    in_tables.add(intable)
        elif _is_in_table(token):
            idx, extracted_tables = _get_tables(tokens, idx, self.default_schema)
            for table in extracted_tables:
                if table.name not in self.ctes:
                    in_tables.add(table)
        elif _is_out_table(token):
            idx, extracted_tables = _get_tables(tokens, idx, self.default_schema)
            out_tables.add(extracted_tables[0])  # assuming only one out_table
        idx, token = tokens.token_next_by(t=T.Keyword, idx=idx)
    return SqlMeta(list(in_tables), list(out_tables))
def parse_cte(self, idx, tokens: TokenList):
    gidx, group = tokens.token_next(idx, skip_ws=True, skip_cm=True)

    # handle recursive keyword
    if group.match(T.Keyword, values=['RECURSIVE']):
        gidx, group = tokens.token_next(gidx, skip_ws=True, skip_cm=True)

    if not group.is_group:
        return [], None

    # get CTE name
    offset = 1
    cte_name = group.token_first(skip_ws=True, skip_cm=True)
    self.ctes.add(cte_name.value)

    # AS keyword
    offset, as_keyword = group.token_next(offset, skip_ws=True, skip_cm=True)
    if not as_keyword.match(T.Keyword, values=['AS']):
        raise RuntimeError(f"CTE does not have AS keyword at index {gidx}")

    offset, parens = group.token_next(offset, skip_ws=True, skip_cm=True)
    if isinstance(parens, Parenthesis) or parens.is_group:
        # Parse CTE using recursion.
        return cte_name.value, self.recurse(TokenList(parens.tokens)).in_tables
    raise RuntimeError(f"Parens {parens} are not Parenthesis at index {gidx}")
def _get_table(tlist: TokenList) -> Optional[Table]:
    """
    Return the table if valid, i.e., conforms to the
    [[catalog.]schema.]table construct.

    :param tlist: The SQL tokens
    :returns: The table if the name conforms
    """
    # Strip the alias if present.
    idx = len(tlist.tokens)
    if tlist.has_alias():
        ws_idx, _ = tlist.token_next_by(t=Whitespace)
        if ws_idx != -1:
            idx = ws_idx
    tokens = tlist.tokens[:idx]
    if (len(tokens) in (1, 3, 5)
            and all(imt(token, t=[Name, String]) for token in tokens[::2])
            and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])):
        return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
    return None
def parse(sql: str) -> SqlMeta:
    if sql is None:
        raise ValueError("A sql statement must be provided.")

    # Tokenize the SQL statement
    statements = sqlparse.parse(sql)

    # We assume only one statement in SQL
    tokens = TokenList(statements[0].tokens)
    log.debug(f"Successfully tokenized sql statement: {tokens}")

    in_tables = []
    out_tables = []

    idx, token = tokens.token_next_by(t=T.Keyword)
    while token:
        if _is_in_table(token):
            idx, in_table = _get_table(tokens, idx)
            in_tables.append(in_table)
        elif _is_out_table(token):
            idx, out_table = _get_table(tokens, idx)
            out_tables.append(out_table)
        idx, token = tokens.token_next_by(t=T.Keyword, idx=idx)

    return SqlMeta(in_tables, out_tables)
def __get_full_name(tlist: TokenList) -> Optional[str]:
    """
    Return the full unquoted table name if valid, i.e., conforms to the
    following [[cluster.]schema.]table construct.

    :param tlist: The SQL tokens
    :returns: The valid full table name
    """
    # Strip the alias if present.
    idx = len(tlist.tokens)
    if tlist.has_alias():
        ws_idx, _ = tlist.token_next_by(t=Whitespace)
        if ws_idx != -1:
            idx = ws_idx
    tokens = tlist.tokens[:idx]
    if (len(tokens) in (1, 3, 5)
            and all(imt(token, t=[Name, String]) for token in tokens[0::2])
            and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])):
        return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
    return None
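# A minimal usage sketch for __get_full_name, shown as if it were a module-level
# helper (in its original class it is a name-mangled private method, so the direct
# call is illustrative only). sqlparse groups "schema1.table1 AS a" into a single
# Identifier, from which the unquoted dotted name is rebuilt.
import sqlparse
from sqlparse.sql import Identifier

statement = sqlparse.parse("SELECT * FROM schema1.table1 AS a")[0]
identifier = next(tok for tok in statement.tokens if isinstance(tok, Identifier))
# __get_full_name(identifier) would return "schema1.table1" here: the alias is
# stripped at the first whitespace, leaving Name / Punctuation / Name tokens.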
def __vectorize(self, tokenlist):
    token_list = TokenList(list(tokenlist.flatten()))
    for x in token_list:
        if x.ttype is Comparison:
            # Index of the comparison operator
            idx_comp_op = token_list.token_index(x)
            # Name of the attribute on the left of the operator
            attr = token_list.token_prev(idx_comp_op, skip_ws=True, skip_cm=True)[1].value
            print(attr)
            comp_op = x
            if comp_op.value == '<' or comp_op.value == '<=':
                lit_dir = 'ub'
            elif comp_op.value == '>' or comp_op.value == '>=':
                lit_dir = 'lb'
            else:
                lit_dir = 'bi'
            try:
                # Literal value on the right of the operator
                lit = float(token_list.token_next(idx_comp_op, skip_ws=True, skip_cm=True)[1].value)
            except ValueError:
                print("Possible join, skipping")
                continue
            if lit_dir == 'bi':
                self.query_vec['_'.join([attr, 'lb'])] = lit
                self.query_vec['_'.join([attr, 'ub'])] = lit
                continue
            # lit_dir is either 'lb' or 'ub'
            self.query_vec['_'.join([attr, lit_dir])] = lit
def _handle_target_table_token(self, sub_token: TokenList) -> None:
    if isinstance(sub_token, Function):
        # insert into tab (col1, col2) values (val1, val2); here "tab (col1, col2)" is parsed as a Function.
        # See https://github.com/andialbrecht/sqlparse/issues/483 for further information.
        if not isinstance(sub_token.token_first(skip_cm=True), Identifier):
            raise SQLLineageException(
                "An Identifier is expected, got %s[value: %s] instead"
                % (type(sub_token).__name__, sub_token))
        self._lineage_result.write.add(
            Table.create(sub_token.token_first(skip_cm=True)))
    elif isinstance(sub_token, Comparison):
        # create table tab1 like tab2; here "tab1 like tab2" is parsed as a Comparison.
        # See https://github.com/andialbrecht/sqlparse/issues/543 for further information.
        if not (isinstance(sub_token.left, Identifier)
                and isinstance(sub_token.right, Identifier)):
            raise SQLLineageException(
                "An Identifier is expected, got %s[value: %s] instead"
                % (type(sub_token).__name__, sub_token))
        self._lineage_result.write.add(Table.create(sub_token.left))
        self._lineage_result.read.add(Table.create(sub_token.right))
    else:
        if not isinstance(sub_token, Identifier):
            raise SQLLineageException(
                "An Identifier is expected, got %s[value: %s] instead"
                % (type(sub_token).__name__, sub_token))
        self._lineage_result.write.add(Table.create(sub_token))
def test_group_parentheses():
    tokens = [
        Token(T.Keyword, 'CREATE'),
        Token(T.Whitespace, ' '),
        Token(T.Keyword, 'TABLE'),
        Token(T.Whitespace, ' '),
        Token(T.Name, 'table_name'),
        Token(T.Whitespace, ' '),
        Token(T.Punctuation, '('),
        Token(T.Name, 'id'),
        Token(T.Whitespace, ' '),
        Token(T.Keyword, 'SERIAL'),
        Token(T.Whitespace, ' '),
        Token(T.Keyword, 'CHECK'),
        Token(T.Punctuation, '('),
        Token(T.Name, 'id'),
        Token(T.Operator, '='),
        Token(T.Number, '0'),
        Token(T.Punctuation, ')'),
        Token(T.Punctuation, ')'),
        Token(T.Punctuation, ';'),
    ]
    expected_tokens = TokenList([
        Token(T.Keyword, 'CREATE'),
        Token(T.Keyword, 'TABLE'),
        Token(T.Name, 'table_name'),
        Parenthesis([
            Token(T.Punctuation, '('),
            Token(T.Name, 'id'),
            Token(T.Keyword, 'SERIAL'),
            Token(T.Keyword, 'CHECK'),
            Parenthesis([
                Token(T.Punctuation, '('),
                Token(T.Name, 'id'),
                Token(T.Operator, '='),
                Token(T.Number, '0'),
                Token(T.Punctuation, ')'),
            ]),
            Token(T.Punctuation, ')'),
        ]),
        Token(T.Punctuation, ';'),
    ])

    grouped_tokens = group_parentheses(tokens)

    # Compare the two token trees via their pretty-printed representations.
    stdout = sys.stdout
    try:
        sys.stdout = StringIO()
        expected_tokens._pprint_tree()
        a = sys.stdout.getvalue()
        sys.stdout = StringIO()
        grouped_tokens._pprint_tree()
        b = sys.stdout.getvalue()
    finally:
        sys.stdout = stdout
    assert_multi_line_equal(a, b)
def filter_identifier_list(tkn_list: TokenList, token: Token):
    index = tkn_list.token_index(token)

    prev_token: Token = tkn_list.token_prev(index)[1]
    if prev_token is not None:  # token_prev returns None when there is no previous token (index 0)
        if not prev_token.match(DML, 'SELECT'):
            return False

    next_token: Token = tkn_list.token_next(index)[1]
    if next_token is not None:  # token_next returns None when there is no next token (end of list)
        if not next_token.match(Keyword, 'FROM'):
            return False

    return True
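# A hedged usage sketch for filter_identifier_list: it answers "is this token the
# column list sitting directly between SELECT and FROM?" (token_prev/token_next skip
# whitespace by default). The query below is an illustrative assumption.
import sqlparse
from sqlparse.sql import IdentifierList

parsed = sqlparse.parse("SELECT a, b FROM t")[0]
select_list = next(tok for tok in parsed.tokens if isinstance(tok, IdentifierList))
filter_identifier_list(parsed, select_list)  # expected: True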
def extract_from_column(self):
    '''
    Collect all tokens between 'DML SELECT' and 'Keyword FROM', e.g.:

    [<DML 'SELECT' at 0x3655A08>, <Whitespace ' ' at 0x3655A68>,
     <IdentifierList 'me.Sap...' at 0x366E228>, <Newline ' ' at 0x3665948>,
     <Keyword 'FROM' at 0x36659A8>, <Whitespace ' ' at 0x3665A08>,
     <IdentifierList 'SODS2....' at 0x366E390>, <Whitespace ' ' at 0x3667228>,
     <IdentifierList 't,SHAR...' at 0x366E480>, <Newline ' ' at 0x3667528>]
    '''
    tokens = self.getTokens()
    tokenlist = TokenList(tokens)
    cols_idx, cols_item = [], []
    cols_group = []
    # cols_item only keeps the columns between SELECT and FROM.
    # Note: there can be several such groups when the SQL contains UNION / UNION ALL,
    # so cols_group is used to collect all of them.
    fetch_col_flag = False
    for idx, item in enumerate(tokens):
        before_idx, before_item = tokenlist.token_prev(idx, skip_ws=True)
        next_idx, next_item = tokenlist.token_next(idx, skip_ws=True)
        if not next_item:
            break
        # capture the first column index
        if (isinstance(item, (IdentifierList, Identifier))
                and (before_item.ttype == Keyword.DML
                     or before_item.value.upper() == 'DISTINCT')):
            cols_idx.append(idx)
            fetch_col_flag = True
            cols_item = []
        if fetch_col_flag:
            cols_item.append(item)
        # capture the last column index
        if (isinstance(item, (IdentifierList, Identifier))
                and next_item.ttype is Keyword
                and next_item.value.upper() == 'FROM'):
            cols_idx.append(idx)
            fetch_col_flag = False
            cols_group.append(''.join([item.value for item in cols_item]))
    # cols_idx holds [start, end] index pairs such as [10, 12, 24, 26];
    # expand them to the full index ranges, e.g. [10, 11, 12, 24, 25, 26].
    cols_idxes = sum(
        [list(range(cols_idx[2 * i], cols_idx[2 * i + 1] + 1))
         for i in range(int(len(cols_idx) / 2))], [])
    keep_tokens = [item for idx, item in enumerate(tokens) if idx not in cols_idxes]
    self.tokens = keep_tokens
    self.tokens_val = [item.value for item in tokens]
    return cols_group
def __process_tokenlist(self, tlist: TokenList):
    # exclude subselects
    if '(' not in str(tlist):
        table_name = self.__get_full_name(tlist)
        if table_name and not table_name.startswith(CTE_PREFIX):
            self._table_names.add(table_name)
        return

    # store aliases
    if tlist.has_alias():
        self._alias_names.add(tlist.get_alias())

    # some aliases are not parsed properly
    if tlist.tokens[0].ttype == Name:
        self._alias_names.add(tlist.tokens[0].value)
    self.__extract_from_token(tlist)
def _define_primary_key(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> typing.Optional[AllFieldMetadata]:
    idx, constraint_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "CONSTRAINT"))
    idx, primary_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "PRIMARY"), idx=(idx or -1))

    if constraint_keyword is not None and primary_keyword is None:
        raise exceptions.NotSupportedError(
            "When a column definition clause begins with CONSTRAINT, "
            "only a PRIMARY KEY constraint is supported")

    if primary_keyword is None:
        return None

    # If the keyword isn't followed by column name(s), then it's part of
    # a regular column definition and should be handled by _define_column
    if not _contains_column_name(column_definition_group, idx):
        return None

    new_metadata: AllFieldMetadata = deepcopy(metadata)

    while True:
        idx, primary_key_column = column_definition_group.token_next_by(
            t=token_types.Name, idx=idx)

        # 'id' is defined and managed by Fauna, so we ignore any attempts
        # to manage it from SQLAlchemy
        if primary_key_column is None or primary_key_column.value == "id":
            break

        primary_key_column_name = primary_key_column.value
        new_metadata[primary_key_column_name] = {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **new_metadata.get(primary_key_column_name, {}),  # type: ignore
            "unique": True,
            "not_null": True,
        }

    return new_metadata
def get_query_tokens(query):
    """
    :type query str
    :rtype: list[sqlparse.sql.Token]
    """
    tokens = TokenList(sqlparse.parse(query)[0].tokens).flatten()
    # print([(token.value, token.ttype) for token in tokens])
    return [token for token in tokens if token.ttype is not Whitespace]
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
    """
    Extract the LIMIT clause from a SQL statement.

    :param statement: SQL statement
    :return: Limit extracted from the query, None if no limit is present
    """
    idx, _ = statement.token_next_by(m=(Keyword, "LIMIT"))
    if idx is not None:
        _, token = statement.token_next(idx=idx)
        if token:
            if isinstance(token, IdentifierList):
                # In case of "LIMIT <offset>, <limit>", find the comma and extract
                # the first succeeding non-whitespace token
                idx, _ = token.token_next_by(m=(sqlparse.tokens.Punctuation, ","))
                _, token = token.token_next(idx=idx)
            if token and token.ttype == sqlparse.tokens.Literal.Number.Integer:
                return int(token.value)
    return None
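# A small usage sketch: sqlparse.parse returns Statement objects (TokenList
# subclasses), so a parsed statement can be fed in directly. The expected values
# follow from the logic above, including the "LIMIT <offset>, <limit>" branch.
import sqlparse

_extract_limit_from_query(sqlparse.parse("SELECT * FROM t LIMIT 10")[0])       # -> 10
_extract_limit_from_query(sqlparse.parse("SELECT * FROM t LIMIT 10, 100")[0])  # -> 100
_extract_limit_from_query(sqlparse.parse("SELECT * FROM t")[0])                # -> None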
def to_clickhouse(cls, schema: str, query: str):
    """
    Parse a DDL query.

    :param schema:
    :param query:
    :return:
    """
    token_list = TokenList()
    parsed = sqlparse.parse(query)[0]
    token_list = cls._add_token(schema, parsed, parsed.tokens, token_list)
    return str(token_list)
def get_query_tokens(query: str) -> List[sqlparse.sql.Token]:
    query = preprocess_query(query)
    parsed = sqlparse.parse(query)

    # handle empty queries (#12)
    if not parsed:
        return []

    tokens = TokenList(parsed[0].tokens).flatten()
    return [token for token in tokens if token.ttype is not Whitespace]
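# A hedged usage sketch: assuming preprocess_query only normalizes the text,
# the flattened, whitespace-free token stream for a simple query looks like this.
tokens = get_query_tokens("SELECT id FROM users")
for token in tokens:
    print(token.ttype, token.value)
# Roughly:
#   Token.Keyword.DML  SELECT
#   Token.Name         id
#   Token.Keyword      FROM
#   Token.Name         users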
def parse(cls, sql: str, default_schema: Optional[str] = None) -> SqlMeta:
    if sql is None:
        raise ValueError("A sql statement must be provided.")

    # Tokenize the SQL statement
    statements = sqlparse.parse(sql)

    # We assume only one statement in SQL
    tokens = TokenList(statements[0].tokens)
    log.debug(f"Successfully tokenized sql statement: {tokens}")

    parser = cls(default_schema)
    return parser.recurse(tokens)
def _process_tokenlist(self, token_list: TokenList):
    """
    Add table names to the table set.

    :param token_list: TokenList to be processed
    """
    # exclude subselects
    if "(" not in str(token_list):
        table = self._get_table(token_list)
        if table and not table.table.startswith(CTE_PREFIX):
            self._tables.add(table)
        return

    # store aliases
    if token_list.has_alias():
        self._alias_names.add(token_list.get_alias())

    # some aliases are not parsed properly
    if token_list.tokens[0].ttype == Name:
        self._alias_names.add(token_list.tokens[0].value)
    self._extract_from_token(token_list)
def _process_statement_tokens(cls, statement_tokens, filter_string):
    """
    Process the tokens in a statement to ensure the correct parsing behavior.
    In typical cases the statement tokens contain just the comparison, and no
    additional processing occurs. When the filter string contains the IN operator,
    this function assembles those tokens into a Comparison object, which is then
    parsed by _get_comparison_for_model_registry.

    :param statement_tokens: List of tokens from a statement.
    :param filter_string: Filter string from which the parsed statement tokens
                          originate. Used for informative logging.
    :return: List of tokens
    """
    expected = "Expected search filter with single comparison operator. e.g. name='myModelName'"
    token_list = []
    if len(statement_tokens) == 0:
        raise MlflowException(
            "Invalid filter '%s'. Could not be parsed. %s" % (filter_string, expected),
            error_code=INVALID_PARAMETER_VALUE,
        )
    elif len(statement_tokens) == 1:
        if isinstance(statement_tokens[0], Comparison):
            token_list = statement_tokens
        else:
            raise MlflowException(
                "Invalid filter '%s'. Could not be parsed. %s" % (filter_string, expected),
                error_code=INVALID_PARAMETER_VALUE,
            )
    elif len(statement_tokens) > 1:
        comparison_subtokens = []
        for token in statement_tokens:
            if isinstance(token, Comparison):
                raise MlflowException(
                    "Search filter '%s' contains multiple expressions. "
                    "%s " % (filter_string, expected),
                    error_code=INVALID_PARAMETER_VALUE,
                )
            elif cls._is_list_component_token(token):
                comparison_subtokens.append(token)
            elif not token.is_whitespace:
                break

        # Fewer than 3 subtokens means the statement is incomplete.
        if len(comparison_subtokens) == 3:
            token_list = [Comparison(TokenList(comparison_subtokens))]
        else:
            raise MlflowException(
                "Invalid filter '%s'. Could not be parsed. %s" % (filter_string, expected),
                error_code=INVALID_PARAMETER_VALUE,
            )
    return token_list
def __projections(self, token, tokenlist):
    idx = tokenlist.token_index(token)
    afs_list_idx, afs = tokenlist.token_next(idx, skip_ws=True, skip_cm=True)
    afs_list = TokenList(list(afs.flatten()))
    for af in afs_list:
        # Collect aggregate functions (AFs)
        if af.value.lower() in ['avg', 'count', 'sum', 'min', 'max']:
            af_idx = afs_list.token_index(af)
            punc_idx, _ = afs_list.token_next(af_idx, skip_ws=True, skip_cm=True)
            attr_idx, attr = afs_list.token_next(punc_idx, skip_ws=True, skip_cm=True)
            if attr.ttype is not Wildcard:
                self.afs.append('_'.join([af.value, attr.value]))
            else:
                self.afs.append(af.value)
def extract_from_column(self):
    '''
    Pick up all tokens between 'DML SELECT' and 'Keyword FROM', e.g.:

    [<DML 'SELECT' at 0x3655A08>, <Whitespace ' ' at 0x3655A68>,
     <IdentifierList 'me.Sap...' at 0x366E228>, <Newline ' ' at 0x3665948>,
     <Keyword 'FROM' at 0x36659A8>, <Whitespace ' ' at 0x3665A08>,
     <IdentifierList 'SODS2....' at 0x366E390>, <Whitespace ' ' at 0x3667228>,
     <IdentifierList 't,SHAR...' at 0x366E480>, <Newline ' ' at 0x3667528>]
    '''
    tokens = self.getTokens()
    tokenlist = TokenList(tokens)
    cols_idx, cols_item = [], []
    cols_group = []
    fetch_col_flag = False
    for idx, item in enumerate(tokens):
        before_idx, before_item = tokenlist.token_prev(idx, skip_ws=True)
        next_idx, next_item = tokenlist.token_next(idx, skip_ws=True)
        if not next_item:
            break
        # capture the first column index
        if (isinstance(item, (IdentifierList, Identifier))
                and (before_item.ttype == Keyword.DML
                     or before_item.value.upper() == 'DISTINCT')):
            cols_idx.append(idx)
            fetch_col_flag = True
            cols_item = []
        if fetch_col_flag:
            cols_item.append(item)
        # capture the last column index
        if (isinstance(item, (IdentifierList, Identifier))
                and next_item.ttype is Keyword
                and next_item.value.upper() == 'FROM'):
            cols_idx.append(idx)
            fetch_col_flag = False
            cols_group.append(cols_item)
    # Expand the [start, end] index pairs in cols_idx to full index ranges.
    cols_idxes = sum(
        [list(range(cols_idx[2 * i], cols_idx[2 * i + 1] + 1))
         for i in range(int(len(cols_idx) / 2))], [])
    left_tokens = [item for idx, item in enumerate(tokens) if idx not in cols_idxes]
def group_parentheses(tokens):
    stack = [[]]
    for token in tokens:
        if token.is_whitespace:
            continue
        if token.match(T.Punctuation, '('):
            stack.append([token])
        else:
            stack[-1].append(token)
            if token.match(T.Punctuation, ')'):
                group = stack.pop()
                stack[-1].append(Parenthesis(group))
    return TokenList(stack[0])
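# A short usage sketch with real lexer output instead of hand-built tokens
# (see test_group_parentheses above): flatten() yields ungrouped tokens, and
# group_parentheses drops whitespace while nesting Parenthesis groups.
import sqlparse

flat = list(sqlparse.parse("CHECK (id = 0)")[0].flatten())
grouped = group_parentheses(flat)
# grouped.tokens is roughly [Keyword 'CHECK', Parenthesis '(id=0)']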
def _define_unique_constraint(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> typing.Optional[AllFieldMetadata]:
    idx, unique_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "UNIQUE"))

    if unique_keyword is None:
        return None

    # If the keyword isn't followed by column name(s), then it's part of
    # a regular column definition and should be handled by _define_column
    if not _contains_column_name(column_definition_group, idx):
        return None

    new_metadata = deepcopy(metadata)

    while True:
        idx, unique_key_column = column_definition_group.token_next_by(
            t=token_types.Name, idx=idx)

        # 'id' is defined and managed by Fauna, so we ignore any attempts
        # to manage it from SQLAlchemy
        if unique_key_column is None or unique_key_column.value == "id":
            break

        unique_key_column_name = unique_key_column.value
        new_metadata[unique_key_column_name] = {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **new_metadata.get(unique_key_column_name, {}),  # type: ignore
            "unique": True,
        }

    return new_metadata
def get_query_tokens(query: str) -> List[sqlparse.sql.Token]:
    """
    :type query str
    :rtype: list[sqlparse.sql.Token]
    """
    query = preprocess_query(query)
    parsed = sqlparse.parse(query)

    # handle empty queries (#12)
    if not parsed:
        return []

    tokens = TokenList(parsed[0].tokens).flatten()
    # print([(token.value, token.ttype) for token in tokens])
    return [token for token in tokens if token.ttype is not Whitespace]
def _define_column(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> AllFieldMetadata:
    idx, column = column_definition_group.token_next_by(t=token_types.Name)
    column_name = column.value

    # "id" is auto-generated by Fauna, so we ignore it in SQL column definitions
    if column_name == "id":
        return metadata

    idx, data_type = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    _, not_null_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "NOT NULL"))
    _, unique_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "UNIQUE"))
    _, primary_key_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "PRIMARY KEY"))
    _, default_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "DEFAULT"))
    _, check_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "CHECK"))

    if check_keyword is not None:
        raise exceptions.NotSupportedError("CHECK keyword is not supported.")

    column_metadata: typing.Union[FieldMetadata, typing.Dict[str, str]] = metadata.get(
        column_name, {})
    is_primary_key = primary_key_keyword is not None
    is_not_null = (not_null_keyword is not None or is_primary_key
                   or column_metadata.get("not_null") or False)
    is_unique = (unique_keyword is not None or is_primary_key
                 or column_metadata.get("unique") or False)
    default_value = (default_keyword if default_keyword is None
                     else sql.extract_value(default_keyword.value))

    return {
        **metadata,
        column_name: {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **metadata.get(column_name, EMPTY_DICT),  # type: ignore
            "unique": is_unique,
            "not_null": is_not_null,
            "default": default_value,
            "type": DATA_TYPE_MAP[data_type.value],
        },
    }
def tokens(self) -> List[SQLToken]:
    """
    Tokenizes the query
    """
    if self._tokens is not None:
        return self._tokens

    parsed = sqlparse.parse(self.query)
    tokens = []
    # handle empty queries (#12)
    if not parsed:
        return tokens

    sqlparse_tokens = TokenList(parsed[0].tokens).flatten()
    non_empty_tokens = [
        token for token in sqlparse_tokens if token.ttype is not Whitespace
    ]
    last_keyword = None
    for index, tok in enumerate(non_empty_tokens):
        token = SQLToken(
            tok=tok,
            index=index,
            subquery_level=self._subquery_level,
            last_keyword=last_keyword,
        )
        if index > 0:
            # create links between consecutive tokens
            token.previous_token = tokens[index - 1]
            tokens[index - 1].next_token = token

        if token.is_left_parenthesis:
            self._determine_opening_parenthesis_type(token=token)
        elif token.is_right_parenthesis:
            self._determine_closing_parenthesis_type(token=token)

        if tok.is_keyword and tok.normalized not in KEYWORDS_IGNORED:
            last_keyword = tok.normalized
        token.is_in_nested_function = self._is_in_nested_function
        tokens.append(token)

    self._tokens = tokens
    return tokens
def parse(cls, sql: str, default_schema: Optional[str] = None) -> SqlMeta:
    if sql is None:
        raise ValueError("A sql statement must be provided.")

    # Tokenize the SQL statement
    sql_statements = sqlparse.parse(sql)

    sql_parser = cls(default_schema)
    sql_meta = SqlMeta([], [])

    for sql_statement in sql_statements:
        tokens = TokenList(sql_statement.tokens)
        log.debug(f"Successfully tokenized sql statement: {tokens}")

        result = sql_parser.recurse(tokens)

        # Add the in / out tables (if any) to the sql meta
        sql_meta.add_in_tables(result.in_tables)
        sql_meta.add_out_tables(result.out_tables)

    return sql_meta
def _handle_source_table_token(self, sub_token: TokenList) -> None:
    if isinstance(sub_token, Identifier):
        if isinstance(sub_token.token_first(skip_cm=True), Parenthesis):
            # SELECT col1 FROM (SELECT col2 FROM tab1) dt: the subquery is parsed as an Identifier
            # whose get_real_name method returns the alias name dt.
            # See https://github.com/andialbrecht/sqlparse/issues/218 for further information.
            pass
        else:
            self._lineage_result.read.add(Table.create(sub_token))
    elif isinstance(sub_token, IdentifierList):
        # This is to support join in ANSI-89 syntax
        for token in sub_token.tokens:
            if isinstance(token, Identifier):
                self._lineage_result.read.add(Table.create(token))
    elif isinstance(sub_token, Parenthesis):
        # SELECT col1 FROM (SELECT col2 FROM tab1): the subquery without an alias is parsed as Parenthesis.
        # This syntax is invalid in MySQL but valid in SparkSQL.
        pass
    else:
        raise SQLLineageException(
            "An Identifier is expected, got %s[value: %s] instead"
            % (type(sub_token).__name__, sub_token))
def _define_foreign_key_constraint(
    metadata: AllFieldMetadata,
    column_definition_group: token_groups.TokenList,
) -> typing.Optional[AllFieldMetadata]:
    idx, foreign_keyword = column_definition_group.token_next_by(
        m=(token_types.Keyword, "FOREIGN"))
    if foreign_keyword is None:
        return None

    idx, _ = column_definition_group.token_next_by(
        m=(token_types.Name, "KEY"), idx=idx)
    idx, foreign_key_column = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    column_name = foreign_key_column.value

    idx, _ = column_definition_group.token_next_by(
        m=(token_types.Keyword, "REFERENCES"), idx=idx)
    idx, reference_table = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    reference_table_name = reference_table.value
    idx, reference_column = column_definition_group.token_next_by(
        t=token_types.Name, idx=idx)
    reference_column_name = reference_column.value

    if any(metadata.get(column_name, EMPTY_DICT).get("references", EMPTY_DICT)):
        raise exceptions.NotSupportedError(
            "Foreign keys with multiple references are not currently supported.")

    if reference_column_name != "id":
        raise exceptions.NotSupportedError(
            "Foreign keys referring to fields other than ID are not currently supported.")

    return {
        **metadata,
        column_name: {
            **DEFAULT_FIELD_METADATA,  # type: ignore
            **metadata.get(column_name, EMPTY_DICT),
            "references": {reference_table_name: reference_column_name},
        },
    }
def convert_expression_to_python(token):
    if not token.is_group:
        if token.value.upper() == 'TRUE':
            return 'sql.true()'
        elif token.value.upper() == 'FALSE':
            return 'sql.false()'
        elif token.ttype == T.Name:
            return 'sql.literal_column({0!r})'.format(str(token.value))
        else:
            return 'sql.text({0!r})'.format(str(token.value))

    if isinstance(token, Parenthesis):
        return '({0})'.format(convert_expression_to_python(TokenList(token.tokens[1:-1])))
    elif len(token.tokens) == 1:
        return convert_expression_to_python(token.tokens[0])
    elif len(token.tokens) == 3 and token.tokens[1].ttype == T.Comparison:
        lhs = convert_expression_to_python(token.tokens[0])
        rhs = convert_expression_to_python(token.tokens[2])
        op = token.tokens[1].value
        if op == '=':
            op = '=='
        return '{0} {1} {2}'.format(lhs, op, rhs)
    elif (len(token.tokens) == 3 and token.tokens[1].match(T.Keyword, 'IN')
          and isinstance(token.tokens[2], Parenthesis)):
        lhs = convert_expression_to_python(token.tokens[0])
        rhs = [convert_expression_to_python(t)
               for t in token.tokens[2].tokens[1:-1]
               if not t.match(T.Punctuation, ',')]
        return '{0}.in_({1!r})'.format(lhs, tuple(rhs))
    elif (len(token.tokens) == 4 and token.tokens[1].match(T.Comparison, '~')
          and token.tokens[2].match(T.Name, 'E')
          and token.tokens[3].ttype == T.String.Single):
        lhs = convert_expression_to_python(token.tokens[0])
        pattern = token.tokens[3].value.replace('\\\\', '\\')
        return 'regexp({0}, {1})'.format(lhs, pattern)
    elif (len(token.tokens) == 3 and token.tokens[1].match(T.Keyword, 'IS')
          and token.tokens[2].match(T.Keyword, 'NULL')):
        lhs = convert_expression_to_python(token.tokens[0])
        return '{0} == None'.format(lhs)
    elif (len(token.tokens) == 3 and token.tokens[1].match(T.Keyword, 'IS')
          and token.tokens[2].match(T.Keyword, 'NOT NULL')):
        lhs = convert_expression_to_python(token.tokens[0])
        return '{0} != None'.format(lhs)
    else:
        # Fall back to splitting on AND / OR and recursing into each operand.
        parts = []
        op = None
        idx = -1
        while True:
            new_idx, op_token = token.token_next_by(m=(T.Keyword, ('AND', 'OR')), idx=idx)
            if op_token is None:
                break
            if op is None:
                op = op_token.normalized
            assert op == op_token.normalized
            new_tokens = token.tokens[idx + 1:new_idx]
            if len(new_tokens) == 1:
                parts.append(convert_expression_to_python(new_tokens[0]))
            else:
                parts.append(convert_expression_to_python(TokenList(new_tokens)))
            idx = new_idx + 1
        if idx == -1:
            raise ValueError('unknown expression - {0}'.format(token))
        new_tokens = token.tokens[idx:]
        if len(new_tokens) == 1:
            parts.append(convert_expression_to_python(new_tokens[0]))
        else:
            parts.append(convert_expression_to_python(TokenList(new_tokens)))
        return 'sql.{0}_({1})'.format(op.lower(), ', '.join(parts))
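# A hedged usage sketch for convert_expression_to_python. It appears to expect
# whitespace-free token groups such as those produced by group_parentheses above;
# wiring the two together here is an assumption, not the original call site.
import sqlparse

flat = list(sqlparse.parse("(id = 0)")[0].flatten())
expr = group_parentheses(flat).tokens[0]  # a Parenthesis group
convert_expression_to_python(expr)
# -> "(sql.literal_column('id') == sql.text('0'))"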