def test_not_handled_processor() -> None:
    columnset = ColumnSet([])
    unprocessed = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(name=None, expression=Column(None, None, "id")),
            SelectedExpression(
                "result",
                FunctionCall("result", "notHandled", tuple()),
            ),
        ],
    )
    expected = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(name=None, expression=Column(None, None, "id")),
            SelectedExpression(
                "result",
                FunctionCall(
                    "result",
                    "arrayExists",
                    (
                        Lambda(
                            None,
                            ("x",),
                            binary_condition(
                                BooleanFunctions.AND,
                                FunctionCall(None, "isNotNull", (Argument(None, "x"),)),
                                binary_condition(
                                    ConditionFunctions.EQ,
                                    FunctionCall(
                                        None, "assumeNotNull", (Argument(None, "x"),)
                                    ),
                                    Literal(None, 0),
                                ),
                            ),
                        ),
                        Column(None, None, "exception_stacks.mechanism_handled"),
                    ),
                ),
            ),
        ],
    )
    processor = handled_functions.HandledFunctionsProcessor(
        "exception_stacks.mechanism_handled", columnset
    )
    processor.process_query(unprocessed, HTTPRequestSettings())
    assert expected.get_selected_columns() == unprocessed.get_selected_columns()

    ret = unprocessed.get_selected_columns()[1].expression.accept(
        ClickhouseExpressionFormatter()
    )
    assert ret == (
        "(arrayExists((x -> isNotNull(x) AND equals(assumeNotNull(x), 0)), exception_stacks.mechanism_handled) AS result)"
    )

def test_timeseries_format_expressions(
    granularity: int,
    condition: Optional[FunctionCall],
    exp_column: FunctionCall,
    exp_condition: Optional[FunctionCall],
    formatted_column: str,
    formatted_condition: str,
) -> None:
    unprocessed = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(
                "transaction.duration",
                Column("transaction.duration", None, "duration"),
            ),
            SelectedExpression("my_time", Column("my_time", None, "time")),
        ],
        condition=condition,
        groupby=[Column("my_time", None, "time")],
        granularity=granularity,
    )
    expected = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(
                "transaction.duration",
                Column("transaction.duration", None, "duration"),
            ),
            SelectedExpression(exp_column.alias, exp_column),
        ],
        condition=exp_condition,
    )

    entity = TransactionsEntity()
    processors = entity.get_query_processors()
    for processor in processors:
        if isinstance(processor, TimeSeriesProcessor):
            processor.process_query(unprocessed, HTTPRequestSettings())

    assert expected.get_selected_columns() == unprocessed.get_selected_columns()
    assert expected.get_condition() == unprocessed.get_condition()

    ret = unprocessed.get_selected_columns()[1].expression.accept(
        ClickhouseExpressionFormatter()
    )
    assert ret == formatted_column
    if condition:
        query_condition = unprocessed.get_condition()
        assert query_condition is not None
        ret = query_condition.accept(ClickhouseExpressionFormatter())
        assert formatted_condition == ret

    assert extract_granularity_from_query(unprocessed, "finish_ts") == granularity

def test_functions(
    default_validators: Mapping[str, FunctionCallValidator],
    entity_validators: Mapping[str, FunctionCallValidator],
    exception: Optional[Type[InvalidExpressionException]],
) -> None:
    fn_cached = functions.default_validators
    functions.default_validators = default_validators

    entity_return = MagicMock()
    entity_return.return_value = entity_validators
    events_entity = get_entity(EntityKey.EVENTS)
    cached = events_entity.get_function_call_validators
    setattr(events_entity, "get_function_call_validators", entity_return)

    data_source = QueryEntity(EntityKey.EVENTS, ColumnSet([]))
    expression = FunctionCall(
        None, "f", (Column(alias=None, table_name=None, column_name="col"),)
    )
    if exception is None:
        FunctionCallsValidator().validate(expression, data_source)
    else:
        with pytest.raises(exception):
            FunctionCallsValidator().validate(expression, data_source)

    # TODO: This should use a fixture to restore the patched state.
    setattr(events_entity, "get_function_call_validators", cached)
    functions.default_validators = fn_cached

def test_outcomes_columns_validation(key: EntityKey) -> None:
    entity = get_entity(key)
    query_entity = QueryEntity(key, entity.get_data_model())

    bad_query = LogicalQuery(
        query_entity,
        selected_columns=[
            SelectedExpression("asdf", Column("_snuba_asdf", None, "asdf")),
        ],
    )
    good_query = LogicalQuery(
        query_entity,
        selected_columns=[
            SelectedExpression(
                column.name, Column(f"_snuba_{column.name}", None, column.name)
            )
            for column in entity.get_data_model().columns
        ],
    )

    validator = EntityContainsColumnsValidator(
        entity.get_data_model(), validation_mode=ColumnValidationMode.ERROR
    )
    with pytest.raises(InvalidQueryException):
        validator.validate(bad_query)
    validator.validate(good_query)

def test_invalid_function_name(expression: FunctionCall, should_raise: bool) -> None:
    data_source = QueryEntity(EntityKey.EVENTS, ColumnSet([]))
    state.set_config("function-validator.enabled", True)

    # Only the cases flagged by the parametrization should raise; the
    # should_raise parameter was previously ignored.
    if should_raise:
        with pytest.raises(InvalidExpressionException):
            FunctionCallsValidator().validate(expression, data_source)
    else:
        FunctionCallsValidator().validate(expression, data_source)

def test_apply_quota(
    enabled: int,
    referrer: str,
    config_to_set: str,
    expected_quota: Optional[ResourceQuota],
) -> None:
    state.set_config(ENABLED_CONFIG, enabled)
    state.set_config(config_to_set, 5)

    query = Query(
        QueryEntity(EntityKey.EVENTS, EntityColumnSet([])),
        selected_columns=[
            SelectedExpression("column2", Column(None, None, "column2"))
        ],
        condition=binary_condition(
            ConditionFunctions.EQ,
            Column("_snuba_project_id", None, "project_id"),
            Literal(None, 1),
        ),
    )

    settings = HTTPQuerySettings()
    settings.referrer = referrer
    ResourceQuotaProcessor("project_id").process_query(query, settings)
    assert settings.get_resource_quota() == expected_quota

def test_failure_rate_format_expressions() -> None:
    unprocessed = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(name=None, expression=Column(None, None, "column2")),
            SelectedExpression("perf", FunctionCall("perf", "failure_rate", ())),
        ],
    )
    expected = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(name=None, expression=Column(None, None, "column2")),
            SelectedExpression(
                "perf",
                divide(
                    FunctionCall(
                        None,
                        "countIf",
                        (
                            combine_and_conditions(
                                [
                                    binary_condition(
                                        ConditionFunctions.NEQ,
                                        Column(None, None, "transaction_status"),
                                        Literal(None, code),
                                    )
                                    for code in [0, 1, 2]
                                ]
                            ),
                        ),
                    ),
                    count(),
                    "perf",
                ),
            ),
        ],
    )

    failure_rate_processor(ColumnSet([])).process_query(
        unprocessed, HTTPRequestSettings()
    )
    assert (
        expected.get_selected_columns_from_ast()
        == unprocessed.get_selected_columns_from_ast()
    )

    ret = unprocessed.get_selected_columns_from_ast()[1].expression.accept(
        ClickhouseExpressionFormatter()
    )
    assert ret == (
        "(divide(countIf(notEquals(transaction_status, 0) AND notEquals(transaction_status, 1) AND notEquals(transaction_status, 2)), count()) AS perf)"
    )

def query_fn(cond: Optional[Expression]) -> LogicalQuery:
    return LogicalQuery(
        QueryEntity(key, entity.get_data_model()),
        selected_columns=[
            SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")),
        ],
        condition=cond,
    )

def test_granularity_added(
    entity_key: EntityKey,
    column: str,
    requested_granularity: Optional[int],
    query_granularity: int,
) -> None:
    query = Query(
        QueryEntity(entity_key, ColumnSet([])),
        selected_columns=[SelectedExpression(column, Column(None, None, column))],
        condition=binary_condition(
            ConditionFunctions.EQ,
            Column(None, None, "metric_id"),
            Literal(None, 123),
        ),
        granularity=requested_granularity,
    )

    try:
        GranularityProcessor().process_query(query, HTTPQuerySettings())
    except InvalidGranularityException:
        assert query_granularity is None
    else:
        assert query == Query(
            QueryEntity(entity_key, ColumnSet([])),
            selected_columns=[SelectedExpression(column, Column(None, None, column))],
            condition=binary_condition(
                BooleanFunctions.AND,
                binary_condition(
                    ConditionFunctions.EQ,
                    Column(None, None, "granularity"),
                    Literal(None, query_granularity),
                ),
                binary_condition(
                    ConditionFunctions.EQ,
                    Column(None, None, "metric_id"),
                    Literal(None, 123),
                ),
            ),
            granularity=requested_granularity,
        )

def visit_entity_single(
    self,
    node: Node,
    visited_children: Tuple[
        Any, Any, EntityKey, Union[Optional[float], Node], Any, Any
    ],
) -> QueryEntity:
    _, _, name, sample, _, _ = visited_children
    # An unmatched optional SAMPLE clause comes back from parsimonious as a
    # raw Node rather than a float; normalize it to None.
    if isinstance(sample, Node):
        sample = None
    return QueryEntity(name, get_entity(name).get_data_model(), sample)

def parse_query(body: MutableMapping[str, Any], dataset: Dataset) -> Query:
    """
    Parses the query body, generating the AST. This only takes into account
    the initial query body. Extensions are parsed by extension processors
    and are supposed to update the AST.

    Parsing includes two phases. The first transforms the JSON body into a
    minimal Query object, resolving expressions, conditions, etc. The second
    phase performs some query processing to provide a sane query to the
    dataset-specific section:
    - It prevents alias shadowing.
    - It transforms columns from the tags[asd] form into
      SubscriptableReference.
    - It applies aliases to all columns that do not have one and that do not
      represent a reference to an existing alias. During query processing a
      column can be transformed into a different expression, so it is
      essential to preserve the original column name so that the result set
      still has a column with the name the user provided, no matter which
      transformation was applied. By applying aliases at this stage, every
      processor just needs to preserve them to guarantee the correctness of
      the query.
    - It expands all references to aliases by inlining the expression, to
      make aliasing transparent to all query processing phases. References
      to aliases are reintroduced at the end of query processing.
    """
    # TODO: Parse the entity out of the query body and select the correct one from the dataset
    entity = dataset.get_default_entity()

    query = _parse_query_impl(body, entity)
    # TODO: These should support composite queries.
    _validate_empty_table_names(query)
    _validate_aliases(query)
    _parse_subscriptables(query)
    _apply_column_aliases(query)
    _expand_aliases(query)
    # WARNING: These steps above assume table resolution did not happen
    # yet. If it is put earlier than here (unlikely), we need to adapt them.
    _deescape_aliases(query)
    _mangle_aliases(query)
    _validate_arrayjoin(query)

    # XXX: Select the entity to be used for the query. This step is temporary. Eventually
    # entity selection will be moved to Sentry and specified for all SnQL queries.
    selected_entity = dataset.select_entity(query)
    query_entity = QueryEntity(
        selected_entity, get_entity(selected_entity).get_data_model()
    )
    query.set_from_clause(query_entity)

    validate_query(query)
    return query

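# A minimal usage sketch for parse_query, assuming the events dataset and a
# legacy JSON body. The body fields shown here are illustrative, not the full
# schema, and get_dataset is assumed to come from snuba.datasets.factory.
from snuba.datasets.factory import get_dataset

body = {
    "selected_columns": ["event_id"],
    "conditions": [["project_id", "=", 1]],
}
query = parse_query(body, get_dataset("events"))
# The returned logical Query has aliases applied and expanded, and its FROM
# clause set to the QueryEntity chosen by dataset.select_entity().
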
def test_project_extension_query_processing(
    raw_data: Mapping[str, Any],
    expected_conditions: Sequence[Condition],
    expected_ast_conditions: Expression,
) -> None:
    extension = ProjectExtension(project_column="project_id")
    valid_data = validate_jsonschema(raw_data, extension.get_schema())
    query = Query({"conditions": []}, QueryEntity(EntityKey.EVENTS, ColumnSet([])))
    request_settings = HTTPRequestSettings()

    extension.get_processor().process_query(query, valid_data, request_settings)

    assert query.get_condition_from_ast() == expected_ast_conditions

def test_allowed_functions_validator(
    expression: FunctionCall, should_raise: bool
) -> None:
    data_source = QueryEntity(EntityKey.EVENTS, ColumnSet([]))
    state.set_config("function-validator.enabled", True)

    if should_raise:
        with pytest.raises(InvalidFunctionCall):
            AllowedFunctionValidator().validate(
                expression.function_name, expression.parameters, data_source
            )
    else:
        AllowedFunctionValidator().validate(
            expression.function_name, expression.parameters, data_source
        )

def test_entity_validation_failure(
    key: EntityKey, condition: Optional[Expression]
) -> None:
    entity = get_entity(key)
    query = LogicalQuery(
        QueryEntity(key, entity.get_data_model()),
        selected_columns=[
            SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")),
        ],
        condition=condition,
    )
    assert not entity.validate_required_conditions(query)

def visit_entity_match(
    self,
    node: Node,
    visited_children: Tuple[
        Any, str, Any, Any, EntityKey, Union[Optional[float], Node], Any, Any
    ],
) -> IndividualNode[QueryEntity]:
    _, alias, _, _, name, sample, _, _ = visited_children
    # As in visit_entity_single, an unmatched optional SAMPLE clause arrives
    # as a raw parsimonious Node; normalize it to None.
    if isinstance(sample, Node):
        sample = None
    return IndividualNode(
        alias, QueryEntity(name, get_entity(name).get_data_model(), sample)
    )

def test_organization_extension_query_processing_happy_path() -> None:
    extension = OrganizationExtension()
    raw_data = {"organization": 2}

    valid_data = validate_jsonschema(raw_data, extension.get_schema())
    query = Query({"conditions": []}, QueryEntity(EntityKey.EVENTS, ColumnSet([])))
    request_settings = HTTPRequestSettings()

    extension.get_processor().process_query(query, valid_data, request_settings)

    assert query.get_condition_from_ast() == binary_condition(
        ConditionFunctions.EQ,
        Column(None, None, "org_id"),
        Literal(None, 2),
    )

def test_organization_extension_query_processing_happy_path() -> None:
    extension = OrganizationExtension()
    schema = cast(MutableMapping[str, Any], extension.get_schema())
    raw_data = {"organization": 2}

    valid_data = validate_jsonschema(raw_data, schema)
    query = Query(QueryEntity(EntityKey.EVENTS, ColumnSet([])))
    request_settings = HTTPRequestSettings()

    extension.get_processor().process_query(query, valid_data, request_settings)

    assert query.get_condition() == binary_condition(
        ConditionFunctions.EQ,
        Column("_snuba_org_id", None, "org_id"),
        Literal(None, 2),
    )

def test_entity_validation(key: EntityKey, condition: Optional[Expression]) -> None:
    query = LogicalQuery(
        QueryEntity(key, get_entity(key).get_data_model()),
        selected_columns=[
            SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")),
        ],
        condition=condition,
    )
    validator = EntityRequiredColumnValidator({"project_id"})
    validator.validate(query)

def test_no_time_based_validation(key: EntityKey, condition: Expression) -> None:
    entity = get_entity(key)
    query = LogicalQuery(
        QueryEntity(key, entity.get_data_model()),
        selected_columns=[
            SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")),
        ],
        condition=condition,
    )
    assert entity.required_time_column is not None
    validator = NoTimeBasedConditionValidator(entity.required_time_column)
    validator.validate(query)

def test_project_rate_limit_processor(
    unprocessed: Expression, project_id: int
) -> None:
    query = Query(
        QueryEntity(EntityKey.EVENTS, EntityColumnSet([])),
        selected_columns=[SelectedExpression("column2", Column(None, None, "column2"))],
        condition=unprocessed,
    )
    settings = HTTPQuerySettings()

    num_before = len(settings.get_rate_limit_params())
    ProjectRateLimiterProcessor("project_id").process_query(query, settings)
    assert len(settings.get_rate_limit_params()) == num_before + 1
    rate_limiter = settings.get_rate_limit_params()[-1]
    assert rate_limiter.rate_limit_name == PROJECT_RATE_LIMIT_NAME
    assert rate_limiter.bucket == str(project_id)
    assert rate_limiter.per_second_limit == 1000
    assert rate_limiter.concurrent_limit == 1000

def test_handled_processor_invalid() -> None:
    columnset = ColumnSet([])
    unprocessed = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(
                "result",
                FunctionCall("result", "isHandled", (Column(None, None, "type"),)),
            ),
        ],
    )
    processor = handled_functions.HandledFunctionsProcessor(
        "exception_stacks.mechanism_handled", columnset
    )
    with pytest.raises(InvalidExpressionException):
        processor.process_query(unprocessed, HTTPRequestSettings())

def test_referrer_rate_limit_processor_no_config(
    unprocessed: Expression, project_id: int
) -> None:
    query = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[SelectedExpression("column2", Column(None, None, "column2"))],
        condition=unprocessed,
    )
    # don't configure it, rate limit shouldn't fire
    referrer = "abusive_delivery"

    settings = HTTPQuerySettings()
    settings.referrer = referrer

    num_before = len(settings.get_rate_limit_params())
    ProjectReferrerRateLimiter("project_id").process_query(query, settings)
    assert len(settings.get_rate_limit_params()) == num_before

def selector(query: Union[CompositeQuery[QueryEntity], LogicalQuery]) -> None:
    # If you are doing a JOIN, then you have to specify the entity
    if isinstance(query, CompositeQuery):
        return

    if get_dataset_name(dataset) == "discover":
        query_entity = query.get_from_clause()
        # The legacy -> snql parser will mark queries with no entity specified as
        # the "discover" entity, so only do this selection in that case. If someone
        # wants the "discover" entity specifically, then their query will have to
        # only use fields from that entity.
        if query_entity.key == EntityKey.DISCOVER:
            selected_entity_key = dataset.select_entity(query)
            selected_entity = get_entity(selected_entity_key)
            query_entity = QueryEntity(
                selected_entity_key, selected_entity.get_data_model()
            )
            query.set_from_clause(query_entity)

            # XXX: This exists only to ensure that the generated SQL matches legacy queries.
            def replace_time_condition_aliases(exp: Expression) -> Expression:
                if (
                    isinstance(exp, FunctionCall)
                    and len(exp.parameters) == 2
                    and isinstance(exp.parameters[0], Column)
                    and exp.parameters[0].alias == "_snuba_timestamp"
                ):
                    return FunctionCall(
                        exp.alias,
                        exp.function_name,
                        (
                            Column(
                                f"_snuba_{selected_entity.required_time_column}",
                                exp.parameters[0].table_name,
                                exp.parameters[0].column_name,
                            ),
                            exp.parameters[1],
                        ),
                    )
                return exp

            condition = query.get_condition()
            if condition is not None:
                query.set_ast_condition(
                    condition.transform(replace_time_condition_aliases)
                )

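# A minimal, self-contained sketch of the time-alias rewrite performed above,
# assuming the selected entity's required_time_column is "finish_ts" (as for
# transactions). replace_alias is a hypothetical stand-in for the
# replace_time_condition_aliases closure, which additionally captures the
# selected entity; expression types and binary_condition reuse the imports
# already present in this module.
def replace_alias(exp: Expression) -> Expression:
    if (
        isinstance(exp, FunctionCall)
        and len(exp.parameters) == 2
        and isinstance(exp.parameters[0], Column)
        and exp.parameters[0].alias == "_snuba_timestamp"
    ):
        return FunctionCall(
            exp.alias,
            exp.function_name,
            (
                Column(
                    "_snuba_finish_ts",
                    exp.parameters[0].table_name,
                    exp.parameters[0].column_name,
                ),
                exp.parameters[1],
            ),
        )
    return exp


condition = binary_condition(
    ConditionFunctions.GTE,
    Column("_snuba_timestamp", None, "timestamp"),
    Literal(None, "2021-01-01T00:00:00"),
)
# Only the alias changes: the condition still reads the "timestamp" column,
# but its alias now points at the selected entity's time column.
assert condition.transform(replace_alias) == binary_condition(
    ConditionFunctions.GTE,
    Column("_snuba_finish_ts", None, "timestamp"),
    Literal(None, "2021-01-01T00:00:00"),
)
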
def test_project_extension_project_rate_limits_are_overridden() -> None:
    extension = ProjectExtension(project_column="project_id")
    raw_data = {"project": [3, 4]}
    valid_data = validate_jsonschema(raw_data, extension.get_schema())
    query = Query({"conditions": []}, QueryEntity(EntityKey.EVENTS, ColumnSet([])))
    request_settings = HTTPRequestSettings()
    state.set_config("project_per_second_limit_3", 5)
    state.set_config("project_concurrent_limit_3", 10)

    extension.get_processor().process_query(query, valid_data, request_settings)

    rate_limits = request_settings.get_rate_limit_params()
    most_recent_rate_limit = rate_limits[-1]
    assert most_recent_rate_limit.bucket == "3"
    assert most_recent_rate_limit.per_second_limit == 5
    assert most_recent_rate_limit.concurrent_limit == 10

def test_project_extension_query_adds_rate_limits() -> None:
    extension = ProjectExtension(project_column="project_id")
    raw_data = {"project": [1, 2]}
    valid_data = validate_jsonschema(raw_data, extension.get_schema())
    query = Query({"conditions": []}, QueryEntity(EntityKey.EVENTS, ColumnSet([])))
    request_settings = HTTPRequestSettings()

    num_rate_limits_before_processing = len(request_settings.get_rate_limit_params())
    extension.get_processor().process_query(query, valid_data, request_settings)

    rate_limits = request_settings.get_rate_limit_params()
    # make sure a rate limit was added by the processing
    assert len(rate_limits) == num_rate_limits_before_processing + 1

    most_recent_rate_limit = rate_limits[-1]
    assert most_recent_rate_limit.bucket == "1"
    assert most_recent_rate_limit.per_second_limit == 1000
    assert most_recent_rate_limit.concurrent_limit == 1000

def test_invalid_datetime() -> None:
    unprocessed = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[
            SelectedExpression(
                "transaction.duration",
                Column("transaction.duration", None, "duration"),
            ),
        ],
        condition=binary_condition(
            ConditionFunctions.EQ,
            Column("my_time", None, "time"),
            Literal(None, ""),
        ),
    )

    entity = TransactionsEntity()
    processors = entity.get_query_processors()
    for processor in processors:
        if isinstance(processor, TimeSeriesProcessor):
            with pytest.raises(InvalidQueryException):
                processor.process_query(unprocessed, HTTPRequestSettings())

def test_query_extension_processing(
    raw_data: Mapping[str, Any],
    expected_ast_condition: Expression,
    expected_granularity: int,
) -> None:
    state.set_config("max_days", 1)
    extension = TimeSeriesExtension(
        default_granularity=60,
        default_window=timedelta(days=5),
        timestamp_column="timestamp",
    )
    valid_data = validate_jsonschema(raw_data, extension.get_schema())
    query = Query({"conditions": []}, QueryEntity(EntityKey.EVENTS, ColumnSet([])))
    request_settings = HTTPRequestSettings()

    extension.get_processor().process_query(query, valid_data, request_settings)
    assert query.get_condition_from_ast() == expected_ast_condition
    assert query.get_granularity() == expected_granularity

def test_referrer_specified_project(unprocessed: Expression, project_id: int) -> None:
    query = Query(
        QueryEntity(EntityKey.EVENTS, ColumnSet([])),
        selected_columns=[SelectedExpression("column2", Column(None, None, "column2"))],
        condition=unprocessed,
    )
    state.set_config("project_referrer_per_second_limit_abusive_delivery_1", 10)
    state.set_config("project_referrer_concurrent_limit_abusive_delivery_1", 10)
    referrer = "abusive_delivery"

    settings = HTTPQuerySettings()
    settings.referrer = referrer

    num_before = len(settings.get_rate_limit_params())
    ProjectReferrerRateLimiter("project_id").process_query(query, settings)
    assert len(settings.get_rate_limit_params()) == num_before + 1
    rate_limiter = settings.get_rate_limit_params()[-1]
    assert rate_limiter.rate_limit_name == PROJECT_REFERRER_RATE_LIMIT_NAME
    assert rate_limiter.bucket == f"{project_id}"
    assert rate_limiter.per_second_limit == 10
    assert rate_limiter.concurrent_limit == 10

def test_org_rate_limit_processor_overridden(
    unprocessed: Expression, org_id: int
) -> None:
    query = Query(
        QueryEntity(EntityKey.EVENTS, EntityColumnSet([])),
        selected_columns=[SelectedExpression("column2", Column(None, None, "column2"))],
        condition=unprocessed,
    )
    settings = HTTPQuerySettings()
    state.set_config(f"org_per_second_limit_{org_id}", 5)
    state.set_config(f"org_concurrent_limit_{org_id}", 10)

    num_before = len(settings.get_rate_limit_params())
    OrganizationRateLimiterProcessor("org_id").process_query(query, settings)
    assert len(settings.get_rate_limit_params()) == num_before + 1
    rate_limiter = settings.get_rate_limit_params()[-1]
    assert rate_limiter.rate_limit_name == ORGANIZATION_RATE_LIMIT_NAME
    assert rate_limiter.bucket == str(org_id)
    assert rate_limiter.per_second_limit == 5
    assert rate_limiter.concurrent_limit == 10

def test_like_validator(
    expressions: Sequence[Expression],
    expected_types: Sequence[ParamType],
    extra_param: bool,
    should_raise: bool,
) -> None:
    entity = QueryEntity(
        EntityKey.EVENTS,
        ColumnSet(
            [
                ("event_id", String()),
                ("level", String(Modifiers(nullable=True))),
                ("str_col", String()),
                ("timestamp", DateTime()),
                ("received", DateTime(Modifiers(nullable=True))),
            ]
        ),
    )
    func_name = "like"
    validator = SignatureValidator(expected_types, extra_param)

    if should_raise:
        with pytest.raises(InvalidFunctionCall):
            validator.validate(func_name, expressions, entity)
    else:
        validator.validate(func_name, expressions, entity)