Esempio n. 1
0
    def _get_subsequent_whitespace(
        self,
        context,
    ) -> Tuple[Optional[BaseSegment], Optional[BaseSegment]]:
        """Search forwards through the raw segments for subsequent whitespace.

        Starting just after ``context.segment`` (the current comma), skip the
        run of whitespace and meta segments that directly follows it.

        Args:
            context: The rule context whose ``segment`` is the comma from
                which the forward search starts.

        Returns:
            A tuple ``(whitespace, next_segment)``:

            * ``whitespace`` is the last whitespace segment in that run, or
              ``None`` if the run contains no whitespace.
            * ``next_segment`` is the first raw segment after the run, or
              ``None`` if the run reaches the end of the raw segments.
        """
        # Get all raw segments. "raw_segments" is appropriate as the
        # only segments we can care about are comma, whitespace,
        # newline, and comment, which are all raw. Using the
        # raw_segments allows us to account for possible unexpected
        # parse tree structures resulting from other rule fixes.
        raw_segments = FunctionalContext(context).raw_segments
        # Start after the current comma within the list. Get all the
        # following whitespace (and metas, which carry no text).
        following_segments = raw_segments.select(
            loop_while=sp.or_(sp.is_meta(), sp.is_type("whitespace")),
            start_seg=context.segment,
        )
        whitespace_matches = following_segments.last(sp.is_type("whitespace"))
        # Unwrap to a single segment (or None) so BOTH return paths honour the
        # declared return type; previously the end-of-file path leaked the
        # ``Segments`` container itself.
        subsequent_whitespace = (
            whitespace_matches[0] if whitespace_matches else None
        )
        # Index of the first raw segment after the comma and the skipped run.
        next_index = (
            raw_segments.index(context.segment) + len(following_segments) + 1
        )
        try:
            next_segment = raw_segments[next_index]
        except IndexError:
            # If we find ourselves here it's all whitespace (or nothing) to the
            # end of the file. This can only happen in bigquery (see
            # test_pass_bigquery_trailing_comma).
            next_segment = None
        return subsequent_whitespace, next_segment
Esempio n. 2
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Check (or forbid) a trailing comma at the end of a select clause.

        The ``select_clause_trailing_comma`` setting drives the behaviour:

        * ``"forbid"`` -- a final comma is reported and a delete fix offered.
        * ``"require"`` -- a missing final comma is reported and a fix that
          appends a comma after the last code segment is offered.
        * any other value -- nothing is reported.
        """
        # Config type hints
        self.select_clause_trailing_comma: str

        code_children = FunctionalContext(context).segment.children().last(
            sp.is_code()
        )
        final_code_segment: BaseSegment = code_children[0]
        ends_with_comma = final_code_segment.is_type("comma")
        policy = self.select_clause_trailing_comma

        if policy == "forbid" and ends_with_comma:
            return LintResult(
                anchor=final_code_segment,
                fixes=[LintFix.delete(final_code_segment)],
                description="Trailing comma in select statement forbidden",
            )
        if policy == "require" and not ends_with_comma:
            # Re-emit the final segment followed by a freshly built comma.
            appended = [final_code_segment, SymbolSegment(",", type="comma")]
            return LintResult(
                anchor=final_code_segment,
                fixes=[LintFix.replace(final_code_segment, appended)],
                description="Trailing comma in select statement required",
            )
        return None
Esempio n. 3
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Flag a redundant ``ELSE NULL`` in a CASE expression and remove it.

        Procedure:

        0. The crawler hands us a ``case_expression``.
        1. Find its ELSE clause; without one, the rule passes.
        2. If a direct child of that clause has raw text ``NULL``, delete
           the whole ELSE clause together with the run of
           whitespace/newline/meta segments immediately before it.
        3. Otherwise the rule passes.
        """
        assert context.segment.is_type("case_expression")
        case_children = FunctionalContext(context).segment.children()
        else_clause = case_children.first(sp.is_type("else_clause"))
        if not else_clause:
            return None

        # Looking for a raw "NULL" is safe: an expression would *contain*
        # NULL rather than be equal to it.
        null_children = else_clause.children(
            lambda child: child.raw_upper == "NULL"
        )
        if not null_children:
            return None

        else_segment = else_clause[0]
        # :TRICKY: walking a reversed copy makes select() scan backwards from
        # the ELSE clause, collecting only the filler right before it.
        leading_filler = case_children.reversed().select(
            start_seg=else_segment,
            loop_while=sp.or_(sp.is_name("whitespace", "newline"), sp.is_meta()),
        )
        fixes = [LintFix.delete(else_segment)]
        fixes.extend(LintFix.delete(seg) for seg in leading_filler)
        return LintResult(anchor=context.segment, fixes=fixes)
Esempio n. 4
0
    def _eval(self, context: RuleContext) -> LintResult:
        """Function name not immediately followed by bracket.

        Triggered on ``function`` segments. Anything located between the
        ``function_name`` child and the first ``bracketed`` child is a
        violation. It is auto-fixed (deleted) only when it consists purely of
        whitespace and newlines; otherwise it is just reported.
        """
        segment = FunctionalContext(context).segment
        assert segment.all(sp.is_type("function"))
        children = segment.children()

        name_seg = children.first(sp.is_type("function_name"))[0]
        bracket_seg = children.first(sp.is_type("bracketed"))[0]
        between = children.select(start_seg=name_seg, stop_seg=bracket_seg)

        if not between:
            return LintResult()

        first_offender = between[0]
        if not between.all(sp.is_type("whitespace", "newline")):
            # Something other than whitespace sits there: unsafe to fix.
            return LintResult(anchor=first_offender)
        return LintResult(
            anchor=first_offender,
            fixes=[LintFix.delete(seg) for seg in between],
        )
Esempio n. 5
0
    def _eval(self, context: RuleContext) -> LintResult:
        """Nested CASE statement in ELSE clause could be flattened.

        When the ELSE clause of an outer CASE consists solely of another CASE
        expression, the inner WHEN/ELSE clauses are moved to just after the
        outer CASE's last WHEN clause (each on a new line), and the outer ELSE
        clause is deleted.

        Returns:
            An empty LintResult when the pattern does not apply; otherwise a
            LintResult anchored on the nested CASE with the fixes above.
        """
        segment = FunctionalContext(context).segment
        assert segment.select(sp.is_type("case_expression"))
        case1_children = segment.children()
        case1_last_when = case1_children.last(sp.is_type("when_clause")).get()
        case1_else_clause = case1_children.select(sp.is_type("else_clause"))
        case1_else_expressions = case1_else_clause.children(
            sp.is_type("expression")
        )
        expression_children = case1_else_expressions.children()
        case2 = expression_children.select(sp.is_type("case_expression"))
        # The len() checks below are for safety, to ensure the CASE inside
        # the ELSE is not part of a larger expression. In that case, it's
        # not safe to simplify in this way -- we'd be deleting other code.
        if (
            not case1_last_when
            or len(case1_else_expressions) > 1
            or len(expression_children) > 1
            or not case2
        ):
            return LintResult()

        # We can assert that this exists because of the previous check.
        assert case1_last_when
        # We can also assert that we'll also have an else clause because
        # otherwise the case2 check above would fail.
        case1_else_clause_seg = case1_else_clause.get()
        assert case1_else_clause_seg

        # Delete everything between the last outer "WHEN" clause and the
        # outer "ELSE" clause.
        case1_to_delete = case1_children.select(
            start_seg=case1_last_when, stop_seg=case1_else_clause_seg
        )
        fixes = case1_to_delete.apply(lambda seg: LintFix.delete(seg))

        # Determine the indentation to use when we move the nested "WHEN"
        # and "ELSE" clauses, based on the indentation of case1_last_when.
        # Only the run of whitespace/meta segments *immediately* before that
        # clause counts; selecting every earlier whitespace child would
        # concatenate the indents of all preceding WHEN clauses.
        # If no whitespace segments are found, use the default indent.
        indent = (
            case1_children.select(stop_seg=case1_last_when)
            .reversed()
            .select(
                select_if=sp.is_type("whitespace"),
                loop_while=sp.or_(sp.is_type("whitespace"), sp.is_meta()),
            )
        )
        indent_str = (
            "".join(seg.raw for seg in indent.reversed()) if indent else self.indent
        )

        # Move the nested "when" and "else" clauses after the last outer
        # "when", each preceded by a newline and the computed indentation.
        nested_clauses = case2.children(sp.is_type("when_clause", "else_clause"))
        create_after_last_when = nested_clauses.apply(
            lambda seg: [NewlineSegment(), WhitespaceSegment(indent_str), seg]
        )
        segments = [item for sublist in create_after_last_when for item in sublist]
        fixes.append(
            LintFix.create_after(case1_last_when, segments, source=segments)
        )

        # Delete the outer "else" clause (which holds the nested CASE).
        fixes.append(LintFix.delete(case1_else_clause_seg))
        return LintResult(case2[0], fixes=fixes)
Esempio n. 6
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Collect FROM/JOIN alias information for a select statement and lint it.

        Locates the base table reference of the FROM clause. It then gathers
        every ``from_expression_element`` and ``column_reference`` found in
        the FROM clause and in the clauses after it, and passes them to
        ``_lint_aliases_in_join``. The rule is skipped for dialects in
        ``_dialects_disabled_by_default`` unless ``force_enable`` is set.
        """
        # Config type hints
        self.force_enable: bool

        # Issue 2810: BigQuery has some tricky expectations (apparently not
        # documented, but subject to change, e.g.:
        # https://www.reddit.com/r/bigquery/comments/fgk31y/new_in_bigquery_no_more_backticks_around_table/)
        # about whether backticks are required (and whether the query is valid
        # or not, even with them), depending on whether the GCP project name is
        # present, or just the dataset name. Since SQLFluff doesn't have access
        # to BigQuery when it is looking at the query, it would be complex for
        # this rule to do the right thing. For now, the rule simply disables
        # itself.
        dialect_disabled = context.dialect.name in self._dialects_disabled_by_default
        if dialect_disabled and not self.force_enable:
            return LintResult()

        assert context.segment.is_type("select_statement")

        children = FunctionalContext(context).segment.children()
        from_clause_segment = children.select(sp.is_type("from_clause")).first()

        # Drill down FROM -> from_expression -> from_expression_element ->
        # table_expression -> object_reference, keeping the first match.
        base_table = from_clause_segment
        for child_type in (
            "from_expression",
            "from_expression_element",
            "table_expression",
            "object_reference",
        ):
            base_table = base_table.children(sp.is_type(child_type)).first()
        if not base_table:
            return None

        # Buffers for all table expressions and column references found in
        # the FROM clause and everything after it (joins, WHERE, ...).
        from_expression_elements = []
        column_reference_segments = []
        clauses_to_scan = from_clause_segment + children.select(
            start_seg=from_clause_segment[0]
        )
        for clause in clauses_to_scan:
            from_expression_elements.extend(
                clause.recursive_crawl("from_expression_element")
            )
            column_reference_segments.extend(
                clause.recursive_crawl("column_reference")
            )

        # base_table is guaranteed non-empty by the early return above.
        join_result = self._lint_aliases_in_join(
            base_table[0],
            from_expression_elements,
            column_reference_segments,
            context.segment,
        )
        return join_result or None
Esempio n. 7
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Lint table aliases of a select statement.

        Every ``from_expression_element`` below the statement is handed to
        ``_lint_aliases``; the ``min_alias_length`` / ``max_alias_length``
        settings are declared here for it. A falsy result becomes ``None``.
        """
        # Config type hints
        self.min_alias_length: Optional[int]
        self.max_alias_length: Optional[int]

        assert context.segment.is_type("select_statement")
        elements = (
            FunctionalContext(context)
            .segment.children()
            .recursive_crawl("from_expression_element")
        )
        result = self._lint_aliases(elements)
        return result if result else None
Esempio n. 8
0
 def _eval(self, context: RuleContext) -> Optional[LintResult]:
     """Ambiguous use of DISTINCT in select statement with GROUP BY.

     Reports the DISTINCT keyword of the select clause modifier when the
     same select statement also has a GROUP BY clause.
     """
     select_statement = FunctionalContext(context).segment
     # The crawler only ever hands us select statements.
     assert select_statement.all(sp.is_type("select_statement"))
     if not select_statement.children(sp.is_type("groupby_clause")):
         return None
     modifier_keywords = (
         select_statement.children(sp.is_type("select_clause"))
         .children(sp.is_type("select_clause_modifier"))
         .children(sp.is_type("keyword"))
     )
     distinct = modifier_keywords.select(sp.is_name("distinct"))
     return LintResult(anchor=distinct[0]) if distinct else None
Esempio n. 9
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Use ``!=`` instead of ``<>`` for "not equal to" comparison.

        Only ``not_equal_to`` segments whose raw comparison operator children
        are exactly ``<`` then ``>`` are reported; the fix swaps them for
        ``!`` and ``=``.
        """
        # The crawler yields comparison operators, but only some of them are
        # "not_equal_to".
        if context.segment.name != "not_equal_to":
            return None

        operator_parts = (
            FunctionalContext(context)
            .segment.children()
            .select(select_if=sp.is_type("raw_comparison_operator"))
        )
        if [part.raw for part in operator_parts] != ["<", ">"]:
            return None

        # ``<>`` is two separate symbol segments, so each is replaced in
        # turn: "<" becomes "!" and ">" becomes "=".
        fixes = [
            LintFix.replace(
                part,
                [SymbolSegment(raw=symbol, type="raw_comparison_operator")],
            )
            for part, symbol in zip(operator_parts, ("!", "="))
        ]
        return LintResult(context.segment, fixes)
Esempio n. 10
0
    def _eval(self, context: RuleContext) -> EvalResultType:
        """Check table references in the queries of an outermost statement.

        Skipped for dialects in ``_dialects_disabled_by_default`` unless
        ``force_enable`` is set, and whenever an ancestor is one of
        ``_START_TYPES``. For a non-SELECT statement, its first
        ``table_reference`` is treated as the DML target table. All queries
        found by ``SelectCrawler`` are then checked by
        ``_analyze_table_references``, which appends to ``violations``.
        """
        # Config type hints
        self.force_enable: bool

        dialect_disabled = context.dialect.name in self._dialects_disabled_by_default
        if dialect_disabled and not self.force_enable:
            return LintResult()

        # Only evaluate when no ancestor is a start type (presumably so the
        # outermost statement's analysis covers nested ones -- confirm).
        if FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
            return None

        violations: List[LintResult] = []
        dml_target_table: Optional[Tuple[str, ...]] = None
        self.logger.debug("Trigger on: %s", context.segment)
        if not context.segment.is_type("select_statement"):
            # Extract first table reference. This will be the target
            # table in a DML statement.
            first_table_ref = next(
                context.segment.recursive_crawl("table_reference"), None
            )
            if first_table_ref:
                dml_target_table = self._table_ref_as_tuple(first_table_ref)
        self.logger.debug("DML Reference Table: %s", dml_target_table)

        # Verify table references in any SELECT statements found in or
        # below context.segment in the parser tree.
        crawler = SelectCrawler(
            context.segment, context.dialect, query_class=L026Query
        )
        query: L026Query = cast(L026Query, crawler.query_tree)
        if query:
            self._analyze_table_references(
                query, dml_target_table, context.dialect, violations
            )
        return violations or None
Esempio n. 11
0
 def _eval(self, context: RuleContext) -> Optional[LintResult]:
     """Defer to the parent rule unless the element ends with ``=``.

     T-SQL supports an alternative alias expression for L012,
     ``select alias = value`` instead of ``select value as alias``. When
     the last child of the segment has raw text ``=``, this returns
     ``None`` early; otherwise the base implementation decides.
     """
     last_child = FunctionalContext(context).segment.children()[-1]
     if last_child.raw != "=":
         return super()._eval(context)
     return None
Esempio n. 12
0
 def _eval(self, context: RuleContext):
     """Dispatch select-clause target checks by the number of targets.

     With exactly one select target, the single-target check runs unless
     the clause contains a wildcard and ``wildcard_policy`` is not
     ``"single"``. In that case, as with more than one target, the
     multiple-target check runs. With no targets, ``None`` is returned.
     """
     # Config type hints
     self.wildcard_policy: str
     assert context.segment.is_type("select_clause")
     select_targets_info = self._get_indexes(context)
     wildcards = (
         FunctionalContext(context)
         .segment.children(sp.is_type("select_clause_element"))
         .children(sp.is_type("wildcard_expression"))
     )
     target_count = len(select_targets_info.select_targets)
     single_allowed = not wildcards or self.wildcard_policy == "single"
     if target_count == 1 and single_allowed:
         return self._eval_single_select_target_element(
             select_targets_info,
             context,
         )
     if target_count:
         return self._eval_multiple_select_target_elements(
             select_targets_info, context.segment
         )
     return None
Esempio n. 13
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Outermost query should produce known number of columns.

        Runs only when no ancestor is one of ``_START_TYPES``. Analysis
        starts at the crawler's query tree; a ``RuleFailure`` raised during
        analysis is turned into a LintResult anchored on its ``anchor``.
        """
        if FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
            return None

        crawler = SelectCrawler(context.segment, context.dialect)
        if not crawler.query_tree:
            return None

        # Begin analysis at the outer query.
        try:
            return self._analyze_result_columns(crawler.query_tree)
        except RuleFailure as failure:
            return LintResult(anchor=failure.anchor)
Esempio n. 14
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Looking for DISTINCT before a bracket.

        Triggered on ``select_clause``. When a ``select_clause_modifier`` is
        present and the first select element (or its first ``expression``
        child) contains a ``bracketed`` child, up to two fixes are built:

        * if the brackets are the expression's only child, replace them with
          their inner segments (metas filtered out);
        * if no whitespace separates the modifier from the first element,
          insert a single whitespace segment.

        A result anchored on the modifier is returned only if a fix exists.
        """
        assert context.segment.is_type("select_clause")
        children = FunctionalContext(context).segment.children()
        modifier = children.select(sp.is_type("select_clause_modifier"))
        first_element = children.select(sp.is_type("select_clause_element")).first()
        if not (modifier and first_element):
            return None

        # Prefer the element's first expression child; fall back to the
        # element itself.
        expression = (
            first_element.children(sp.is_type("expression")).first()
            or first_element
        )
        bracketed = expression.children(sp.is_type("bracketed")).first()
        if not bracketed:
            return None

        fixes = []
        if len(expression[0].segments) == 1:
            # Brackets are the whole expression: keep only what is inside,
            # dropping the bracket segments and any metas.
            inner_segments = self.filter_meta(bracketed[0].segments)[1:-1]
            fixes.append(LintFix.replace(bracketed[0], inner_segments))

        gap_whitespace = children.select(
            sp.is_whitespace(),
            start_seg=modifier[0],
            stop_seg=first_element[0],
        )
        if not gap_whitespace:
            # DISTINCT runs straight into the expression; add a space.
            fixes.append(
                LintFix.create_before(first_element[0], [WhitespaceSegment()])
            )

        return LintResult(anchor=modifier[0], fixes=fixes) if fixes else None
Esempio n. 15
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Column expression without alias. Use explicit `AS` clause.

        Triggered on ``select_clause_element``. The element is accepted when:

        * it already has an ``alias_expression``;
        * it is a function with an EMITS clause (EMITS plays the role of AS);
        * it sits in a CTE that declares a ``cte_column_list``;
        * it is not complex (per ``_recursively_check_is_complex``);
        * ``allow_scalar`` is set and it is the only select element.

        Otherwise it is reported. No fix is offered, because only the author
        knows what the alias should be.
        """
        functional_context = FunctionalContext(context)
        children = functional_context.segment.children()

        if children.any(sp.is_type("alias_expression")):
            return None

        emits_clauses = (
            children.select(sp.is_type("function"))
            .children()
            .select(sp.is_type("emits_segment"))
        )
        if emits_clauses:
            return None

        parent_stack = functional_context.parent_stack
        enclosing_cte = parent_stack.last(sp.is_type("common_table_expression"))
        if enclosing_cte.children().any(sp.is_type("cte_column_list")):
            return None

        non_star_children = children.select(sp.not_(sp.is_name("star")))
        if not _recursively_check_is_complex(non_star_children):
            return None

        if not self.allow_scalar:  # type: ignore
            return LintResult(anchor=context.segment)

        # With allow_scalar, a lone element in the select clause is fine;
        # count its siblings in the immediate parent.
        sibling_elements = parent_stack.last().children(
            sp.is_type("select_clause_element")
        )
        if len(sibling_elements) > 1:
            return LintResult(anchor=context.segment)
        return None
Esempio n. 16
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Files must end with a single trailing newline.

        Counts the trailing newlines of the file. When a templated file is
        available, only those whose templated slices are all ``literal`` count.

        * none -> report and create a newline after the child of the file
          segment (or after the last segment when it is the only level);
        * more than one -> report and delete all but the first;
        * exactly one -> no violation.
        """
        # We only care about the final segment of the parse tree.
        parent_stack, segment = get_last_segment(FunctionalContext(context).segment)
        self.logger.debug("Found last segment as: %s", segment)

        trailing_newlines = Segments(*get_trailing_newlines(context.segment))
        self.logger.debug("Untemplated trailing newlines: %s", trailing_newlines)
        trailing_literal_newlines = trailing_newlines
        if context.templated_file:
            # Keep only newlines coming from literal (non-templated) source.
            trailing_literal_newlines = trailing_newlines.select(
                loop_while=lambda seg: sp.templated_slices(
                    seg, context.templated_file
                ).all(tsp.is_slice_type("literal"))
            )
        self.logger.debug("Templated trailing newlines: %s", trailing_literal_newlines)

        newline_count = len(trailing_literal_newlines)
        if newline_count == 1:
            # Single newline, no need for fix.
            return None
        if newline_count > 1:
            # Delete extra newlines.
            return LintResult(
                anchor=segment[0],
                fixes=[LintFix.delete(d) for d in trailing_literal_newlines[1:]],
            )

        # No trailing newline: create one after the child of the FileSegment.
        fix_anchor_segment = segment[0] if len(parent_stack) == 1 else parent_stack[1]
        self.logger.debug("Anchor on: %s", fix_anchor_segment)
        return LintResult(
            anchor=segment[0],
            fixes=[LintFix.create_after(fix_anchor_segment, [NewlineSegment()])],
        )
Esempio n. 17
0
    def _eval(self, context: RuleContext) -> LintResult:
        """Unnecessary trailing whitespace.

        Looks at the raw segment just before the current one (a newline or
        the file end). If it is whitespace, the whole run of whitespace
        ending there is reported and deleted.
        """
        raw_stack = context.raw_stack
        if not raw_stack or not raw_stack[-1].is_type("whitespace"):
            return LintResult()

        # Walk backwards over the run of whitespace preceding this point.
        deletions = (
            FunctionalContext(context)
            .raw_stack.reversed()
            .select(loop_while=sp.is_type("whitespace"))
        )
        # deletions is in reverse order, so its last item is the earliest
        # whitespace segment of the run.
        return LintResult(
            anchor=deletions[-1],
            fixes=[LintFix.delete(d) for d in deletions],
        )
Esempio n. 18
0
    def _eval(self, context: RuleContext) -> EvalResultType:
        """Override base class for dialects that use structs, or SELECT aliases.

        Skipped entirely for struct dialects unless ``force_enable`` is set.
        Otherwise records whether the current dialect uses structs in
        ``self._is_struct_dialect``. Then, when no ancestor is one of
        ``_START_TYPES``, every query in the crawler's query tree is visited.

        Returns:
            An empty LintResult when disabled, a list of results from
            ``_visit_queries``, or None when there is nothing to check.
        """
        # Config type hints
        self.force_enable: bool
        is_struct_dialect = context.dialect.name in self._dialects_with_structs
        # Some dialects use structs (e.g. column.field) which look like
        # table references and so incorrectly trigger this rule.
        if is_struct_dialect and not self.force_enable:
            return LintResult()

        # Assign unconditionally: previously the flag was only ever set to
        # True, so one struct-dialect evaluation left it stuck at True for
        # any later evaluation on a non-struct dialect.
        self._is_struct_dialect = is_struct_dialect

        if FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
            return None
        crawler = SelectCrawler(context.segment, context.dialect)
        if crawler.query_tree:
            # Recursively visit and check each query in the tree.
            return list(self._visit_queries(crawler.query_tree))
        return None
Esempio n. 19
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Unnecessary CASE statement.

        Only CASE expressions with exactly one WHEN clause are considered.
        Two simplifications are recognised:

        1. The THEN and ELSE expressions are opposite boolean literals
           (TRUE/FALSE). The fix comes from ``_coalesce_fix_list`` using the
           WHEN condition and ``false``, with ``preceding_not`` set when the
           THEN branch is FALSE.
        2. The WHEN condition is a single ``IS [NOT] NULL`` test (no AND/OR)
           on a column reference. That column must be the ELSE expression
           (for ``IS NULL``) or the THEN expression (for ``IS NOT NULL``).
           The fix is a COALESCE, or just the column when the remaining
           branch is NULL. An absent ELSE counts as ``ELSE NULL``.

        Returns:
            A LintResult anchored on the WHEN condition, or None when no
            simplification applies.
        """
        # Look for CASE expression.
        if context.segment.segments[0].raw_upper == "CASE":
            # Find all 'WHEN' clauses and the optional 'ELSE' clause.
            children = FunctionalContext(context).segment.children()
            when_clauses = children.select(sp.is_type("when_clause"))
            else_clauses = children.select(sp.is_type("else_clause"))

            # Can't fix if multiple WHEN clauses.
            if len(when_clauses) > 1:
                return None

            # Find condition and then expressions.
            condition_expression = when_clauses.children(sp.is_type("expression"))[0]
            then_expression = when_clauses.children(sp.is_type("expression"))[1]

            # Method 1: Check if THEN/ELSE expressions are both Boolean and can
            # therefore be reduced.
            if else_clauses:
                else_expression = else_clauses.children(sp.is_type("expression"))[0]
                upper_bools = ["TRUE", "FALSE"]
                if (
                    (then_expression.raw_upper in upper_bools)
                    and (else_expression.raw_upper in upper_bools)
                    and (then_expression.raw_upper != else_expression.raw_upper)
                ):
                    coalesce_arg_1: BaseSegment = condition_expression
                    coalesce_arg_2: BaseSegment = KeywordSegment("false")
                    preceding_not = then_expression.raw_upper == "FALSE"

                    fixes = self._coalesce_fix_list(
                        context,
                        coalesce_arg_1,
                        coalesce_arg_2,
                        preceding_not,
                    )

                    return LintResult(
                        anchor=condition_expression,
                        fixes=fixes,
                        description="Unnecessary CASE statement. "
                        "Use COALESCE function instead.",
                    )

            # Method 2: Check if the condition expression is comparing a column
            # reference to NULL and whether that column reference is also in either the
            # THEN/ELSE expression. We can only apply this method when there is only
            # one condition in the condition expression.
            condition_expression_segments_raw = {
                segment.raw_upper for segment in condition_expression.segments
            }
            if {"IS", "NULL"}.issubset(condition_expression_segments_raw) and (
                not condition_expression_segments_raw.intersection({"AND", "OR"})
            ):
                # Check if the comparison is to NULL or NOT NULL.
                is_not_prefix = "NOT" in condition_expression_segments_raw

                # Locate column reference in condition expression.
                column_reference_segment = (
                    Segments(condition_expression)
                    .children(sp.is_type("column_reference"))
                    .get()
                )

                # Return None if none found (this condition does not apply to functions)
                if not column_reference_segment:
                    return None

                if else_clauses:
                    else_expression = else_clauses.children(
                        sp.is_type("expression")
                    )[0]
                    # Check if we can reduce the CASE expression to a single coalesce
                    # function.
                    if (
                        not is_not_prefix
                        and column_reference_segment.raw_upper
                        == else_expression.raw_upper
                    ):
                        # CASE WHEN col IS NULL THEN x ELSE col -> COALESCE(col, x)
                        coalesce_arg_1 = else_expression
                        coalesce_arg_2 = then_expression
                    elif (
                        is_not_prefix
                        and column_reference_segment.raw_upper
                        == then_expression.raw_upper
                    ):
                        # CASE WHEN col IS NOT NULL THEN col ELSE x -> COALESCE(col, x)
                        coalesce_arg_1 = then_expression
                        coalesce_arg_2 = else_expression
                    else:
                        return None

                    if coalesce_arg_2.raw_upper == "NULL":
                        # Can just specify the column on it's own
                        # rather than using a COALESCE function.
                        return LintResult(
                            anchor=condition_expression,
                            fixes=self._column_only_fix_list(
                                context,
                                column_reference_segment,
                            ),
                            description="Unnecessary CASE statement. "
                            f"Just use column '{column_reference_segment.raw}'.",
                        )

                    return LintResult(
                        anchor=condition_expression,
                        fixes=self._coalesce_fix_list(
                            context,
                            coalesce_arg_1,
                            coalesce_arg_2,
                        ),
                        description="Unnecessary CASE statement. "
                        "Use COALESCE function instead.",
                    )
                elif (
                    is_not_prefix
                    and column_reference_segment.raw_upper
                    == then_expression.raw_upper
                ):
                    # Can just specify the column on it's own
                    # rather than using a COALESCE function.
                    # In this case no ELSE statement is equivalent to ELSE NULL.
                    # Only valid for IS NOT NULL: CASE WHEN col IS NULL THEN col
                    # END always yields NULL, so replacing it with col would
                    # change the query's result.
                    return LintResult(
                        anchor=condition_expression,
                        fixes=self._column_only_fix_list(
                            context,
                            column_reference_segment,
                        ),
                        description="Unnecessary CASE statement. "
                        f"Just use column '{column_reference_segment.raw}'.",
                    )

        return None
Esempio n. 20
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Find rule violations and provide fixes.

        Compares the single argument of a COUNT call with the configured
        preference: ``1`` when ``prefer_count_1`` is set, otherwise ``0`` when
        ``prefer_count_0`` is set, otherwise ``*``. A ``*`` argument is only
        rewritten when a numeric form is preferred; a lone ``0``/``1`` literal
        is rewritten whenever it differs from the preferred form.
        """
        # Config type hints
        self.prefer_count_0: bool
        self.prefer_count_1: bool

        # The crawl behaviour only hands us functions, so only the name matters.
        if context.segment.get_child("function_name").raw_upper != "COUNT":
            return None

        # Everything between the brackets except the brackets themselves,
        # layout segments (whitespace/newlines) and meta segments.
        argument_segments = (
            FunctionalContext(context)
            .segment.children(sp.is_type("bracketed"))
            .children(
                sp.and_(
                    sp.not_(sp.is_meta()),
                    sp.not_(
                        sp.is_type(
                            "start_bracket", "end_bracket", "whitespace", "newline"
                        )
                    ),
                )
            )
        )
        if len(argument_segments) != 1:  # pragma: no cover
            return None
        argument = argument_segments[0]

        if self.prefer_count_1:
            preferred = "1"
        elif self.prefer_count_0:
            preferred = "0"
        else:
            preferred = "*"

        # COUNT(*) is only a violation when a numeric form is preferred.
        if argument.is_type("star") and preferred != "*":
            return LintResult(
                anchor=context.segment,
                fixes=[
                    LintFix.replace(
                        argument,
                        [argument.edit(argument.raw.replace("*", preferred))],
                    ),
                ],
            )

        if not argument.is_type("expression"):
            return None

        # A numeric argument arrives wrapped in an expression; look inside it.
        inner = [seg for seg in argument.segments if not seg.is_meta]
        if len(inner) != 1:
            return None
        literal = inner[0]
        if (
            literal.is_type("literal")
            and literal.raw in ("0", "1")
            and literal.raw != preferred
        ):
            # The whole literal is replaced by the preferred text.
            return LintResult(
                anchor=context.segment,
                fixes=[LintFix.replace(literal, [literal.edit(preferred)])],
            )
        return None
Esempio n. 21
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Inconsistent column references in GROUP BY/ORDER BY clauses.

        Each direct child of the clause is classified as an ``explicit``
        reference (column reference or expression) or an ``implicit`` one
        (numeric literal). With the ``consistent`` style, mixing both kinds in
        one clause, or using a different kind than an earlier clause (kept in
        ``context.memory``), is a violation. With any other configured style,
        a reference of a different category is a violation. Memory is always
        returned so later clauses can compare against it.
        """
        # Config type hints
        self.group_by_and_order_by_style: str

        # We only care about GROUP BY/ORDER BY clauses.
        assert context.segment.is_type("groupby_clause", "orderby_clause")

        # Clauses nested inside the ignored types (windowing) are skipped.
        if FunctionalContext(context).parent_stack.any(
            sp.is_type(*self._ignore_types)
        ):
            return LintResult(memory=context.memory)

        # Segment *types* map to a reference category; the numeric literal is
        # matched by its specific type.
        category_by_type = {
            "column_reference": "explicit",
            "expression": "explicit",
            "numeric_literal": "implicit",
        }
        categories = set()
        for child in context.segment.segments:
            if child.is_type(*category_by_type):
                categories.add(category_by_type[child.get_type()])

        # Nothing to compare if there are no references at all.
        if not categories:  # pragma: no cover
            return LintResult(memory=context.memory)

        style = self.group_by_and_order_by_style
        if style != "consistent":
            # Fixed style: every reference must belong to that category.
            if categories != {style}:
                return LintResult(anchor=context.segment, memory=context.memory)
            return LintResult(memory=context.memory)

        # Consistent style, case 1: both kinds appear within this clause.
        if len(categories) > 1:
            return LintResult(anchor=context.segment, memory=context.memory)

        # Consistent style, case 2: this clause contradicts an earlier one.
        (current,) = categories
        prior = context.memory.get("prior_group_by_order_by_convention")
        if prior and prior != current:
            return LintResult(anchor=context.segment, memory=context.memory)

        # Record the convention so later clauses are judged against it.
        context.memory["prior_group_by_order_by_convention"] = current
        return LintResult(memory=context.memory)
Esempio n. 22
0
    def _eval(self, context: RuleContext) -> List[LintResult]:
        """Set operators should be surrounded by newlines.

        For every set operator child of this segment, the non-code siblings
        between it and the nearest code segment on each side are searched for
        a newline segment; any number of newlines is acceptable. Each side
        lacking one is reported, with fixes produced by ``_generate_fixes``
        for the first whitespace segment on that side (which may be absent).
        """
        parent = FunctionalContext(context).segment
        siblings = parent.children()
        operators = parent.children(sp.is_type(*self._target_elems))
        # We should always find some as children because of the ParentOfSegmentCrawler
        assert operators

        results: List[LintResult] = []
        for operator in operators:
            # Nearest code on either side bounds the gaps we inspect.
            code_before = (
                siblings.reversed().select(start_seg=operator).first(sp.is_code())
            )
            code_after = siblings.select(start_seg=operator).first(sp.is_code())
            gap_before = siblings.select(
                start_seg=code_before.get(), stop_seg=operator
            )
            gap_after = siblings.select(start_seg=operator, stop_seg=code_after.get())

            newline_before = gap_before.first(sp.is_type("newline"))
            newline_after = gap_after.first(sp.is_type("newline"))
            if newline_before and newline_after:
                continue

            # Whitespace next to the operator is what the fixes act upon.
            whitespace_before = gap_before.first(sp.is_type("whitespace")).get()
            whitespace_after = gap_after.first(sp.is_type("whitespace")).get()

            if newline_after:
                # Only the newline before the operator is missing.
                detail = f"Missing newline before set operator {operator.raw}."
                fixes = _generate_fixes(whitespace_segment=whitespace_before)
            elif newline_before:
                # Only the newline after the operator is missing.
                detail = f"Missing newline after set operator {operator.raw}."
                fixes = _generate_fixes(whitespace_segment=whitespace_after)
            else:
                fixes_before = _generate_fixes(whitespace_segment=whitespace_before)
                fixes_after = _generate_fixes(whitespace_segment=whitespace_after)
                # make mypy happy
                assert isinstance(fixes_before, Iterable)
                assert isinstance(fixes_after, Iterable)
                fixes = [*fixes_before, *fixes_after]
                detail = (
                    "Missing newline before and after set operator "
                    f"{operator.raw}."
                )

            results.append(
                LintResult(
                    anchor=operator,
                    description=(
                        "Set operators should be surrounded by newlines. " + detail
                    ),
                    fixes=fixes,
                )
            )

        return results
Esempio n. 23
0
    def _eval(self, context: RuleContext):
        """WITH clause closing bracket should be aligned with WITH keyword.

        Look for a with clause and evaluate the position of closing brackets.

        Only the closing bracket of each CTE *query* is checked; the bracket of
        an optional CTE column list is ignored. A closing bracket on the same
        line as the WITH keyword (the one-line form) is accepted. For any other
        closing bracket, if a segment other than whitespace or a dedent
        precedes it on its line, a fix inserting a newline before it is
        returned. At most one result is produced per call: the first
        offending bracket.

        Returns:
            A ``LintResult`` for the first misplaced bracket; an empty
            ``LintResult`` when no WITH keyword was found and the statement
            contains an unparsable section; otherwise ``None``.

        Raises:
            RuntimeError: if no WITH keyword is found and nothing is unparsable.
        """
        # This rule is evaluated on the whole WITH statement.
        assert context.segment.is_type("with_compound_statement")

        # Look for the with keyword; its line is the reference line.
        for child in context.segment.segments:
            if child.name.lower() == "with":
                with_line_no = child.pos_marker.line_no
                break
        else:  # pragma: no cover
            # This *could* happen if the with statement is unparsable,
            # in which case then the user will have to fix that first.
            if any(s.is_type("unparsable") for s in context.segment.segments):
                return LintResult()
            # If it's parsable but we still didn't find a with, then
            # we should raise that.
            raise RuntimeError("Didn't find WITH keyword!")

        # Find the end brackets for the CTE *query* (i.e. ignore optional
        # list of CTE columns): the last end_bracket inside the last bracketed
        # child of each common_table_expression.
        cte_end_brackets = IdentitySet()
        ctes = FunctionalContext(context).segment.children(
            sp.is_type("common_table_expression")
        )
        for cte in ctes.iterate_segments():
            cte_end_bracket = (
                cte.children()
                .last(sp.is_type("bracketed"))
                .children()
                .last(sp.is_type("end_bracket"))
            )
            if cte_end_bracket:
                cte_end_brackets.add(cte_end_bracket[0])

        for seg in context.segment.iter_segments(
            expanding=["common_table_expression", "bracketed"], pass_through=True
        ):
            if seg not in cte_end_brackets:
                continue

            # Narrow the type before any pos_marker access.
            assert seg.pos_marker
            bracket_line_no = seg.pos_marker.line_no
            bracket_line_pos = seg.pos_marker.line_pos

            if bracket_line_no == with_line_no:
                # Skip if it's the one-line version. That's ok
                continue

            # Is it all whitespace before the bracket on this line?
            # Walk raw segments positioned at or before the bracket on its
            # line: a newline resets the flag, anything other than whitespace
            # or a dedent sets it, and reaching the bracket stops the scan.
            contains_non_whitespace = False
            for elem in context.segment.raw_segments:
                elem_pos = cast(PositionMarker, elem.pos_marker)
                if (
                    elem_pos.line_no == bracket_line_no
                    and elem_pos.line_pos <= bracket_line_pos
                ):
                    if elem is seg:
                        break
                    elif elem.is_type("newline"):
                        contains_non_whitespace = False
                    elif not elem.is_type("dedent") and not elem.is_type(
                        "whitespace"
                    ):
                        contains_non_whitespace = True

            if contains_non_whitespace:
                # We have to move it to a newline
                return LintResult(
                    anchor=seg,
                    fixes=[
                        LintFix.create_before(
                            seg,
                            [
                                NewlineSegment(),
                            ],
                        )
                    ],
                )

        return None
Esempio n. 24
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Unnecessary quoted identifier.

        With ``prefer_quoted_identifiers`` set, naked identifiers are reported
        without a fix (which quote character to use is dialect dependent).
        Otherwise, quoted identifiers whose unquoted text would be a valid,
        non-reserved naked identifier for the dialect are reported with a fix
        that drops the quotes. Identifiers in the ignore-words list or matching
        ``ignore_words_regex`` are skipped, and dialects that treat quotes as
        part of the name are skipped unless ``force_enable`` is set.
        """
        # Config type hints
        self.prefer_quoted_identifiers: bool
        self.ignore_words: str
        self.ignore_words_regex: str
        self.force_enable: bool

        segment = context.segment

        # Some dialects allow quotes as PART OF the column name, so that e.g.
        # date and "date" are two different columns. For safety the rule is
        # disabled by default in those dialects.
        dialect_allows_quotes = (
            context.dialect.name in self._dialects_allowing_quotes_in_column_names
        )
        if dialect_allows_quotes and not self.force_enable:
            return LintResult()

        # Identifiers inside the ignored segment types are out of scope.
        if FunctionalContext(context).parent_stack.any(
            sp.is_type(*self._ignore_types)
        ):
            return None

        prefer_quoted = self.prefer_quoted_identifiers
        if prefer_quoted:
            # Target naked identifiers; their text is the identifier itself.
            target_type = "naked_identifier"
            identifier_contents = segment.raw
        else:
            # Target quoted identifiers; drop the outer quote characters.
            target_type = "quoted_identifier"
            identifier_contents = segment.raw[1:-1]

        # The ignore list is read from configuration lazily on first use
        # because that is slow (presumably cached by _init_ignore_words_list).
        try:
            ignore_words_list = self.ignore_words_list
        except AttributeError:
            ignore_words_list = self._init_ignore_words_list()

        if ignore_words_list and identifier_contents.lower() in ignore_words_list:
            return None

        if self.ignore_words_regex and regex.search(
            self.ignore_words_regex, identifier_contents
        ):
            return LintResult(memory=context.memory)

        # Only segments of the targeted type are considered. TSQL also has a
        # keyword called QUOTED_IDENTIFIER which maps to the type name, so it is
        # excluded explicitly by its raw text.
        if not segment.is_type(target_type):
            return None
        if segment.raw.lower() in ("quoted_identifier", "naked_identifier"):
            return None

        if prefer_quoted:
            # Quoting is required, but the quote character is dialect
            # dependent, so no automatic fix is offered.
            return LintResult(
                segment,
                description=f"Missing quoted identifier {identifier_contents}.",
            )

        # Remaining case: prefer_quoted_identifiers is False, so
        # identifier_contents holds the text between the outer quotes.
        naked_identifier_parser = context.dialect._library["NakedIdentifierSegment"]
        IdentifierSegment = cast(
            Type[CodeSegment], context.dialect.get_segment("IdentifierSegment")
        )

        # The unquoted text must fully match the dialect's naked identifier
        # template and must not match its anti-template (reserved keywords).
        matches_naked_template = (
            regex.fullmatch(
                naked_identifier_parser.template,
                identifier_contents,
                regex.IGNORECASE,
            )
            is not None
        )
        if not matches_naked_template:
            return None
        matches_anti_template = (
            regex.fullmatch(
                naked_identifier_parser.anti_template,
                identifier_contents,
                regex.IGNORECASE,
            )
            is not None
        )
        if matches_anti_template:
            return None

        unquoted = IdentifierSegment(
            raw=identifier_contents,
            type="naked_identifier",
        )
        return LintResult(
            segment,
            fixes=[LintFix.replace(segment, [unquoted])],
            description=f"Unnecessary quoted identifier {segment.raw}.",
        )
Esempio n. 25
0
    def _eval(self, context: RuleContext) -> EvalResultType:
        """Look for non-literal segments.

        For a raw segment whose position marker is not literal (presumably
        produced by templating), each ``templated``/``block_start``/
        ``block_end`` raw slice wrapped in ``{`` ... ``}`` is checked: the
        whitespace just inside each end must be a single space or contain a
        newline. Otherwise a source fix rewriting it to a single space is
        proposed.

        The memory is a set of source indexes already examined, so a tag that
        several templated positions map back to is only reported once.

        Returns:
            A list of LintResults (one per offending tag) if any were found,
            otherwise a LintResult carrying the (possibly updated) memory.
        """
        assert context.segment.pos_marker
        if not context.segment.is_raw() or context.segment.pos_marker.is_literal():
            return LintResult(memory=context.memory)

        # Start a fresh set of seen source indexes while memory is still empty.
        memory = context.memory or set()

        # Get any templated raw slices.
        # NOTE: We use this function because a single segment
        # may include multiple raw templated sections:
        # e.g. a single identifier with many templated tags.
        templated_raw_slices = FunctionalContext(context).segment.raw_slices.select(
            rsp.is_slice_type("templated", "block_start", "block_end")
        )
        result = []

        # Iterate through any tags found.
        for raw_slice in templated_raw_slices:
            stripped = raw_slice.raw.strip()
            if not stripped or stripped[0] != "{" or stripped[-1] != "}":
                continue  # pragma: no cover

            self.logger.debug(
                "Tag found @ %s: %r ", context.segment.pos_marker, stripped
            )

            # Dedupe using a memory of source indexes.
            # This is important because several positions in the
            # templated file may refer to the same position in the
            # source file and we only want to get one violation.
            # Check the local set, not context.memory: on the first call
            # context.memory is empty, yet repeated source indexes within this
            # very segment must still be skipped.
            src_idx = raw_slice.source_idx
            if src_idx in memory:
                continue
            memory.add(src_idx)

            # Partition and Position
            tag_pre, ws_pre, inner, ws_post, tag_post = self._get_whitespace_ends(
                stripped
            )
            position = raw_slice.raw.find(stripped[0])

            self.logger.debug(
                "Tag string segments: %r | %r | %r | %r | %r @ %s + %s",
                tag_pre,
                ws_pre,
                inner,
                ws_post,
                tag_post,
                src_idx,
                position,
            )

            # For the following section, whitespace should be a single
            # whitespace OR it should contain a newline.
            pre_fix = None
            post_fix = None
            # Check the initial whitespace.
            if not ws_pre or (ws_pre != " " and "\n" not in ws_pre):
                pre_fix = " "
            # Check latter whitespace.
            if not ws_post or (ws_post != " " and "\n" not in ws_post):
                post_fix = " "

            if pre_fix is None and post_fix is None:
                # Tag is already correctly padded.
                continue

            fixed = (
                tag_pre
                + (pre_fix or ws_pre)
                + inner
                + (post_fix or ws_post)
                + tag_post
            )
            src_fix = [
                SourceFix(
                    fixed,
                    slice(
                        src_idx + position,
                        src_idx + position + len(stripped),
                    ),
                    # NOTE: The templated slice here is
                    # going to be a little imprecise, but
                    # the one that really matters is the
                    # source slice.
                    context.segment.pos_marker.templated_slice,
                )
            ]
            result.append(
                LintResult(
                    memory=memory,
                    anchor=context.segment,
                    description=f"Jinja tags should have a single "
                    f"whitespace on either side: {stripped}",
                    fixes=[
                        LintFix.replace(
                            context.segment,
                            [context.segment.edit(source_fixes=src_fix)],
                        )
                    ],
                )
            )

        if result:
            return result
        return LintResult(memory=memory)
Esempio n. 26
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Enforce the preferred quote character for quoted literals.

        Only runs for dialects in ``_dialects_with_double_quoted_strings``
        unless ``force_enable`` is set. Partially templated literals are only
        checked when both quote characters lie outside the template. With the
        ``consistent`` style, the style of the first quoted literal seen
        (detected from its last character) is stored in memory and used for
        all later literals.

        Returns:
            ``None`` when the segment is not a quoted literal or already uses
            the preferred style; ``LintResult(memory=...)`` when the rule does
            not apply; otherwise a LintResult with a replacement fix.
        """
        # Config type hints
        self.preferred_quoted_literal_style: str
        self.force_enable: bool

        # Only care about quoted literal segments.
        if not context.segment.is_type("quoted_literal"):
            return None

        if not (
            self.force_enable
            or context.dialect.name in self._dialects_with_double_quoted_strings
        ):
            return LintResult(memory=context.memory)

        # This rule can also cover quoted literals that are partially templated.
        # I.e. when the quotes characters are _not_ part of the template we can
        # meaningfully apply this rule.
        templated_raw_slices = FunctionalContext(context).segment.raw_slices.select(
            rsp.is_slice_type("templated")
        )
        # The check depends only on the segment's source string, not on the
        # individual slice, so it is done once if any templated slice exists.
        if templated_raw_slices:
            pos_marker = context.segment.pos_marker
            # This is to make mypy happy.
            assert isinstance(pos_marker, PositionMarker)
            source_str = pos_marker.source_str()

            # Check whether the quote characters are inside the template.
            # For the leading quote, skip any string prefix characters first.
            # Stripping the whole string and slicing with [:1] (instead of
            # stripping only the first two characters and indexing [0]) avoids
            # an IndexError when the literal starts with two prefix characters;
            # an empty result is treated as "inside the template".
            leading_quote_inside_template = source_str.lstrip(
                self._string_prefix_chars
            )[:1] not in ('"', "'")
            trailing_quote_inside_template = source_str[-1:] not in ('"', "'")

            # quotes are not entirely outside of a template, nothing we can do
            if leading_quote_inside_template or trailing_quote_inside_template:
                return LintResult(memory=context.memory)

        # If quoting style is set to consistent we use the quoting style of the first
        # quoted_literal that we encounter.
        if self.preferred_quoted_literal_style == "consistent":
            memory = context.memory
            preferred_quoted_literal_style = memory.get(
                "preferred_quoted_literal_style"
            )

            if not preferred_quoted_literal_style:
                # Getting the quote from LAST character to be able to handle STRING
                # prefixes
                preferred_quoted_literal_style = (
                    "double_quotes"
                    if context.segment.raw[-1] == '"'
                    else "single_quotes"
                )
                memory["preferred_quoted_literal_style"] = (
                    preferred_quoted_literal_style
                )
                self.logger.debug(
                    "Preferred string quotes is set to `consistent`. Derived quoting "
                    "style %s from first quoted literal.",
                    preferred_quoted_literal_style,
                )
        else:
            preferred_quoted_literal_style = self.preferred_quoted_literal_style

        quote_settings = self._quotes_mapping[preferred_quoted_literal_style]
        fixed_string = self._normalize_preferred_quoted_literal_style(
            context.segment.raw,
            preferred_quote_char=quote_settings["preferred_quote_char"],
            alternate_quote_char=quote_settings["alternate_quote_char"],
        )

        if fixed_string != context.segment.raw:
            return LintResult(
                anchor=context.segment,
                memory=context.memory,
                fixes=[
                    LintFix.replace(
                        context.segment,
                        [
                            LiteralSegment(
                                raw=fixed_string,
                                type="quoted_literal",
                            )
                        ],
                    )
                ],
                description=(
                    "Inconsistent use of preferred quote style '"
                    f"{quote_settings['common_name']}"
                    f"'. Use {fixed_string} instead of {context.segment.raw}."
                ),
            )

        return None
Esempio n. 27
0
    def _get_indexes(context: RuleContext):
        """Gather positional information about the select clause's children.

        Builds a ``SelectTargetsInfo`` from, in order:
        - the SELECT keyword index;
        - the first newline index;
        - the first select target index;
        - the first whitespace index after the first newline;
        - the index of a comment following SELECT before the first newline;
        - the select targets themselves;
        - the following FROM clause segment (or ``None``);
        - the whitespace among the following siblings up to that FROM clause.

        Indexes that are explicitly defaulted are -1 when the corresponding
        segment is absent.
        """
        functional = FunctionalContext(context)
        children = functional.segment.children()

        select_targets = children.select(sp.is_type("select_clause_element"))
        first_select_target_idx = children.find(select_targets.get())

        select_keywords = children.select(sp.is_keyword("select"))
        select_idx = children.find(select_keywords.get()) if select_keywords else -1

        newlines = children.select(sp.is_type("newline"))
        first_new_line_idx = -1
        comment_after_select_idx = -1
        if newlines:
            first_new_line_idx = children.find(newlines.get())
            # Comments sitting between SELECT and the first newline.
            trailing_comments = children.select(
                sp.is_type("comment"),
                start_seg=select_keywords.get(),
                stop_seg=newlines.get(),
                loop_while=sp.or_(
                    sp.is_type("comment"), sp.is_type("whitespace"), sp.is_meta()
                ),
            )
            if trailing_comments:
                comment_after_select_idx = children.find(trailing_comments.get())

        first_whitespace_idx = -1
        if first_new_line_idx != -1:
            # TRICKY: Ignore whitespace prior to the first newline, e.g. if
            # the line with "SELECT" (before any select targets) has trailing
            # whitespace.
            whitespace_after_newline = children.select(
                sp.is_type("whitespace"), start_seg=children[first_new_line_idx]
            )
            first_whitespace_idx = children.find(whitespace_after_newline.get())

        following = functional.siblings_post
        from_segment = following.first(sp.is_type("from_clause")).get()
        whitespace_before_from = following.select(
            sp.is_type("whitespace"), stop_seg=from_segment
        )
        return SelectTargetsInfo(
            select_idx,
            first_new_line_idx,
            first_select_target_idx,
            first_whitespace_idx,
            comment_after_select_idx,
            select_targets,
            from_segment,
            list(whitespace_before_from),
        )
Esempio n. 28
0
    def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
        """Relational operators should not be used to check for NULL values.

        Every ``=``, ``!=`` or ``<>`` child whose next code segment is a NULL
        literal is replaced with the sequence built by
        ``_create_base_is_null_sequence`` (presumably IS / IS NOT). Whitespace
        is added before the replacement when the segment preceding the operator
        lacks it, and after it when the segment following the operator does.

        Returns:
            One LintResult per offending operator, or ``None`` if there are none.
        """
        # Context/motivation for this rule:
        # https://news.ycombinator.com/item?id=28772289
        # https://stackoverflow.com/questions/9581745/sql-is-null-and-null
        if len(context.segment.segments) <= 2:
            return None  # pragma: no cover

        # Allow assignments in SET clauses
        if context.parent_stack and context.parent_stack[-1].is_type(
            "set_clause_list", "execute_script_statement"
        ):
            return None

        # Allow assignments in EXEC clauses
        if context.segment.is_type("set_clause_list", "execute_script_statement"):
            return None

        segment = FunctionalContext(context).segment
        # Iterate through children of this segment looking for equals or "not
        # equals". Once found, check if the next code segment is a NULL literal.

        children = segment.children()
        operators = segment.children(sp.raw_is("=", "!=", "<>"))
        if len(operators) == 0:
            return None
        self.logger.debug("Operators found: %s", operators)

        results: List[LintResult] = []
        # We may have many operators
        for operator in operators:
            self.logger.debug("Children found: %s", children)
            after_op_list = children.select(start_seg=operator)
            # If nothing comes after operator then skip
            if not after_op_list:
                continue  # pragma: no cover
            null_literal = after_op_list.first(sp.is_code())
            # if the next bit of code isnt a NULL then we are good.
            # NOTE: .all() is true for an empty selection, so also skip when
            # no code follows the operator at all.
            if not null_literal or not null_literal.all(sp.is_type("null_literal")):
                continue

            sub_seg = null_literal.get()
            assert sub_seg, "TypeGuard: Segment must exist"
            self.logger.debug(
                "Found NULL literal following equals/not equals @%s: %r",
                sub_seg.pos_marker,
                sub_seg.raw,
            )
            edit = _create_base_is_null_sequence(
                is_upper=sub_seg.raw[0] == "N",
                operator_raw=operator.raw,
            )
            # Padding before the replacement depends on the segment directly
            # preceding the operator; padding after it on the one following it.
            prev_seg = children.select(stop_seg=operator).last().get()
            next_seg = after_op_list.first().get()
            if self._missing_whitespace(prev_seg, before=True):
                whitespace_segment: CorrectionListType = [WhitespaceSegment()]
                edit = whitespace_segment + edit
            if self._missing_whitespace(next_seg, before=False):
                edit = edit + [WhitespaceSegment()]
            res = LintResult(
                anchor=operator,
                fixes=[
                    LintFix.replace(
                        operator,
                        edit,
                    )
                ],
            )
            results.append(res)

        return results or None
Esempio n. 29
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Select clause modifiers must appear on same line as SELECT.

        When the ``select_clause_modifier`` child of a select clause sits on a
        later line than the ``SELECT`` keyword, re-insert it straight after
        the keyword and delete the original modifier together with the
        newlines/whitespace surrounding its old position.

        Returns:
            A ``LintResult`` anchored on the select clause carrying the fixes,
            or ``None`` when the clause has no modifier or the modifier is
            already on the ``SELECT`` line.
        """
        # This rule only ever evaluates select_clause segments.
        assert context.segment.is_type("select_clause")

        children = FunctionalContext(context).segment.children()
        keyword = children[0]

        modifiers = children.first(sp.is_type("select_clause_modifier"))
        if not modifiers:
            return None
        modifier = modifiers[0]

        # Scan forward over whitespace/meta only, so the scans below never
        # look past the next code (or comment) segment.
        gap_predicate = sp.or_(sp.is_whitespace(), sp.is_meta())

        newlines_before = children.select(
            select_if=sp.is_type("newline"),
            loop_while=gap_predicate,
            start_seg=keyword,
        )
        if not newlines_before:
            # No newline between SELECT and the modifier: nothing to fix.
            return None

        whitespace_before = children.select(
            select_if=sp.is_type("whitespace"),
            loop_while=gap_predicate,
            start_seg=keyword,
        )
        newlines_after = children.select(
            select_if=sp.is_type("newline"),
            loop_while=gap_predicate,
            start_seg=modifier,
        )
        # Whitespace directly after the modifier; the scan stops at the first
        # segment that is neither whitespace-typed nor meta (e.g. a newline).
        whitespace_after = children.select(
            select_if=sp.is_whitespace(),
            loop_while=sp.or_(sp.is_type("whitespace"), sp.is_meta()),
            start_seg=modifier,
        )

        target_shares_modifier_line = not newlines_after

        # The modifier is re-inserted right after SELECT. If the first select
        # target was on the modifier's line, it gets pushed onto a new line.
        insertion = [WhitespaceSegment(), modifier]
        if target_shares_modifier_line:
            insertion.append(NewlineSegment())
            stale_gap = newlines_before
        else:
            # A newline already follows the modifier, so the whitespace that
            # preceded it has to go as well.
            stale_gap = newlines_before + whitespace_before

        fixes = [
            LintFix.create_after(
                anchor_segment=keyword,
                edit_segments=insertion,
            )
        ]
        fixes += [LintFix.delete(seg) for seg in stale_gap]
        fixes.append(LintFix.delete(modifier))
        fixes += [LintFix.delete(seg) for seg in whitespace_after]

        return LintResult(anchor=context.segment, fixes=fixes)
Esempio n. 30
0
    def _eval_single_select_target_element(
        self, select_targets_info, context: RuleContext
    ):
        """Move a lone select target up onto the ``SELECT`` line.

        Applies when the clause's select target sits on a line after
        ``SELECT``, i.e. ``select_idx < first_new_line_idx <
        first_select_target_idx``. The fix deletes the target (plus a
        ``select_clause_modifier`` that is not already on the first line)
        from its original place and replaces the clause's first newline with
        them. When the parent is a ``select_statement``, leftover children
        between the first newline/modifier and the old target position are
        moved to follow the select clause.

        Args:
            select_targets_info: Index information about the select clause's
                children (``select_idx``, ``first_new_line_idx``,
                ``first_select_target_idx``, ``comment_after_select_idx``).
            context: The current rule context; its segment is the select
                clause.

        Returns:
            A ``LintResult`` anchored on the select clause (with an empty fix
            list when a comment follows ``SELECT`` on the same line), or
            ``None`` when the target is not on a later line than ``SELECT``.
        """
        select_clause = FunctionalContext(context).segment
        parent_stack = context.parent_stack

        if (
            select_targets_info.select_idx
            < select_targets_info.first_new_line_idx
            < select_targets_info.first_select_target_idx
        ):
            # Do we have a modifier?
            select_children = select_clause.children()
            modifier: Optional[Segments]
            modifier = select_children.first(sp.is_type("select_clause_modifier"))

            # Prepare the select clause which will be inserted
            insert_buff = [
                WhitespaceSegment(),
                select_children[select_targets_info.first_select_target_idx],
            ]

            # Check if the modifier is one we care about
            if modifier:
                # If it's already on the first line, ignore it.
                if (
                    select_children.index(modifier.get())
                    < select_targets_info.first_new_line_idx
                ):
                    modifier = None
            fixes = [
                # Delete the first select target from its original location.
                # We'll add it to the right section at the end, once we know
                # what to add.
                LintFix.delete(
                    select_children[select_targets_info.first_select_target_idx],
                ),
            ]

            # If we have a modifier to move:
            if modifier:
                # Add it to the insert
                insert_buff = [WhitespaceSegment(), modifier[0]] + insert_buff

                modifier_idx = select_children.index(modifier.get())
                # Delete the whitespace after it (which is two after, thanks to
                # indent). The bounds check covers the index actually read.
                trailing_ws_idx = modifier_idx + 2
                if (
                    len(select_children) > trailing_ws_idx
                    and select_children[trailing_ws_idx].is_whitespace
                ):
                    fixes.append(LintFix.delete(select_children[trailing_ws_idx]))

                # Delete the modifier itself
                fixes.append(LintFix.delete(modifier[0]))

                # Set the position marker for removing the preceding
                # whitespace and newline, which we'll use below.
                start_idx = modifier_idx
            else:
                # Set the position marker for removing the preceding
                # whitespace and newline, which we'll use below.
                start_idx = select_targets_info.first_select_target_idx

            if parent_stack and parent_stack[-1].is_type("select_statement"):
                select_stmt = parent_stack[-1]
                select_clause_idx = select_stmt.segments.index(select_clause.get())
                after_select_clause_idx = select_clause_idx + 1
                if len(select_stmt.segments) > after_select_clause_idx:

                    def _fixes_for_move_after_select_clause(
                        stop_seg: BaseSegment,
                        delete_segments: Optional[Segments] = None,
                        add_newline: bool = True,
                    ) -> List[LintFix]:
                        """Cleans up by moving leftover select_clause segments.

                        Context: Some of the other fixes we make in
                        _eval_single_select_target_element() leave leftover
                        child segments that need to be moved to become
                        *siblings* of the select_clause.

                        Returns every fix this cleanup needs (deletes of
                        ``delete_segments``, deletes of the moved segments and
                        the re-insertion after the select clause). The
                        enclosing ``fixes`` list is only read, never mutated;
                        callers append the result.
                        """
                        start_seg = (
                            modifier[0]
                            if modifier
                            else select_children[select_targets_info.first_new_line_idx]
                        )
                        move_after_select_clause = select_children.select(
                            start_seg=start_seg,
                            stop_seg=stop_seg,
                        )
                        # :TRICKY: Below, we have a couple places where we
                        # filter to guard against deleting the same segment
                        # multiple times -- this is illegal.
                        # :TRICKY: Use IdentitySet rather than set() since
                        # different segments may compare as equal.
                        all_deletes = IdentitySet(
                            fix.anchor for fix in fixes if fix.edit_type == "delete"
                        )
                        fixes_: List[LintFix] = []
                        for seg in delete_segments or []:
                            if seg not in all_deletes:
                                fixes_.append(LintFix.delete(seg))
                                all_deletes.add(seg)
                        fixes_ += [
                            LintFix.delete(seg)
                            for seg in move_after_select_clause
                            if seg not in all_deletes
                        ]
                        fixes_.append(
                            LintFix.create_after(
                                select_clause[0],
                                ([NewlineSegment()] if add_newline else [])
                                + list(move_after_select_clause),
                            )
                        )
                        return fixes_

                    if select_stmt.segments[after_select_clause_idx].is_type("newline"):
                        # Since we're deleting the newline, we should also delete all
                        # whitespace before it or it will add random whitespace to
                        # following statements. So walk back through the segment
                        # deleting whitespace until you get the previous newline, or
                        # something else.
                        to_delete = select_children.reversed().select(
                            loop_while=sp.is_type("whitespace"),
                            start_seg=select_children[start_idx],
                        )
                        if to_delete:
                            # The select_clause is immediately followed by a
                            # newline. Delete the newline in order to avoid leaving
                            # behind an empty line after fix, *unless* we stopped
                            # due to something other than a newline.
                            delete_last_newline = select_children[
                                start_idx - len(to_delete) - 1
                            ].is_type("newline")

                            # Delete the newline if we decided to.
                            if delete_last_newline:
                                fixes.append(
                                    LintFix.delete(
                                        select_stmt.segments[after_select_clause_idx],
                                    )
                                )

                            fixes += _fixes_for_move_after_select_clause(
                                to_delete[-1], to_delete
                            )
                    elif select_stmt.segments[after_select_clause_idx].is_type(
                        "whitespace"
                    ):
                        # The select_clause has stuff after (most likely a comment)
                        # Delete the whitespace immediately after the select clause
                        # so the other stuff aligns nicely based on where the select
                        # clause started.
                        fixes.append(
                            LintFix.delete(
                                select_stmt.segments[after_select_clause_idx],
                            )
                        )
                        fixes += _fixes_for_move_after_select_clause(
                            select_children[
                                select_targets_info.first_select_target_idx
                            ],
                        )
                    elif select_stmt.segments[after_select_clause_idx].is_type(
                        "dedent"
                    ):
                        # Again let's strip back the whitespace, but simpler
                        # as don't need to worry about new line so just break
                        # if see non-whitespace
                        # NOTE(review): ``select_clause_idx`` indexes the parent
                        # statement's segments but is used here to index the
                        # select clause's children; when the clause is the
                        # statement's first child this is ``[-1]`` (the last
                        # child). Confirm this is the intended start segment.
                        to_delete = select_children.reversed().select(
                            loop_while=sp.is_type("whitespace"),
                            start_seg=select_children[select_clause_idx - 1],
                        )
                        if to_delete:
                            fixes += _fixes_for_move_after_select_clause(
                                to_delete[-1],
                                to_delete,
                                # If we deleted a newline, create a newline.
                                # NOTE(review): ``to_delete`` only collects
                                # "whitespace"-typed segments (see loop_while
                                # above), so this is currently always False.
                                any(seg.is_type("newline") for seg in to_delete),
                            )
                    else:
                        fixes += _fixes_for_move_after_select_clause(
                            select_children[
                                select_targets_info.first_select_target_idx
                            ],
                        )

            if select_targets_info.comment_after_select_idx == -1:
                fixes.append(
                    # Insert the select_clause in place of the first newline in the
                    # Select statement
                    LintFix.replace(
                        select_children[select_targets_info.first_new_line_idx],
                        insert_buff,
                    ),
                )
            else:
                # The SELECT is followed by a comment on the same line. In order
                # to autofix this, we'd need to move the select target between
                # SELECT and the comment and potentially delete the entire line
                # where the select target was (if it is now empty). This is
                # *fairly tricky and complex*, in part because the newline on
                # the select target's line is several levels higher in the
                # parser tree. Hence, we currently don't autofix this. Could be
                # autofixed in the future if/when we have the time.
                fixes = []
            return LintResult(
                anchor=select_clause.get(),
                fixes=fixes,
            )
        return None