def build_match(
    col: str, ops: Sequence[str], param_type: Any, alias: Optional[str] = None
) -> Or[Expression]:
    """
    Builds a pattern matching a condition on the given column: either a
    simple binary condition with a literal right hand side, or an IN
    condition over an array/tuple of literals.

    The IN condition has to be checked separately since each parameter
    has to be checked individually.
    """
    table_match = AnyOptionalString() if alias is None else String(alias)
    matched_col = Param("column", ColumnPattern(table_match, String(col)))

    literal_rhs = Param("rhs", LiteralPattern(AnyPattern(param_type)))
    simple_condition = FunctionCallPattern(
        Or([String(op) for op in ops]),
        (matched_col, literal_rhs),
    )

    collection_rhs = Param(
        "rhs",
        FunctionCallPattern(
            Or([String("array"), String("tuple")]),
            all_parameters=LiteralPattern(AnyPattern(param_type)),
        ),
    )
    in_condition = FunctionCallPattern(
        String(ConditionFunctions.IN),
        (matched_col, collection_rhs),
    )

    return Or([simple_condition, in_condition])
def __init__(self, uuid_columns: Set[str]) -> None:
    # Columns whose values are UUIDs and need to be reformatted.
    self.__unique_uuid_columns = uuid_columns
    self.__uuid_column_match = Or([String(u_col) for u_col in uuid_columns])
    # Matches `col IN tuple(...)` / `col NOT IN tuple(...)` where the left
    # hand side is a formatted uuid column expression.
    self.uuid_in_condition = FunctionCallMatch(
        Or((String(ConditionFunctions.IN), String(ConditionFunctions.NOT_IN))),
        (
            self.formatted_uuid_pattern(),
            Param("params", FunctionCallMatch(String("tuple"), None)),
        ),
    )
    # Matches any other binary condition where either operand may be a
    # string literal or a formatted uuid column expression; the "_0"/"_1"
    # suffixes keep the two operands' captures distinct.
    self.uuid_condition = FunctionCallMatch(
        Or(
            [
                String(op)
                for op in FUNCTION_TO_OPERATOR
                if op not in (ConditionFunctions.IN, ConditionFunctions.NOT_IN)
            ]
        ),
        (
            Or(
                (
                    Param("literal_0", LiteralMatch(AnyOptionalString())),
                    self.formatted_uuid_pattern("_0"),
                )
            ),
            Or(
                (
                    Param("literal_1", LiteralMatch(AnyOptionalString())),
                    self.formatted_uuid_pattern("_1"),
                )
            ),
        ),
    )
    # Holds state during processing; None until a match is found.
    self.formatted: Optional[str] = None
def condition_pattern(
    operators: Set[str],
    lhs_pattern: Pattern[Expression],
    rhs_pattern: Pattern[Expression],
    commutative: bool,
) -> Pattern[Expression]:
    """
    Matches a binary condition given the two operands and the valid
    operators. It also supports commutative conditions.
    """

    def one_ordering(
        first: Pattern[Expression], second: Pattern[Expression]
    ) -> FunctionCallPattern:
        # Pattern for a single operand ordering of any accepted operator.
        return FunctionCallPattern(
            Or([String(op) for op in operators]), (first, second)
        )

    if commutative:
        # Accept the operands in either order.
        return Or(
            [
                one_ordering(lhs_pattern, rhs_pattern),
                one_ordering(rhs_pattern, lhs_pattern),
            ]
        )
    return one_ordering(lhs_pattern, rhs_pattern)
def __init__(self, column_name: str, hash_map_name: str, killswitch: str) -> None:
    self.__column_name = column_name
    self.__hash_map_name = hash_map_name
    # Runtime setting name that can disable this optimization.
    self.__killswitch = killswitch

    # TODO: Add the support for IN conditions.
    # Matches an equality against the mapping access (mapping_pattern is
    # defined at module level), optionally wrapped in ifNull(..., ''),
    # capturing the right hand side string literal.
    self.__optimizable_pattern = FunctionCall(
        function_name=String("equals"),
        parameters=(
            Or(
                [
                    mapping_pattern,
                    FunctionCall(
                        function_name=String("ifNull"),
                        parameters=(mapping_pattern, Literal(String(""))),
                    ),
                ]
            ),
            Param("right_hand_side", Literal(Any(str))),
        ),
    )
    # Conditions asserting the existence of a tag: either a notEquals on the
    # (possibly ifNull-wrapped) mapping, or has(<column>.key, <key literal>).
    self.__tag_exists_patterns = [
        FunctionCall(
            function_name=String("notEquals"),
            parameters=(
                Or(
                    [
                        mapping_pattern,
                        FunctionCall(
                            function_name=String("ifNull"),
                            parameters=(mapping_pattern, Literal(String(""))),
                        ),
                    ]
                ),
                Param("right_hand_side", Literal(Any(str))),
            ),
        ),
        FunctionCall(
            function_name=String("has"),
            parameters=(
                ColumnMatcher(
                    Param(TABLE_MAPPING_PARAM, AnyOptionalString()),
                    Param(VALUE_COL_MAPPING_PARAM, String(f"{column_name}.key")),
                ),
                Literal(Param(KEY_MAPPING_PARAM, Any(str))),
            ),
        ),
    ]
def _get_date_range(query: Query) -> Optional[int]:
    """
    Best guess to find the time range for the query.
    We pick the first column that is compared with a datetime Literal.
    """
    condition = query.get_condition_from_ast()
    if condition is None:
        return None

    # Matches `col > <datetime>` / `col >= <datetime>` capturing the column
    # name so the full range can be looked up for that column.
    lower_bound_pattern = FunctionCall(
        Or([String(ConditionFunctions.GT), String(ConditionFunctions.GTE)]),
        (Column(None, Param("col_name", Any(str))), Literal(Any(datetime))),
    )

    for expression in condition:
        match = lower_bound_pattern.match(expression)
        if match is None:
            continue
        from_date, to_date = get_time_range(query, match.string("col_name"))
        if from_date is None or to_date is None:
            return None
        return (to_date - from_date).days

    return None
def get_time_range_expressions(
    conditions: Sequence[Expression],
    timestamp_field: str,
    table_name: Optional[str] = None,
) -> Tuple[
    Optional[Tuple[datetime, FunctionCallExpr]],
    Optional[Tuple[datetime, FunctionCallExpr]],
]:
    """
    Finds the tightest time range implied by the given conditions on
    `timestamp_field`: the `>=` condition with the highest timestamp and
    the `<` condition with the lowest.

    Returns (max_lower_bound, min_upper_bound); each element, when present,
    is the (timestamp, condition_expression) pair that produced the bound.
    """
    max_lower_bound: Optional[Tuple[datetime, FunctionCallExpr]] = None
    min_upper_bound: Optional[Tuple[datetime, FunctionCallExpr]] = None
    table_match = String(table_name) if table_name else None
    # PERF: the pattern is loop invariant — build it once instead of
    # rebuilding it for every condition as the original code did.
    pattern = FunctionCall(
        Param(
            "operator",
            Or(
                [
                    String(OPERATOR_TO_FUNCTION[">="]),
                    String(OPERATOR_TO_FUNCTION["<"]),
                ]
            ),
        ),
        (
            Column(table_match, String(timestamp_field)),
            Literal(Param("timestamp", Any(datetime))),
        ),
    )
    for c in conditions:
        match = pattern.match(c)
        if match is not None:
            timestamp = cast(datetime, match.scalar("timestamp"))
            assert isinstance(c, FunctionCallExpr)
            if match.string("operator") == OPERATOR_TO_FUNCTION[">="]:
                # Keep the greatest lower bound seen so far.
                if not max_lower_bound or timestamp > max_lower_bound[0]:
                    max_lower_bound = (timestamp, c)
            else:
                # Keep the smallest upper bound seen so far.
                if not min_upper_bound or timestamp < min_upper_bound[0]:
                    min_upper_bound = (timestamp, c)
    return (max_lower_bound, min_upper_bound)
def get_time_range_estimate(
    query: ProcessableQuery[Table],
) -> Tuple[Optional[datetime], Optional[datetime]]:
    """
    Best guess to find the time range for the query.
    We pick the first column that is compared with a datetime Literal.
    """
    condition = query.get_condition()
    if condition is None:
        return None, None

    # Matches `col > <datetime>` / `col >= <datetime>`, capturing the
    # column name as "col_name".
    pattern = FunctionCall(
        Or([String(ConditionFunctions.GT), String(ConditionFunctions.GTE)]),
        (Column(None, Param("col_name", Any(str))), Literal(Any(datetime))),
    )

    from_date: Optional[datetime] = None
    to_date: Optional[datetime] = None
    for expression in condition:
        match = pattern.match(expression)
        if match is not None:
            # First matching column decides the range.
            from_date, to_date = get_time_range(query, match.string("col_name"))
            break

    return from_date, to_date
def __init__(self, columns: Set[str]):
    # Columns this processor is allowed to rewrite.
    self.columns = columns
    column_match = Or([String(col) for col in columns])
    literal = Param("literal", LiteralMatch(AnyMatch(str)))
    # Any binary operator except IN / NOT IN, which are handled separately
    # because each element of the tuple must be translated individually.
    operator = Param(
        "operator",
        Or(
            [
                String(op)
                for op in FUNCTION_TO_OPERATOR
                if op not in (ConditionFunctions.IN, ConditionFunctions.NOT_IN)
            ]
        ),
    )
    in_operators = Param(
        "operator",
        Or((String(ConditionFunctions.IN), String(ConditionFunctions.NOT_IN))),
    )
    col = Param("col", ColumnMatch(None, column_match))

    # Matches `literal op col`, `col op literal` or `has(col, literal)`.
    self.__condition_matcher = Or(
        [
            FunctionCallMatch(operator, (literal, col)),
            FunctionCallMatch(operator, (col, literal)),
            FunctionCallMatch(Param("operator", String("has")), (col, literal)),
        ]
    )

    # Matches `col IN tuple(...)` / `col NOT IN tuple(...)` where every
    # element of the tuple is a literal.
    self.__in_condition_matcher = FunctionCallMatch(
        in_operators,
        (
            col,
            Param(
                "tuple",
                FunctionCallMatch(String("tuple"), all_parameters=LiteralMatch()),
            ),
        ),
    )
def extract_granularity_from_query(query: Query, column: str) -> Optional[int]:
    """
    This extracts the `granularity` from the `groupby` statement of the query.
    The matches are essentially the reverse of `TimeSeriesProcessor.__group_time_function`.
    """
    groupby = query.get_groupby()

    column_match = ColumnMatch(None, String(column))
    # Case 1: a named time-bucketing function applied directly to the
    # column, e.g. toStartOfHour(column).
    fn_match = FunctionCallMatch(
        Param(
            "time_fn",
            Or(
                [
                    String("toStartOfHour"),
                    String("toStartOfMinute"),
                    String("toStartOfDay"),
                    String("toDate"),
                ]
            ),
        ),
        (column_match,),
        with_optionals=True,
    )
    # Case 2: an arbitrary granularity expressed as
    # toDateTime(multiply(intDiv(toUInt32(column), g), g), <tz>).
    expr_match = FunctionCallMatch(
        String("toDateTime"),
        (
            FunctionCallMatch(
                String("multiply"),
                (
                    FunctionCallMatch(
                        String("intDiv"),
                        (
                            FunctionCallMatch(String("toUInt32"), (column_match,)),
                            LiteralMatch(Param("granularity", Any(int))),
                        ),
                    ),
                    LiteralMatch(Param("granularity", Any(int))),
                ),
            ),
            LiteralMatch(Any(str)),
        ),
    )

    for top_expr in groupby:
        for expr in top_expr:
            result = fn_match.match(expr)
            if result is not None:
                # Named bucketing functions map to fixed granularities.
                return GRANULARITY_MAPPING[result.string("time_fn")]

            result = expr_match.match(expr)
            if result is not None:
                return result.integer("granularity")

    return None
def build_match(
    col: str,
    ops: Sequence[str],
    param_type: Any,
    alias: Optional[str] = None,
    key: Optional[str] = None,
) -> Or[Expression]:
    """
    Builds a pattern matching conditions on a column, or on a
    subscriptable reference (e.g. tags[key]) when `key` is provided.
    """
    # The IN condition has to be checked separately since each parameter
    # has to be checked individually.
    alias_match = AnyOptionalString() if alias is None else String(alias)

    pattern: Union[ColumnPattern, SubscriptableReferencePattern]
    if key is not None:
        pattern = SubscriptableReferencePattern(
            table_name=alias_match, column_name=String(col), key=String(key)
        )
    else:
        pattern = ColumnPattern(table_name=alias_match, column_name=String(col))

    column_match = Param("column", pattern)

    return Or(
        [
            # col <op> literal
            FunctionCallPattern(
                Or([String(op) for op in ops]),
                (column_match, Param("rhs", LiteralPattern(AnyPattern(param_type)))),
            ),
            # col IN array(...)/tuple(...) of literals
            FunctionCallPattern(
                String(ConditionFunctions.IN),
                (
                    column_match,
                    Param(
                        "rhs",
                        FunctionCallPattern(
                            Or([String("array"), String("tuple")]),
                            all_parameters=LiteralPattern(AnyPattern(param_type)),
                        ),
                    ),
                ),
            ),
        ]
    )
def get_project_ids_in_condition(condition: Expression) -> Optional[Set[int]]:
    """
    Extract project ids from an expression. Returns None if no project
    id condition is found. It returns an empty set if conflicting
    project_id conditions are found.
    """
    # Case 1: project_id = <int>
    match = FunctionCall(
        None,
        String(ConditionFunctions.EQ),
        (
            Column(column_name=String(project_column)),
            Literal(value=Param("project_id", Any(int))),
        ),
    ).match(condition)
    if match is not None:
        return {match.integer("project_id")}

    # Case 2: project_id IN (...) — collect the int literals in the tuple.
    match = is_in_condition_pattern(
        Column(column_name=String(project_column))
    ).match(condition)
    if match is not None:
        projects = match.expression("tuple")
        assert isinstance(projects, FunctionCallExpr)
        return {
            l.value
            for l in projects.parameters
            if isinstance(l, LiteralExpr) and isinstance(l.value, int)
        }

    # Case 3: boolean combination — recurse on both sides and combine with
    # intersection for AND and union for OR.
    match = FunctionCall(
        None,
        Param(
            "operator",
            Or([String(BooleanFunctions.AND), String(BooleanFunctions.OR)]),
        ),
        (Param("lhs", AnyExpression()), Param("rhs", AnyExpression())),
    ).match(condition)
    if match is not None:
        lhs_projects = get_project_ids_in_condition(match.expression("lhs"))
        rhs_projects = get_project_ids_in_condition(match.expression("rhs"))
        if lhs_projects is None:
            return rhs_projects
        elif rhs_projects is None:
            return lhs_projects
        else:
            return (
                lhs_projects & rhs_projects
                if match.string("operator") == BooleanFunctions.AND
                else lhs_projects | rhs_projects
            )

    return None
def process_query(self, query: Query, request_settings: RequestSettings) -> None:
    # Matches arrayElement(contexts.value, indexOf(contexts.key, <ctx>))
    # for the boolean device contexts listed below.
    matcher = FunctionCall(
        String("arrayElement"),
        (
            Column(
                None,
                String("contexts.value"),
            ),
            FunctionCall(
                String("indexOf"),
                (
                    Column(None, String("contexts.key")),
                    Literal(
                        Or(
                            [
                                String("device.simulator"),
                                String("device.online"),
                                String("device.charging"),
                            ]
                        ),
                    ),
                ),
            ),
        ),
    )

    def process_column(exp: Expression) -> Expression:
        match = matcher.match(exp)

        if match:
            # Normalize the raw context value to "True"/"False": values in
            # ("1", "True") become "True", anything else "False". The alias
            # is moved to the outer `if` call so the result keeps its name.
            inner = replace(exp, alias=None)
            return FunctionCallExpr(
                exp.alias,
                "if",
                (
                    binary_condition(
                        ConditionFunctions.IN,
                        inner,
                        literals_tuple(
                            None,
                            [LiteralExpr(None, "1"), LiteralExpr(None, "True")],
                        ),
                    ),
                    LiteralExpr(None, "True"),
                    LiteralExpr(None, "False"),
                ),
            )

        return exp

    query.transform_expressions(process_column)
def __init__(
    self,
    column_name: str,
    key_names: Sequence[str],
    val_names: Sequence[str],
):
    super().__init__(column_name, key_names, val_names)
    # Matches arrayJoin(col) where col is any of the columns handled by
    # this processor; the matched column name is captured as "col".
    matched_column = Param(
        "col",
        Or([String(column) for column in self.all_columns]),
    )
    self.__array_join_pattern = FunctionCall(
        String("arrayJoin"),
        (Column(column_name=matched_column),),
    )
def build_match(
    col: str, ops: Sequence[str], param_type: Any
) -> Or[Expression]:
    """
    Builds a pattern matching a condition on the given column: either a
    binary condition with any of the given operators and a literal right
    hand side, or an IN condition over an array/tuple of literals.

    The IN condition has to be checked separately since each parameter
    has to be checked individually.
    """
    # BUG FIX: `alias_match` was referenced without ever being defined,
    # which raised a NameError as soon as the function was called. Accept
    # any table alias (or none), as the sibling build_match helpers do.
    alias_match = AnyOptionalString()
    column_match = ColumnMatch(alias_match, StringMatch(col))
    return Or(
        [
            FunctionCallMatch(
                Or([StringMatch(op) for op in ops]),
                (column_match, LiteralMatch(AnyMatch(param_type))),
            ),
            FunctionCallMatch(
                StringMatch(ConditionFunctions.IN),
                (
                    column_match,
                    FunctionCallMatch(
                        Or([StringMatch("array"), StringMatch("tuple")]),
                        all_parameters=LiteralMatch(AnyMatch(param_type)),
                    ),
                ),
            ),
        ]
    )
def __set_condition_pattern(
    lhs: Pattern[Expression], operator: str
) -> FunctionCallPattern:
    """
    Matches `operator(lhs, tuple(...))` or `operator(lhs, array(...))`,
    capturing the left hand side as "lhs" and the collection as "sequence".
    """
    collection = Param(
        "sequence",
        Or(
            [
                FunctionCallPattern(String("tuple"), None),
                FunctionCallPattern(String("array"), None),
            ]
        ),
    )
    return FunctionCallPattern(
        String(operator),
        (Param("lhs", lhs), collection),
    )
def process_query(self, query: Query, request_settings: RequestSettings) -> None:
    # We care only of promoted contexts, so we do not need to match
    # the original nested expression.
    matcher = FunctionCall(
        String("toString"),
        (
            Column(
                None,
                Or(
                    [
                        String("device_simulator"),
                        String("device_online"),
                        String("device_charging"),
                    ]
                ),
            ),
        ),
    )

    def replace_exp(exp: Expression) -> Expression:
        if matcher.match(exp) is not None:
            # Normalize the promoted context value: empty stays empty,
            # values in ("1", "True") become "True", anything else "False".
            # The alias is preserved on the outer multiIf call.
            inner = replace(exp, alias=None)
            return FunctionCallExpr(
                exp.alias,
                "multiIf",
                (
                    binary_condition(
                        None, ConditionFunctions.EQ, inner, Literal(None, "")
                    ),
                    Literal(None, ""),
                    binary_condition(
                        None,
                        ConditionFunctions.IN,
                        inner,
                        literals_tuple(
                            None, [Literal(None, "1"), Literal(None, "True")]
                        ),
                    ),
                    Literal(None, "True"),
                    Literal(None, "False"),
                ),
            )
        return exp

    query.transform_expressions(replace_exp)
def get_time_range(
    query: Query, timestamp_field: str
) -> Tuple[Optional[datetime], Optional[datetime]]:
    """
    Finds the minimal time range for this query. Which means, it finds
    the >= timestamp condition with the highest datetime literal and
    the < timestamp condition with the smallest and returns the interval
    in the form of a tuple of Literals. It only looks into first level
    AND conditions since, if the timestamp is nested in an OR we cannot
    say anything on how that compares to the other timestamp conditions.
    """
    condition_clause = query.get_condition_from_ast()
    if not condition_clause:
        return (None, None)

    max_lower_bound = None
    min_upper_bound = None
    # PERF: the pattern does not depend on the condition being inspected,
    # so build it once instead of once per loop iteration as before.
    pattern = FunctionCall(
        None,
        Param(
            "operator",
            Or(
                [
                    String(OPERATOR_TO_FUNCTION[">="]),
                    String(OPERATOR_TO_FUNCTION["<"]),
                ]
            ),
        ),
        (
            Column(None, None, String(timestamp_field)),
            Literal(None, Param("timestamp", Any(datetime))),
        ),
    )
    for c in get_first_level_and_conditions(condition_clause):
        match = pattern.match(c)
        if match is not None:
            timestamp = cast(datetime, match.scalar("timestamp"))
            if match.string("operator") == OPERATOR_TO_FUNCTION[">="]:
                # Keep the greatest lower bound.
                if not max_lower_bound or timestamp > max_lower_bound:
                    max_lower_bound = timestamp
            else:
                # Keep the smallest upper bound.
                if not min_upper_bound or timestamp < min_upper_bound:
                    min_upper_bound = timestamp

    return (max_lower_bound, min_upper_bound)
def __init__(self, column_name: str, hash_map_name: str, killswitch: str) -> None:
    self.__column_name = column_name
    self.__hash_map_name = hash_map_name
    # Runtime setting name that can disable this optimization.
    self.__killswitch = killswitch

    # TODO: Add the support for IN conditions.
    # Matches an equality against the mapping access (mapping_pattern is
    # defined at module level), optionally wrapped in ifNull(..., ''),
    # capturing the right hand side string literal.
    self.__optimizable_pattern = FunctionCall(
        function_name=String("equals"),
        parameters=(
            Or(
                [
                    mapping_pattern,
                    FunctionCall(
                        function_name=String("ifNull"),
                        parameters=(mapping_pattern, Literal(String(""))),
                    ),
                ]
            ),
            Param("right_hand_side", Literal(Any(str))),
        ),
    )
def __init__(self, array_columns: Sequence[str]):
    # Matches `has(col, 'value') = 1` where col is one of the given array
    # columns; the inner has(...) call is captured as "has".
    any_array_column = Column(
        column_name=Or([String(column) for column in array_columns])
    )
    has_call = Param(
        "has",
        FunctionCall(
            String("has"),
            (any_array_column, Literal(Any(str))),
        ),
    )
    self.__array_has_pattern = FunctionCall(
        String("equals"),
        (has_call, Literal(Integer(1))),
    )
def process_query(self, query: Query, request_settings: RequestSettings) -> None:
    # Matches arrayJoin(col.key) and arrayJoin(col.value) for this
    # processor's column, capturing which of the two as "col".
    arrayjoin_pattern = FunctionCall(
        String("arrayJoin"),
        (
            Column(
                column_name=Param(
                    "col",
                    Or(
                        [
                            String(key_column(self.__column_name)),
                            String(val_column(self.__column_name)),
                        ]
                    ),
                ),
            ),
        ),
    )

    # Records which of the key/value arrayJoins appear anywhere in the query.
    arrayjoins_in_query = set()
    for e in query.get_all_expressions():
        match = arrayjoin_pattern.match(e)
        if match is not None:
            arrayjoins_in_query.add(match.string("col"))

    filtered_keys = [
        LiteralExpr(None, key)
        for key in get_filtered_mapping_keys(query, self.__column_name)
    ]

    # Ensures the alias we apply to the arrayJoin is not already taken.
    used_aliases = {exp.alias for exp in query.get_all_expressions()}
    pair_alias_root = f"snuba_all_{self.__column_name}"
    pair_alias = pair_alias_root
    index = 0
    while pair_alias in used_aliases:
        index += 1
        pair_alias = f"{pair_alias_root}_{index}"

    def replace_expression(expr: Expression) -> Expression:
        """
        Applies the appropriate optimization on a single arrayJoin expression.
        """
        match = arrayjoin_pattern.match(expr)
        if match is None:
            return expr

        if arrayjoins_in_query == {
            key_column(self.__column_name),
            val_column(self.__column_name),
        }:
            # Both arrayJoin(col.key) and arrayJoin(col.value) expressions
            # present in the query. Do the arrayJoin on key-value pairs
            # instead of independent arrayjoin for keys and values.
            array_index = (
                LiteralExpr(None, 1)
                if match.string("col") == key_column(self.__column_name)
                else LiteralExpr(None, 2)
            )

            if not filtered_keys:
                return _unfiltered_mapping_pairs(
                    expr.alias, self.__column_name, pair_alias, array_index
                )
            else:
                return _filtered_mapping_pairs(
                    expr.alias,
                    self.__column_name,
                    pair_alias,
                    filtered_keys,
                    array_index,
                )

        elif filtered_keys:
            # Only one between arrayJoin(col.key) and arrayJoin(col.value)
            # is present, and it is arrayJoin(col.key) since we found
            # filtered keys.
            return _filtered_mapping_keys(
                expr.alias, self.__column_name, filtered_keys
            )
        else:
            # No viable optimization
            return expr

    query.transform_expressions(replace_expression)
# since SnQL will require an entity to always be specified by the user. def select_entity(self, query: Query) -> EntityKey: selected_entity = match_query_to_entity(query, EVENTS_COLUMNS, TRANSACTIONS_COLUMNS) track_bad_query(query, selected_entity, EVENTS_COLUMNS, TRANSACTIONS_COLUMNS) return selected_entity metrics = MetricsWrapper(environment.metrics, "api.discover") logger = logging.getLogger(__name__) EVENT_CONDITION = FunctionCallMatch( Param("function", Or([StringMatch(op) for op in BINARY_OPERATORS])), ( Or([ColumnMatch(None, StringMatch("type")), LiteralMatch(None)]), Param("event_type", Or([ColumnMatch(), LiteralMatch()])), ), ) TRANSACTION_FUNCTIONS = FunctionCallMatch( Or([StringMatch("apdex"), StringMatch("failure_rate")]), None) EVENT_FUNCTIONS = FunctionCallMatch( Or([StringMatch("isHandled"), StringMatch("notHandled")]), None)
None, ), ( "Does not match any Column", FunctionCall(None, (Param("p1", Any(ColumnExpr)), )), FunctionCallExpr( "irrelevant", "irrelevant", (LiteralExpr(None, "str"), ), ), None, ), ( "Union of two patterns - match", Or([ Param("option1", Column(None, String("col_name"))), Param("option2", Column(None, String("other_col_name"))), ]), ColumnExpr(None, None, "other_col_name"), MatchResult({"option2": ColumnExpr(None, None, "other_col_name")}), ), ( "Union of two patterns - no match", Or([ Param("option1", Column(None, String("col_name"))), Param("option2", Column(None, String("other_col_name"))), ]), ColumnExpr(None, None, "none_of_the_two"), None, ), ( "Or within a Param",
def __init__(self, time_group_columns: Mapping[str, str],
             time_parse_columns: Sequence[str]) -> None:
    # Column names that should be mapped to different columns.
    self.__time_replace_columns = time_group_columns

    # time_parse_columns is a list of columns that, if used in a condition, should be compared with datetimes.
    # The columns here might overlap with the columns that get replaced, so we have to search for transformed
    # columns.
    column_match = ColumnMatch(
        None,
        Param(
            "column_name",
            Or([String(tc) for tc in time_parse_columns]),
        ),
    )
    # Matches a comparison between a time column — either bare, wrapped in
    # one of the named bucketing functions, or wrapped in the
    # toDateTime(multiply(intDiv(...))) granularity expression — and a
    # string literal that still needs to be parsed into a datetime.
    self.condition_match = FunctionCallMatch(
        Or(
            [
                String(ConditionFunctions.GT),
                String(ConditionFunctions.GTE),
                String(ConditionFunctions.LT),
                String(ConditionFunctions.LTE),
                String(ConditionFunctions.EQ),
                String(ConditionFunctions.NEQ),
            ]
        ),
        (
            Or(
                [
                    column_match,
                    FunctionCallMatch(
                        Or(
                            [
                                String("toStartOfHour"),
                                String("toStartOfMinute"),
                                String("toStartOfDay"),
                                String("toDate"),
                            ]
                        ),
                        (column_match,),
                        with_optionals=True,
                    ),
                    FunctionCallMatch(
                        String("toDateTime"),
                        (
                            FunctionCallMatch(
                                String("multiply"),
                                (
                                    FunctionCallMatch(
                                        String("intDiv"),
                                        (
                                            FunctionCallMatch(
                                                String("toUInt32"),
                                                (column_match,),
                                            ),
                                            LiteralMatch(Any(int)),
                                        ),
                                    ),
                                    LiteralMatch(Any(int)),
                                ),
                            ),
                            LiteralMatch(Any(str)),
                        ),
                    ),
                ]
            ),
            Param("literal", LiteralMatch(Any(str))),
        ),
    )
from snuba.query.subscripts import subscript_key_column_name
from snuba.query.timeseries_extension import TimeSeriesExtension
from snuba.request.request_settings import RequestSettings
from snuba.util import parse_datetime, qualified_column
from snuba.utils.metrics.wrapper import MetricsWrapper

# Possible outcomes of the table selection below: one of the single-entity
# keys, or the merged events+transactions storage.
EVENTS = EntityKey.EVENTS
TRANSACTIONS = EntityKey.TRANSACTIONS
EVENTS_AND_TRANSACTIONS = "events_and_transactions"

metrics = MetricsWrapper(environment.metrics, "api.discover")
logger = logging.getLogger(__name__)

# Matches a binary condition involving the "type" column (either operand may
# be the column or a literal), capturing the operator and the event type.
EVENT_CONDITION = FunctionCallMatch(
    None,
    Param("function", Or([StringMatch(op) for op in BINARY_OPERATORS])),
    (
        Or(
            [
                ColumnMatch(None, None, StringMatch("type")),
                LiteralMatch(StringMatch("type"), None),
            ]
        ),
        Param("event_type", Or([ColumnMatch(), LiteralMatch()])),
    ),
)


def match_query_to_table(
    query: Query, events_only_columns: ColumnSet, transactions_only_columns: ColumnSet
) -> Union[EntityKey, str]:
logger = logging.getLogger(__name__)

# Matches `col = <int>` in either operand order, capturing the column name
# as "lhs".
EQ_CONDITION_PATTERN = condition_pattern(
    {ConditionFunctions.EQ},
    ColumnPattern(None, Param("lhs", Any(str))),
    LiteralPattern(Any(int)),
    commutative=True,
)

# Additionally accepts `col IN tuple(...)` / `col IN array(...)`.
FULL_CONDITION_PATTERN = Or(
    [
        EQ_CONDITION_PATTERN,
        FunctionCallPattern(
            String(ConditionFunctions.IN),
            (
                ColumnPattern(None, Param("lhs", Any(str))),
                FunctionCallPattern(Or([String("tuple"), String("array")]), None),
            ),
        ),
    ],
)


def _check_expression(
    pattern: Pattern[Expression], expression: Expression, column_name: str
) -> bool:
    # True when the expression matches the pattern and the matched column
    # ("lhs") is the one we are checking for.
    match = pattern.match(expression)
    return match is not None and match.optional_string("lhs") == column_name


class ProjectIdEnforcer(ConditionChecker):
    def get_id(self) -> str:
def _combine_conditions(conditions: Sequence[Expression], function: str) -> Expression: """ Combine multiple independent conditions in a single function representing an AND or an OR. This is the inverse of get_first_level_conditions. """ # TODO: Make BooleanFunctions an enum for stricter typing. assert function in (BooleanFunctions.AND, BooleanFunctions.OR) assert len(conditions) > 0 if len(conditions) == 1: return conditions[0] return binary_condition(None, function, conditions[0], _combine_conditions(conditions[1:], function)) CONDITION_MATCH = Or([ FunctionCallPattern( Or([String(op) for op in BINARY_OPERATORS]), (AnyExpression(), AnyExpression()), ), FunctionCallPattern(Or([String(op) for op in UNARY_OPERATORS]), (AnyExpression(), )), ]) def is_condition(exp: Expression) -> bool: return CONDITION_MATCH.match(exp) is not None
def __init__(self, columns: Set[str], optimize_ordering: bool = False):
    # Columns this processor is allowed to rewrite.
    self.columns = columns
    # When set, ordering comparisons (>, >=, <, <=) are also optimizable.
    self.optimize_ordering = optimize_ordering
    column_match = Or([String(col) for col in columns])
    literal = Param("literal", LiteralMatch(AnyMatch(str)))

    ordering_operators = (
        ConditionFunctions.GT,
        ConditionFunctions.GTE,
        ConditionFunctions.LT,
        ConditionFunctions.LTE,
    )

    # Operators whose conditions can be rewritten in the optimized form.
    operator = Param(
        "operator",
        Or(
            [
                String(op)
                for op in (
                    ConditionFunctions.EQ,
                    ConditionFunctions.NEQ,
                    ConditionFunctions.IS_NULL,
                    ConditionFunctions.IS_NOT_NULL,
                    *(ordering_operators if self.optimize_ordering else ()),
                )
            ]
        ),
    )

    # Operators that prevent the optimized rewrite of the condition.
    unoptimizable_operator = Param(
        "operator",
        Or(
            [
                String(op)
                for op in (
                    ConditionFunctions.LIKE,
                    ConditionFunctions.NOT_LIKE,
                    *(() if self.optimize_ordering else ordering_operators),
                )
            ]
        ),
    )

    in_operators = Param(
        "operator",
        Or((String(ConditionFunctions.IN), String(ConditionFunctions.NOT_IN))),
    )

    col = Param("col", ColumnMatch(None, column_match))

    # Matches `literal op col`, `col op literal` or `has(col, literal)`.
    self.__condition_matcher = Or(
        [
            FunctionCallMatch(operator, (literal, col)),
            FunctionCallMatch(operator, (col, literal)),
            FunctionCallMatch(Param("operator", String("has")), (col, literal)),
        ]
    )

    # Matches `col IN/NOT IN tuple(...)` with all-literal elements.
    self.__in_condition_matcher = FunctionCallMatch(
        in_operators,
        (
            col,
            Param(
                "tuple",
                FunctionCallMatch(String("tuple"), all_parameters=LiteralMatch()),
            ),
        ),
    )

    self.__unoptimizable_condition_matcher = Or(
        [
            FunctionCallMatch(unoptimizable_operator, (literal, col)),
            FunctionCallMatch(unoptimizable_operator, (col, literal)),
        ]
    )
class BaseTypeConverter(QueryProcessor, ABC):
    """
    Base class for processors that translate literals compared against a
    fixed set of columns into a different representation (subclasses define
    the translation via _translate_literal / _process_expressions).
    """

    def __init__(self, columns: Set[str], optimize_ordering: bool = False):
        self.columns = columns
        # When set, ordering comparisons (>, >=, <, <=) are also optimizable.
        self.optimize_ordering = optimize_ordering
        column_match = Or([String(col) for col in columns])
        literal = Param("literal", LiteralMatch(AnyMatch(str)))

        ordering_operators = (
            ConditionFunctions.GT,
            ConditionFunctions.GTE,
            ConditionFunctions.LT,
            ConditionFunctions.LTE,
        )

        # Operators whose conditions can be rewritten in the optimized form.
        operator = Param(
            "operator",
            Or(
                [
                    String(op)
                    for op in (
                        ConditionFunctions.EQ,
                        ConditionFunctions.NEQ,
                        ConditionFunctions.IS_NULL,
                        ConditionFunctions.IS_NOT_NULL,
                        *(ordering_operators if self.optimize_ordering else ()),
                    )
                ]
            ),
        )

        # Operators that prevent the optimized rewrite of the condition.
        unoptimizable_operator = Param(
            "operator",
            Or(
                [
                    String(op)
                    for op in (
                        ConditionFunctions.LIKE,
                        ConditionFunctions.NOT_LIKE,
                        *(() if self.optimize_ordering else ordering_operators),
                    )
                ]
            ),
        )

        in_operators = Param(
            "operator",
            Or((String(ConditionFunctions.IN), String(ConditionFunctions.NOT_IN))),
        )

        col = Param("col", ColumnMatch(None, column_match))

        # literal op col / col op literal / has(col, literal)
        self.__condition_matcher = Or(
            [
                FunctionCallMatch(operator, (literal, col)),
                FunctionCallMatch(operator, (col, literal)),
                FunctionCallMatch(Param("operator", String("has")), (col, literal)),
            ]
        )

        # col IN/NOT IN tuple(...) with all-literal elements.
        self.__in_condition_matcher = FunctionCallMatch(
            in_operators,
            (
                col,
                Param(
                    "tuple",
                    FunctionCallMatch(String("tuple"), all_parameters=LiteralMatch()),
                ),
            ),
        )

        self.__unoptimizable_condition_matcher = Or(
            [
                FunctionCallMatch(unoptimizable_operator, (literal, col)),
                FunctionCallMatch(unoptimizable_operator, (col, literal)),
            ]
        )

    def process_query(self, query: Query, query_settings: QuerySettings) -> None:
        # The condition is transformed separately so that optimizable
        # conditions can be rewritten rather than generically transformed.
        query.transform_expressions(
            self._process_expressions, skip_transform_condition=True
        )

        condition = query.get_condition()
        if condition is not None:
            if self.__contains_unoptimizable_condition(condition):
                # Any unoptimizable condition forces the generic
                # transformation for the whole condition clause.
                processed = condition.transform(self._process_expressions)
            else:
                processed = condition.transform(self.__process_optimizable_condition)
                if condition == processed:
                    # Nothing was rewritten: fall back to the generic
                    # transformation.
                    processed = processed.transform(self._process_expressions)

            query.set_ast_condition(processed)

    def __strip_column_alias(self, exp: Expression) -> Expression:
        assert isinstance(exp, Column)
        # Drop the alias so the rewritten condition does not redefine it.
        return Column(
            alias=None, table_name=exp.table_name, column_name=exp.column_name
        )

    def __contains_unoptimizable_condition(self, exp: Expression) -> bool:
        """
        Returns true if there is an unoptimizable condition, otherwise false.
        """
        for e in exp:
            match = self.__unoptimizable_condition_matcher.match(e)
            if match is not None:
                return True

        return False

    def __process_optimizable_condition(self, exp: Expression) -> Expression:
        def assert_literal(lit: Expression) -> Literal:
            assert isinstance(lit, Literal)
            return lit

        # Simple binary condition: translate the single literal operand.
        match = self.__condition_matcher.match(exp)
        if match is not None:
            return FunctionCall(
                exp.alias,
                match.string("operator"),
                (
                    self.__strip_column_alias(match.expression("col")),
                    self._translate_literal(
                        assert_literal(match.expression("literal"))
                    ),
                ),
            )

        # IN condition: translate every literal inside the tuple.
        in_condition_match = self.__in_condition_matcher.match(exp)
        if in_condition_match is not None:
            tuple_func = in_condition_match.expression("tuple")
            assert isinstance(tuple_func, FunctionCall)
            params = tuple_func.parameters
            for param in params:
                assert isinstance(param, Literal)

            new_tuple_func = FunctionCall(
                tuple_func.alias,
                tuple_func.function_name,
                parameters=tuple(
                    [
                        self._translate_literal(assert_literal(lit))
                        for lit in tuple_func.parameters
                    ]
                ),
            )
            return FunctionCall(
                exp.alias,
                in_condition_match.string("operator"),
                (
                    self.__strip_column_alias(
                        in_condition_match.expression("col")
                    ),
                    new_tuple_func,
                ),
            )

        return exp

    @abstractmethod
    def _translate_literal(self, exp: Literal) -> Expression:
        raise NotImplementedError

    @abstractmethod
    def _process_expressions(self, exp: Expression) -> Expression:
        raise NotImplementedError
def __post_init__(self) -> None:
    # Precompile a matcher recognizing a call to any configured function.
    name_patterns = [StringMatch(func) for func in self.function_names]
    self.function_match = FunctionCallMatch(Or(name_patterns))
def get_first_level_or_conditions(condition: Expression) -> Sequence[Expression]:
    # Returns the first-level OR operands (see _get_first_level_conditions).
    return _get_first_level_conditions(condition, BooleanFunctions.OR)


# For each boolean function (AND / OR): a pattern matching either the plain
# `func(left, right)` call or the equivalent `equals(func(left, right), 1)`
# form, capturing the two operands as "left" and "right".
TOP_LEVEL_CONDITIONS = {
    func_name: Or(
        [
            FunctionCallPattern(
                String(func_name),
                (Param("left", AnyExpression()), Param("right", AnyExpression())),
            ),
            FunctionCallPattern(
                String("equals"),
                (
                    FunctionCallPattern(
                        String(func_name),
                        (
                            Param("left", AnyExpression()),
                            Param("right", AnyExpression()),
                        ),
                    ),
                    LiteralPattern(Integer(1)),
                ),
            ),
        ]
    )
    for func_name in [BooleanFunctions.AND, BooleanFunctions.OR]
}


def _get_first_level_conditions(condition: Expression, function: str) -> Sequence[Expression]: