def _flatten_sqlparse(self):
    for token in self.sqlparse_tokens:
        # sqlparse returns mysql digit starting identifiers as group
        # check https://github.com/andialbrecht/sqlparse/issues/337
        is_grouped_mysql_digit_name = (
            token.is_group
            and len(token.tokens) == 2
            and token.tokens[0].ttype is Number.Integer
            and (
                token.tokens[1].is_group
                and token.tokens[1].tokens[0].ttype is Name
            )
        )
        if token.is_group and not is_grouped_mysql_digit_name:
            yield from token.flatten()
        elif is_grouped_mysql_digit_name:
            # we have a digit starting name
            new_tok = Token(
                value=f"{token.tokens[0].normalized}"
                f"{token.tokens[1].tokens[0].normalized}",
                ttype=token.tokens[1].tokens[0].ttype,
            )
            new_tok.parent = token.parent
            yield new_tok
            if len(token.tokens[1].tokens) > 1:
                # unfortunately there might be nested groups
                remaining_tokens = token.tokens[1].tokens[1:]
                for tok in remaining_tokens:
                    if tok.is_group:
                        yield from tok.flatten()
                    else:
                        yield tok
        else:
            yield token
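# A hedged, standalone illustration of the sqlparse quirk worked around above
# (https://github.com/andialbrecht/sqlparse/issues/337): on affected versions a
# digit-leading MySQL identifier such as "2users" may be lexed as Number.Integer
# '2' followed by Name 'users', so a plain flatten() yields two tokens where
# _flatten_sqlparse stitches them back into a single Name token.
import sqlparse

stmt = sqlparse.parse("SELECT * FROM 2users")[0]
for tok in stmt.flatten():
    if not tok.is_whitespace:
        print(tok.ttype, repr(tok.value))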
def sql_recursively_simplify(node, hide_columns=True):
    # Erase which fields are being updated in an UPDATE
    if node.tokens[0].value == "UPDATE":
        i_set = [i for (i, t) in enumerate(node.tokens) if t.value == "SET"][0]
        i_where = [
            i
            for (i, t) in enumerate(node.tokens)
            if t.is_group and t.tokens[0].value == "WHERE"
        ][0]
        middle = [Token(tokens.Punctuation, " ... ")]
        node.tokens = node.tokens[:i_set + 1] + middle + node.tokens[i_where:]

    # Erase the names of savepoints since they are non-deterministic
    if hasattr(node, "tokens"):
        # SAVEPOINT x
        if str(node.tokens[0]) == "SAVEPOINT":
            node.tokens[2].tokens[0].value = "`#`"
            return
        # RELEASE SAVEPOINT x
        elif len(node.tokens) >= 3 and node.tokens[2].value == "SAVEPOINT":
            node.tokens[4].tokens[0].value = "`#`"
            return
        # ROLLBACK TO SAVEPOINT X
        token_values = [getattr(t, "value", "") for t in node.tokens]
        if len(node.tokens) == 7 and token_values[:6] == [
            "ROLLBACK",
            " ",
            "TO",
            " ",
            "SAVEPOINT",
            " ",
        ]:
            node.tokens[6].tokens[0].value = "`#`"
            return

    # Erase volatile part of PG cursor name
    if node.tokens[0].value.startswith('"_django_curs_'):
        node.tokens[0].value = '"_django_curs_#"'

    prev_word_token = None

    for token in node.tokens:
        ttype = getattr(token, "ttype", None)
        # Detect IdentifierList tokens within ORDER BY, GROUP BY or HAVING clauses
        inside_order_group_having = match_keyword(
            prev_word_token, ["ORDER BY", "GROUP BY", "HAVING"]
        )
        replace_columns = not inside_order_group_having and hide_columns
        if isinstance(token, IdentifierList) and replace_columns:
            token.tokens = [Token(tokens.Punctuation, "...")]
        elif hasattr(token, "tokens"):
            sql_recursively_simplify(token, hide_columns=hide_columns)
        elif ttype in sql_deleteable_tokens:
            token.value = "#"
        elif getattr(token, "value", None) == "NULL":
            token.value = "#"

        if not token.is_whitespace:
            prev_word_token = token
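# A minimal usage sketch for sql_recursively_simplify, assuming the module-level
# helpers it references (match_keyword, sql_deleteable_tokens) are defined; the
# exact output depends on which token types that module treats as deleteable.
import sqlparse

parsed = sqlparse.parse("UPDATE users SET name = 'x', age = 2 WHERE id = 5")[0]
sql_recursively_simplify(parsed)
print(str(parsed))  # e.g. "UPDATE users SET ... WHERE id = #"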
def handle_read_operation(
    self,
    idx: int,
    next_token: Token,
    columns: Columns,
    dynamic_tables: Dict[str, Schema],
) -> Tuple[Schema, int, Columns]:
    # Look for a sub-query first - it can be the current token or its first child
    extracted = {}
    if isinstance(next_token, Parenthesis):
        # The sub-query has no alias name:
        # `... from (<sub_query>) where ...`
        extracted = self.extract_from_subquery(next_token, "anon", columns)
    elif isinstance(next_token, TokenList) and isinstance(
        next_token.token_first(), Parenthesis
    ):
        # The sub-query has an alias name:
        # `... from (<sub_query>) as <alias_name> where ...`
        extracted = self.extract_from_subquery(
            next_token.token_first(), next_token.get_alias(), columns
        )
    elif next_token.ttype not in Keyword:
        # No sub-query - just parse the source/dest name of the operator
        extracted, left_columns = self.generate_schema(columns, next_token)
        columns = left_columns

    extracted = self.enrich_with_dynamic(extracted, dynamic_tables)
    return extracted, idx, columns
def _add_tenant_to_sql(rule_sql):
    prefix = f'/+/{g.tenant_uid}'
    parsed = sqlparse.parse(rule_sql)
    stmt = parsed[0]
    token_dict = {}
    for token in stmt.tokens:
        if isinstance(token, Identifier):
            if re.match(r'^"/.*', token.value):
                # FROM "/productID/deviceID/"
                index = stmt.token_index(token)
                new_value = f'\"{prefix}{token.value[1:]}'
                token_dict[index] = new_value
            else:
                # SELECT getMetadataPropertyValue('/productID/deviceID/','topic') as topic FROM
                new_value = _replace_func_sql(prefix, token.value)
                if new_value:
                    index = stmt.token_index(token)
                    token_dict[index] = new_value
        # SELECT getMetadataPropertyValue('/productID/deviceID/','topic') as topic,* FROM
        if isinstance(token, IdentifierList):
            for index, identifier in enumerate(token.get_identifiers()):
                new_value = _replace_func_sql(prefix, identifier.value)
                if new_value:
                    token.tokens[index] = Token(None, new_value)
    for index, value in token_dict.items():
        token = Token(None, value)
        stmt.tokens[index] = token
    return str(stmt)
def get_value(self, tok: Token):
    if tok.ttype == tokens.Name.Placeholder:
        return self.placeholder_index(tok)
    elif tok.match(tokens.Keyword, 'NULL'):
        return None
    elif tok.match(tokens.Keyword, 'DEFAULT'):
        return 'DEFAULT'
    else:
        raise SQLDecodeError
def _token2op(self, tok: Token, statement: SQLStatement) -> '_Op':
    op = None
    kw = {'statement': statement, 'query': self.query}
    if tok.match(tokens.Keyword, 'AND'):
        op = AndOp(**kw)
    elif tok.match(tokens.Keyword, 'OR'):
        op = OrOp(**kw)
    elif tok.match(tokens.Keyword, 'IN'):
        op = InOp(**kw)
    elif tok.match(tokens.Keyword, 'NOT'):
        if statement.next_token.match(tokens.Keyword, 'IN'):
            op = NotInOp(**kw)
            statement.skip(1)
        else:
            op = NotOp(**kw)
    elif tok.match(tokens.Keyword, 'LIKE'):
        op = LikeOp(**kw)
    elif tok.match(tokens.Keyword, 'iLIKE'):
        op = iLikeOp(**kw)
    elif tok.match(tokens.Keyword, 'BETWEEN'):
        op = BetweenOp(**kw)
        statement.skip(3)
    elif tok.match(tokens.Keyword, 'IS'):
        op = IsOp(**kw)
    elif isinstance(tok, Comparison):
        op = CmpOp(tok, self.query)
    elif isinstance(tok, Parenthesis):
        if (tok[1].match(tokens.Name.Placeholder, '.*', regex=True)
                or tok[1].match(tokens.Keyword, 'Null')
                or isinstance(tok[1], IdentifierList)
                or tok[1].ttype == tokens.DML):
            pass
        else:
            op = ParenthesisOp(SQLStatement(tok), self.query)
    elif tok.match(tokens.Punctuation, (')', '(')):
        pass
    elif isinstance(tok, Identifier):
        pass
    else:
        raise SQLDecodeError
    return op
def token2sql(
    token: Token,
    query: 'query_module.BaseQuery'
) -> U['CountFuncAll', 'CountFuncSingle']:
    try:
        ## FIX: COUNT(DISTINCT COL)
        ## TODO: This just gets the parser through the token, but distinct logic
        ## is not actually handled yet.
        if isinstance(token[0], Identifier):
            token.get_parameters()[0]
        else:
            token[0].get_parameters()[0]
    except IndexError:
        return CountFuncAll(token, query)
    else:
        return CountFuncSingle(token, query)
def handle_into(
    self, idx: int, next_token: Token, statement: TokenList
) -> Tuple[Schema, int]:
    table_alias = self.get_identifier_name(next_token)
    table_name = self.get_full_name(next_token)
    if isinstance(next_token, Function):
        extracted = self.extract_write_schema(
            next_token.get_parameters(), table_alias, table_alias
        )
    else:
        extracted = self.extract_write_schema(
            [Token(Wildcard, "*")], table_name, table_alias
        )
    return extracted, idx
def sql_recursively_simplify(node):
    # Erase which fields are being updated in an UPDATE
    if node.tokens[0].value == 'UPDATE':
        i_set = [i for (i, t) in enumerate(node.tokens) if t.value == 'SET'][0]
        i_where = [
            i
            for (i, t) in enumerate(node.tokens)
            if _is_group(t) and t.tokens[0].value == 'WHERE'
        ][0]
        middle = [Token(tokens.Punctuation, ' ... ')]
        node.tokens = node.tokens[:i_set + 1] + middle + node.tokens[i_where:]

    # Erase the names of savepoints since they are non-deterministic
    if hasattr(node, 'tokens'):
        # SAVEPOINT x
        if str(node.tokens[0]) == 'SAVEPOINT':
            node.tokens[2].tokens[0].value = '`#`'
            return
        # RELEASE SAVEPOINT x
        elif len(node.tokens) >= 3 and node.tokens[2].value == 'SAVEPOINT':
            node.tokens[4].tokens[0].value = "`#`"
            return
        # ROLLBACK TO SAVEPOINT X
        token_values = [getattr(t, 'value', '') for t in node.tokens]
        if len(node.tokens) == 7 and token_values[:6] == [
            'ROLLBACK', ' ', 'TO', ' ', 'SAVEPOINT', ' '
        ]:
            node.tokens[6].tokens[0].value = '`#`'
            return

    # Erase volatile part of PG cursor name
    if node.tokens[0].value.startswith('"_django_curs_'):
        node.tokens[0].value = '"_django_curs_#"'

    for i, token in enumerate(node.tokens):
        ttype = getattr(token, 'ttype', None)
        if isinstance(token, IdentifierList):
            token.tokens = [Token(tokens.Punctuation, '...')]
        elif hasattr(token, 'tokens'):
            sql_recursively_simplify(token)
        elif ttype in sql_deleteable_tokens:
            token.value = '#'
        elif ttype == tokens.Whitespace.Newline:
            token.value = ''  # Erase newlines
        elif ttype == tokens.Whitespace:
            token.value = ' '
        elif getattr(token, 'value', None) == 'NULL':
            token.value = '#'
def test_simple_join2(self):
    data = [
        JoinData(left_key="aa", right_key="ab"),
        JoinData(left_key="ec", right_key="eb"),
    ]
    table_a = Table(
        table_name="A",
        columns=[
            Column(name="aa", column_type=TypeEnum.int),
            Column(name="b", column_type=TypeEnum.int),
            Column(name="c", column_type=TypeEnum.int),
        ],
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )
    table_b = Table(
        table_name="B",
        columns=[
            Column(name="ab", column_type=TypeEnum.int),
            Column(name="eb", column_type=TypeEnum.int),
        ],
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )
    table_c = Table(
        table_name="C",
        columns=[
            Column(name="ec", column_type=TypeEnum.int),
            Column(name="f", column_type=TypeEnum.int),
        ],
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )
    tables = [table_a, table_b, table_c]

    root = SelectNode(tables=tables, annotation_name="demo")
    root.set_identifier_list([
        Identifier(tokens=[Token(None, "ec")]),
        Identifier(tokens=[Token(None, "f")]),
    ])
    root.next = JoinNode(join_list=data, tables=tables)
    root.next.prev = root
    root.next.merge()
    result = root.next.to_code(table_a.get_root())
    self.assertTrue(len(result) > 0)
def handle_into(
    self, idx: int, next_token: Token, statement: TokenList
) -> Tuple[Schema, int]:
    if next_token.is_group:
        # If the token is a group, use the inner Function token if one exists
        for token in next_token.tokens:
            if isinstance(token, Function):
                next_token = token
    table_alias = self.get_identifier_name(next_token)
    table_name = self.get_full_name(next_token)
    if isinstance(next_token, Function):
        extracted = self.extract_write_schema(
            next_token.get_parameters(), table_alias, table_alias
        )
    else:
        extracted = self.extract_write_schema(
            [Token(Wildcard, "*")], table_name, table_alias
        )
    return extracted, idx
def add_table_name(rls: TokenList, table: str) -> None:
    """
    Modify an RLS expression in place, ensuring columns are fully qualified.
    """
    tokens = rls.tokens[:]
    while tokens:
        token = tokens.pop(0)

        if isinstance(token, Identifier) and token.get_parent_name() is None:
            token.tokens = [
                Token(Name, table),
                Token(Punctuation, "."),
                Token(Name, token.get_name()),
            ]
        elif isinstance(token, TokenList):
            tokens.extend(token.tokens)
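# A minimal usage sketch for add_table_name: parse an RLS predicate with sqlparse
# and qualify its bare columns with the table name (mirrors the behavior asserted
# by the get_rls_for_table test below).
import sqlparse

rls = sqlparse.parse("organization_id = 1")[0]
add_table_name(rls, "some_table")
print(str(rls))  # "some_table.organization_id = 1"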
def tokens2sql(
    token: Token,
    query: 'query_module.BaseQuery'
) -> Iterator[all_token_types]:
    from .functions import SQLFunc

    if isinstance(token, Identifier):
        # Bug fix for sql parse
        if isinstance(token[0], Parenthesis):
            try:
                int(token[0][1].value)
            except ValueError:
                yield SQLIdentifier(token[0][1], query)
            else:
                yield SQLConstIdentifier(token, query)
        elif isinstance(token[0], Function):
            yield SQLFunc.token2sql(token, query)
        else:
            yield SQLIdentifier(token, query)
    elif isinstance(token, Function):
        yield SQLFunc.token2sql(token, query)
    elif isinstance(token, Comparison):
        yield SQLComparison(token, query)
    elif isinstance(token, IdentifierList):
        for tok in token.get_identifiers():
            yield from SQLToken.tokens2sql(tok, query)
    elif isinstance(token, Parenthesis):
        yield SQLPlaceholder(token, query)
    else:
        raise SQLDecodeError(f'Unsupported: {token.value}')
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
    """
    Tests for ``get_rls_for_table``.
    """
    candidate = Identifier([Token(Name, "some_table")])
    db = mocker.patch("superset.db")
    dataset = db.session.query().filter().one_or_none()
    dataset.__str__.return_value = "some_table"

    dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
    assert (
        str(get_rls_for_table(candidate, 1, "public"))
        == "some_table.organization_id = 1"
    )

    dataset.get_sqla_row_level_filters.return_value = [
        text("organization_id = 1"),
        text("foo = 'bar'"),
    ]
    assert (
        str(get_rls_for_table(candidate, 1, "public"))
        == "some_table.organization_id = 1 AND some_table.foo = 'bar'"
    )

    dataset.get_sqla_row_level_filters.return_value = []
    assert get_rls_for_table(candidate, 1, "public") is None
def process(self, stack, stream):
    splitlevel = 0
    stmt = None
    consume_ws = False
    stmt_tokens = []
    for ttype, value in stream:
        # Before appending the token
        if (consume_ws and ttype is not T.Whitespace
                and ttype is not T.Comment.Single):
            consume_ws = False
            stmt.tokens = stmt_tokens
            yield stmt
            self._reset()
            stmt = None
            splitlevel = 0
        if stmt is None:
            stmt = Statement()
            stmt_tokens = []
        splitlevel += self._change_splitlevel(ttype, value)
        # Append the token
        stmt_tokens.append(Token(ttype, value))
        # After appending the token
        if (splitlevel <= 0 and ttype is T.Punctuation
                and value == ';'):
            consume_ws = True
    if stmt is not None:
        stmt.tokens = stmt_tokens
        yield stmt
def handle_select(self, idx: int, next_token: Token) -> Tuple[Columns, int]:
    tokens: List[Token] = (
        next_token.get_identifiers()
        if isinstance(next_token, IdentifierList)
        else [next_token]
    )
    columns = self._extract_columns(tokens)
    return columns, idx
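# A standalone sqlparse check (independent of the lineage handler above) of the
# branch it relies on: a comma-separated select list is grouped into an
# IdentifierList whose get_identifiers() yields one token per column.
import sqlparse
from sqlparse.sql import IdentifierList

stmt = sqlparse.parse("SELECT a, b, c FROM t")[0]
id_list = next(tok for tok in stmt.tokens if isinstance(tok, IdentifierList))
print([str(tok) for tok in id_list.get_identifiers()])  # ['a', 'b', 'c']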
def _extract_from_token(self, token: Token) -> None:
    """
    <Identifier> stores a list of subtokens and <IdentifierList> stores a list of
    subtoken lists. This method extracts <IdentifierList> and <Identifier> from
    ``token`` and loops through all subtokens recursively. It finds
    table_name_preceding_token and passes <IdentifierList> and <Identifier> to
    self._process_tokenlist to populate self._tables.

    :param token: instance of Token or child class, e.g. TokenList, to be processed
    """
    if not hasattr(token, "tokens"):
        return

    table_name_preceding_token = False

    # If the table name is a reserved word (eg, "table_name") it won't be returned. We
    # fix this by ensuring that at least one identifier is returned after the FROM
    # before stopping on a keyword.
    has_processed_identifier = False

    for item in token.tokens:
        if item.is_group and (
            not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
        ):
            self._extract_from_token(item)

        if item.ttype in Keyword and (
            item.normalized in PRECEDES_TABLE_NAME
            or item.normalized.endswith(" JOIN")
        ):
            table_name_preceding_token = True
            continue

        # If we haven't processed any identifiers it means the table name is a
        # reserved keyword (eg, "table_name") and we shouldn't skip it.
        if item.ttype in Keyword and has_processed_identifier:
            table_name_preceding_token = False
            continue

        if table_name_preceding_token:
            if isinstance(item, Identifier):
                self._process_tokenlist(item)
                has_processed_identifier = True
            elif isinstance(item, IdentifierList):
                for token2 in item.get_identifiers():
                    if isinstance(token2, TokenList):
                        self._process_tokenlist(token2)
                        has_processed_identifier = True
            elif item.ttype in Keyword:
                # convert into an identifier
                fixed = Identifier([Token(Name, item.value)])
                self._process_tokenlist(fixed)
                has_processed_identifier = True
        elif isinstance(item, IdentifierList):
            if any(not self._is_identifier(token2) for token2 in item.tokens):
                self._extract_from_token(item)
def __parse_from__(self):
    from_tables = [
        Identifier(tokens=[Token(value=f.variable_table_name, ttype="")])
        for f in self.db_driver.perform_select_from()
    ]
    from_node = FromNode(tables=self.tables)
    from_node.set_identifier_list(from_tables)
    last = self.root.get_last_node()
    last.next = from_node
    last.next.prev = last
def test_simple_select(self):
    select_node = SelectNode(tables=[self.table_a], annotation_name="demo")
    select_node.from_tables = [
        Identifier(tokens=[Token("int", "A")]),
    ]
    code = select_node.to_code(self.table_a.get_root())
    print(code)
    self.assert_content_in_arr(code, "Relation a(a_ri, a_ai);")
    self.assert_content_in_arr(code, "CLIENT")
    self.assert_content_in_arr(code, "{ Relation::INT,Relation::STRING }")
def test_select2(self):
    table_a = PythonFreeConnexTable(
        table_name="A",
        columns=[
            Column(name="a", column_type=TypeEnum.int),
            Column(name="b", column_type=TypeEnum.string),
        ],
        owner=CharacterEnum.client,
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )
    table_b = PythonFreeConnexTable(
        table_name="B",
        columns=[
            Column(name="c", column_type=TypeEnum.int),
            Column(name="d", column_type=TypeEnum.string),
        ],
        owner=CharacterEnum.server,
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )

    node = SelectNodePython(tables=[table_a, table_b], annotation_name="demo")
    node.from_tables = [
        Identifier(tokens=[Token("int", "A")]),
        Identifier(tokens=[Token("int", "B")]),
    ]
    relations = node.to_code(table_a.get_root(), should_load_data=False)
    self.assertEqual(len(relations), 2)

    table = relations[0]
    self.assertEqual(table.variable_table_name, "a")
    self.assertIsNotNone(table.relation)

    table = relations[1]
    self.assertEqual(table.variable_table_name, "b")
    self.assertIsNotNone(table.relation)
def add_pidal(self, value: int):
    if getattr(self, "pidal_c", None):
        return
    self.pidal_c = value
    for i in self.table_f.tokens:
        if isinstance(i, Parenthesis):
            for j in i.tokens:
                if isinstance(j, IdentifierList):
                    j.insert_after(len(j.tokens), Token(token.Punctuation, ","))
                    j.insert_after(len(j.tokens), sqlparse.parse("pidal_c")[0].tokens[0])
    for i in self.values.tokens:
        if isinstance(i, Parenthesis):
            for j in i.tokens:
                if isinstance(j, IdentifierList):
                    j.insert_after(len(j.tokens), Token(token.Punctuation, ","))
                    j.insert_after(len(j.tokens), Token(token.Number.Integer, value))
def token2sql(
    token: Token,
    query: 'query_module.BaseQuery'
) -> U['CountFunc', 'SimpleFunc']:
    ## FIX: agg(distinct)
    token.is_agg_distinct = isinstance(token[0], Identifier) and token[1][1].match(
        tokens.Keyword, 'DISTINCT')
    func = token[0].get_name()

    if func == 'COUNT':
        return CountFunc.token2sql(token, query)
    else:
        return SimpleFunc(token, query)
def matches_table_name(candidate: Token, table: str) -> bool:
    """
    Return whether the token represents a reference to the table.

    Tables can be fully qualified with periods.

    Note that in theory a table should be represented as an identifier, but due to
    sqlparse's aggressive list of keywords (spanning multiple dialects) it often
    gets classified as a keyword.
    """
    if not isinstance(candidate, Identifier):
        candidate = Identifier([Token(Name, candidate.value)])

    target = sqlparse.parse(table)[0].tokens[0]
    if not isinstance(target, Identifier):
        target = Identifier([Token(Name, target.value)])

    # match from right to left, splitting on the period, eg, schema.table == table
    for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
        if left.value != right.value:
            return False

    return True
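# A hedged usage sketch for matches_table_name (token grouping is somewhat
# sqlparse-version dependent): the candidate is matched against the target name
# from right to left, so a schema-qualified reference still matches.
import sqlparse

candidate = sqlparse.parse("my_schema.my_table")[0].tokens[0]
print(matches_table_name(candidate, "my_table"))     # True
print(matches_table_name(candidate, "other_table"))  # False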
def build_query(self, sql, limit=2):
    """
    Parse the SQL query and change the LIMIT clause value.
    """
    import re

    from sqlparse import parse, tokens
    from sqlparse.sql import Token

    # It's important to have a whitespace right after every LIMIT
    pattern = re.compile('limit([^ ])', re.IGNORECASE)
    sql = pattern.sub(r'LIMIT \1', sql)

    query = parse(sql.rstrip(';'))[0]

    # Find LIMIT statement
    token = query.token_next_match(0, tokens.Keyword, 'LIMIT')
    if token:
        # Find and replace LIMIT value
        value = query.token_next(query.token_index(token), skip_ws=True)
        if value:
            new_token = Token(value.ttype, str(limit))
            query.tokens[query.token_index(value)] = new_token
    else:
        # If limit is not found, append one
        new_tokens = [
            Token(tokens.Whitespace, ' '),
            Token(tokens.Keyword, 'LIMIT'),
            Token(tokens.Whitespace, ' '),
            Token(tokens.Number, str(limit)),
        ]
        last_token = query.tokens[-1]
        if last_token.ttype == tokens.Punctuation:
            query.tokens.remove(last_token)
        for new_token in new_tokens:
            query.tokens.append(new_token)

    return str(query)
def _execute_copy_from_local_sql(sql_tree, cursor):
    # Search for 'LOCAL' keyword
    for i, token in enumerate(sql_tree.tokens):
        if token.is_keyword and token.value.lower() == 'local':
            break

    file_path = sql_tree.tokens[i + 2].value.strip('\'"')

    # Replace "LOCAL <file_path>" with "stdin"
    sql_tree.tokens = (
        sql_tree.tokens[0:i]
        + [Token(_Token.Keyword, 'stdin')]
        + sql_tree.tokens[i + 3:]
    )
    new_sql = sql_tree.to_unicode()

    cursor.flush_to_query_ready()
    with open(file_path, 'rb') as f:
        cursor.copy(new_sql, f)
    cursor.flush_to_query_ready()
def get_rls_for_table(
    candidate: Token,
    database_id: int,
    default_schema: Optional[str],
    username: Optional[str] = None,
) -> Optional[TokenList]:
    """
    Given a table name, return any associated RLS predicates.
    """
    # pylint: disable=import-outside-toplevel
    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    if not isinstance(candidate, Identifier):
        candidate = Identifier([Token(Name, candidate.value)])

    table = ParsedQuery.get_table(candidate)
    if not table:
        return None

    dataset = (
        db.session.query(SqlaTable)
        .filter(
            and_(
                SqlaTable.database_id == database_id,
                SqlaTable.schema == (table.schema or default_schema),
                SqlaTable.table_name == table.table,
            )
        )
        .one_or_none()
    )
    if not dataset:
        return None

    template_processor = dataset.get_template_processor()
    predicate = " AND ".join(
        str(filter_)
        for filter_ in dataset.get_sqla_row_level_filters(template_processor, username)
    )
    if not predicate:
        return None

    rls = sqlparse.parse(predicate)[0]
    add_table_name(rls, str(dataset))

    return rls
def process(self, stack, stream): "Process the stream" consume_ws = False splitlevel = 0 stmt = None stmt_tokens = [] # Run over all stream tokens for ttype, value in stream: # Yield token if we finished a statement and there's no whitespaces if consume_ws and ttype not in (T.Whitespace, T.Comment.Single): stmt.tokens = stmt_tokens yield stmt # Reset filter and prepare to process next statement self._reset() consume_ws = False splitlevel = 0 stmt = None # Create a new statement if we are not currently in one of them if stmt is None: stmt = Statement() stmt_tokens = [] # Change current split level (increase, decrease or remain equal) splitlevel += self._change_splitlevel(ttype, value) # Append the token to the current statement stmt_tokens.append(Token(ttype, value)) # Check if we get the end of a statement if splitlevel <= 0 and ttype is T.Punctuation and value == ';': consume_ws = True # Yield pending statement (if any) if stmt is not None: stmt.tokens = stmt_tokens yield stmt
def test_simple_join1(self):
    data = [JoinData(left_key="aa", right_key="ba")]
    table_a = Table(
        table_name="A",
        columns=[
            Column(name="aa", column_type=TypeEnum.int),
            Column(name="b", column_type=TypeEnum.int),
            Column(name="c", column_type=TypeEnum.int),
        ],
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )
    table_b = Table(
        table_name="B",
        columns=[
            Column(name="ba", column_type=TypeEnum.int),
            Column(name="e", column_type=TypeEnum.int),
        ],
        data_sizes=[100],
        data_paths=[""],
        annotations=[],
    )

    root = SelectNode(tables=[table_a, table_b], annotation_name="demo")
    root.set_identifier_list([
        Identifier(tokens=[Token(None, "b")]),
        Identifier(tokens=[Token(None, "c")]),
    ])
    root.next = JoinNode(join_list=data, tables=[table_a, table_b])
    root.next.prev = root
    root.next.merge()
    result = root.next.to_code(table_a.get_root())
    self.assertTrue('a.Aggregate({ "aa" });' in result[0])
def handle_update(
    self, idx: int, next_token: Token, statement: TokenList
) -> Tuple[Schema, int]:
    table_alias = next_token.get_name()
    table_name = next_token.normalized
    idx, token = self._next_non_empty_token(idx, statement)
    operation_name = token.value.upper()
    extracted = {}
    if operation_name == "SET":
        idx, token = self._next_non_empty_token(idx, statement)
        if isinstance(token, Comparison):
            # The token is a Comparison object when the SQL has a single SET condition
            extracted = self.extract_write_schema([token], table_name, table_alias)
        else:
            extracted = self.extract_write_schema(
                token.get_identifiers(), table_name, table_alias
            )
    return extracted, idx
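# A standalone sqlparse check (independent of the handler above, and somewhat
# version dependent) of the distinction its comment relies on: a single SET
# assignment is grouped as a Comparison, several are grouped as an IdentifierList.
import sqlparse
from sqlparse.sql import Comparison, IdentifierList

single = sqlparse.parse("UPDATE t SET a = 1")[0]
multi = sqlparse.parse("UPDATE t SET a = 1, b = 2")[0]
print(any(isinstance(tok, Comparison) for tok in single.tokens))     # True
print(any(isinstance(tok, IdentifierList) for tok in multi.tokens))  # True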
def _token2op(self, tok: Token, statement: SQLStatement) -> '_Op':
    op = None
    kw = {'statement': statement, 'query': self.query}

    if tok.match(tokens.Name.Placeholder, '%(.+)s', regex=True):
        return op

    if tok.match(tokens.Keyword, 'AND'):
        op = AndOp(**kw)
    elif tok.match(tokens.Keyword, 'OR'):
        op = OrOp(**kw)
    elif any(
        t.match(tokens.Comparison, 'IN')
        for t in [tok, *getattr(tok, 'tokens', [])]
    ):
        op = InOp(**kw, token='current_token')
    elif tok.match(tokens.Keyword, 'NOT'):
        if statement.next_token.match(tokens.Keyword, 'IN'):
            op = NotInOp(**kw)
            statement.skip(1)
        else:
            op = NotOp(**kw)
    elif tok.value.endswith("REGEXP"):
        op = RegexpOp(**kw)
    elif isinstance(tok, Comparison) and 'LIKE' in tok.normalized:
        op = LikeOp(**kw)
    elif isinstance(tok, Comparison) and 'iLIKE' in tok.normalized:
        op = iLikeOp(**kw)
    elif tok.match(tokens.Keyword, 'BETWEEN'):
        op = BetweenOp(**kw)
        statement.skip(3)
    elif tok.match(tokens.Keyword, 'IS'):
        op = IsOp(**kw)
    elif tok.value in JSON_OPERATORS:
        op = JSONOp(**kw)
    elif isinstance(tok, Comparison):
        op = CmpOp(tok, self.query)
    elif isinstance(tok, Parenthesis):
        if (tok[1].match(tokens.Name.Placeholder, '.*', regex=True)
                or tok[1].match(tokens.Keyword, 'Null')
                or isinstance(tok[1], IdentifierList)
                or tok[1].ttype == tokens.DML):
            pass
        else:
            op = ParenthesisOp(SQLStatement(tok), self.query)
    elif tok.match(tokens.Punctuation, (')', '(')):
        pass
    elif isinstance(tok, Identifier):
        t = statement.next_token
        if (not t
                or t.match(tokens.Punctuation, (')', '('))
                or t.match(tokens.Keyword, ('AND', 'OR'))):
            op = ColOp(tok, self.query)
    else:
        raise SQLDecodeError

    return op