Beispiel #1
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Files must end with a single trailing newline.

        Only the final segment of the parse tree is examined; every other
        segment yields ``None``. The raw stack plus the current segment is
        reversed and the trailing newline segments whose templated slices are
        all of type "literal" are collected. Then:

        * none collected: a result is returned with a fix creating a newline;
        * more than one collected: a result is returned with fixes deleting
          every collected newline except the first (in reversed order);
        * exactly one collected: ``None`` is returned.
        """
        if not self.is_final_segment(context):
            # Nothing to check until the end of the parse tree is reached.
            return None

        parent_stack: Segments = context.functional.parent_stack
        # Everything before the current segment plus the segment itself,
        # flipped so that the end of the file comes first.
        reversed_complete_stack = (
            context.functional.raw_stack + context.functional.segment
        ).reversed()

        # Newlines at the end of the file, looking past whitespace and dedents.
        trailing_newlines = reversed_complete_stack.select(
            select_if=sp.is_type("newline"),
            loop_while=sp.or_(sp.is_whitespace(), sp.is_type("dedent")),
        )

        def _only_literal_slices(seg):
            """Whether every templated slice of ``seg`` is a literal slice."""
            return sp.templated_slices(seg, context.templated_file).all(
                tsp.is_slice_type("literal")
            )

        trailing_literal_newlines = trailing_newlines.select(
            loop_while=_only_literal_slices
        )

        if trailing_literal_newlines:
            if len(trailing_literal_newlines) > 1:
                # Too many newlines: keep the first collected one, drop the rest.
                return LintResult(
                    anchor=context.segment,
                    fixes=[
                        LintFix.delete(extra)
                        for extra in trailing_literal_newlines[1:]
                    ],
                )
            # Exactly one newline, nothing to fix.
            return None

        # No trailing newline: create one after the child of the FileSegment.
        fix_anchor_segment = (
            context.segment if len(parent_stack) == 1 else parent_stack[1]
        )
        return LintResult(
            anchor=context.segment,
            fixes=[LintFix.create_after(fix_anchor_segment, [NewlineSegment()])],
        )
Beispiel #2
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Outermost query should produce known number of columns.

        Only a select statement, set expression or WITH compound statement
        that has no segment of those types in its parent stack is analysed.
        A ``RuleFailure`` raised during the analysis is converted into a
        ``LintResult`` anchored on the failure's anchor; otherwise whatever
        the analysis returns is passed through.
        """
        start_types = (
            "select_statement",
            "set_expression",
            "with_compound_statement",
        )
        if not context.segment.is_type(*start_types):
            return None
        if context.functional.parent_stack.any(sp.is_type(*start_types)):
            # Nested in another query; only the outermost one is analysed.
            return None

        # Begin analysis at the outer query.
        query_tree = SelectCrawler(context.segment, context.dialect).query_tree
        if not query_tree:
            return None

        try:
            return self._analyze_result_columns(query_tree)
        except RuleFailure as failure:
            return LintResult(anchor=failure.anchor)
Beispiel #3
0
    def select_info(self):
        """Return a SelectStatementColumnsAndTables describing the selectable.

        For a ``select_statement`` the work is delegated to
        ``get_select_statement_info``. Any other selectable is handled as a
        values clause: it is presented as if it were a SELECT with a single
        table alias (taken from its alias expression, when there is one) and
        with every other collection left empty.
        """
        selectable = self.selectable
        if selectable.is_type("select_statement"):
            return get_select_statement_info(
                selectable, self.dialect, early_exit=False
            )

        # values_clause
        #
        # Treating a VALUES clause as a SELECT is a slightly dodgy but very
        # useful abstraction. It may need tweaking some day, e.g. by adding a
        # separate QueryType, depending on the needs of the rules that use it.
        #
        # For the syntax and behaviour of VALUES, and how it resembles a SELECT
        # of literal values with no table source, see the "Examples" section
        # of https://www.postgresql.org/docs/8.2/sql-values.html
        alias_expression = (
            Segments(selectable).children().first(sp.is_type("alias_expression"))
        )
        name = alias_expression.children().first(
            sp.is_name("naked_identifier", "quoted_identifier")
        )

        if name:
            identifier = name[0]
            ref_str = identifier.raw
            is_aliased = True
        else:
            identifier = None
            ref_str = ""
            is_aliased = False

        alias_info = AliasInfo(
            ref_str,
            identifier,
            is_aliased,
            selectable,
            alias_expression[0] if alias_expression else None,
            None,
        )

        return SelectStatementColumnsAndTables(
            select_statement=selectable,
            table_aliases=[alias_info],
            standalone_aliases=[],
            reference_buffer=[],
            select_targets=[],
            col_aliases=[],
            using_cols=[],
        )
Beispiel #4
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Use ``!=`` instead of ``<>`` for "not equal to" comparison.

        Segments not named ``not_equal_to`` are ignored, as are those whose
        raw comparison operator children are anything other than ``<``
        followed by ``>``. For a ``<>`` operator a result is returned whose
        fixes rewrite it as ``!=``.
        """
        if context.segment.name != "not_equal_to":
            return None

        # The individual symbols that make up the operator.
        operator_parts = context.functional.segment.children().select(
            select_if=sp.is_type("raw_comparison_operator")
        )
        if tuple(part.raw for part in operator_parts) != ("<", ">"):
            # Already written some other way (e.g. ``!=``).
            return None

        # Each symbol is a separate segment, so they are replaced pairwise:
        # ``<`` becomes ``!`` and ``>`` becomes ``=``.
        replacements = (("!", "raw_not"), ("=", "raw_equals"))
        fixes = [
            LintFix.replace(
                part,
                [
                    CodeSegment(
                        raw=new_raw,
                        name=new_name,
                        type="raw_comparison_operator",
                    )
                ],
            )
            for part, (new_raw, new_name) in zip(operator_parts, replacements)
        ]

        return LintResult(context.segment, fixes)
Beispiel #5
0
    def _eval(self, context: RuleContext) -> EvalResultType:
        """Check table references for an outermost SELECT, DELETE or UPDATE.

        For the bigquery, hive, redshift and spark3 dialects an empty
        ``LintResult`` is returned unless ``force_enable`` is set. Otherwise,
        only a select, delete or update statement with no segment of those
        types in its parent stack is analysed. The list of violations
        handed to ``_analyze_table_references`` is returned when it is
        non-empty, and ``None`` in every other case.
        """
        # Config type hints
        self.force_enable: bool

        if (
            context.dialect.name in ("bigquery", "hive", "redshift", "spark3")
            and not self.force_enable
        ):
            return LintResult()

        start_types = ("select_statement", "delete_statement", "update_statement")
        if not context.segment.is_type(*start_types):
            return None
        if context.functional.parent_stack.any(sp.is_type(*start_types)):
            # Nested statement; only the outermost one is analysed.
            return None

        dml_target_table: Optional[Tuple[str, ...]] = None
        if not context.segment.is_type("select_statement"):
            # The first table reference is the target table of a DELETE or
            # UPDATE statement.
            first_table_reference = next(
                context.segment.recursive_crawl("table_reference"), None
            )
            if first_table_reference:
                dml_target_table = self._table_ref_as_tuple(first_table_reference)

        # Verify table references in any SELECT statements found in or below
        # context.segment in the parser tree.
        violations: List[LintResult] = []
        select_crawler = SelectCrawler(
            context.segment, context.dialect, query_class=L026Query
        )
        query: L026Query = cast(L026Query, select_crawler.query_tree)
        if query:
            self._analyze_table_references(
                query, dml_target_table, context.dialect, violations
            )
        return violations or None
Beispiel #6
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Unnecessary CASE statement.

        Looks at ``case_expression`` segments whose first child is named
        "case" and which have at most one WHEN clause. Two reductions are
        recognised:

        1. THEN and ELSE are the two different Boolean literals, so the CASE
           can be replaced by a COALESCE of the condition.
        2. The condition is a lone ``IS [NOT] NULL`` test on a column
           reference that also appears as the THEN or ELSE expression, so
           the CASE can be replaced by a COALESCE, or by the column alone.

        Returns a ``LintResult`` with fixes when a reduction applies,
        otherwise ``None``.
        """
        case_segment = context.segment
        if not case_segment.is_type("case_expression"):
            return None
        if case_segment.segments[0].name != "case":
            return None

        # Find all 'WHEN' clauses and the optional 'ELSE' clause.
        clause_children = context.functional.segment.children()
        when_clauses = clause_children.select(sp.is_type("when_clause"))
        else_clauses = clause_children.select(sp.is_type("else_clause"))

        # Can't fix if multiple WHEN clauses.
        if len(when_clauses) > 1:
            return None

        # The WHEN clause's expressions: first the condition, then the result.
        when_expressions = when_clauses.children(sp.is_type("expression"))
        condition_expression = when_expressions[0]
        then_expression = when_expressions[1]

        # Method 1: THEN/ELSE are both Boolean literals and differ from each
        # other, so the CASE collapses to the condition (possibly negated).
        else_expression = None
        if else_clauses:
            else_expression = else_clauses.children(sp.is_type("expression"))[0]
            then_else_raws = {then_expression.raw_upper, else_expression.raw_upper}
            if then_else_raws == {"TRUE", "FALSE"}:
                fixes = self._coalesce_fix_list(
                    context,
                    condition_expression,
                    KeywordSegment("false"),
                    # A NOT is needed when the condition maps to FALSE.
                    then_expression.raw_upper == "FALSE",
                )
                return LintResult(
                    anchor=condition_expression,
                    fixes=fixes,
                    description="Unnecessary CASE statement. "
                    "Use COALESCE function instead.",
                )

        # Method 2: the condition compares a column reference to NULL and that
        # column reference is also the THEN/ELSE expression. This only applies
        # when the condition expression holds a single condition.
        condition_tokens = {
            part.raw_upper for part in condition_expression.segments
        }
        if not {"IS", "NULL"} <= condition_tokens:
            return None
        if condition_tokens & {"AND", "OR"}:
            return None

        # Is the comparison to NULL or to NOT NULL?
        is_not_prefix = "NOT" in condition_tokens

        # Locate column reference in condition expression.
        column_reference = (
            Segments(condition_expression)
            .children(sp.is_type("column_reference"))
            .get()
        )
        # Nothing to do if none found (this method does not apply to functions).
        if not column_reference:
            return None

        if else_clauses:
            # The expression returned for a non-NULL column must be the column
            # itself: THEN for ``IS NOT NULL``, ELSE for ``IS NULL``.
            if is_not_prefix:
                coalesce_arg_1, coalesce_arg_2 = then_expression, else_expression
            else:
                coalesce_arg_1, coalesce_arg_2 = else_expression, then_expression
            if column_reference.raw_upper != coalesce_arg_1.raw_upper:
                return None

            if coalesce_arg_2.raw_upper != "NULL":
                return LintResult(
                    anchor=condition_expression,
                    fixes=self._coalesce_fix_list(
                        context,
                        coalesce_arg_1,
                        coalesce_arg_2,
                    ),
                    description="Unnecessary CASE statement. "
                    "Use COALESCE function instead.",
                )
        elif column_reference.raw_segments_upper != then_expression.raw_segments_upper:
            # Without an ELSE (equivalent to ELSE NULL) the THEN expression
            # has to be the column itself.
            return None

        # The fallback value is NULL, so the column can be used on its own
        # rather than through a COALESCE function.
        return LintResult(
            anchor=condition_expression,
            fixes=self._column_only_fix_list(
                context,
                column_reference,
            ),
            description="Unnecessary CASE statement. "
            f"Just use column '{column_reference.raw}'.",
        )
Beispiel #7
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Select clause modifiers must appear on same line as SELECT.

        Applies to ``select_clause`` segments that contain a select clause
        modifier with at least one newline between the first child (the
        select keyword) and that modifier. The returned result carries fixes
        which re-create the modifier directly after the select keyword and
        delete the original modifier together with the newlines before it
        and the whitespace after it.
        """
        if not context.segment.is_type("select_clause"):
            return None

        # Children of the select_clause; the first one is the select keyword.
        child_segments = context.functional.segment.children()
        select_keyword = child_segments[0]

        modifier_matches = child_segments.first(sp.is_type("select_clause_modifier"))
        if not modifier_matches:
            # Rule doesn't apply if there's no select clause modifier.
            return None
        select_clause_modifier = modifier_matches[0]

        def _newlines_following(start_seg):
            """Select newlines after ``start_seg``, past whitespace and metas."""
            return child_segments.select(
                select_if=sp.is_type("newline"),
                loop_while=sp.or_(sp.is_whitespace(), sp.is_meta()),
                start_seg=start_seg,
            )

        leading_newline_segments = _newlines_following(select_keyword)
        if not leading_newline_segments:
            # The modifier is already on the same line as the select keyword.
            return None

        # Is the following select clause element on its own line already?
        trailing_newline_segments = _newlines_following(select_clause_modifier)

        # Whitespace (on the same line) after the select clause modifier.
        trailing_whitespace_segments = child_segments.select(
            select_if=sp.is_whitespace(),
            loop_while=sp.or_(sp.is_type("whitespace"), sp.is_meta()),
            start_seg=select_clause_modifier,
        )

        # These segments get inserted directly after the select keyword.
        edit_segments = [WhitespaceSegment(), select_clause_modifier]
        if not trailing_newline_segments:
            # The first select clause element shares a line with the
            # modifier, so a newline has to be inserted as well.
            edit_segments.append(NewlineSegment())

        fixes = [
            # Move select clause modifier after select keyword.
            LintFix.create_after(
                anchor_segment=select_keyword,
                edit_segments=edit_segments,
            ),
            # Drop the newlines that separated keyword and modifier ...
            *(LintFix.delete(seg) for seg in leading_newline_segments),
            # ... the modifier at its original position ...
            LintFix.delete(select_clause_modifier),
            # ... and any whitespace that trailed it.
            *(LintFix.delete(seg) for seg in trailing_whitespace_segments),
        ]

        return LintResult(
            anchor=context.segment,
            fixes=fixes,
        )