コード例 #1
0
def validate_query(query: Query) -> None:
    """
    Run every expression validator against every expression of the query.

    Each validator's ``validate`` method is called with the expression and
    the query's FROM clause. Nothing is returned; any failure is whatever
    the individual validators raise.
    """
    pairs = (
        (expression, validator)
        for expression in query.get_all_expressions()
        for validator in validators
    )
    for expression, validator in pairs:
        # The FROM clause is fetched again for each (expression, validator)
        # pair rather than cached.
        validator.validate(expression, query.get_from_clause())
コード例 #2
0
ファイル: tracing.py プロジェクト: isabella232/snuba
def _format_query_body(query: Query) -> Mapping[str, Any]:
    """
    Build a dictionary representation of the query clauses.

    "SELECT", "GROUPBY" and "ORDERBY" are always present. "ARRAYJOIN",
    "WHERE", "HAVING", "LIMITBY", "LIMIT" and "OFFSET" are only added when
    the corresponding query getter returns a truthy value. Expressions are
    rendered by visiting them with a TracingExpressionFormatter.
    """
    visitor = TracingExpressionFormatter()

    select_part = [
        [selected.name, selected.expression.accept(visitor)]
        for selected in query.get_selected_columns_from_ast()
    ]
    groupby_part = [
        grouped.accept(visitor) for grouped in query.get_groupby_from_ast()
    ]
    orderby_part = [
        [ordered.expression.accept(visitor), ordered.direction]
        for ordered in query.get_orderby_from_ast()
    ]
    body = {
        "SELECT": select_part,
        "GROUPBY": groupby_part,
        "ORDERBY": orderby_part,
    }

    # Clauses that consist of a single expression: fetch, then render only
    # if present. Getters are invoked lazily, one at a time, in this order.
    single_expression_clauses = (
        ("ARRAYJOIN", query.get_arrayjoin_from_ast),
        ("WHERE", query.get_condition_from_ast),
        ("HAVING", query.get_having_from_ast),
    )
    for key, getter in single_expression_clauses:
        clause = getter()
        if clause:
            body[key] = clause.accept(visitor)

    limit_by = query.get_limitby()
    if limit_by:
        body["LIMITBY"] = {
            "LIMIT": limit_by.limit,
            "BY": limit_by.expression.accept(visitor),
        }

    # Plain values: copied as they are when truthy (so 0 / None are omitted).
    row_limit = query.get_limit()
    if row_limit:
        body["LIMIT"] = row_limit
    row_offset = query.get_offset()
    if row_offset:
        body["OFFSET"] = row_offset

    return body
コード例 #3
0
def _format_groupby(query: AbstractQuery,
                    formatter: ExpressionVisitor[str]) -> Optional[StringNode]:
    """
    Render the GROUP BY clause of the query.

    Returns None when the query's group by is empty or missing. Otherwise
    returns a StringNode holding "GROUP BY " followed by the comma-separated
    formatted expressions, with " WITH TOTALS" appended when the query
    reports ``has_totals()``.
    """
    grouping = query.get_groupby()
    if not grouping:
        return None

    rendered = ", ".join([expression.accept(formatter) for expression in grouping])
    if query.has_totals():
        rendered = rendered + " WITH TOTALS"
    return StringNode("GROUP BY " + rendered)
コード例 #4
0
def _format_select(query: AbstractQuery,
                   formatter: ClickhouseExpressionFormatter) -> StringNode:
    """
    Render the SELECT clause: "SELECT " followed by the formatted expression
    of every selected column, separated by ", ".
    """
    rendered_columns = []
    for selected in query.get_selected_columns_from_ast():
        rendered_columns.append(selected.expression.accept(formatter))
    return StringNode("SELECT " + ", ".join(rendered_columns))
コード例 #5
0
ファイル: validators.py プロジェクト: getsentry/snuba
 def validate(self, query: Query, alias: Optional[str] = None) -> None:
     """
     Validate the granularity of the query.

     Raises InvalidQueryException when the granularity is absent and this
     validator is marked as required, or when it is present but is smaller
     than ``self.minimum`` or not an exact multiple of it. ``alias`` is
     not used.
     """
     value = query.get_granularity()

     if value is None:
         # Absence is only an error for validators marked as required.
         if self.required:
             raise InvalidQueryException("Granularity is missing")
         return

     if value < self.minimum or value % self.minimum != 0:
         raise InvalidQueryException(
             f"granularity must be multiple of {self.minimum}")
コード例 #6
0
def _format_granularity(
    query: AbstractQuery, formatter: ExpressionVisitor[str]
) -> Optional[StringNode]:
    """
    Render the GRANULARITY clause, or return None when the query has no
    granularity. ``formatter`` is accepted for signature parity with the
    other clause formatters but is not used.
    """
    granularity = query.get_granularity()
    if granularity is None:
        return None
    return StringNode(f"GRANULARITY {granularity}")
コード例 #7
0
def _format_arrayjoin(
        query: AbstractQuery,
        formatter: ExpressionVisitor[str]) -> Optional[StringNode]:
    """
    Render the ARRAY JOIN clause.

    Returns None when the query has no array join; otherwise a StringNode
    with "ARRAY JOIN " followed by the formatted elements joined by ","
    (no space after the comma).
    """
    elements = query.get_arrayjoin()
    if elements is None:
        return None

    rendered = ",".join([element.accept(formatter) for element in elements])
    return StringNode(f"ARRAY JOIN {rendered}")
コード例 #8
0
def _format_limitby(
        query: AbstractQuery,
        formatter: ClickhouseExpressionFormatterBase) -> Optional[StringNode]:
    """
    Render the LIMIT BY clause as "LIMIT <limit> BY <expression>", or
    return None when the query has no limit-by.
    """
    limit_by = query.get_limitby()
    if limit_by is None:
        return None

    row_limit = limit_by.limit
    by_expression = limit_by.expression.accept(formatter)
    return StringNode(f"LIMIT {row_limit} BY {by_expression}")
コード例 #9
0
def _format_orderby(query: AbstractQuery,
                    formatter: ExpressionVisitor[str]) -> Optional[StringNode]:
    """
    Render the ORDER BY clause.

    Each entry becomes "<formatted expression> <direction value>"; entries
    are joined by ", ". Returns None when the query's order by is empty or
    missing.
    """
    ordering = query.get_orderby()
    if not ordering:
        return None

    parts = []
    for entry in ordering:
        parts.append(
            f"{entry.expression.accept(formatter)} {entry.direction.value}")
    return StringNode("ORDER BY " + ", ".join(parts))
コード例 #10
0
    def validate(self, query: Query, alias: Optional[str] = None) -> None:
        """
        Validate the shape of a subscription query.

        Raises InvalidQueryException unless exactly one column is selected,
        and also when any of the groupby / having / orderby getters of the
        query returns a truthy value. ``alias`` is not used.
        """
        if len(query.get_selected_columns()) != 1:
            raise InvalidQueryException(
                "only one aggregation in the select allowed")

        for clause in ("groupby", "having", "orderby"):
            # Resolve query.get_<clause> dynamically and call it.
            getter = getattr(query, f"get_{clause}")
            if getter():
                raise InvalidQueryException(
                    f"invalid clause {clause} in subscription query")
コード例 #11
0
 def validate(self, query: Query, alias: Optional[str] = None) -> None:
     """
     Reject queries whose top-level AND conditions already contain one
     that ``self.match`` recognises.

     Raises InvalidExpressionException (with ``report=False``) for the
     first matching condition. Does nothing when the query has no
     condition. ``alias`` is not used.
     """
     condition = query.get_condition()
     if not condition:
         return

     for candidate in get_first_level_and_conditions(condition):
         if not self.match.match(candidate):
             continue
         raise InvalidExpressionException(
             candidate,
             f"Cannot have existing conditions on time field {self.required_time_column}",
             report=False,
         )
コード例 #12
0
def _format_limitby(query: AbstractQuery,
                    formatter: ExpressionVisitor[str]) -> Optional[StringNode]:
    """
    Render the LIMIT BY clause as "LIMIT <limit> BY <col>,<col>,...", or
    return None when the query has no limit-by. Columns are joined by ","
    without a following space.
    """
    limit_by = query.get_limitby()
    if limit_by is None:
        return None

    rendered_columns = ",".join(
        [column.accept(formatter) for column in limit_by.columns])
    return StringNode(f"LIMIT {limit_by.limit} BY {rendered_columns}")
コード例 #13
0
    def _replace_time_condition(
        self,
        query: Query,
        from_date: datetime,
        from_exp: FunctionCall,
        to_date: datetime,
        to_exp: FunctionCall,
    ) -> None:
        """
        Rewrite the lower/upper time bound conditions of the query in place.

        Both dates are aligned downwards according to the
        "date_align_seconds" config, and when the "max_days" config is set
        and the range spans more days than that, the lower bound is moved
        to ``to_date - max_days``. Every top-level AND condition equal to
        ``from_exp`` / ``to_exp`` is then replaced by a copy whose second
        parameter is a literal with the adjusted date, and the recombined
        condition is stored back on the query.
        """
        max_days, date_align = state.get_configs(
            [("max_days", None), ("date_align_seconds", 1)]
        )

        def align_fn(moment: datetime) -> datetime:
            assert isinstance(date_align, int)
            # timedelta.seconds is the seconds-within-the-day component of
            # the distance from datetime.min; its remainder is dropped.
            remainder = (moment - moment.min).seconds % date_align
            return moment - timedelta(seconds=remainder)

        from_date = align_fn(from_date)
        to_date = align_fn(to_date)
        assert from_date <= to_date

        if max_days is not None:
            if (to_date - from_date).days > max_days:
                # Clamp the window by pulling the lower bound forward.
                from_date = to_date - timedelta(days=max_days)

        def replace_cond(node: Expression) -> Expression:
            # Only function calls can be one of the two bound conditions.
            if isinstance(node, FunctionCall):
                if node == from_exp:
                    return replace(
                        node,
                        parameters=(
                            from_exp.parameters[0],
                            Literal(None, from_date),
                        ),
                    )
                if node == to_exp:
                    return replace(
                        node,
                        parameters=(
                            to_exp.parameters[0],
                            Literal(None, to_date),
                        ),
                    )
            return node

        current = query.get_condition_from_ast()
        first_level = get_first_level_and_conditions(current) if current else []
        rewritten = [replace_cond(node) for node in first_level]
        query.set_ast_condition(combine_and_conditions(rewritten))
コード例 #14
0
ファイル: validators.py プロジェクト: getsentry/snuba
    def _validate_groupby_fields_have_matching_conditions(
            query: Query, alias: Optional[str] = None) -> None:
        """
        Ensure that every field in the group by clause has a matching
        top-level condition in the where clause.

        For example, with a group by of [project_id, tags[3]] the where
        clause is expected to contain something like `project_id = 3 AND
        tags[3] IN array(1,2,3)`. This is necessary because we want to avoid
        the case where an unspecified number of buckets is returned.

        Raises InvalidQueryException when a group by entry is neither a
        column nor a subscriptable reference, or when no top-level condition
        matches one of the entries.
        """
        where = query.get_condition()
        candidates = get_first_level_and_conditions(where) if where else []

        for grouped in query.get_groupby():
            if isinstance(grouped, SubscriptableReferenceExpr):
                column_name = str(grouped.column.column_name)
                key: Optional[str] = str(grouped.key.value)
            elif isinstance(grouped, Column):
                column_name = grouped.column_name
                key = None
            else:
                raise InvalidQueryException(
                    "Unhandled column type in group by validation")

            matcher = build_match(
                col=column_name,
                ops=[ConditionFunctions.EQ],
                param_type=int,
                alias=alias,
                key=key,
            )

            # Stop at the first condition that satisfies the matcher; the
            # else branch runs only when none did.
            for candidate in candidates:
                if matcher.match(candidate):
                    break
            else:
                raise InvalidQueryException(
                    f"Every field in groupby must have a corresponding condition in "
                    f"where clause. missing condition for field {grouped}")
コード例 #15
0
    def validate(self, query: Query, alias: Optional[str] = None) -> None:
        """
        Check that every required column has a matching top-level condition.

        For each name in ``self.required_columns`` a matcher is built with
        ``build_match`` (EQ operator, ``int`` parameter type, the given
        ``alias``) and tried against the first-level AND conditions of the
        query.

        Raises:
            InvalidQueryException: when one or more required columns have no
                matching condition. The column names are listed in sorted
                order so the message is identical across runs.
        """
        condition = query.get_condition()
        top_level = get_first_level_and_conditions(
            condition) if condition else []

        missing = set()
        # `or ()` keeps the original tolerance for an empty / unset
        # required_columns attribute.
        for col in self.required_columns or ():
            match = build_match(col, [ConditionFunctions.EQ], int, alias)
            if not any(match.match(cond) for cond in top_level):
                missing.add(col)

        if missing:
            # Joining the set directly yields an arbitrary order that can
            # change between interpreter runs (string hash randomization);
            # sorting makes the error message deterministic.
            raise InvalidQueryException(
                f"missing required conditions for {', '.join(sorted(missing))}")
コード例 #16
0
ファイル: validators.py プロジェクト: getsentry/snuba
    def validate(self, query: Query, alias: Optional[str] = None) -> None:
        """
        Check that the columns referenced by the query exist in the entity
        data model.

        Only columns whose ``table_name`` equals ``alias`` are checked; a
        column is missing when its name is not in ``self.entity_data_model``.
        What happens with missing columns depends on ``self.validation_mode``:
        DO_NOTHING skips the check entirely, ERROR raises, WARN logs a
        warning.

        Raises:
            InvalidQueryException: in ERROR mode when at least one column is
                missing. Column names appear in sorted order so the message
                is identical across runs.
        """
        if self.validation_mode == ColumnValidationMode.DO_NOTHING:
            return

        missing = {
            column.column_name
            for column in query.get_all_ast_referenced_columns()
            if column.table_name == alias
            and column.column_name not in self.entity_data_model
        }
        if not missing:
            return

        # Joining the set directly yields an arbitrary order that can change
        # between interpreter runs (string hash randomization); sorting keeps
        # the error and the log line stable.
        error_message = (
            f"query column(s) {', '.join(sorted(missing))} do not exist")
        if self.validation_mode == ColumnValidationMode.ERROR:
            raise InvalidQueryException(error_message)
        elif self.validation_mode == ColumnValidationMode.WARN:
            # NOTE(review): exc_info=True is kept as-is although no exception
            # is being handled here — confirm this is intentional.
            logger.warning(error_message, exc_info=True)
コード例 #17
0
ファイル: validators.py プロジェクト: getsentry/snuba
    def validate(
        self,
        query: Query,
        alias: Optional[str] = None,
    ) -> None:
        """
        Validate the select and clause usage of a subscription query.

        Raises InvalidQueryException when more columns are selected than
        ``self.max_allowed_aggregations``, or when the query getter for any
        clause named in ``self.disallowed_aggregations`` returns a truthy
        value. When "groupby" is not disallowed, the group by fields are
        additionally checked for matching conditions.
        """
        if len(query.get_selected_columns()) > self.max_allowed_aggregations:
            if self.max_allowed_aggregations == 1:
                allowed_text = "1 aggregation is"
            else:
                allowed_text = (
                    f"{self.max_allowed_aggregations} aggregations are")
            raise InvalidQueryException(
                f"A maximum of {allowed_text} allowed in the select")

        for clause in self.disallowed_aggregations:
            # Resolve query.get_<clause> dynamically and call it.
            getter = getattr(query, f"get_{clause}")
            if getter():
                raise InvalidQueryException(
                    f"invalid clause {clause} in subscription query")

        groupby_allowed = "groupby" not in self.disallowed_aggregations
        if groupby_allowed:
            self._validate_groupby_fields_have_matching_conditions(
                query, alias)
コード例 #18
0
def _format_limit(
        query: AbstractQuery,
        formatter: ClickhouseExpressionFormatterBase) -> Optional[StringNode]:
    """
    Render "LIMIT <limit> OFFSET <offset>", or return None when the query
    has no limit. The offset is only read when a limit is present.
    ``formatter`` is not used.
    """
    row_limit = query.get_limit()
    if row_limit is None:
        return None
    return StringNode(f"LIMIT {row_limit} OFFSET {query.get_offset()}")
コード例 #19
0
def _format_limit(query: AbstractQuery,
                  formatter: ExpressionVisitor[str]) -> Optional[StringNode]:
    """
    Render "LIMIT <limit> OFFSET <offset>", or return None when the query
    has no limit. The offset is only read when a limit is present.
    ``formatter`` is not used.
    """
    row_limit = query.get_limit()
    if row_limit is None:
        return None
    return StringNode(f"LIMIT {row_limit} OFFSET {query.get_offset()}")
コード例 #20
0
def get_object_ids_in_query_ast(query: AbstractQuery,
                                object_column: str) -> Optional[Set[int]]:
    """
    Find the object ids (e.g. project ids) the query filters on, based on
    the AST representation of its condition.

    Equality conditions, IN conditions and AND / OR combinations of them on
    ``object_column`` are recognised. Returns None when the query has no
    condition or no recognised condition on the column is found.
    """
    def _ids_in(expression: Expression) -> Optional[Set[int]]:
        """
        Extract the object ids from one expression.

        Returns None if no condition on the column is found. Returns an
        empty set if conflicting conditions are found (the intersection of
        the two sides of an AND is empty).
        """
        # Case 1: <object_column> = <int literal>
        equality = FunctionCall(
            String(ConditionFunctions.EQ),
            (
                Column(column_name=String(object_column)),
                Literal(value=Param("object_id", Any(int))),
            ),
        ).match(expression)
        if equality is not None:
            return {equality.integer("object_id")}

        # Case 2: <object_column> IN (...); only int literals are collected.
        membership = is_in_condition_pattern(
            Column(column_name=String(object_column))).match(expression)
        if membership is not None:
            candidates = membership.expression("tuple")
            assert isinstance(candidates, FunctionCallExpr)
            return {
                parameter.value
                for parameter in candidates.parameters
                if isinstance(parameter, LiteralExpr)
                and isinstance(parameter.value, int)
            }

        # Case 3: boolean AND / OR of two sub-expressions, handled
        # recursively.
        boolean = FunctionCall(
            Param(
                "operator",
                Or([String(BooleanFunctions.AND),
                    String(BooleanFunctions.OR)]),
            ),
            (Param("lhs", AnyExpression()), Param("rhs", AnyExpression())),
        ).match(expression)
        if boolean is None:
            return None

        left = _ids_in(boolean.expression("lhs"))
        right = _ids_in(boolean.expression("rhs"))
        # NOTE(review): when only one side constrains the column, that side
        # is returned for OR as well as for AND — confirm this is intended.
        if left is None:
            return right
        if right is None:
            return left
        if boolean.string("operator") == BooleanFunctions.AND:
            return left & right
        return left | right

    root = query.get_condition()
    if root is None:
        return None
    return _ids_in(root)
コード例 #21
0
    def validate_required_conditions(
        self, query: Query, alias: Optional[str] = None
    ) -> bool:
        """
        Check that the query has the top-level conditions this object
        requires, and normalise the time range condition when one is found.

        Returns True when nothing is required, or when every required filter
        column and the required time column (if any) have a matching
        first-level AND condition. Returns False when the query has no
        top-level conditions or when any requirement is not met.

        Side effect: when the time column is constrained by a lower and an
        upper bound (rather than by an EQ / IN condition), the query's
        condition is rewritten through ``_replace_time_condition``.
        """
        # Nothing to enforce.
        if not self._required_filter_columns and not self._required_time_column:
            return True

        condition = query.get_condition_from_ast()
        top_level = get_first_level_and_conditions(condition) if condition else []
        if not top_level:
            return False

        # Without an alias any (or no) table name on the column is accepted.
        alias_match = AnyOptionalString() if alias is None else StringMatch(alias)

        def build_match(
            col: str, ops: Sequence[str], param_type: Any
        ) -> Or[Expression]:
            # Builds a pattern matching either `<col> <op> <literal>` for one
            # of the given ops, or `<col> IN array(...)/tuple(...)` whose
            # parameters are all literals of param_type.
            # The IN condition has to be checked separately since each parameter
            # has to be checked individually.
            column_match = ColumnMatch(alias_match, StringMatch(col))
            return Or(
                [
                    FunctionCallMatch(
                        Or([StringMatch(op) for op in ops]),
                        (column_match, LiteralMatch(AnyMatch(param_type))),
                    ),
                    FunctionCallMatch(
                        StringMatch(ConditionFunctions.IN),
                        (
                            column_match,
                            FunctionCallMatch(
                                Or([StringMatch("array"), StringMatch("tuple")]),
                                all_parameters=LiteralMatch(AnyMatch(param_type)),
                            ),
                        ),
                    ),
                ]
            )

        # Every required filter column needs an EQ / IN condition on ints.
        if self._required_filter_columns:
            for col in self._required_filter_columns:
                match = build_match(col, [ConditionFunctions.EQ], int)
                found = any(match.match(cond) for cond in top_level)
                if not found:
                    return False

        if self._required_time_column:
            # An EQ / IN condition on datetimes is sufficient on its own.
            match = build_match(
                self._required_time_column, [ConditionFunctions.EQ], datetime,
            )
            found = any(match.match(cond) for cond in top_level)
            if found:
                return True

            # Otherwise both a lower and an upper bound must be present.
            lower, upper = get_time_range_expressions(
                top_level, self._required_time_column, alias
            )
            if not lower or not upper:
                return False

            # At this point we have valid conditions. However we need to align them and
            # make sure they don't exceed the max_days. Replace the conditions.
            # `lower` / `upper` are unpacked positionally into the
            # (date, expression) arguments of _replace_time_condition.
            self._replace_time_condition(query, *lower, *upper)

        return True
コード例 #22
0
def _format_select(query: AbstractQuery,
                   formatter: ExpressionVisitor[str]) -> StringNode:
    """
    Render the SELECT clause: "SELECT " followed by the formatted expression
    of every selected column, separated by ", ".
    """
    rendered_columns = []
    for selected in query.get_selected_columns():
        rendered_columns.append(selected.expression.accept(formatter))
    return StringNode("SELECT " + ", ".join(rendered_columns))