Пример #1
0
class Rule_L024(Rule_L023):
    """Single whitespace expected after ``USING`` in ``JOIN`` clause.

    **Anti-pattern**

    .. code-block:: sql

        SELECT b
        FROM foo
        LEFT JOIN zoo USING(a)

    **Best practice**

    Add a space after ``USING``, to avoid confusing it
    for a function.

    .. code-block:: sql
       :force:

        SELECT b
        FROM foo
        LEFT JOIN zoo USING (a)
    """

    groups = ("all", "core")

    # Only JOIN clauses are inspected. This class defines no methods of its
    # own, so all evaluation logic is inherited from Rule_L023; the
    # attributes below presumably parameterise that logic (see Rule_L023).
    crawl_behaviour = SegmentSeekerCrawler({"join_clause"})

    # Segment on the left of the gap: matched via its upper-cased raw text.
    pre_segment_identifier = ("raw_upper", "USING")
    # Segment on the right of the gap: matched via its segment type.
    post_segment_identifier = ("type", "bracketed")

    # Further knobs read by the parent rule (semantics defined in Rule_L023).
    allow_newline = True
    expand_children = None
Пример #2
0
class Rule_L061(BaseRule):
    """Use ``!=`` instead of ``<>`` for "not equal to" comparisons.

    **Anti-pattern**

    ``<>`` means ``not equal`` but doesn't sound like this when we say it out loud.

    .. code-block:: sql

        SELECT * FROM X WHERE 1 <> 2;

    **Best practice**

    Use ``!=`` instead because it sounds more natural and is more common in other
    programming languages.

    .. code-block:: sql

        SELECT * FROM X WHERE 1 != 2;

    """

    groups = ("all",)
    crawl_behaviour = SegmentSeekerCrawler({"comparison_operator"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Use ``!=`` instead of ``<>`` for "not equal to" comparison.

        Returns ``None`` unless the segment is a ``not_equal_to`` operator
        spelled exactly as ``<`` followed by ``>``, in which case the result
        carries fixes rewriting it to ``!=``.
        """
        # Only care about not_equal_to segments. We should only get
        # comparison operator types from the crawler, but not all will
        # be "not_equal_to".
        if context.segment.name != "not_equal_to":
            return None

        # Get the comparison operator children
        raw_comparison_operators = (
            FunctionalContext(context)
            .segment.children()
            .select(select_if=sp.is_type("raw_comparison_operator"))
        )

        # Only care about ``<>`` (``!=`` is also "not_equal_to" but is fine).
        if [r.raw for r in raw_comparison_operators] != ["<", ">"]:
            return None

        # Provide a fix and replace ``<>`` with ``!=``
        # As each symbol is a separate symbol this is done in two steps:
        # 1. Replace < with !
        # 2. Replace > with =
        fixes = [
            LintFix.replace(
                raw_comparison_operators[0],
                [SymbolSegment(raw="!", type="raw_comparison_operator")],
            ),
            LintFix.replace(
                raw_comparison_operators[1],
                [SymbolSegment(raw="=", type="raw_comparison_operator")],
            ),
        ]

        return LintResult(context.segment, fixes)
Пример #3
0
class Rule_L060(BaseRule):
    """Use ``COALESCE`` instead of ``IFNULL`` or ``NVL``.

    **Anti-pattern**

    ``IFNULL`` or ``NVL`` are used to fill ``NULL`` values.

    .. code-block:: sql

        SELECT ifnull(foo, 0) AS bar
        FROM baz;

        SELECT nvl(foo, 0) AS bar
        FROM baz;

    **Best practice**

    Use ``COALESCE`` instead.
    ``COALESCE`` is universally supported,
    whereas Redshift doesn't support ``IFNULL``
    and BigQuery doesn't support ``NVL``.
    Additionally, ``COALESCE`` is more flexible
    and accepts an arbitrary number of arguments.

    .. code-block:: sql

        SELECT coalesce(foo, 0) AS bar
        FROM baz;

    """

    groups = ("all", )
    crawl_behaviour = SegmentSeekerCrawler({"function_name_identifier"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Use ``COALESCE`` instead of ``IFNULL`` or ``NVL``.

        Returns ``None`` for any other function name; otherwise a result
        whose fix replaces the name with ``COALESCE``.
        """
        # We only care about function names, and they should be the
        # only things we get
        assert context.segment.is_type("function_name_identifier")

        # Only care if the function is ``IFNULL`` or ``NVL`` (case-insensitive).
        if context.segment.raw_upper not in {"IFNULL", "NVL"}:
            return None

        # Create fix to replace ``IFNULL`` or ``NVL`` with ``COALESCE``.
        fix = LintFix.replace(
            context.segment,
            [
                CodeSegment(
                    raw="COALESCE",
                    type="function_name_identifier",
                )
            ],
        )

        return LintResult(
            anchor=context.segment,
            fixes=[fix],
            description=(
                f"Use 'COALESCE' instead of '{context.segment.raw_upper}'."
            ),
        )
Пример #4
0
class Rule_L017(BaseRule):
    """Function name not immediately followed by parenthesis.

    **Anti-pattern**

    In this example, there is a space between the function and the parenthesis.

    .. code-block:: sql

        SELECT
            sum (a)
        FROM foo

    **Best practice**

    Remove the space between the function and the parenthesis.

    .. code-block:: sql

        SELECT
            sum(a)
        FROM foo

    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"function"})

    def _eval(self, context: RuleContext) -> LintResult:
        """Function name not immediately followed by bracket.

        Look for Function Segment with anything other than the
        function name before brackets. If everything in between is
        whitespace/newlines, propose deleting it; otherwise only report.
        """
        segment = FunctionalContext(context).segment
        # The crawler only yields ``function`` segments.
        assert segment.all(sp.is_type("function"))
        children = segment.children()

        function_name = children.first(sp.is_type("function_name"))
        start_bracket = children.first(sp.is_type("bracketed"))
        # Without both a name and a bracketed argument list there is no
        # gap to inspect; bail out rather than raising IndexError on [0].
        if not function_name or not start_bracket:
            return LintResult()

        intermediate_segments = children.select(
            start_seg=function_name[0], stop_seg=start_bracket[0]
        )
        if not intermediate_segments:
            return LintResult()

        # It's only safe to fix if there is only whitespace
        # or newlines in the intervening section.
        if intermediate_segments.all(sp.is_type("whitespace", "newline")):
            return LintResult(
                anchor=intermediate_segments[0],
                fixes=[LintFix.delete(seg) for seg in intermediate_segments],
            )
        # It's not all whitespace, just report the error.
        return LintResult(anchor=intermediate_segments[0])
Пример #5
0
class Rule_L005(BaseRule):
    """Commas should not have whitespace directly before them.

    Unless it's an indent. Trailing/leading commas are dealt with
    in a different rule.

    **Anti-pattern**

    The ``•`` character represents a space.
    There is an extra space in line two before the comma.

    .. code-block:: sql
       :force:

        SELECT
            a•,
            b
        FROM foo

    **Best practice**

    Remove the space before the comma.

    .. code-block:: sql

        SELECT
            a,
            b
        FROM foo
    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"comma"}, provide_raw_stack=True)

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Commas should not have whitespace directly before them.

        Flags (and deletes) the whitespace segment immediately preceding a
        comma, unless that whitespace starts at the beginning of its line.
        """
        if not context.raw_stack:
            return None  # pragma: no cover
        # The raw segment immediately before the comma.
        anchor: Optional[RawSegment] = context.raw_stack[-1]
        if (
            # We need at least one previous segment for this to work.
            anchor is not None
            and context.segment.is_type("comma")
            and anchor.is_type("whitespace")
            # Indentation is exempt: whitespace that starts at the first
            # position of its line is not flagged (line_pos presumably
            # 1-based — confirm against PositionMarker).
            and anchor.pos_marker.line_pos > 1
        ):
            return LintResult(anchor=anchor, fixes=[LintFix.delete(anchor)])
        # Otherwise fine.
        return None
Пример #6
0
class Rule_Example_L001(BaseRule):
    """ORDER BY on these columns is forbidden!

    **Anti-pattern**

    Using ``ORDER BY`` on some forbidden columns.

    .. code-block:: sql

        SELECT *
        FROM foo
        ORDER BY
            bar,
            baz

    **Best practice**

    Do not order by these columns.

    .. code-block:: sql

        SELECT *
        FROM foo
        ORDER BY bar
    """

    groups = ("all", )
    config_keywords = ["forbidden_columns"]
    # ``_eval`` inspects whole ORDER BY clauses, so those are what the
    # crawler must yield (seeking ``column_reference`` meant the
    # ``orderby_clause`` check below could never be true).
    crawl_behaviour = SegmentSeekerCrawler({"orderby_clause"})

    def __init__(self, *args, **kwargs):
        """Overwrite __init__ to set config.

        Turns the comma-separated ``forbidden_columns`` config string into a
        list of lower-cased, stripped names, ignoring empty entries.
        """
        super().__init__(*args, **kwargs)
        # Lower-case to match ``col_name`` in ``_eval``; drop empty entries
        # (e.g. from a trailing comma) so zero-width segments never match.
        self.forbidden_columns = [
            col.strip().lower()
            for col in self.forbidden_columns.split(",")
            if col.strip()
        ]

    def _eval(self, context: RuleContext):
        """We should not ORDER BY forbidden columns.

        Returns a result anchored on the first child of the ORDER BY clause
        whose (lower-cased) raw text is a forbidden column, else ``None``.
        """
        if context.segment.is_type("orderby_clause"):
            for seg in context.segment.segments:
                col_name = seg.raw.lower()
                if col_name in self.forbidden_columns:
                    return LintResult(
                        anchor=seg,
                        description=(
                            f"Column `{col_name}` not allowed in ORDER BY."
                        ),
                    )
        return None
Пример #7
0
class Rule_L001(BaseRule):
    """Unnecessary trailing whitespace.

    **Anti-pattern**

    The ``•`` character represents a space.

    .. code-block:: sql
       :force:

        SELECT
            a
        FROM foo••

    **Best practice**

    Remove trailing spaces.

    .. code-block:: sql

        SELECT
            a
        FROM foo
    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler(
        {"newline", "end_of_file"}, provide_raw_stack=True
    )

    def _eval(self, context: RuleContext) -> LintResult:
        """Flag whitespace sitting right before a newline or end of file.

        Evaluated on every newline / end-of-file segment; looks back at the
        raw segments that preceded it.
        """
        preceding = context.raw_stack
        if len(preceding) == 0 or not preceding[-1].is_type("whitespace"):
            # Nothing trailing before this newline / file end.
            return LintResult()

        # Walk backwards from the newline, collecting the contiguous run of
        # whitespace segments.
        # NOTE: The presence of a loop marker should prevent false
        # flagging of newlines before jinja loop tags.
        trailing = FunctionalContext(context).raw_stack.reversed().select(
            loop_while=sp.is_type("whitespace")
        )
        # ``trailing`` is in reverse order, so its last item is the earliest
        # whitespace segment of the run; anchor the error there.
        return LintResult(
            anchor=trailing[-1],
            fixes=[LintFix.delete(ws) for ws in trailing],
        )
Пример #8
0
class Rule_L055(BaseRule):
    """Use ``LEFT JOIN`` instead of ``RIGHT JOIN``.

    **Anti-pattern**

    ``RIGHT JOIN`` is used.

    .. code-block:: sql
       :force:

        SELECT
            foo.col1,
            bar.col2
        FROM foo
        RIGHT JOIN bar
            ON foo.bar_id = bar.id;

    **Best practice**

    Refactor and use ``LEFT JOIN`` instead.

    .. code-block:: sql
       :force:

        SELECT
            foo.col1,
            bar.col2
        FROM bar
        LEFT JOIN foo
            ON foo.bar_id = bar.id;
    """

    groups = ("all",)
    crawl_behaviour = SegmentSeekerCrawler({"join_clause"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Use LEFT JOIN instead of RIGHT JOIN."""
        join_clause = context.segment
        # The crawler only ever hands us JOIN clauses.
        assert join_clause.is_type("join_clause")

        child_names = {child.name for child in join_clause.segments}
        # A RIGHT JOIN has both a ``right`` and a ``join`` child; report it
        # on the clause's first child.
        if "right" in child_names and "join" in child_names:
            return LintResult(join_clause.segments[0])
        return None
Пример #9
0
class Rule_L040(Rule_L010):
    """Inconsistent capitalisation of boolean/null literal.

    **Anti-pattern**

    In this example, ``null`` and ``false`` are in lower-case whereas ``TRUE`` is in
    upper-case.

    .. code-block:: sql

        select
            a,
            null,
            TRUE,
            false
        from foo

    **Best practice**

    Ensure all ``null``/``true``/``false`` literals are consistently
    upper or lower case.

    .. code-block:: sql

        select
            a,
            NULL,
            TRUE,
            FALSE
        from foo

        -- Also good

        select
            a,
            null,
            true,
            false
        from foo

    """

    groups = ("all", "core")
    lint_phase = "post"
    # Capitalisation logic is inherited from Rule_L010; this subclass only
    # narrows the segments it is applied to.
    crawl_behaviour = SegmentSeekerCrawler({"null_literal", "boolean_literal"})
    _exclude_elements: List[Tuple[str, str]] = []
    _description_elem = "Boolean/null literals"
Пример #10
0
class Rule_L021(BaseRule):
    """Ambiguous use of ``DISTINCT`` in a ``SELECT`` statement with ``GROUP BY``.

    When using ``GROUP BY`` a ``DISTINCT`` clause should not be necessary as every
    non-distinct ``SELECT`` clause must be included in the ``GROUP BY`` clause.

    **Anti-pattern**

    ``DISTINCT`` and ``GROUP BY`` are conflicting.

    .. code-block:: sql

        SELECT DISTINCT
            a
        FROM foo
        GROUP BY a

    **Best practice**

    Remove ``DISTINCT`` or ``GROUP BY``. In our case, removing ``GROUP BY`` is better.

    .. code-block:: sql

        SELECT DISTINCT
            a
        FROM foo
    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"select_statement"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Ambiguous use of DISTINCT in select statement with GROUP BY.

        Returns a result anchored on the ``DISTINCT`` keyword when the
        statement also has a ``GROUP BY`` clause, else ``None``.
        """
        segment = FunctionalContext(context).segment
        # We know it's a select_statement from the seeker crawler
        assert segment.all(sp.is_type("select_statement"))
        # Do we have a group by clause
        if segment.children(sp.is_type("groupby_clause")):
            # Do we have the "DISTINCT" keyword in the select clause
            distinct = (
                segment.children(sp.is_type("select_clause"))
                .children(sp.is_type("select_clause_modifier"))
                .children(sp.is_type("keyword"))
                .select(sp.is_name("distinct"))
            )
            if distinct:
                return LintResult(anchor=distinct[0])
        return None
Пример #11
0
class Rule_L030(Rule_L010):
    """Inconsistent capitalisation of function names.

    **Anti-pattern**

    In this example, the two ``SUM`` functions don't have the same capitalisation.

    .. code-block:: sql

        SELECT
            sum(a) AS aa,
            SUM(b) AS bb
        FROM foo

    **Best practice**

    Make the case consistent.

    .. code-block:: sql

        SELECT
            sum(a) AS aa,
            sum(b) AS bb
        FROM foo

    """

    groups = ("all", "core")
    lint_phase = "post"
    crawl_behaviour = SegmentSeekerCrawler(
        {"function_name_identifier", "bare_function"}
    )
    _exclude_elements: List[Tuple[str, str]] = []
    config_keywords = [
        "extended_capitalisation_policy",
        "ignore_words",
        "ignore_words_regex",
    ]
    _description_elem = "Function names"
    # ``_get_fix`` and all evaluation logic are inherited unchanged from
    # Rule_L010 (the previous override only forwarded to ``super()``).
Пример #12
0
class Rule_L045(BaseRule):
    """Query defines a CTE (common-table expression) but does not use it.

    **Anti-pattern**

    Defining a CTE that is not used by the query is harmless, but it means
    the code is unnecessary and could be removed.

    .. code-block:: sql

        WITH cte1 AS (
          SELECT a
          FROM t
        ),
        cte2 AS (
          SELECT b
          FROM u
        )

        SELECT *
        FROM cte1

    **Best practice**

    Remove unused CTEs.

    .. code-block:: sql

        WITH cte1 AS (
          SELECT a
          FROM t
        )

        SELECT *
        FROM cte1
    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"statement"})

    @classmethod
    def _find_all_ctes(cls, query: Query) -> Iterator[Query]:
        """Yield every query in the CTE tree that defines CTEs.

        ``query`` itself is yielded if it defines any CTEs, then each of its
        CTE queries is searched recursively.
        """
        if query.ctes:
            yield query
        for cte_query in query.ctes.values():
            yield from cls._find_all_ctes(cte_query)

    @classmethod
    def _visit_sources(cls, query: Query):
        """Walk every source referenced by ``query`` and its children.

        Calls ``crawl_sources(..., pop=True)``; presumably this removes each
        referenced CTE from its defining query's ``ctes`` so that only unused
        CTEs remain afterwards (confirm in ``Query.crawl_sources``).
        """
        for selectable in query.selectables:
            for source in query.crawl_sources(selectable.selectable, pop=True):
                if isinstance(source, Query):
                    cls._visit_sources(source)
        for child in query.children:
            cls._visit_sources(child)

    def _eval(self, context: RuleContext) -> EvalResultType:
        """Report each CTE still left in a ``ctes`` mapping after the walk."""
        result = []
        crawler = SelectCrawler(context.segment, context.dialect)
        if crawler.query_tree:
            # Collect CTE-defining queries *before* visiting sources, since
            # the visit mutates their ``ctes`` mappings.
            # Begin analysis at the final, outer query (key=None).
            queries_with_ctes = list(self._find_all_ctes(crawler.query_tree))
            self._visit_sources(crawler.query_tree)
            for query in queries_with_ctes:
                if query.ctes:
                    result += [
                        LintResult(
                            anchor=cte.cte_name_segment,
                            description=f"Query defines CTE "
                            f'"{cte.cte_name_segment.raw}" '
                            f"but does not use it.",
                        )
                        for cte in query.ctes.values()
                        if cte.cte_name_segment
                    ]
        return result
Пример #13
0
class Rule_L018(BaseRule):
    """``WITH`` clause closing bracket should be on a new line.

    **Anti-pattern**

    In this example, the closing bracket is on the same line as CTE.

    .. code-block:: sql
       :force:

        WITH zoo AS (
            SELECT a FROM foo)

        SELECT * FROM zoo

    **Best practice**

    Move the closing bracket to a new line.

    .. code-block:: sql

        WITH zoo AS (
            SELECT a FROM foo
        )

        SELECT * FROM zoo

    """

    groups = ("all", "core")
    # The raw stack is never consulted by ``_eval``, so it is not requested.
    crawl_behaviour = SegmentSeekerCrawler({"with_compound_statement"})

    def _eval(self, context: RuleContext):
        """WITH clause closing bracket should be on a new line.

        Look for a with clause and evaluate the position of the closing
        bracket of each CTE query. A bracket on the same line as the ``WITH``
        keyword (the one-line form) is accepted. Otherwise, if anything other
        than whitespace/dedent precedes it on its line, return a result whose
        fix inserts a newline before it. Only the first offending bracket is
        reported.
        """
        # The crawler only yields WITH statements.
        assert context.segment.is_type("with_compound_statement")
        # Look for the with keyword
        for seg in context.segment.segments:
            if seg.name.lower() == "with":
                seg_line_no = seg.pos_marker.line_no
                break
        else:  # pragma: no cover
            # This *could* happen if the with statement is unparsable,
            # in which case then the user will have to fix that first.
            if any(s.is_type("unparsable") for s in context.segment.segments):
                return LintResult()
            # If it's parsable but we still didn't find a with, then
            # we should raise that.
            raise RuntimeError("Didn't find WITH keyword!")

        # Find the end brackets for the CTE *query* (i.e. ignore optional
        # list of CTE columns).
        cte_end_brackets = IdentitySet()
        for cte in (
            FunctionalContext(context)
            .segment.children(sp.is_type("common_table_expression"))
            .iterate_segments()
        ):
            cte_end_bracket = (
                cte.children()
                .last(sp.is_type("bracketed"))
                .children()
                .last(sp.is_type("end_bracket"))
            )
            if cte_end_bracket:
                cte_end_brackets.add(cte_end_bracket[0])

        for seg in context.segment.iter_segments(
            expanding=["common_table_expression", "bracketed"], pass_through=True
        ):
            # Only CTE-query closing brackets are of interest.
            if seg not in cte_end_brackets:
                continue

            assert seg.pos_marker
            if seg.pos_marker.line_no == seg_line_no:
                # Skip if it's the one-line version. That's ok
                continue

            # Is it all whitespace before the bracket on this line? Scan the
            # raw segments on the bracket's line up to the bracket itself;
            # a newline resets the flag, anything other than whitespace or a
            # dedent sets it.
            contains_non_whitespace = False
            for elem in context.segment.raw_segments:
                elem_pos = cast(PositionMarker, elem.pos_marker)
                if (
                    elem_pos.line_no == seg.pos_marker.line_no
                    and elem_pos.line_pos <= seg.pos_marker.line_pos
                ):
                    if elem is seg:
                        break
                    elif elem.is_type("newline"):
                        contains_non_whitespace = False
                    elif not elem.is_type("dedent") and not elem.is_type(
                        "whitespace"
                    ):
                        contains_non_whitespace = True

            if contains_non_whitespace:
                # We have to move it to a newline
                return LintResult(
                    anchor=seg,
                    fixes=[
                        LintFix.create_before(
                            seg,
                            [
                                NewlineSegment(),
                            ],
                        )
                    ],
                )
        return None
Пример #14
0
class Rule_L059(BaseRule):
    """Unnecessary quoted identifier.

    This rule will fail if the quotes used to quote an identifier are (un)necessary
    depending on the ``prefer_quoted_identifiers`` configuration.

    When ``prefer_quoted_identifiers = False`` (default behaviour), the quotes are
    unnecessary, except for reserved keywords and special characters in identifiers.

    .. note::
       This rule is disabled by default for Postgres and Snowflake because they allow
       quotes as part of the column name. In other words, ``date`` and ``"date"`` are
       two different columns.

       It can be enabled with the ``force_enable = True`` flag.

    **Anti-pattern**

    In this example, a valid unquoted identifier,
    that is also not a reserved keyword, is needlessly quoted.

    .. code-block:: sql

        SELECT 123 as "foo"

    **Best practice**

    Use unquoted identifiers where possible.

    .. code-block:: sql

        SELECT 123 as foo

    When ``prefer_quoted_identifiers = True``, the quotes are always necessary, no
    matter if the identifier is valid, a reserved keyword, or contains special
    characters.

    .. note::
       Due to different quotes being used by different dialects supported by
       `SQLFluff`, and those quotes meaning different things in different contexts,
       this mode is not ``sqlfluff fix`` compatible.

    **Anti-pattern**

    In this example, a valid unquoted identifier, that is also not a reserved keyword,
    is required to be quoted.

    .. code-block:: sql

        SELECT 123 as foo

    **Best practice**

    Use quoted identifiers.

    .. code-block:: sql

        SELECT 123 as "foo" -- For ANSI, ...
        -- or
        SELECT 123 as `foo` -- For BigQuery, MySql, ...

    """

    groups = ("all", )
    config_keywords = [
        "prefer_quoted_identifiers",
        "ignore_words",
        "ignore_words_regex",
        "force_enable",
    ]
    crawl_behaviour = SegmentSeekerCrawler(
        {"quoted_identifier", "naked_identifier"})
    _dialects_allowing_quotes_in_column_names = ["postgres", "snowflake"]

    # Ignore "password_auth" type to allow quotes around passwords within
    # `CREATE USER` statements in Exasol dialect.
    _ignore_types: List[str] = ["password_auth"]

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Unnecessary quoted identifier.

        With ``prefer_quoted_identifiers`` set, naked identifiers are
        reported (no fix). Otherwise quoted identifiers whose contents would
        be a valid, non-reserved naked identifier are reported with a fix
        that strips the quotes.
        """
        # Config type hints
        self.prefer_quoted_identifiers: bool
        self.ignore_words: str
        self.ignore_words_regex: str
        self.force_enable: bool
        # Some dialects allow quotes as PART OF the column name. In other words,
        # these are two different columns:
        # - date
        # - "date"
        # For safety, disable this rule by default in those dialects.
        if (context.dialect.name
                in self._dialects_allowing_quotes_in_column_names
                and not self.force_enable):
            return LintResult()

        # Ignore some segment types
        if FunctionalContext(context).parent_stack.any(
                sp.is_type(*self._ignore_types)):
            return None

        if self.prefer_quoted_identifiers:
            context_policy = "naked_identifier"
            identifier_contents = context.segment.raw
        else:
            context_policy = "quoted_identifier"
            # Strip the surrounding quote characters.
            identifier_contents = context.segment.raw[1:-1]

        # Get the ignore_words_list configuration.
        try:
            ignore_words_list = self.ignore_words_list
        except AttributeError:
            # First-time only, read the settings from configuration. This is
            # very slow.
            ignore_words_list = self._init_ignore_words_list()

        # Skip if in ignore list
        if ignore_words_list and identifier_contents.lower(
        ) in ignore_words_list:
            return None

        # Skip if matches ignore regex
        if self.ignore_words_regex and regex.search(self.ignore_words_regex,
                                                    identifier_contents):
            return LintResult(memory=context.memory)

        # Ignore the segments that are not of the same type as the defined policy above.
        # Also TSQL has a keyword called QUOTED_IDENTIFIER which maps to the name so
        # need to explicitly check for that.
        if not context.segment.is_type(
                context_policy) or context.segment.raw.lower() in (
                    "quoted_identifier",
                    "naked_identifier",
                ):
            return None

        # Manage cases of identifiers must be quoted first.
        # Naked identifiers are _de facto_ making this rule fail as configuration forces
        # them to be quoted.
        # In this case, it cannot be fixed as which quote to use is dialect dependent
        if self.prefer_quoted_identifiers:
            return LintResult(
                context.segment,
                description=f"Missing quoted identifier {identifier_contents}.",
            )

        # Now we only deal with NOT forced quoted identifiers configuration
        # (meaning prefer_quoted_identifiers=False).

        # Extract contents of outer quotes.
        quoted_identifier_contents = context.segment.raw[1:-1]

        # Retrieve NakedIdentifierSegment RegexParser for the dialect.
        naked_identifier_parser = context.dialect._library[
            "NakedIdentifierSegment"]
        IdentifierSegment = cast(
            Type[CodeSegment],
            context.dialect.get_segment("IdentifierSegment"))

        # Check if quoted_identifier_contents could be a valid naked identifier
        # and that it is not a reserved keyword.
        if (regex.fullmatch(
                naked_identifier_parser.template,
                quoted_identifier_contents,
                regex.IGNORECASE,
        ) is not None) and (regex.fullmatch(
                naked_identifier_parser.anti_template,
                quoted_identifier_contents,
                regex.IGNORECASE,
        ) is None):
            return LintResult(
                context.segment,
                fixes=[
                    LintFix.replace(
                        context.segment,
                        [
                            IdentifierSegment(
                                raw=quoted_identifier_contents,
                                type="naked_identifier",
                            )
                        ],
                    )
                ],
                description=
                f"Unnecessary quoted identifier {context.segment.raw}.",
            )

        return None

    def _init_ignore_words_list(self):
        """Called first time rule is evaluated to fetch & cache the policy.

        Caches the lower-cased, comma-split ``ignore_words`` config on
        ``self.ignore_words_list`` (empty list when unset/"None") and
        returns it.
        """
        ignore_words_config: str = str(getattr(self, "ignore_words"))
        if ignore_words_config and ignore_words_config != "None":
            self.ignore_words_list = self.split_comma_separated_string(
                ignore_words_config.lower())
        else:
            self.ignore_words_list = []

        return self.ignore_words_list
Пример #15
0
class Rule_L015(BaseRule):
    """``DISTINCT`` used with parentheses.

    **Anti-pattern**

    In this example, parentheses are not needed and confuse
    ``DISTINCT`` with a function. The parentheses can also be misleading
    about which columns are affected by the ``DISTINCT`` (all the columns!).

    .. code-block:: sql

        SELECT DISTINCT(a), b FROM foo

    **Best practice**

    Remove parentheses to be clear that the ``DISTINCT`` applies to
    both columns.

    .. code-block:: sql

        SELECT DISTINCT a, b FROM foo

    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"select_clause"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Flag a select modifier whose first target is wrapped in brackets.

        Proposes removing the brackets when they are the whole expression,
        and inserting a space between the modifier and the first target when
        none is present. Returns ``None`` when neither fix applies.
        """
        # The crawler only yields ``select_clause`` segments; we look at
        # their ``select_clause_modifier`` child.
        assert context.segment.is_type("select_clause")
        clause_children = FunctionalContext(context).segment.children()
        modifier = clause_children.select(sp.is_type("select_clause_modifier"))
        first_element = clause_children.select(
            sp.is_type("select_clause_element")
        ).first()
        if not (modifier and first_element):
            return None

        # Inspect the element's expression when it has one, otherwise the
        # element itself.
        expression = first_element.children(sp.is_type("expression")).first()
        if not expression:
            expression = first_element
        bracketed = expression.children(sp.is_type("bracketed")).first()
        if not bracketed:
            return None

        bracket_seg = bracketed[0]
        fixes = []
        # Only unwrap when the brackets are the expression's sole child.
        if len(expression[0].segments) == 1:
            # Drop the brackets themselves and strip any meta segments.
            unwrapped = self.filter_meta(bracket_seg.segments)[1:-1]
            fixes.append(LintFix.replace(bracket_seg, unwrapped))

        # Make sure DISTINCT and the expression are separated by whitespace.
        gap_whitespace = clause_children.select(
            sp.is_whitespace(),
            start_seg=modifier[0],
            stop_seg=first_element[0],
        )
        if not gap_whitespace:
            fixes.append(
                LintFix.create_before(first_element[0], [WhitespaceSegment()])
            )

        if not fixes:
            return None
        return LintResult(anchor=modifier[0], fixes=fixes)
Пример #16
0
class Rule_L058(BaseRule):
    """Nested ``CASE`` statement in ``ELSE`` clause could be flattened.

    **Anti-pattern**

    In this example, the outer ``CASE``'s ``ELSE`` is an unnecessary, nested ``CASE``.

    .. code-block:: sql

        SELECT
          CASE
            WHEN species = 'Cat' THEN 'Meow'
            ELSE
            CASE
               WHEN species = 'Dog' THEN 'Woof'
            END
          END as sound
        FROM mytable

    **Best practice**

    Move the body of the inner ``CASE`` to the end of the outer one.

    .. code-block:: sql

        SELECT
          CASE
            WHEN species = 'Cat' THEN 'Meow'
            WHEN species = 'Dog' THEN 'Woof'
          END AS sound
        FROM mytable

    """

    groups = ("all", )
    crawl_behaviour = SegmentSeekerCrawler({"case_expression"})

    def _eval(self, context: RuleContext) -> LintResult:
        """Nested CASE statement in ELSE clause could be flattened.

        If the ``ELSE`` clause of the crawled ``case_expression`` holds a
        single expression whose only child is another ``case_expression``,
        propose fixes that:

        * delete the segments between the outer case's last ``WHEN`` clause
          and its ``ELSE`` clause,
        * insert the inner case's ``WHEN``/``ELSE`` clauses after the outer
          case's last ``WHEN``, each preceded by a newline and indentation,
        * delete the outer ``ELSE`` clause (which contains the inner case).

        An empty ``LintResult`` is returned when the pattern does not apply.
        """
        segment = FunctionalContext(context).segment
        assert segment.select(sp.is_type("case_expression"))
        case1_children = segment.children()
        case1_last_when = case1_children.last(sp.is_type("when_clause")).get()
        case1_else_clause = case1_children.select(sp.is_type("else_clause"))
        case1_else_expressions = case1_else_clause.children(
            sp.is_type("expression"))
        expression_children = case1_else_expressions.children()
        case2 = expression_children.select(sp.is_type("case_expression"))
        # The len() checks below are for safety, to ensure the CASE inside
        # the ELSE is not part of a larger expression. In that case, it's
        # not safe to simplify in this way -- we'd be deleting other code.
        if (not case1_last_when or len(case1_else_expressions) > 1
                or len(expression_children) > 1 or not case2):
            return LintResult()

        # We can assert that this exists because of the previous check.
        assert case1_last_when
        # We can also assert that we'll also have an else clause because
        # otherwise the case2 check above would fail.
        case1_else_clause_seg = case1_else_clause.get()
        assert case1_else_clause_seg

        # Delete stuff between the last "WHEN" clause and the "ELSE" clause
        # (both bounds exclusive) -- typically the newline/indent before ELSE.
        case1_to_delete = case1_children.select(start_seg=case1_last_when,
                                                stop_seg=case1_else_clause_seg)
        fixes = case1_to_delete.apply(lambda seg: LintFix.delete(seg))

        # Determine the indentation to use when we move the nested "WHEN"
        # and "ELSE" clauses, based on the indentation of case1_last_when.
        # Only the whitespace segment *nearest* before the last WHEN is used:
        # joining every preceding whitespace segment would add up the
        # indents of all earlier WHEN lines.
        # If no whitespace segments found, use default indent.
        indent = (case1_children.select(
            stop_seg=case1_last_when).reversed().first(
                sp.is_type("whitespace")))
        indent_str = "".join(seg.raw
                             for seg in indent) if indent else self.indent

        # Move the nested "when" and "else" clauses after the last outer
        # "when", each on its own line.
        nested_clauses = case2.children(
            sp.is_type("when_clause", "else_clause"))
        create_after_last_when = nested_clauses.apply(
            lambda seg: [NewlineSegment(),
                         WhitespaceSegment(indent_str), seg])
        segments = [
            item for sublist in create_after_last_when for item in sublist
        ]
        fixes.append(
            LintFix.create_after(case1_last_when, segments, source=segments))

        # Delete the outer "else" clause (and with it the nested CASE).
        fixes.append(LintFix.delete(case1_else_clause_seg))
        return LintResult(case2[0], fixes=fixes)
Пример #17
0
class Rule_L022(BaseRule):
    """Blank line expected but not found after CTE closing bracket.

    **Anti-pattern**

    There is no blank line after the CTE closing bracket. In queries with many
    CTEs, this hinders readability.

    .. code-block:: sql

        WITH plop AS (
            SELECT * FROM foo
        )
        SELECT a FROM plop

    **Best practice**

    Add a blank line.

    .. code-block:: sql

        WITH plop AS (
            SELECT * FROM foo
        )

        SELECT a FROM plop

    """

    groups = ("all", "core")
    config_keywords = ["comma_style"]
    crawl_behaviour = SegmentSeekerCrawler({"with_compound_statement"})

    def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
        """Blank line expected but not found after CTE definition.

        For each bracketed segment among the children of the ``WITH``
        statement (with ``common_table_expression`` children expanded), scan
        forward over commas and non-code segments up to the next code
        segment. If that span contains no blank line, report a
        ``LintResult`` anchored on that next code segment, with a fix that
        inserts newline(s). Where they are inserted depends on the comma
        layout detected in the span and, for single-line layouts, on the
        configured ``comma_style``.

        Returns the list of results, or ``None`` if there are none.
        """
        # Config type hints
        self.comma_style: str

        error_buffer = []
        assert context.segment.is_type("with_compound_statement")
        # First we need to find all the commas, the end brackets, the
        # things that come after that and the blank lines in between.

        # Find all the bracketed segments. What follows each one comes after
        # its closing bracket, so they are our anchor points.
        bracket_indices = []
        expanded_segments = list(
            context.segment.iter_segments(
                expanding=["common_table_expression"]))
        for idx, seg in enumerate(expanded_segments):
            if seg.is_type("bracketed"):
                bracket_indices.append(idx)

        # Work through each point and deal with it individually
        for bracket_idx in bracket_indices:
            forward_slice = expanded_segments[bracket_idx:]
            seg_idx = 1
            line_idx = 0
            comma_seg_idx = 0
            blank_lines = 0
            comma_line_idx = None
            line_blank = False
            comma_style = None
            # Maps line number (counted from the bracket's line, 0) to the
            # index in forward_slice of the first segment on that line.
            line_starts = {}
            comment_lines = []

            self.logger.info(
                "## CTE closing bracket found at %s, idx: %s. Forward slice: %.20r",
                forward_slice[0].pos_marker,
                bracket_idx,
                "".join(elem.raw for elem in forward_slice),
            )

            # Work forward to map out the following segments, stopping at
            # the first code segment that is not a comma.
            while (forward_slice[seg_idx].is_type("comma")
                   or not forward_slice[seg_idx].is_code):
                if forward_slice[seg_idx].is_type("newline"):
                    if line_blank:
                        # It's a blank line!
                        blank_lines += 1
                    line_blank = True
                    line_idx += 1
                    line_starts[line_idx] = seg_idx + 1
                elif forward_slice[seg_idx].is_type("comment"):
                    # Lines with comments aren't blank
                    line_blank = False
                    comment_lines.append(line_idx)
                elif forward_slice[seg_idx].is_type("comma"):
                    # Keep track of where the comma is.
                    # We'll evaluate it later.
                    comma_line_idx = line_idx
                    comma_seg_idx = seg_idx
                seg_idx += 1

            # Infer the comma style (NB this could be different for each case!)
            #   final:    no comma (the last CTE)
            #   oneline:  no newline between the bracket and the next code
            #   trailing: comma on the bracket's line
            #   leading:  comma on the same line as the next code
            #   floating: comma on a line of its own
            if comma_line_idx is None:
                comma_style = "final"
            elif line_idx == 0:
                comma_style = "oneline"
            elif comma_line_idx == 0:
                comma_style = "trailing"
            elif comma_line_idx == line_idx:
                comma_style = "leading"
            else:
                comma_style = "floating"

            # Readout of findings
            self.logger.info(
                "blank_lines: %s, comma_line_idx: %s. final_line_idx: %s, "
                "final_seg_idx: %s",
                blank_lines,
                comma_line_idx,
                line_idx,
                seg_idx,
            )
            self.logger.info(
                "comma_style: %r, line_starts: %r, comment_lines: %r",
                comma_style,
                line_starts,
                comment_lines,
            )

            if blank_lines < 1:
                # We've got an issue
                self.logger.info("!! Found CTE without enough blank lines.")

                # Based on the current location of the comma we insert newlines
                # to correct the issue.
                fix_type = "create_before"  # In most cases we just insert newlines.
                if comma_style == "oneline":
                    # Here we respect the target comma style to insert at the
                    # relevant point.
                    if self.comma_style == "trailing":
                        # Add a blank line after the comma
                        fix_point = forward_slice[comma_seg_idx + 1]
                        # Optionally here, if the segment we've landed on is
                        # whitespace then we REPLACE it rather than inserting.
                        if forward_slice[comma_seg_idx +
                                         1].is_type("whitespace"):
                            fix_type = "replace"
                    elif self.comma_style == "leading":
                        # Add a blank line before the comma
                        fix_point = forward_slice[comma_seg_idx]
                    # In both cases it's a double newline.
                    num_newlines = 2
                else:
                    # In the following cases we only care which one we're in
                    # when comments don't get in the way. If they *do*, then
                    # we just work around them.
                    if not comment_lines or line_idx - 1 not in comment_lines:
                        self.logger.info("Comment routines not applicable")
                        if comma_style in ("trailing", "final", "floating"):
                            # Detected an existing trailing comma or it's a final
                            # CTE, OR the comma isn't leading or trailing.
                            # If the preceding segment is whitespace, replace it
                            if forward_slice[seg_idx -
                                             1].is_type("whitespace"):
                                fix_point = forward_slice[seg_idx - 1]
                                fix_type = "replace"
                            else:
                                # Otherwise add a single newline before the end
                                # content.
                                fix_point = forward_slice[seg_idx]
                        elif comma_style == "leading":
                            # Detected an existing leading comma.
                            fix_point = forward_slice[comma_seg_idx]
                    else:
                        self.logger.info("Handling preceding comments")
                        offset = 1
                        while line_idx - offset in comment_lines:
                            offset += 1
                        # If the offset - 1 equals the line_idx then there aren't
                        # really any comment-only lines (ref #2945).
                        # Reset to line_idx
                        fix_point = forward_slice[line_starts[line_idx -
                                                              (offset - 1)
                                                              or line_idx]]
                    num_newlines = 1

                fixes = [
                    LintFix(
                        fix_type,
                        fix_point,
                        # Build distinct segment instances: ``[seg] * n``
                        # would insert the *same* object n times into the
                        # tree.
                        [NewlineSegment() for _ in range(num_newlines)],
                    )
                ]
                # Create a result, anchored on the start of the next content.
                error_buffer.append(
                    LintResult(anchor=forward_slice[seg_idx], fixes=fixes))
        # Return the buffer if we have one.
        return error_buffer or None
Пример #18
0
class Rule_L037(BaseRule):
    """Ambiguous ordering directions for columns in order by clause.

    **Anti-pattern**

    .. code-block:: sql

        SELECT
            a, b
        FROM foo
        ORDER BY a, b DESC

    **Best practice**

    If any columns in the ``ORDER BY`` clause specify ``ASC`` or ``DESC``, they should
    all do so.

    .. code-block:: sql

        SELECT
            a, b
        FROM foo
        ORDER BY a ASC, b DESC
    """

    groups = ("all", )
    crawl_behaviour = SegmentSeekerCrawler({"orderby_clause"})

    @staticmethod
    def _get_orderby_info(segment: BaseSegment) -> List[OrderByColumnInfo]:
        """Summarise each ORDER BY column and its explicit direction.

        Walks the direct children of an ``orderby_clause``. A
        ``column_reference`` opens a column, and an ``ASC``/``DESC`` keyword
        records its direction. The first later child that is not a keyword,
        whitespace or indent/dedent closes the column and becomes its
        ``separator``, the point before which a missing direction would be
        inserted. A column still open at the end uses the clause's final
        child as its separator.
        """
        assert segment.is_type("orderby_clause")

        ignorable = ("keyword", "whitespace", "indent", "dedent")
        columns = []
        column_open = False
        direction = None
        for child in segment.segments:
            if child.is_type("column_reference"):
                column_open = True
                continue
            if child.is_type("keyword") and child.name in ("asc", "desc"):
                direction = child.name
                continue
            if not column_open or child.type in ignorable:
                continue
            columns.append(OrderByColumnInfo(separator=child,
                                             order=direction))
            # Start looking for the next column.
            column_open = False
            direction = None

        # Special handling for last column
        if column_open:
            columns.append(
                OrderByColumnInfo(separator=segment.segments[-1],
                                  order=direction))
        return columns

    def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
        """Ambiguous ordering directions for columns in order by clause.

        This rule checks if some ORDER BY columns explicitly specify ASC or
        DESC and some don't. If so, an explicit ``ASC`` is proposed for each
        column without one.
        """
        # We only trigger on orderby_clause
        column_infos = self._get_orderby_info(context.segment)
        directions = {info.order for info in column_infos}
        # Only a mix of explicit and implicit directions is a problem.
        if not (None in directions and directions != {None}):
            return None

        # ASC is the SQL default, so spell it out wherever it is implied.
        asc_fixes = [
            LintFix.create_before(
                info.separator,
                [WhitespaceSegment(), KeywordSegment("ASC")],
            ) for info in column_infos if not info.order
        ]

        return [
            LintResult(
                anchor=context.segment,
                fixes=asc_fixes,
                description=
                ("Ambiguous order by clause. Order by clauses should specify "
                 "order direction for ALL columns or NO columns."),
            )
        ]
Пример #19
0
class Rule_L023(BaseRule):
    """Single whitespace expected after ``AS`` in ``WITH`` clause.

    **Anti-pattern**

    .. code-block:: sql

        WITH plop AS(
            SELECT * FROM foo
        )

        SELECT a FROM plop


    **Best practice**

    Add a space after ``AS``, to avoid confusing it for a function.
    The ``•`` character represents a space.

    .. code-block:: sql
       :force:

        WITH plop AS•(
            SELECT * FROM foo
        )

        SELECT a FROM plop
    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"with_compound_statement"})
    pre_segment_identifier = ("raw_upper", "AS")
    post_segment_identifier = ("type", "bracketed")
    allow_newline = False  # hard-coded, could be configurable
    expand_children: Optional[List[str]] = ["common_table_expression"]

    def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
        """Check the gap between a ``pre`` and ``post`` code segment pair.

        Walks the children of the crawled segment, expanding the types in
        ``expand_children``. It tracks the last non-meta code segment and the
        non-code segments seen since. When that code segment matches
        ``pre_segment_identifier`` and the current code segment matches
        ``post_segment_identifier``, the gap must be exactly one space. If
        ``allow_newline`` is set, a gap containing a newline is also
        accepted. Whitespace-only gaps get a fix that collapses them to one
        space. Other gaps (e.g. containing a comment) are reported without
        a fix.
        """
        violations: List[LintResult] = []
        prev_code = None
        gap: List[BaseSegment] = []
        for current in context.segment.iter_segments(
                expanding=self.expand_children):
            if not current.is_code:
                # Collect everything sitting between two code segments.
                gap.append(current)
                continue

            if (prev_code and self.matches_target_tuples(
                    prev_code, [self.pre_segment_identifier])
                    and self.matches_target_tuples(
                        current, [self.post_segment_identifier])):
                gap_raw = "".join(piece.raw for piece in gap)
                acceptable = gap_raw == " " or (self.allow_newline and any(
                    piece.is_type("newline") for piece in gap))
                if not acceptable:
                    if gap_raw.strip():
                        # Something other than whitespace/newlines (e.g. a
                        # comment) is in the gap, so no fix is offered.
                        # TODO: Enable more complex fixing here.
                        fixes = None  # pragma: no cover
                    else:
                        # Remove existing whitespace/newlines (only when
                        # there is raw content to remove), then enforce a
                        # single space.
                        fixes = ([LintFix.delete(piece) for piece in gap]
                                 if gap_raw else [])
                        fixes.append(
                            LintFix.create_before(
                                current,
                                [WhitespaceSegment()],
                            ))
                    violations.append(
                        LintResult(anchor=prev_code, fixes=fixes))

            gap = []
            if not current.is_meta:
                prev_code = current
        return violations or None
Пример #20
0
class Rule_L020(BaseRule):
    """Table aliases should be unique within each clause.

    Reusing table aliases is very likely a coding error.

    **Anti-pattern**

    In this example, the alias ``t`` is reused for two different tables:

    .. code-block:: sql

        SELECT
            t.a,
            t.b
        FROM foo AS t, bar AS t

        -- This can also happen when using schemas where the
        -- implicit alias is the table name:

        SELECT
            a,
            b
        FROM
            2020.foo,
            2021.foo

    **Best practice**

    Make all tables have a unique alias.

    .. code-block:: sql

        SELECT
            f.a,
            b.b
        FROM foo AS f, bar AS b

        -- Also use explicit aliases when referencing two tables
        -- with the same name from two different schemas.

        SELECT
            f1.a,
            f2.b
        FROM
            2020.foo AS f1,
            2021.foo AS f2

    """

    groups: Tuple[str, ...] = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"select_statement"})

    def _lint_references_and_aliases(
        self,
        table_aliases: List[AliasInfo],
        standalone_aliases: List[str],
        references: List[BaseSegment],
        col_aliases: List[ColumnAliasInfo],
        using_cols: List[str],
        parent_select: Optional[BaseSegment],
    ) -> Optional[List[LintResult]]:
        """Check whether any aliases are duplicates.

        NB: Subclasses of this error should override this function.

        This implementation looks only at ``table_aliases``. The other
        arguments are there for subclasses. An alias is reported when its
        non-empty ``ref_str`` was already used by an earlier alias. Each
        such alias is reported once, in source order, so results are
        deterministic.

        Returns one ``LintResult`` per duplicate, or ``None`` if there are
        none.
        """
        seen_refs = set()
        duplicates: List[AliasInfo] = []
        for alias in table_aliases:
            # Aliases without a reference string can never clash.
            if not alias.ref_str:
                continue
            if alias.ref_str not in seen_refs:
                seen_refs.add(alias.ref_str)
            elif alias not in duplicates:
                duplicates.append(alias)

        if not duplicates:
            return None
        return [
            LintResult(
                # Reference the element, not the string.
                anchor=alias.segment,
                description=(
                    "Duplicate table alias {!r}. Table " "aliases should be unique."
                ).format(alias.ref_str),
            )
            for alias in duplicates
        ]

    def _eval(self, context: RuleContext) -> EvalResultType:
        """Get References and Aliases and allow linting.

        This rule covers a lot of potential cases of odd usages of
        references, see the code for each of the potential cases.

        Subclasses of this rule should override the
        `_lint_references_and_aliases` method.
        """
        assert context.segment.is_type("select_statement")
        select_info = get_select_statement_info(context.segment, context.dialect)
        if not select_info:
            return None

        # Work out if we have a parent select function: the nearest
        # enclosing select_statement on the parent stack, if any.
        parent_select = next(
            (seg for seg in reversed(context.parent_stack)
             if seg.is_type("select_statement")),
            None,
        )

        # Pass them all to the function that does all the work.
        # NB: Subclasses of this rules should override the function below
        return self._lint_references_and_aliases(
            select_info.table_aliases,
            select_info.standalone_aliases,
            select_info.reference_buffer,
            select_info.col_aliases,
            select_info.using_cols,
            parent_select,
        )
Пример #21
0
class Rule_L010(BaseRule):
    """Inconsistent capitalisation of keywords.

    **Anti-pattern**

    In this example, ``select`` is in lower-case whereas ``FROM`` is in upper-case.

    .. code-block:: sql

        select
            a
        FROM foo

    **Best practice**

    Make all keywords either in upper-case or in lower-case.

    .. code-block:: sql

        SELECT
            a
        FROM foo

        -- Also good

        select
            a
        from foo
    """

    groups: Tuple[str, ...] = ("all", "core")
    lint_phase = "post"
    # Binary operators behave like keywords too.
    crawl_behaviour = SegmentSeekerCrawler(
        {"keyword", "binary_operator", "date_part"})
    # Skip boolean and null literals (which are also keywords)
    # as they have their own rule (L040)
    _exclude_elements: List[Tuple[str, str]] = [
        ("type", "null_literal"),
        ("type", "boolean_literal"),
        ("parenttype", "data_type"),
        ("parenttype", "datetime_type_identifier"),
        ("parenttype", "primitive_type"),
    ]
    config_keywords = [
        "capitalisation_policy", "ignore_words", "ignore_words_regex"
    ]
    # Human readable target elem for description
    _description_elem = "Keywords"

    def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
        """Inconsistent capitalisation of keywords.

        We use the `memory` feature here to keep track of cases known to be
        INconsistent with what we've seen so far as well as the top choice
        for what the possible case is.

        """
        # Skip if not an element of the specified type/name
        parent: Optional[BaseSegment] = (context.parent_stack[-1]
                                         if context.parent_stack else None)
        if self.matches_target_tuples(context.segment, self._exclude_elements,
                                      parent):
            return [LintResult(memory=context.memory)]

        return [self._handle_segment(context.segment, context.memory)]

    def _handle_segment(self, segment, memory) -> LintResult:
        """Check one segment's capitalisation against the configured policy.

        NOTE: this mutates ``memory``. It updates the ``"refuted_cases"``
        set and, under the ``consistent`` policy, ``"latest_possible_case"``.

        Returns a ``LintResult`` carrying ``memory``. When the segment's raw
        text differs from what the resolved policy requires, the result also
        has an anchor, a fix from ``_get_fix`` and a description.
        """
        # Logging uses lazy %-style arguments: this runs for every keyword,
        # so the messages (which repr ``memory``) should only be built when
        # debug logging is actually enabled.
        self.logger.info("_handle_segment: %s, %s", segment,
                         segment.get_type())
        # Config type hints
        self.ignore_words_regex: str

        # Get the capitalisation policy configuration.
        try:
            cap_policy = self.cap_policy
            cap_policy_opts = self.cap_policy_opts
            ignore_words_list = self.ignore_words_list
        except AttributeError:
            # First-time only, read the settings from configuration. This is
            # very slow.
            (
                cap_policy,
                cap_policy_opts,
                ignore_words_list,
            ) = self._init_capitalisation_policy()

        # Skip if in ignore list
        if ignore_words_list and segment.raw.lower() in ignore_words_list:
            return LintResult(memory=memory)

        # Skip if matches ignore regex
        if self.ignore_words_regex and regex.search(self.ignore_words_regex,
                                                    segment.raw):
            return LintResult(memory=memory)

        # Skip if templated.
        if segment.is_templated:
            return LintResult(memory=memory)

        # Skip if empty.
        if not segment.raw:
            return LintResult(memory=memory)

        refuted_cases = memory.get("refuted_cases", set())

        # Which cases are definitely inconsistent with the segment?
        if segment.raw[0] != segment.raw[0].upper():
            refuted_cases.update(["upper", "capitalise", "pascal"])
            if segment.raw != segment.raw.lower():
                refuted_cases.update(["lower"])
        else:
            refuted_cases.update(["lower"])
            if segment.raw != segment.raw.upper():
                refuted_cases.update(["upper"])
            if segment.raw != segment.raw.capitalize():
                refuted_cases.update(["capitalise"])
            if not segment.raw.isalnum():
                refuted_cases.update(["pascal"])

        # Update the memory
        memory["refuted_cases"] = refuted_cases

        self.logger.debug("Refuted cases after segment '%s': %s", segment.raw,
                          refuted_cases)

        # Skip if no inconsistencies, otherwise compute a concrete policy
        # to convert to.
        if cap_policy == "consistent":
            possible_cases = [
                c for c in cap_policy_opts if c not in refuted_cases
            ]
            self.logger.debug("Possible cases after segment '%s': %s",
                              segment.raw, possible_cases)
            if possible_cases:
                # Save the latest possible case and skip
                memory["latest_possible_case"] = possible_cases[0]
                self.logger.debug(
                    "Consistent capitalization, returning with memory: %s",
                    memory)
                return LintResult(memory=memory)
            else:
                concrete_policy = memory.get("latest_possible_case", "upper")
                self.logger.debug("Getting concrete policy '%s' from memory",
                                  concrete_policy)
        else:
            if cap_policy not in refuted_cases:
                # Skip
                self.logger.debug(
                    "Consistent capitalization %s, returning with memory: %s",
                    cap_policy, memory)
                return LintResult(memory=memory)
            else:
                concrete_policy = cap_policy
                self.logger.debug(
                    "Setting concrete policy '%s' from cap_policy",
                    concrete_policy)

        # Set the fixed to same as initial in case any of below don't match
        fixed_raw = segment.raw
        # We need to change the segment to match the concrete policy
        if concrete_policy == "upper":
            fixed_raw = fixed_raw.upper()
        elif concrete_policy == "lower":
            fixed_raw = fixed_raw.lower()
        elif concrete_policy == "capitalise":
            fixed_raw = fixed_raw.capitalize()
        elif concrete_policy == "pascal":
            # For Pascal we set the first letter in each "word" to uppercase
            # We do not lowercase other letters to allow for PascalCase style
            # words. This does mean we allow all UPPERCASE and also don't
            # correct Pascalcase to PascalCase, but there's only so much we can
            # do. We do correct underscore_words to Underscore_Words.
            fixed_raw = regex.sub(
                "([^a-zA-Z0-9]+|^)([a-zA-Z0-9])([a-zA-Z0-9]*)",
                lambda match: match.group(1) + match.group(2).upper() + match.
                group(3),
                segment.raw,
            )

        if fixed_raw == segment.raw:
            # No need to fix
            self.logger.debug(
                "Capitalisation of segment '%s' already OK with policy '%s', "
                "returning with memory %s", segment.raw, concrete_policy,
                memory)
            return LintResult(memory=memory)

        # build description based on the policy in use
        consistency = "consistently " if cap_policy == "consistent" else ""
        # "upper"/"lower" (and any other policy name) read as "<name> case.";
        # the default also guarantees ``policy`` is always bound.
        policy = {
            "capitalise": "capitalised.",
            "pascal": "pascal case.",
        }.get(concrete_policy, f"{concrete_policy} case.")

        # Return the fixed segment
        self.logger.debug(
            "INCONSISTENT Capitalisation of segment '%s', fixing to '%s' and "
            "returning with memory %s", segment.raw, fixed_raw, memory)
        return LintResult(
            anchor=segment,
            fixes=[self._get_fix(segment, fixed_raw)],
            memory=memory,
            description=
            f"{self._description_elem} must be {consistency}{policy}",
        )

    def _get_fix(self, segment, fixed_raw):
        """Given a segment found to have a fix, returns a LintFix for it.

        May be overridden by subclasses, which is useful when the parse tree
        structure varies from this simple base case.
        """
        return LintFix.replace(segment, [segment.edit(fixed_raw)])

    def _init_capitalisation_policy(self):
        """Called first time rule is evaluated to fetch & cache the policy.

        Caches ``cap_policy``, ``cap_policy_opts`` (the validation options
        for the ``*capitalisation_policy`` config keyword, minus
        ``"consistent"``) and ``ignore_words_list`` on ``self``, and returns
        them as a tuple in that order.
        """
        cap_policy_name = next(k for k in self.config_keywords
                               if k.endswith("capitalisation_policy"))
        self.cap_policy = getattr(self, cap_policy_name)
        self.cap_policy_opts = [
            opt for opt in get_config_info()[cap_policy_name]["validation"]
            if opt != "consistent"
        ]
        # Use str() as L040 uses bools which might otherwise be read as bool
        ignore_words_config = str(getattr(self, "ignore_words"))
        if ignore_words_config and ignore_words_config != "None":
            self.ignore_words_list = self.split_comma_separated_string(
                ignore_words_config.lower())
        else:
            self.ignore_words_list = []
        self.logger.debug("Selected '%s': '%s' from options %s",
                          cap_policy_name, self.cap_policy,
                          self.cap_policy_opts)
        return self.cap_policy, self.cap_policy_opts, self.ignore_words_list
Пример #22
0
class Rule_L066(BaseRule):
    """Enforce table alias lengths in from clauses and join conditions.

    **Anti-pattern**

    In this example, alias ``o`` is used for the orders table.

    .. code-block:: sql

        SELECT
            SUM(o.amount) as order_amount
        FROM orders as o


    **Best practice**

    Avoid aliases. Avoid short aliases when aliases are necessary.

    See also: L031.

    .. code-block:: sql

        SELECT
            SUM(orders.amount) as order_amount
        FROM orders

        -- Also good

        SELECT
            replacement_orders.amount,
            previous_orders.amount
        FROM
            orders AS replacement_orders
        JOIN
            orders AS previous_orders
            ON replacement_orders.id = previous_orders.replacement_id
    """

    groups = ("all",)
    config_keywords = ["min_alias_length", "max_alias_length"]
    crawl_behaviour = SegmentSeekerCrawler({"select_statement"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Identify aliases in from clause and join conditions.

        Find base table, table expressions in join, and other expressions in select
        clause and decide if it's needed to report them.

        Returns a list of ``LintResult`` (one per violated bound), or None when
        every alias satisfies the configured lengths.
        """
        # Config type hints
        self.min_alias_length: Optional[int]
        self.max_alias_length: Optional[int]

        assert context.segment.is_type("select_statement")
        children = FunctionalContext(context).segment.children()
        from_expression_elements = children.recursive_crawl("from_expression_element")

        return self._lint_aliases(from_expression_elements) or None

    def _lint_aliases(self, from_expression_elements):
        """Lint all table aliases against ``min/max_alias_length``.

        Args:
            from_expression_elements: iterable of ``from_expression_element``
                segments to inspect.

        Returns:
            A list of ``LintResult`` anchored on the offending alias
            identifiers, or None when there are no violations.
        """
        violation_buff = []

        for from_expression_element in from_expression_elements:
            table_expression = from_expression_element.get_child("table_expression")
            table_ref = (
                table_expression.get_child("object_reference")
                if table_expression
                else None
            )

            # If the from_expression_element has no object_reference - skip it.
            # An example case is a lateral flatten, where we have a function
            # segment instead of a table_reference segment.
            if not table_ref:
                continue

            # If there's no alias expression - skip it.
            alias_exp_ref = from_expression_element.get_child("alias_expression")
            if alias_exp_ref is None:
                continue

            # Guard against alias expressions without an identifier child,
            # which previously raised AttributeError on ``.raw``.
            alias_identifier_ref = alias_exp_ref.get_child("identifier")
            if alias_identifier_ref is None:
                continue

            alias_length = len(alias_identifier_ref.raw)

            if (self.min_alias_length is not None
                    and alias_length < self.min_alias_length):
                violation_buff.append(
                    LintResult(
                        anchor=alias_identifier_ref,
                        description=(
                            "Aliases should be at least {} character(s) long."
                        ).format(self.min_alias_length),
                    )
                )

            if (self.max_alias_length is not None
                    and alias_length > self.max_alias_length):
                violation_buff.append(
                    LintResult(
                        anchor=alias_identifier_ref,
                        description=(
                            "Aliases should be no more than {} character(s) long."
                        ).format(self.max_alias_length),
                    )
                )

        return violation_buff or None
Пример #23
0
class Rule_L044(BaseRule):
    """Query produces an unknown number of result columns.

    **Anti-pattern**

    Querying all columns using ``*`` produces a query result where the number
    or ordering of columns changes if the upstream table's schema changes.
    This should generally be avoided because it can cause slow performance,
    cause important schema changes to go undetected, or break production code.
    For example:

    * If a query does ``SELECT t.*`` and is expected to return columns ``a``, ``b``,
      and ``c``, the actual columns returned will be wrong/different if columns
      are added to or deleted from the input table.
    * Set operations such as ``UNION``, ``INTERSECT`` and ``EXCEPT`` require
      their inputs to have the same number of columns (and compatible types).
    * ``JOIN`` queries may break due to new column name conflicts, e.g. the
      query references a column ``c`` which initially existed in only one input
      table but a column of the same name is added to another table.
    * ``CREATE TABLE (<<column schema>>) AS SELECT *`` statements, where the
      declared schema must match the selected columns.


    .. code-block:: sql

        WITH cte AS (
            SELECT * FROM foo
        )

        SELECT * FROM cte
        UNION
        SELECT a, b FROM t

    **Best practice**

    Somewhere along the "path" to the source data, specify columns explicitly.

    .. code-block:: sql

        WITH cte AS (
            SELECT * FROM foo
        )

        SELECT a, b FROM cte
        UNION
        SELECT a, b FROM t

    """

    groups = ("all",)
    crawl_behaviour = SegmentSeekerCrawler(set(_START_TYPES))

    def _handle_alias(self, selectable, alias_info, query):
        """Follow a wildcard's table alias to its source.

        Raises:
            RuleFailure: if the alias points at an external table, whose
                column count cannot be known here.
        """
        select_info_target = SelectCrawler.get(
            query, alias_info.from_expression_element
        )[0]
        if isinstance(select_info_target, str):
            # It's an alias to an external table whose
            # number of columns could vary without our
            # knowledge. Thus, warn.
            self.logger.debug(
                f"Query target {select_info_target} is external. Generating warning."
            )
            raise RuleFailure(selectable.selectable)
        else:
            # Handle nested SELECT.
            self._analyze_result_columns(select_info_target)

    def _analyze_result_columns(self, query: Query):
        """Given info on a list of SELECTs, determine whether to warn.

        Recursively walks from ``query`` to any wildcard columns in the select
        targets of *every* selectable (e.g. each branch of a ``UNION``). If
        every wildcard eventually resolves to a query without wildcards, all
        is well. Otherwise, warn.

        Raises:
            RuleFailure: anchored on the selectable whose wildcard cannot be
                resolved to a known set of columns.
        """
        if not query.selectables:
            return  # pragma: no cover
        for selectable in query.selectables:
            self.logger.debug(f"Analyzing query: {selectable.selectable.raw}")
            for wildcard in selectable.get_wildcard_info():
                if wildcard.tables:
                    for wildcard_table in wildcard.tables:
                        # Both parts are f-strings so the table name is
                        # interpolated (previously logged literally).
                        self.logger.debug(
                            f"Wildcard: {wildcard.segment.raw} has target "
                            f"{wildcard_table}"
                        )
                        # Is it an alias?
                        alias_info = selectable.find_alias(wildcard_table)
                        if alias_info:
                            # Found the alias matching the wildcard. Recurse,
                            # analyzing the query associated with that alias.
                            self._handle_alias(selectable, alias_info, query)
                        else:
                            # Not an alias. Is it a CTE?
                            cte = query.lookup_cte(wildcard_table)
                            if cte:
                                # Wildcard refers to a CTE. Analyze it.
                                self._analyze_result_columns(cte)
                            else:
                                # Not CTE, not table alias. Presumably an
                                # external table. Warn.
                                self.logger.debug(
                                    f"Query target {wildcard_table} is external. "
                                    "Generating warning."
                                )
                                raise RuleFailure(selectable.selectable)
                else:
                    # No table was specified with the wildcard. Assume we're
                    # querying from a nested select in the FROM of *this*
                    # selectable (not necessarily the first branch of a set
                    # operation).
                    target = selectable.selectable
                    nested_query = next(
                        (
                            source
                            for source in SelectCrawler.get(query, target)
                            if isinstance(source, Query)
                        ),
                        None,
                    )
                    if nested_query is None:
                        self.logger.debug(
                            f'Query target "{target.raw}" has no '
                            "targets. Generating warning."
                        )
                        raise RuleFailure(target)
                    # Analyze the nested query, then carry on with the
                    # remaining wildcards/selectables instead of returning
                    # early (which previously skipped later UNION branches).
                    self._analyze_result_columns(nested_query)

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Outermost query should produce known number of columns."""
        # Only analyze from the outermost statement; nested ones are reached
        # through recursion.
        if not FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
            crawler = SelectCrawler(context.segment, context.dialect)

            # Begin analysis at the outer query.
            if crawler.query_tree:
                try:
                    self._analyze_result_columns(crawler.query_tree)
                except RuleFailure as e:
                    return LintResult(anchor=e.anchor)
        return None
Пример #24
0
class Rule_L041(BaseRule):
    """``SELECT`` modifiers (e.g. ``DISTINCT``) must be on the same line as ``SELECT``.

    **Anti-pattern**

    .. code-block:: sql

        select
            distinct a,
            b
        from x


    **Best practice**

    .. code-block:: sql

        select distinct
            a,
            b
        from x

    """

    groups = ("all", "core")
    crawl_behaviour = SegmentSeekerCrawler({"select_clause"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Move a select clause modifier up onto the ``SELECT`` keyword's line."""
        # The crawler only hands us select_clause segments.
        assert context.segment.is_type("select_clause")

        children = FunctionalContext(context).segment.children()
        keyword = children[0]

        modifier_matches = children.first(sp.is_type("select_clause_modifier"))
        if not modifier_matches:
            # No modifier, nothing to move.
            return None
        modifier = modifier_matches[0]

        # Segments that may sit between the keyword and the modifier.
        gap = sp.or_(sp.is_whitespace(), sp.is_meta())

        newlines_before = children.select(
            select_if=sp.is_type("newline"), loop_while=gap, start_seg=keyword
        )
        if not newlines_before:
            # Modifier already shares the SELECT keyword's line.
            return None

        spaces_before = children.select(
            select_if=sp.is_type("whitespace"), loop_while=gap, start_seg=keyword
        )
        newlines_after = children.select(
            select_if=sp.is_type("newline"), loop_while=gap, start_seg=modifier
        )

        # What gets inserted immediately after the SELECT keyword.
        moved = [WhitespaceSegment(), modifier]
        if not newlines_after:
            # The first select target shares the modifier's line, so push it
            # onto its own line once the modifier has moved up.
            moved.append(NewlineSegment())

        fixes = [LintFix.create_after(anchor_segment=keyword, edit_segments=moved)]

        # Clear the old gap. With a newline already following the modifier the
        # indentation before it goes too; otherwise it is reused as indent.
        if newlines_after:
            to_remove = newlines_before + spaces_before
        else:
            to_remove = newlines_before
        fixes.extend(LintFix.delete(seg) for seg in to_remove)

        # Drop the modifier from its old position.
        fixes.append(LintFix.delete(modifier))

        # Drop same-line whitespace trailing the old modifier position.
        spaces_after = children.select(
            select_if=sp.is_whitespace(),
            loop_while=sp.or_(sp.is_type("whitespace"), sp.is_meta()),
            start_seg=modifier,
        )
        fixes.extend(LintFix.delete(seg) for seg in spaces_after)

        return LintResult(anchor=context.segment, fixes=fixes)
Пример #25
0
class Rule_L047(BaseRule):
    """Use consistent syntax to express "count number of rows".

    Note:
        If both ``prefer_count_1`` and ``prefer_count_0`` are set to true
        then ``prefer_count_1`` has precedence.

    ``COUNT(*)``, ``COUNT(1)``, and even ``COUNT(0)`` are equivalent syntaxes
    in many SQL engines due to optimizers interpreting these instructions as
    "count number of rows in result".

    The ANSI-92_ spec mentions the ``COUNT(*)`` syntax specifically as
    having a special meaning:

        If COUNT(*) is specified, then
        the result is the cardinality of T.

    So by default, `SQLFluff` enforces the consistent use of ``COUNT(*)``.

    If the SQL engine you work with, or your team, prefers ``COUNT(1)`` or
    ``COUNT(0)`` over ``COUNT(*)``, you can configure this rule to consistently
    enforce your preference.

    .. _ANSI-92: http://msdn.microsoft.com/en-us/library/ms175997.aspx

    **Anti-pattern**

    .. code-block:: sql

        select
            count(1)
        from table_a

    **Best practice**

    Use ``count(*)`` unless specified otherwise by config ``prefer_count_1``,
    or ``prefer_count_0`` as preferred.

    .. code-block:: sql

        select
            count(*)
        from table_a

    """

    groups = ("all", "core")
    config_keywords = ["prefer_count_1", "prefer_count_0"]
    crawl_behaviour = SegmentSeekerCrawler({"function"})

    def _eval(self, context: RuleContext) -> Optional[LintResult]:
        """Find rule violations and provide fixes.

        Flags ``COUNT(*)``, ``COUNT(0)`` or ``COUNT(1)`` when it differs from
        the configured preferred form, with a fix rewriting the argument.
        """
        # Config type hints
        self.prefer_count_0: bool
        self.prefer_count_1: bool

        # We already know we're in a function because of the crawl_behaviour,
        # but guard against a missing function_name child rather than raising
        # AttributeError on ``.raw_upper``.
        function_name = context.segment.get_child("function_name")
        if function_name is None or function_name.raw_upper != "COUNT":
            return None

        # Get the bracketed content, ignoring brackets, whitespace and meta.
        f_content = FunctionalContext(context).segment.children(
            sp.is_type("bracketed")).children(
                sp.and_(
                    sp.not_(sp.is_meta()),
                    sp.not_(
                        sp.is_type("start_bracket", "end_bracket",
                                   "whitespace", "newline")),
                ))
        if len(f_content) != 1:  # pragma: no cover
            return None

        # prefer_count_1 takes precedence over prefer_count_0.
        if self.prefer_count_1:
            preferred = "1"
        elif self.prefer_count_0:
            preferred = "0"
        else:
            preferred = "*"

        content = f_content[0]

        if content.is_type("star") and preferred != "*":
            return LintResult(
                anchor=context.segment,
                fixes=[
                    LintFix.replace(
                        content,
                        [content.edit(content.raw.replace("*", preferred))],
                    ),
                ],
            )

        if content.is_type("expression"):
            expression_content = [
                seg for seg in content.segments if not seg.is_meta
            ]
            if (len(expression_content) == 1
                    and expression_content[0].is_type("literal")
                    and expression_content[0].raw in ["0", "1"]
                    and expression_content[0].raw != preferred):
                literal = expression_content[0]
                # The raw is exactly "0" or "1" here, so the new raw is
                # simply the preferred form.
                return LintResult(
                    anchor=context.segment,
                    fixes=[
                        LintFix.replace(literal, [literal.edit(preferred)]),
                    ],
                )
        return None
Пример #26
0
class Rule_L028(BaseRule):
    """References should be consistent in statements with a single table.

    .. note::
        For BigQuery, Hive and Redshift this rule is disabled by default.
        This is due to historical false positives associated with STRUCT data types.
        This default behaviour may be changed in the future.
        The rule can be enabled with the ``force_enable = True`` flag.

    In "consistent" mode, any inconsistency found is fixed to "qualified".

    **Anti-pattern**

    In this example, only the field ``b`` is qualified with the table name
    (``foo.b``), while ``a`` is not.

    .. code-block:: sql

        SELECT
            a,
            foo.b
        FROM foo

    **Best practice**

    Either qualify all column references with the table name, or none of them.

    .. code-block:: sql

        SELECT
            a,
            b
        FROM foo

        -- Also good

        SELECT
            foo.a,
            foo.b
        FROM foo

    """

    groups = ("all", )
    config_keywords = [
        "single_table_references",
        "force_enable",
    ]
    crawl_behaviour = SegmentSeekerCrawler(set(_START_TYPES))
    _is_struct_dialect = False
    _dialects_with_structs = ["bigquery", "hive", "redshift"]
    # This could be turned into an option
    _fix_inconsistent_to = "qualified"

    def _eval(self, context: RuleContext) -> EvalResultType:
        """Override base class for dialects that use structs, or SELECT aliases."""
        # Config type hints
        self.force_enable: bool
        is_struct_dialect = context.dialect.name in self._dialects_with_structs
        # Some dialects use structs (e.g. column.field) which look like
        # table references and so incorrectly trigger this rule.
        if is_struct_dialect and not self.force_enable:
            return LintResult()

        # Recompute on every evaluation so a True left over from a previous
        # struct dialect cannot leak into a later, non-struct evaluation.
        self._is_struct_dialect = is_struct_dialect

        # Only start from the outermost statement; nested queries are visited
        # recursively from there.
        if not FunctionalContext(context).parent_stack.any(
                sp.is_type(*_START_TYPES)):
            crawler = SelectCrawler(context.segment, context.dialect)
            if crawler.query_tree:
                # Recursively visit and check each query in the tree.
                return list(self._visit_queries(crawler.query_tree))
        return None

    def _visit_queries(self, query: Query) -> Iterator[LintResult]:
        """Yield reference-consistency results for ``query`` and its children."""
        if query.selectables:
            select_info = query.selectables[0].select_info
            # How many table names are visible from here? If not exactly one
            # then do nothing.
            if select_info and len(select_info.table_aliases) == 1:
                fixable = True
                # :TRICKY: Subqueries in the column list of a SELECT can see tables
                # in the FROM list of the containing query. Thus, count tables at
                # the *parent* query level.
                table_search_root = query.parent if query.parent else query
                if table_search_root.selectables:
                    query_list = SelectCrawler.get(
                        table_search_root,
                        table_search_root.selectables[0].selectable,
                    )
                else:
                    query_list = []
                filtered_query_list = [
                    q for q in query_list if isinstance(q, str)
                ]
                if len(filtered_query_list) != 1:
                    # If the number of visible table names is not exactly one,
                    # still report potential lint warnings, but don't generate
                    # fixes: they are unsafe without a single visible table.
                    fixable = False
                yield from _check_references(
                    select_info.table_aliases,
                    select_info.standalone_aliases,
                    select_info.reference_buffer,
                    select_info.col_aliases,
                    self.single_table_references,  # type: ignore
                    self._is_struct_dialect,
                    self._fix_inconsistent_to,
                    fixable,
                )
        for child in query.children:
            yield from self._visit_queries(child)