Ejemplo n.º 1
0
    def _get_indexes(context: RuleContext):
        """Gather positional facts about the children of a select clause.

        Builds a ``SelectTargetsInfo`` from, in order: the child index of the
        SELECT keyword, of the first newline, of the first select target and
        of the first whitespace after that newline (-1 when there is no
        newline), then the select-target segments themselves, the first
        ``from_clause`` among the following siblings, and the whitespace
        siblings that precede that FROM clause.
        """
        kids = context.functional.segment.children()

        targets = kids.select(sp.is_type("select_clause_element"))
        target_idx = kids.find(targets.get())
        select_keyword_idx = kids.find(
            kids.select(sp.is_keyword("select")).get())
        newline_idx = kids.find(kids.select(sp.is_type("newline")).get())

        # -1 is used as the "not found" marker, mirroring the values compared
        # against below.
        whitespace_idx = -1
        if newline_idx != -1:
            # TRICKY: only whitespace *after* the first newline counts; any
            # trailing whitespace on the SELECT line itself (before the
            # targets) must be ignored.
            post_newline_whitespace = kids.select(
                sp.is_type("whitespace"), start_seg=kids[newline_idx])
            whitespace_idx = kids.find(post_newline_whitespace.get())

        following = context.functional.siblings_post
        from_clause = following.first(sp.is_type("from_clause")).first().get()
        whitespace_before_from = following.select(
            sp.is_type("whitespace"), stop_seg=from_clause)

        return SelectTargetsInfo(
            select_keyword_idx,
            newline_idx,
            target_idx,
            whitespace_idx,
            targets,
            from_clause,
            list(whitespace_before_from),
        )
Ejemplo n.º 2
0
    def _get_subsequent_whitespace(
        self,
        context,
    ) -> Tuple[Optional[BaseSegment], Optional[BaseSegment]]:
        """Search forwards through the raw segments for subsequent whitespace.

        Starting just after ``context.segment`` (the current comma), skip the
        run of meta and whitespace raw segments that follows it.

        Returns:
            A tuple ``(trailing_whitespace, next_segment)``:
            ``trailing_whitespace`` is the last whitespace segment in that run
            (``None`` if the run contains no whitespace) and ``next_segment``
            is the first raw segment after the run (``None`` if the run
            reaches the end of the raw segments).
        """
        # Get all raw segments. "raw_segments" is appropriate as the
        # only segments we can care about are comma, whitespace,
        # newline, and comment, which are all raw. Using the
        # raw_segments allows us to account for possible unexpected
        # parse tree structures resulting from other rule fixes.
        raw_segments = context.functional.raw_segments
        # Start after the current comma within the list. Get all the
        # following whitespace.
        following_segments = raw_segments.select(
            loop_while=sp.or_(sp.is_meta(), sp.is_type("whitespace")),
            start_seg=context.segment,
        )
        whitespace_run = following_segments.last(sp.is_type("whitespace"))
        # Unwrap to a single segment (or None) up front so that BOTH return
        # paths honour the annotated Optional[BaseSegment]. Previously the
        # end-of-file path returned the ``Segments`` container itself.
        subsequent_whitespace = whitespace_run[0] if whitespace_run else None
        next_index = (
            raw_segments.index(context.segment) + len(following_segments) + 1
        )
        try:
            return subsequent_whitespace, raw_segments[next_index]
        except IndexError:
            # If we find ourselves here it's all whitespace (or nothing) to the
            # end of the file. This can only happen in bigquery (see
            # test_pass_bigquery_trailing_comma).
            return subsequent_whitespace, None
Ejemplo n.º 3
0
def _recursively_check_is_complex(
        select_clause_or_exp_children: Segments) -> bool:
    """Decide whether a select element / expression is "complex".

    Children of the "forgivable" types (whitespace, newlines, column
    references, wildcards, casts and bracketed groups) are discarded. If
    nothing is left the construct is simple; if more than one child is left,
    or the single survivor is not an ``expression``, it is complex. A lone
    ``expression`` is unwrapped and its own children are examined the same
    way (done here with a loop rather than recursion).
    """
    is_unforgivable = sp.not_(
        sp.is_type(
            "whitespace",
            "newline",
            "column_reference",
            "wildcard_expression",
            "cast_expression",
            "bracketed",
        ))
    current = select_clause_or_exp_children
    while True:
        leftovers = current.select(is_unforgivable)
        # Nothing but forgivable content: simple.
        if not leftovers:
            return False
        head = leftovers.first()
        # Several leftovers, or a single non-expression leftover: complex.
        if len(leftovers) > 1 or not head.all(sp.is_type("expression")):
            return True
        # Exactly one expression remains: look inside it.
        current = head.children()
Ejemplo n.º 4
0
    def _eval(self, context: RuleContext) -> LintResult:
        """Function name not immediately followed by bracket.

        Look for Function Segment with anything other than the
        function name before brackets. When everything in between is
        whitespace or newlines, those segments are deleted; otherwise the
        problem is only reported.
        """
        segment = context.functional.segment
        # We only trigger on function segments.
        if not segment.all(sp.is_type("function")):
            return LintResult()

        children = segment.children()
        function_name = children.first(sp.is_type("function_name"))
        start_bracket = children.first(sp.is_type("bracketed"))
        # Indexing [0] on an empty result would raise IndexError, so bail out
        # if the parse tree does not contain both a name and a bracket.
        if not function_name or not start_bracket:
            return LintResult()

        intermediate_segments = children.select(
            start_seg=function_name[0], stop_seg=start_bracket[0])
        if not intermediate_segments:
            return LintResult()

        anchor = intermediate_segments[0]
        # It's only safe to fix if there is only whitespace
        # or newlines in the intervening section.
        if intermediate_segments.all(sp.is_type("whitespace", "newline")):
            return LintResult(
                anchor=anchor,
                fixes=[LintFix.delete(seg) for seg in intermediate_segments],
            )
        # It's not all whitespace, just report the error.
        return LintResult(anchor=anchor)
Ejemplo n.º 5
0
 def _eval(self, context: RuleContext) -> Optional[LintResult]:
     """Ambiguous use of DISTINCT in select statement with GROUP BY.

     Flags the DISTINCT keyword of a select statement's select clause when
     that statement also has a GROUP BY clause.
     """
     statement = context.functional.segment
     # Only select statements that carry a GROUP BY clause are of interest.
     if not statement.all(sp.is_type("select_statement")):
         return None
     if not statement.children(sp.is_type("groupby_clause")):
         return None

     # Drill down: select_clause -> select_clause_modifier -> DISTINCT keyword.
     modifiers = statement.children(sp.is_type("select_clause")).children(
         sp.is_type("select_clause_modifier"))
     distinct_keywords = modifiers.children(sp.is_type("keyword")).select(
         sp.is_name("distinct"))
     if not distinct_keywords:
         return None
     return LintResult(anchor=distinct_keywords[0])
Ejemplo n.º 6
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Trailing commas within select clause.

        In "forbid" mode, a trailing comma as the select clause's last code
        segment is deleted; in "require" mode, a comma is appended after the
        last code segment when it is not already a comma.
        """
        # Config type hints
        self.select_clause_trailing_comma: str

        segment = context.functional.segment
        if not segment.all(sp.is_type("select_clause")):
            return None

        # The final code segment of the clause is the one we inspect.
        last_content: BaseSegment = segment.children().last(sp.is_code())[0]
        ends_with_comma = last_content.is_type("comma")
        mode = self.select_clause_trailing_comma

        if mode == "forbid" and ends_with_comma:
            return LintResult(
                anchor=last_content,
                fixes=[LintFix.delete(last_content)],
                description="Trailing comma in select statement forbidden",
            )
        if mode == "require" and not ends_with_comma:
            new_comma = SymbolSegment(",", name="comma", type="comma")
            return LintResult(
                anchor=last_content,
                fixes=[LintFix.replace(last_content, [last_content, new_comma])],
                description="Trailing comma in select statement required",
            )
        return None
Ejemplo n.º 7
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Find rule violations and provide fixes.

        0. Look for a case expression
        1. Look for "ELSE"
        2. Mark "ELSE" for deletion (populate "fixes")
        3. Backtrack and mark all newlines/whitespaces for deletion
        4. Look for a raw "NULL" segment
        5.a. The raw "NULL" segment is found, we mark it for deletion and return
        5.b. We reach the end of case when without matching "NULL": the rule passes
        """
        if not context.segment.is_type("case_expression"):
            return None

        case_children = context.functional.segment.children()
        else_clause = case_children.first(sp.is_type("else_clause"))
        if not else_clause:
            return None

        # NOTE: it is safe to look for a child whose raw text is "NULL": an
        # expression would *contain* NULL rather than be equal to it.
        null_children = else_clause.children(
            lambda child: child.raw_upper == "NULL")
        if not null_children:
            return None

        # :TRICKY: reversed() makes select() walk backwards from the ELSE,
        # collecting the whitespace/newline/meta segments just before it.
        padding_before_else = case_children.reversed().select(
            start_seg=else_clause[0],
            loop_while=sp.or_(sp.is_name("whitespace", "newline"), sp.is_meta()),
        )
        # Delete the whole ELSE clause, then the padding preceding it.
        deletions = [LintFix.delete(else_clause[0])]
        deletions.extend(LintFix.delete(seg) for seg in padding_before_else)
        return LintResult(anchor=context.segment, fixes=deletions)
Ejemplo n.º 8
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Column expression without alias. Use explicit `AS` clause.

        Applies to ``select_clause_element`` segments without an alias whose
        expression is complex (see ``_recursively_check_is_complex``). When
        ``allow_scalar`` is set, an element is only reported if its immediate
        parent holds more than one select element. No fix is offered.
        """
        segment = context.functional.segment
        children = segment.children()

        # Already aliased: nothing to report.
        if children.any(sp.is_type("alias_expression")):
            return None
        # Only select_clause_element segments are relevant.
        if not segment.all(sp.is_type("select_clause_element")):
            return None
        # A function with an EMITS clause counts as aliased (EMITS ~ AS).
        emits_clauses = children.select(sp.is_type("function")).children().select(
            sp.is_type("emits_segment"))
        if emits_clauses:
            return None

        non_star_children = children.select(sp.not_(sp.is_name("star")))
        if not _recursively_check_is_complex(non_star_children):
            return None

        # No fixes, because we don't know what the alias should be,
        # the user should document it themselves.
        if not self.allow_scalar:  # type: ignore
            return LintResult(anchor=context.segment)

        # With allow_scalar, a lone select element is tolerated; only report
        # when the immediate parent contains several of them.
        sibling_elements = context.functional.parent_stack.last().children(
            sp.is_type("select_clause_element"))
        if len(sibling_elements) > 1:
            return LintResult(anchor=context.segment)
        return None
Ejemplo n.º 9
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Looking for DISTINCT before a bracket.

        Look for DISTINCT keyword immediately followed by open parenthesis.
        Fixes unwrap the brackets when they are the expression's only child,
        and insert whitespace between DISTINCT and the first select element
        when none is present.
        """
        # We trigger on `select_clause` and look for `select_clause_modifier`.
        if not context.segment.is_type("select_clause"):
            return None

        clause_children = context.functional.segment.children()
        modifier = clause_children.select(sp.is_type("select_clause_modifier"))
        first_element = clause_children.select(
            sp.is_type("select_clause_element")).first()
        if not (modifier and first_element):
            return None

        # Inspect the wrapping expression if there is one, else the element.
        expression = (first_element.children(sp.is_type("expression")).first()
                      or first_element)
        bracketed = expression.children(sp.is_type("bracketed")).first()
        if not bracketed:
            return None

        fixes = []
        # Brackets that are the expression's sole child can be removed: keep
        # everything between them, with meta segments filtered out.
        if len(expression[0].segments) == 1:
            unwrapped = self.filter_meta(bracketed[0].segments)[1:-1]
            fixes.append(LintFix.replace(bracketed[0], unwrapped))
        # Make sure DISTINCT and the expression are separated by whitespace.
        separating_whitespace = clause_children.select(
            sp.is_whitespace(),
            start_seg=modifier[0],
            stop_seg=first_element[0],
        )
        if not separating_whitespace:
            fixes.append(
                LintFix.create_before(first_element[0], [WhitespaceSegment()]))
        # If no fixes, no problem.
        if not fixes:
            return None
        return LintResult(anchor=modifier[0], fixes=fixes)
Ejemplo n.º 10
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Identify aliases in from clause and join conditions.

        Find base table, table expressions in join, and other expressions in select
        clause and decide if it's needed to report them.
        """
        if not context.segment.is_type("select_statement"):
            return None

        statement_children = context.functional.segment.children()
        from_clause = statement_children.select(
            sp.is_type("from_clause")).first()

        # Descend from_clause -> from_expression -> from_expression_element
        # -> table_expression -> object_reference to find the base table.
        base_table = from_clause
        for child_type in (
            "from_expression",
            "from_expression_element",
            "table_expression",
            "object_reference",
        ):
            base_table = base_table.children(sp.is_type(child_type)).first()
        if not base_table:
            return None

        # Collect every from_expression_element and column_reference found in
        # the FROM clause and in the children that come after it.
        trailing_children = statement_children.select(start_seg=from_clause[0])
        from_expression_elements = []
        column_reference_segments = []
        for clause in from_clause + trailing_children:
            from_expression_elements.extend(
                clause.recursive_crawl("from_expression_element"))
            column_reference_segments.extend(
                clause.recursive_crawl("column_reference"))

        # base_table is known to be non-empty at this point.
        violations = self._lint_aliases_in_join(
            base_table[0],
            from_expression_elements,
            column_reference_segments,
            context.segment,
        )
        return violations or None
Ejemplo n.º 11
0
    def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
        """Operators should follow a standard for being before/after newlines.

        For each binary/comparison operator that is a direct child of the
        current segment, split the surrounding non-code material into the
        stretch back to the previous code segment and the stretch forward to
        the next one. Depending on ``operator_new_lines`` one side is the
        "anchor" side; if it contains a newline, ``_generate_fixes`` builds a
        result. Returns ``None`` when nothing needs fixing.
        """
        operator_types = ["binary_operator", "comparison_operator"]
        segment = context.functional.segment
        # bring var to this scope so as to only have one type ignore
        placement: str = self.operator_new_lines  # type: ignore
        expr = segment.children()
        results: List[LintResult] = []

        for operator in segment.children(sp.is_type(*operator_types)):
            prev_code = expr.reversed().select(start_seg=operator).first(
                sp.is_code())
            next_code = expr.select(start_seg=operator).first(sp.is_code())
            before_operator = expr.select(
                start_seg=prev_code.get(), stop_seg=operator)
            after_operator = expr.select(
                start_seg=operator, stop_seg=next_code.get())

            # In the "before" style the roles of the two sides swap and each
            # side is walked in reverse.
            if placement == "before":
                change_list = after_operator.reversed()
                anchor_list = before_operator.reversed()
            else:
                change_list = before_operator
                anchor_list = after_operator

            # No newline on the anchor side means the layout is already fine.
            if not anchor_list.any(sp.is_name("newline")):
                continue

            insert_anchor = anchor_list.last().get()
            assert insert_anchor, "Insert Anchor must be present"
            results.append(
                _generate_fixes(
                    placement,
                    change_list,
                    operator,
                    insert_anchor,
                ))

        return results or None
Ejemplo n.º 12
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Join/From clauses should not contain subqueries. Use CTEs instead.

        NB: No fix for this routine because it would be very complex to
        implement reliably.
        """
        parent_types = self._config_mapping[self.forbid_subquery_in]  # type: ignore
        for parent_type in parent_types:
            if not context.segment.is_type(parent_type):
                continue

            # The table expression referenced by this clause.
            table_expression = context.functional.segment.children(
                is_type("from_expression_element")
            ).children(is_type("table_expression"))

            # If it is bracketed, inspect the bracket contents instead.
            bracketed_expression = table_expression.children(is_type("bracketed"))
            target = bracketed_expression or table_expression

            # Any child of a "problem" type is a subquery to report.
            offenders = target.children(
                is_type(
                    "with_compound_statement",
                    "set_expression",
                    "select_statement",
                )
            )
            if offenders:
                return LintResult(
                    anchor=offenders[0],
                    description=f"{parent_type} clauses should not contain "
                    "subqueries. Use CTEs instead",
                )
        return None
Ejemplo n.º 13
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Files must not begin with newlines or whitespace.

        Fires on the first non-whitespace leaf segment of the file when
        everything before it (the raw stack) is whitespace, not all meta,
        and free of templated slices; the fix deletes that leading run.
        """
        # An empty parent_stack means we are at the FileSegment itself; an
        # empty raw_stack means there is nothing before us to remove.
        if len(context.parent_stack) == 0 or len(context.raw_stack) == 0:
            return None

        whitespace_types = {"newline", "whitespace", "indent", "dedent"}
        is_whitespace = sp.is_type(*whitespace_types)
        segment = context.functional.segment
        raw_stack = context.functional.raw_stack

        # We want the first non-whitespace segment...
        if segment.all(is_whitespace):
            return None
        # ...so everything before it must be whitespace, and at least one of
        # those preceding segments must be non-meta.
        if not raw_stack.all(is_whitespace) or raw_stack.all(sp.is_meta()):
            return None
        # Only leaves of the parse tree are considered.
        if segment.all(sp.is_expandable()):
            return None
        # A template segment (e.g. {{ config(materialized='view') }}) may render
        # to an empty string and so be absent from the parse tree. If a
        # templated raw slice intersects the raw stack's source slices, skip
        # the rule to avoid colliding with template objects.
        if raw_stack.raw_slices.any(rsp.is_slice_type("templated")):
            return None

        return LintResult(
            anchor=context.parent_stack[0],
            fixes=[LintFix.delete(d) for d in raw_stack],
        )
Ejemplo n.º 14
0
    def _eval(self, context: RuleContext) -> LintResult:
        """Unnecessary trailing whitespace.

        Look for newline segments, and then evaluate what
        it was preceded by. The whitespace run directly before the newline is
        deleted unless the raw source after it is template code.
        """
        raw_stack = context.raw_stack
        # We only trigger on a newline that directly follows whitespace.
        if not (
            context.segment.is_type("newline")
            and len(raw_stack) > 0
            and raw_stack[-1].is_type("whitespace")
        ):
            return LintResult()

        # Walk backwards over the whitespace run preceding the newline.
        deletions = context.functional.raw_stack.reversed().select(
            loop_while=sp.is_type("whitespace"))
        # deletions is ordered newest-first, so [-1] is the earliest segment
        # of the run in file order.
        # NOTE(review): if the run ever spans several segments, the template
        # check below looks just after this *earliest* one — confirm intended.
        earliest = deletions[-1]
        source_slice = earliest.pos_marker.source_slice

        # Text that looks like trailing whitespace in the rendered SQL may
        # really come from a line such as "    {% for elem in elements %}\n";
        # in the raw source that is not trailing whitespace at all.
        if context.templated_file:
            following_slices = (
                context.templated_file.raw_slices_spanning_source_slice(
                    slice(source_slice.stop, source_slice.stop)))
            # A literal slice is regular code (safe to delete); anything else
            # is template code, so leave the whitespace alone.
            if following_slices and following_slices[0].slice_type != "literal":
                return LintResult()

        return LintResult(
            anchor=earliest,
            fixes=[LintFix.delete(d) for d in deletions],
        )
Ejemplo n.º 15
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Find rule violations and provide fixes.

        Normalises the single argument of ``COUNT(...)`` to the configured
        preference: "1" if ``prefer_count_1``, else "0" if ``prefer_count_0``,
        else "*". A ``*`` argument is rewritten only when one of the prefer
        flags is set; a lone 0/1 literal inside an expression is rewritten
        whenever it differs from the preference.
        """
        # Config type hints
        self.prefer_count_0: bool
        self.prefer_count_1: bool

        segment = context.segment
        if not (segment.is_type("function")
                and segment.get_child("function_name").raw_upper == "COUNT"):
            return None

        # Bracket contents, ignoring meta, the brackets and whitespace/newlines.
        not_noise = sp.and_(
            sp.not_(sp.is_meta()),
            sp.not_(
                sp.is_type("start_bracket", "end_bracket", "whitespace",
                           "newline")),
        )
        f_content = context.functional.segment.children(
            sp.is_type("bracketed")).children(not_noise)
        if len(f_content) != 1:  # pragma: no cover
            return None
        argument = f_content[0]

        if self.prefer_count_1:
            preferred = "1"
        elif self.prefer_count_0:
            preferred = "0"
        else:
            preferred = "*"

        if argument.is_type("star") and (self.prefer_count_1
                                         or self.prefer_count_0):
            replacement = argument.edit(argument.raw.replace("*", preferred))
            return LintResult(
                anchor=context.segment,
                fixes=[LintFix.replace(argument, [replacement])],
            )

        if not argument.is_type("expression"):
            return None
        meaningful = [seg for seg in argument.segments if not seg.is_meta]
        if len(meaningful) != 1:
            return None
        literal = meaningful[0]
        if (literal.is_type("literal") and literal.raw in ["0", "1"]
                and literal.raw != preferred):
            # The literal's whole text is swapped for the preferred form.
            return LintResult(
                anchor=context.segment,
                fixes=[LintFix.replace(literal, [literal.edit(preferred)])],
            )
        return None
Ejemplo n.º 16
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Inconsistent column references in GROUP BY/ORDER BY clauses.

        Classifies the direct children of a GROUP BY / ORDER BY clause as
        "explicit" (column references, expressions) or "implicit" (numeric
        literals). In "consistent" mode a clause may not mix both kinds, nor
        contradict the kind recorded in memory by an earlier clause; in
        "explicit"/"implicit" mode any reference of the other kind is flagged.
        Memory is always passed back for later clauses.
        """
        # Config type hints
        self.group_by_and_order_by_style: str

        # We only care about GROUP BY/ORDER BY clauses.
        if not context.segment.is_type("groupby_clause", "orderby_clause"):
            return None

        memory = context.memory
        # Ignore Windowing clauses
        if context.functional.parent_stack.any(sp.is_type(*self._ignore_types)):
            return LintResult(memory=memory)

        # N.B. segment names are used as the numeric literal type is 'raw', so
        # best to be specific with the name.
        category_by_name = {
            "ColumnReferenceSegment": "explicit",
            "ExpressionSegment": "explicit",
            "numeric_literal": "implicit",
        }
        categories = {
            category_by_name[child.name]
            for child in context.segment.segments
            if child.name in category_by_name
        }

        # If there are no column references then just return
        if not categories:
            return LintResult(memory=memory)

        style = self.group_by_and_order_by_style
        if style != "consistent":
            # Explicit/implicit mode: any reference of the opposite kind fails.
            if any(category != style for category in categories):
                return LintResult(anchor=context.segment, memory=memory)
            return LintResult(memory=memory)

        # Consistent mode, check 1: both kinds used within this clause.
        if len(categories) > 1:
            return LintResult(anchor=context.segment, memory=memory)

        # Check 2: this clause contradicts the precedent of earlier clauses.
        (current_convention,) = categories
        prior_convention = memory.get("prior_group_by_order_by_convention")
        if prior_convention and prior_convention != current_convention:
            return LintResult(anchor=context.segment, memory=memory)

        memory["prior_group_by_order_by_convention"] = current_convention
        # Return memory for later clauses.
        return LintResult(memory=memory)
Ejemplo n.º 17
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Unnecessary quoted identifier.

        With ``prefer_quoted_identifiers`` set, naked identifiers are reported
        (no fix, since the quote character is dialect dependent). Otherwise,
        quoted identifiers whose contents fully match the dialect's naked
        identifier pattern, and do not match its anti-template, are reported
        and fixed by replacing them with a naked identifier.
        """
        # Config type hints
        self.prefer_quoted_identifiers: bool
        self.ignore_words: str

        # Ignore some segment types
        if context.functional.parent_stack.any(
                sp.is_type(*self._ignore_types)):
            return None

        if self.prefer_quoted_identifiers:
            context_policy = "naked_identifier"
            identifier_contents = context.segment.raw
        else:
            context_policy = "quoted_identifier"
            # Drop the surrounding quote characters.
            identifier_contents = context.segment.raw[1:-1]

        # Get the ignore_words_list configuration.
        try:
            ignore_words_list = self.ignore_words_list
        except AttributeError:
            # First-time only, read the settings from configuration. This is
            # very slow.
            ignore_words_list = self._init_ignore_words_list()

        # Skip if in ignore list
        if ignore_words_list and identifier_contents.lower(
        ) in ignore_words_list:
            return None

        # Ignore the segments that are not of the same type as the defined
        # policy above. This must be an exact comparison: ``not in`` against a
        # str is a *substring* test, which let names such as "identifier" or
        # "quoted" slip through as if they matched the policy.
        # Also TSQL has a keyword called QUOTED_IDENTIFIER which maps to the
        # name so need to explicitly check for that.
        if (context.segment.name != context_policy
                or context.segment.raw.lower()
                in ("quoted_identifier", "naked_identifier")):
            return None

        # Manage cases of identifiers must be quoted first.
        # Naked identifiers are _de facto_ making this rule fail as configuration forces
        # them to be quoted.
        # In this case, it cannot be fixed as which quote to use is dialect dependent
        if self.prefer_quoted_identifiers:
            return LintResult(
                context.segment,
                description=f"Missing quoted identifier {identifier_contents}.",
            )

        # Now we only deal with NOT forced quoted identifiers configuration
        # (meaning prefer_quoted_identifiers=False).

        # Extract contents of outer quotes.
        quoted_identifier_contents = context.segment.raw[1:-1]

        # Retrieve NakedIdentifierSegment RegexParser for the dialect.
        naked_identifier_parser = context.dialect._library[
            "NakedIdentifierSegment"]

        # Check if quoted_identifier_contents could be a valid naked identifier
        # and that it is not a reserved keyword.
        if (regex.fullmatch(
                naked_identifier_parser.template,
                quoted_identifier_contents,
                regex.IGNORECASE,
        ) is not None) and (regex.fullmatch(
                naked_identifier_parser.anti_template,
                quoted_identifier_contents,
                regex.IGNORECASE,
        ) is None):
            return LintResult(
                context.segment,
                fixes=[
                    LintFix.replace(
                        context.segment,
                        [
                            CodeSegment(
                                raw=quoted_identifier_contents,
                                name="naked_identifier",
                                type="identifier",
                            )
                        ],
                    )
                ],
                description=
                f"Unnecessary quoted identifier {context.segment.raw}.",
            )

        return None
Ejemplo n.º 18
0
    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Look for USING in a join clause.

        Any join clause containing a USING keyword is reported. A fix that
        rewrites the USING columns as an explicit ``ON`` condition is only
        offered when it can be generated safely (exactly two tables
        involved); otherwise the violation is reported without fixes.

        Args:
            context: The rule context for the segment being evaluated.

        Returns:
            None if the segment is not a join clause or contains no USING
            keyword; otherwise a LintResult anchored on the USING keyword,
            carrying fixes when they could be generated.
        """
        segment = context.functional.segment
        parent_stack = context.functional.parent_stack
        # We are not concerned with non join clauses
        if not segment.all(sp.is_type("join_clause")):
            return None

        using_anchor = segment.children(sp.is_keyword("using")).first()
        # If there is no evidence of a USING then we exit
        if len(using_anchor) == 0:
            return None

        anchor = using_anchor.get()
        description = "Found USING statement. Expected only ON statements."
        # All returns from here out will be some form of linting error, so
        # prepare the fix-less variant up front.
        unfixable_result = LintResult(
            anchor=anchor,
            description=description,
        )

        tables_in_join = parent_stack.last().children(
            sp.is_type("join_clause", "from_expression_element"))

        # If we have more than 2 tables we won't try to fix join.
        # TODO: if this is table 2 of 3 it is still fixable
        if len(tables_in_join) > 2:
            return unfixable_result

        parent_select = parent_stack.last(sp.is_type("select_statement")).get()
        if not parent_select:  # pragma: no cover
            return unfixable_result

        select_info = get_select_statement_info(parent_select, context.dialect)
        if not select_info:  # pragma: no cover
            return unfixable_result

        # table_aliases is derived from the whole parent SELECT, not just from
        # this join's immediate parent checked above, so it need not hold
        # exactly two entries (e.g. ``FROM a, b JOIN c USING (x)``). The ON
        # condition is built between exactly two aliases; report without a
        # fix instead of failing on the tuple unpacking below.
        if len(select_info.table_aliases) != 2:
            return unfixable_result

        to_delete, insert_after_anchor = _extract_deletion_sequence_and_anchor(
            tables_in_join.last())
        table_a, table_b = select_info.table_aliases
        edit_segments = [
            KeywordSegment(raw="ON"),
            WhitespaceSegment(raw=" "),
        ] + _generate_join_conditions(
            table_a.ref_str,
            table_b.ref_str,
            select_info.using_cols,
        )

        fixes = [
            LintFix.create_before(
                anchor_segment=insert_after_anchor,
                edit_segments=edit_segments,
            ),
            *[LintFix.delete(seg) for seg in to_delete],
        ]
        return LintResult(
            anchor=anchor,
            description=description,
            fixes=fixes,
        )
Ejemplo n.º 19
0
    def _eval_single_select_target_element(self, select_targets_info,
                                           context: RuleContext):
        """Fix the layout of a select clause with a single select target.

        Two situations are handled:

        * A single non-wildcard target that starts on a line after SELECT is
          moved onto the SELECT line (replacing the first newline in the
          select clause), together with any select_clause_modifier that is
          not already on the SELECT line. Segments left dangling at the end
          of the select_clause are moved out of it.
        * A wildcard target that shares a line with the FROM segment, but not
          with the start of the select clause, gets a newline inserted before
          FROM (and the whitespace preceding FROM is deleted).

        Args:
            select_targets_info: Index/segment summary of the select clause,
                as returned by ``_get_indexes`` (SelectTargetsInfo).
            context: The rule context; ``context.functional.segment`` is the
                select_clause being evaluated.

        Returns:
            A LintResult anchored on the select_clause with its fixes, or
            None when no change is needed.
        """
        select_clause = context.functional.segment
        parent_stack = context.parent_stack
        wildcards = select_clause.children(
            sp.is_type("select_clause_element")).children(
                sp.is_type("wildcard_expression"))
        is_wildcard = bool(wildcards)
        # Only read below when is_wildcard is True.
        wildcard_select_clause_element = wildcards[0] if is_wildcard else None

        if (select_targets_info.select_idx <
                select_targets_info.first_new_line_idx <
                select_targets_info.first_select_target_idx) and (
                    not is_wildcard):
            # Do we have a modifier?
            select_children = select_clause.children()
            modifier: Optional[Segments]
            modifier = select_children.first(
                sp.is_type("select_clause_modifier"))

            # Prepare the select clause which will be inserted
            insert_buff = [
                WhitespaceSegment(),
                select_children[select_targets_info.first_select_target_idx],
            ]

            # Check if the modifier is one we care about
            if modifier:
                # If it's already on the first line, ignore it.
                if (select_children.index(modifier.get()) <
                        select_targets_info.first_new_line_idx):
                    modifier = None
            fixes = [
                # Delete the first select target from its original location.
                # We'll add it to the right section at the end, once we know
                # what to add.
                LintFix.delete(
                    select_children[
                        select_targets_info.first_select_target_idx], ),
            ]

            # If we have a modifier to move:
            if modifier:

                # Add it to the insert
                insert_buff = [WhitespaceSegment(), modifier[0]] + insert_buff

                modifier_idx = select_children.index(modifier.get())
                # Delete the whitespace after it (which is two after, thanks
                # to indent). Bounds-check the index that is actually read
                # (modifier_idx + 2) to avoid an IndexError.
                if (len(select_children) > modifier_idx + 2
                        and select_children[modifier_idx + 2].is_whitespace):
                    fixes += [
                        LintFix.delete(select_children[modifier_idx + 2], ),
                    ]

                # Delete the modifier itself
                fixes += [
                    LintFix.delete(modifier[0], ),
                ]

                # Set the position marker for removing the preceding
                # whitespace and newline, which we'll use below.
                start_idx = modifier_idx
            else:
                # Set the position marker for removing the preceding
                # whitespace and newline, which we'll use below.
                start_idx = select_targets_info.first_select_target_idx

            if parent_stack and parent_stack[-1].is_type("select_statement"):
                select_stmt = parent_stack[-1]
                select_clause_idx = select_stmt.segments.index(
                    select_clause.get())
                after_select_clause_idx = select_clause_idx + 1
                if len(select_stmt.segments) > after_select_clause_idx:

                    def _fixes_for_move_after_select_clause(
                        stop_seg: BaseSegment,
                        add_newline: bool = True,
                    ) -> List[LintFix]:
                        """Cleans up by moving leftover select_clause segments.

                        Context: Some of the other fixes we make in
                        _eval_single_select_target_element() leave the
                        select_clause in an illegal state -- a select_clause's
                        *rightmost children cannot be whitespace or comments*.
                        This function addresses that by moving these segments
                        up the parse tree to an ancestor segment chosen by
                        _choose_anchor_segment(). After these fixes are applied,
                        these segments may, for example, be *siblings* of
                        select_clause.
                        """
                        start_seg = (
                            modifier[0] if modifier else select_children[
                                select_targets_info.first_new_line_idx])
                        move_after_select_clause = select_children.select(
                            start_seg=start_seg,
                            stop_seg=stop_seg,
                        )
                        fixes = [
                            LintFix.delete(seg)
                            for seg in move_after_select_clause
                        ]
                        fixes.append(
                            LintFix.create_after(
                                self._choose_anchor_segment(
                                    context,
                                    "create_after",
                                    select_clause[0],
                                    filter_meta=True,
                                ),
                                ([NewlineSegment()] if add_newline else []) +
                                list(move_after_select_clause),
                            ))
                        return fixes

                    if select_stmt.segments[after_select_clause_idx].is_type(
                            "newline"):
                        # Since we're deleting the newline, we should also delete all
                        # whitespace before it or it will add random whitespace to
                        # following statements. So walk back through the segment
                        # deleting whitespace until you get the previous newline, or
                        # something else.
                        to_delete = select_children.reversed().select(
                            loop_while=sp.is_type("whitespace"),
                            start_seg=select_children[start_idx],
                        )
                        # With no whitespace to strip, to_delete is empty:
                        # skip this cleanup (and keep the newline after the
                        # select_clause) instead of raising IndexError on
                        # to_delete[-1].
                        if to_delete:
                            fixes += [LintFix.delete(seg) for seg in to_delete]

                            # The select_clause is immediately followed by a
                            # newline. Delete the newline in order to avoid
                            # leaving behind an empty line after fix, *unless*
                            # we stopped due to something other than a newline.
                            delete_last_newline = select_children[
                                start_idx - len(to_delete) -
                                1].is_type("newline")

                            # Delete the newline if we decided to.
                            if delete_last_newline:
                                fixes.append(
                                    LintFix.delete(
                                        select_stmt.
                                        segments[after_select_clause_idx], ))

                            fixes += _fixes_for_move_after_select_clause(
                                to_delete[-1], )
                    elif select_stmt.segments[after_select_clause_idx].is_type(
                            "whitespace"):
                        # The select_clause has stuff after (most likely a comment)
                        # Delete the whitespace immediately after the select clause
                        # so the other stuff aligns nicely based on where the select
                        # clause started
                        fixes += [
                            LintFix.delete(
                                select_stmt.segments[after_select_clause_idx],
                            ),
                        ]
                        fixes += _fixes_for_move_after_select_clause(
                            select_children[
                                select_targets_info.first_select_target_idx], )
                    elif select_stmt.segments[after_select_clause_idx].is_type(
                            "dedent"):
                        # Again let's strip back the whitespace, but simpler
                        # as don't need to worry about new line so just break
                        # if see non-whitespace
                        # NOTE(review): select_clause_idx indexes
                        # select_stmt.segments, yet it is used here to index
                        # select_children; when it is 0 this picks the last
                        # child of the select clause. Confirm this is intended.
                        to_delete = select_children.reversed().select(
                            loop_while=sp.is_type("whitespace"),
                            start_seg=select_children[select_clause_idx - 1],
                        )
                        fixes += [LintFix.delete(seg) for seg in to_delete]

                        if to_delete:
                            fixes += _fixes_for_move_after_select_clause(
                                to_delete[-1],
                                # If we deleted a newline, create a newline.
                                any(seg for seg in to_delete
                                    if seg.is_type("newline")),
                            )
                    else:
                        fixes += _fixes_for_move_after_select_clause(
                            select_children[
                                select_targets_info.first_select_target_idx], )

            fixes += [
                # Insert the select_clause in place of the first newline in the
                # Select statement
                LintFix.replace(
                    select_children[select_targets_info.first_new_line_idx],
                    insert_buff,
                ),
            ]

            return LintResult(
                anchor=select_clause.get(),
                fixes=fixes,
            )

        # If we have a wildcard on the same line as the FROM keyword, but not the same
        # line as the SELECT keyword, we need to move the FROM keyword to its own line.
        # i.e.
        # SELECT
        #   * FROM foo
        if select_targets_info.from_segment:
            if (is_wildcard and
                (select_clause[0].pos_marker.working_line_no !=
                 select_targets_info.from_segment.pos_marker.working_line_no)
                    and
                (wildcard_select_clause_element.pos_marker.working_line_no ==
                 select_targets_info.from_segment.pos_marker.working_line_no)):
                fixes = [
                    LintFix.delete(ws)
                    for ws in select_targets_info.pre_from_whitespace
                ]
                fixes.append(
                    LintFix.create_before(
                        select_targets_info.from_segment,
                        [NewlineSegment()],
                    ))
                return LintResult(anchor=select_clause.get(), fixes=fixes)

        return None