def _get_subsequent_whitespace(
    self,
    context,
) -> Tuple[Optional[BaseSegment], Optional[BaseSegment]]:
    """Search forwards through the raw segments for subsequent whitespace.

    Starting after ``context.segment`` (the current comma), collect the run
    of meta and whitespace raw segments that follows it.

    Returns:
        A tuple ``(trailing_whitespace, first_non_whitespace)``:

        * ``trailing_whitespace`` is the last whitespace segment in that
          run, or ``None`` when the run holds no whitespace.
        * ``first_non_whitespace`` is the raw segment located just past the
          run, or ``None`` when the run reaches the end of the raw segments.
    """
    # Get all raw segments. "raw_segments" is appropriate as the only
    # segments we can care about are comma, whitespace, newline, and
    # comment, which are all raw. Using the raw_segments allows us to
    # account for possible unexpected parse tree structures resulting from
    # other rule fixes.
    raw_segments = FunctionalContext(context).raw_segments

    # Start after the current comma within the list. Get all the
    # following whitespace.
    following_segments = raw_segments.select(
        loop_while=sp.or_(sp.is_meta(), sp.is_type("whitespace")),
        start_seg=context.segment,
    )
    subsequent_whitespace = following_segments.last(sp.is_type("whitespace"))

    # Always hand back a single segment (or None), never the Segments
    # collection itself, so both return paths honour the annotation.
    trailing_whitespace = (
        subsequent_whitespace[0] if subsequent_whitespace else None
    )

    next_index = (
        raw_segments.index(context.segment) + len(following_segments) + 1
    )
    try:
        first_non_whitespace = raw_segments[next_index]
    except IndexError:
        # If we find ourselves here it's all whitespace (or nothing) to the
        # end of the file. This can only happen in bigquery (see
        # test_pass_bigquery_trailing_comma).
        first_non_whitespace = None

    return trailing_whitespace, first_non_whitespace
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Enforce the configured trailing-comma policy within a select clause.

    The policy is read from ``self.select_clause_trailing_comma``:

    * ``"forbid"``: a trailing comma is reported and deleted.
    * ``"require"``: a missing trailing comma is reported and a comma is
      added straight after the last code segment.

    Returns ``None`` when the clause already complies, or when the policy is
    neither of the two values above.
    """
    # Config type hints
    self.select_clause_trailing_comma: str

    # The last code child of the clause decides the outcome.
    final_code: BaseSegment = (
        FunctionalContext(context).segment.children().last(sp.is_code())[0]
    )
    policy = self.select_clause_trailing_comma

    if policy == "forbid":
        if final_code.is_type("comma"):
            return LintResult(
                anchor=final_code,
                fixes=[LintFix.delete(final_code)],
                description="Trailing comma in select statement forbidden",
            )
        return None

    if policy == "require" and not final_code.is_type("comma"):
        # Replace the last element with itself followed by a new comma.
        return LintResult(
            anchor=final_code,
            fixes=[
                LintFix.replace(
                    final_code,
                    [final_code, SymbolSegment(",", type="comma")],
                )
            ],
            description="Trailing comma in select statement required",
        )

    return None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Report a redundant ``ELSE NULL`` in a CASE expression and remove it.

    Steps:

    1. Find the first ``else_clause`` child of the case expression.
    2. Check whether that clause has a direct child whose raw text is
       ``NULL`` (case-insensitive).
    3. If so, delete the else clause together with the run of
       whitespace/newline/meta segments immediately preceding it.

    Returns ``None`` when there is no else clause or it is not a bare NULL.
    """
    assert context.segment.is_type("case_expression")

    case_children = FunctionalContext(context).segment.children()
    else_clause = case_children.first(sp.is_type("else_clause"))
    if not else_clause:
        return None

    # NOTE: It's safe to look for a child that *is* "NULL": a larger
    # expression would *contain* NULL but not be == NULL.
    if not else_clause.children(lambda child: child.raw_upper == "NULL"):
        return None

    # :TRICKY: reversed() makes select() effectively walk backwards from
    # the ELSE, gathering the indents/whitespace/meta in front of it.
    leading_gap = case_children.reversed().select(
        start_seg=else_clause[0],
        loop_while=sp.or_(sp.is_name("whitespace", "newline"), sp.is_meta()),
    )

    fixes = [LintFix.delete(else_clause[0])]
    fixes.extend(LintFix.delete(seg) for seg in leading_gap)
    return LintResult(anchor=context.segment, fixes=fixes)
def _eval(self, context: RuleContext) -> LintResult:
    """Function name not immediately followed by bracket.

    Inspect a function segment and report anything sitting between the
    function name and its bracketed argument list. When the gap consists
    solely of whitespace/newlines, fixes deleting it are attached;
    otherwise the problem is reported without a fix.
    """
    segment = FunctionalContext(context).segment
    # This rule is only expected to be evaluated on function segments.
    assert segment.all(sp.is_type("function"))

    children = segment.children()
    name_seg = children.first(sp.is_type("function_name"))[0]
    bracketed_seg = children.first(sp.is_type("bracketed"))[0]
    gap = children.select(start_seg=name_seg, stop_seg=bracketed_seg)

    if not gap:
        # Name is directly followed by the brackets: nothing to report.
        return LintResult()

    if not gap.all(sp.is_type("whitespace", "newline")):
        # It's not all whitespace, just report the error.
        return LintResult(anchor=gap[0])

    # Only whitespace or newlines in between, so it is safe to delete them.
    return LintResult(
        anchor=gap[0],
        fixes=[LintFix.delete(seg) for seg in gap],
    )
def _eval(self, context: RuleContext) -> LintResult:
    """Nested CASE statement in ELSE clause could be flattened.

    When the ELSE clause of a CASE expression consists of nothing but
    another CASE expression, the inner WHEN/ELSE clauses are moved up to
    follow the outer expression's last WHEN clause and the outer ELSE
    clause is deleted. An empty ``LintResult()`` is returned when the
    pattern does not apply.
    """
    segment = FunctionalContext(context).segment
    assert segment.select(sp.is_type("case_expression"))

    outer_children = segment.children()
    case1_last_when = outer_children.last(sp.is_type("when_clause")).get()
    outer_else = outer_children.select(sp.is_type("else_clause"))
    else_expressions = outer_else.children(sp.is_type("expression"))
    expression_children = else_expressions.children()
    inner_case = expression_children.select(sp.is_type("case_expression"))

    # The len() checks are for safety, to ensure the CASE inside the ELSE
    # is not part of a larger expression. In that case, it's not safe to
    # simplify in this way -- we'd be deleting other code.
    can_flatten = (
        case1_last_when
        and len(else_expressions) <= 1
        and len(expression_children) <= 1
        and inner_case
    )
    if not can_flatten:
        return LintResult()

    # We can assert that this exists because of the previous check.
    assert case1_last_when
    # We will also have an else clause, because otherwise the nested-case
    # check above would have failed.
    outer_else_seg = outer_else.get()
    assert outer_else_seg

    # Delete stuff between the last "WHEN" clause and the "ELSE" clause.
    between_when_and_else = outer_children.select(
        start_seg=case1_last_when, stop_seg=outer_else_seg
    )
    fixes = between_when_and_else.apply(LintFix.delete)

    # Determine the indentation to use when we move the nested "WHEN" and
    # "ELSE" clauses. If no whitespace segments are found, use the default
    # indent.
    # NOTE(review): this joins *every* whitespace child located before the
    # last WHEN clause (in reverse order), not only the one adjacent to
    # it -- confirm that is intended.
    preceding_whitespace = (
        outer_children.select(stop_seg=case1_last_when)
        .reversed()
        .select(sp.is_type("whitespace"))
    )
    if preceding_whitespace:
        indent_str = "".join(seg.raw for seg in preceding_whitespace)
    else:
        indent_str = self.indent

    # Move the nested "when" and "else" clauses after the last outer
    # "when", each one on its own freshly indented line.
    moved_segments = []
    for clause in inner_case.children(sp.is_type("when_clause", "else_clause")):
        moved_segments.extend(
            [NewlineSegment(), WhitespaceSegment(indent_str), clause]
        )
    fixes.append(
        LintFix.create_after(
            case1_last_when, moved_segments, source=moved_segments
        )
    )

    # Delete the outer "else" clause.
    fixes.append(LintFix.delete(outer_else_seg))
    return LintResult(inner_case[0], fixes=fixes)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Identify aliases in from clause and join conditions.

    Find base table, table expressions in join, and other expressions in
    select clause and decide if it's needed to report them.

    Returns:
        An empty ``LintResult()`` when the rule is disabled for the current
        dialect, ``None`` when no base table can be located or nothing is
        reported, otherwise whatever ``self._lint_aliases_in_join`` returns.
    """
    # Config type hints
    self.force_enable: bool

    # Issue 2810: BigQuery has some tricky expectations (apparently not
    # documented, but subject to change, e.g.:
    # https://www.reddit.com/r/bigquery/comments/fgk31y/new_in_bigquery_no_more_backticks_around_table/)
    # about whether backticks are required (and whether the query is valid
    # or not, even with them), depending on whether the GCP project name is
    # present, or just the dataset name. Since SQLFluff doesn't have access
    # to BigQuery when it is looking at the query, it would be complex for
    # this rule to do the right thing. For now, the rule simply disables
    # itself.
    if (
        context.dialect.name in self._dialects_disabled_by_default
        and not self.force_enable
    ):
        return LintResult()

    assert context.segment.is_type("select_statement")

    children = FunctionalContext(context).segment.children()
    from_clause_segment = children.select(sp.is_type("from_clause")).first()

    # Drill down from the FROM clause to the first object reference.
    base_table = (
        from_clause_segment.children(sp.is_type("from_expression"))
        .first()
        .children(sp.is_type("from_expression_element"))
        .first()
        .children(sp.is_type("table_expression"))
        .first()
        .children(sp.is_type("object_reference"))
        .first()
    )
    if not base_table:
        return None

    # Buffers for all table expressions and column references found in the
    # FROM clause and in everything selected after it.
    from_expression_elements = []
    column_reference_segments = []

    after_from_clause = children.select(start_seg=from_clause_segment[0])
    for clause in from_clause_segment + after_from_clause:
        from_expression_elements.extend(
            clause.recursive_crawl("from_expression_element")
        )
        column_reference_segments.extend(
            clause.recursive_crawl("column_reference")
        )

    # base_table is known to be non-empty here, so index it directly.
    return (
        self._lint_aliases_in_join(
            base_table[0],
            from_expression_elements,
            column_reference_segments,
            context.segment,
        )
        or None
    )
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Identify aliases in from clause and join conditions.

    Crawl the children of the select statement for every
    ``from_expression_element`` and pass them to ``self._lint_aliases``.
    A falsy outcome from that helper is normalised to ``None``.
    """
    # Config type hints
    self.min_alias_length: Optional[int]
    self.max_alias_length: Optional[int]

    assert context.segment.is_type("select_statement")

    statement_children = FunctionalContext(context).segment.children()
    violations = self._lint_aliases(
        statement_children.recursive_crawl("from_expression_element")
    )
    return violations if violations else None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Ambiguous use of DISTINCT in select statement with GROUP BY.

    Returns a result anchored on the DISTINCT keyword when the statement
    has both a GROUP BY clause and a DISTINCT select modifier, and ``None``
    otherwise.
    """
    segment = FunctionalContext(context).segment
    # We know it's a select_statement from the seeker crawler
    assert segment.all(sp.is_type("select_statement"))

    # Without a GROUP BY clause there is nothing ambiguous.
    if not segment.children(sp.is_type("groupby_clause")):
        return None

    # Look for the "DISTINCT" keyword among the select clause modifiers.
    select_clauses = segment.children(sp.is_type("select_clause"))
    modifiers = select_clauses.children(sp.is_type("select_clause_modifier"))
    distinct_keywords = modifiers.children(sp.is_type("keyword")).select(
        sp.is_name("distinct")
    )
    if not distinct_keywords:
        return None
    return LintResult(anchor=distinct_keywords[0])
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Use ``!=`` instead of ``<>`` for "not equal to" comparison.

    Returns ``None`` for any segment that is not a ``not_equal_to`` operator
    spelled exactly as ``<`` followed by ``>``; otherwise returns a result
    whose fixes rewrite the two symbols.
    """
    # We should only get comparison operator types from the crawler, but
    # not all of them will be "not_equal_to".
    if context.segment.name != "not_equal_to":
        return None

    # The operator is made of separate raw symbols; collect them.
    operator_parts = (
        FunctionalContext(context)
        .segment.children()
        .select(select_if=sp.is_type("raw_comparison_operator"))
    )

    # Only care about ``<>``
    if [part.raw for part in operator_parts] != ["<", ">"]:
        return None

    # Each symbol is its own segment, so swap them pairwise:
    # ``<`` becomes ``!`` and ``>`` becomes ``=``.
    fixes = [
        LintFix.replace(
            old_symbol,
            [SymbolSegment(raw=new_raw, type="raw_comparison_operator")],
        )
        for old_symbol, new_raw in zip(operator_parts, ("!", "="))
    ]
    return LintResult(context.segment, fixes)
def _eval(self, context: RuleContext) -> EvalResultType:
    """Verify table references in SELECT statements at or below the segment.

    Returns an empty ``LintResult()`` when the rule is disabled for the
    dialect, ``None`` when the segment has an ancestor of one of the
    ``_START_TYPES`` or when no violations were collected, and otherwise the
    list of violations filled in by ``self._analyze_table_references``.
    """
    # Config type hints
    self.force_enable: bool

    if (
        context.dialect.name in self._dialects_disabled_by_default
        and not self.force_enable
    ):
        return LintResult()

    # Skip segments that sit beneath another segment of a start type.
    if FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
        return None

    violations: List[LintResult] = []
    dml_target_table: Optional[Tuple[str, ...]] = None
    self.logger.debug("Trigger on: %s", context.segment)

    if not context.segment.is_type("select_statement"):
        # Extract first table reference. This will be the target
        # table in a DML statement.
        table_reference = next(
            context.segment.recursive_crawl("table_reference"), None
        )
        if table_reference:
            dml_target_table = self._table_ref_as_tuple(table_reference)
            self.logger.debug("DML Reference Table: %s", dml_target_table)

    # Verify table references in any SELECT statements found in or
    # below context.segment in the parser tree.
    crawler = SelectCrawler(
        context.segment, context.dialect, query_class=L026Query
    )
    query: L026Query = cast(L026Query, crawler.query_tree)
    if query:
        self._analyze_table_references(
            query, dml_target_table, context.dialect, violations
        )
    return violations or None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Skip T-SQL ``alias = value`` expressions, else defer to the parent.

    T-SQL supports an alternative alias expression for L012::

        select alias = value

    instead of::

        select value as alias

    When the last child of the segment is a raw ``=`` the alternative form
    is in use and nothing is reported.
    """
    last_child = FunctionalContext(context).segment.children()[-1]
    if last_child.raw == "=":
        # Recognised the T-SQL form: exit early.
        return None
    return super()._eval(context)
def _eval(self, context: RuleContext):
    """Dispatch select-clause evaluation on the number of select targets.

    * Exactly one target, and either no wildcard expression is present or
      ``self.wildcard_policy`` is ``"single"``: handled by
      ``self._eval_single_select_target_element``.
    * Otherwise, at least one target: handled by
      ``self._eval_multiple_select_target_elements``.
    * No targets at all: ``None``.
    """
    # Config type hints
    self.wildcard_policy: str

    assert context.segment.is_type("select_clause")

    select_targets_info = self._get_indexes(context)
    wildcards = (
        FunctionalContext(context)
        .segment.children(sp.is_type("select_clause_element"))
        .children(sp.is_type("wildcard_expression"))
    )
    target_count = len(select_targets_info.select_targets)

    if target_count == 1 and (
        not wildcards or self.wildcard_policy == "single"
    ):
        return self._eval_single_select_target_element(
            select_targets_info,
            context,
        )
    if target_count:
        return self._eval_multiple_select_target_elements(
            select_targets_info, context.segment
        )
    return None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Outermost query should produce known number of columns.

    Segments that have an ancestor of one of the ``_START_TYPES`` are
    skipped. For the others, the query tree built by ``SelectCrawler`` is
    handed to ``self._analyze_result_columns``; a ``RuleFailure`` raised
    there is turned into a result anchored on the failure's anchor.
    """
    if FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
        return None

    crawler = SelectCrawler(context.segment, context.dialect)
    outer_query = crawler.query_tree
    if not outer_query:
        return None

    # Begin analysis at the outer query.
    try:
        return self._analyze_result_columns(outer_query)
    except RuleFailure as failure:
        return LintResult(anchor=failure.anchor)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Looking for DISTINCT before a bracket.

    Triggered on a ``select_clause``. When the clause has a modifier and
    its first select element (or that element's expression) contains a
    bracketed child, up to two fixes are produced:

    * unwrap the brackets if they are the only thing in the expression;
    * insert whitespace between the modifier and the first element if
      there is none.

    Returns ``None`` when no fix is needed.
    """
    assert context.segment.is_type("select_clause")

    children = FunctionalContext(context).segment.children()
    modifier = children.select(sp.is_type("select_clause_modifier"))
    first_element = children.select(
        sp.is_type("select_clause_element")
    ).first()
    if not (modifier and first_element):
        return None

    # Is the first element only an expression with only brackets? Fall
    # back to the element itself when it has no expression child.
    expression = (
        first_element.children(sp.is_type("expression")).first()
        or first_element
    )
    bracketed = expression.children(sp.is_type("bracketed")).first()
    if not bracketed:
        return None

    fixes = []

    # If there's nothing else in the expression, remove the brackets.
    if len(expression[0].segments) == 1:
        # Strip any meta segments, then drop the opening/closing bracket.
        unwrapped = self.filter_meta(bracketed[0].segments)[1:-1]
        fixes.append(LintFix.replace(bracketed[0], unwrapped))

    # If no whitespace between DISTINCT and expression, add it.
    separating_whitespace = children.select(
        sp.is_whitespace(),
        start_seg=modifier[0],
        stop_seg=first_element[0],
    )
    if not separating_whitespace:
        fixes.append(
            LintFix.create_before(first_element[0], [WhitespaceSegment()])
        )

    # If no fixes, no problem.
    if not fixes:
        return None
    return LintResult(anchor=modifier[0], fixes=fixes)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Column expression without alias. Use explicit `AS` clause.

    Evaluated on a select clause element. The element is reported (without
    a fix) when it has no alias, is judged complex by
    ``_recursively_check_is_complex``, and none of the exemptions below
    apply. ``parent_stack`` is used to count sibling elements.
    """
    functional_context = FunctionalContext(context)
    children = functional_context.segment.children()

    # If we have an alias its all good
    if children.any(sp.is_type("alias_expression")):
        return None

    # Ignore if it's a function with EMITS clause as EMITS is equivalent to AS
    emits_segments = (
        children.select(sp.is_type("function"))
        .children()
        .select(sp.is_type("emits_segment"))
    )
    if emits_segments:
        return None

    # Ignore if it is part of a CTE with column names
    parent_stack = functional_context.parent_stack
    enclosing_cte = parent_stack.last(sp.is_type("common_table_expression"))
    if enclosing_cte.children().any(sp.is_type("cte_column_list")):
        return None

    non_star_children = children.select(sp.not_(sp.is_name("star")))
    if not _recursively_check_is_complex(non_star_children):
        return None

    # No fixes, because we don't know what the alias should be,
    # the user should document it themselves.
    if self.allow_scalar:  # type: ignore
        # Check *how many* elements/columns there are in the select
        # statement. If this is the only one, then we won't report an error.
        sibling_elements = parent_stack.last().children(
            sp.is_type("select_clause_element")
        )
        if len(sibling_elements) <= 1:
            return None

    return LintResult(anchor=context.segment)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Files must end with a single trailing newline.

    Only the final segment of the parse tree and the newlines trailing the
    file matter here. When a templated file is present, only the leading
    run of trailing newlines whose templated slices are all literal is
    considered.

    Returns a result creating a newline when there is none, a result
    deleting the surplus when there are several, and ``None`` when there is
    exactly one.
    """
    # We only care about the final segment of the parse tree.
    parent_stack, segment = get_last_segment(FunctionalContext(context).segment)
    self.logger.debug("Found last segment as: %s", segment)

    trailing_newlines = Segments(*get_trailing_newlines(context.segment))
    self.logger.debug("Untemplated trailing newlines: %s", trailing_newlines)

    if context.templated_file:

        def _is_literal(seg):
            # True when every templated slice of the segment is literal.
            return sp.templated_slices(seg, context.templated_file).all(
                tsp.is_slice_type("literal")
            )

        trailing_literal_newlines = trailing_newlines.select(
            loop_while=_is_literal
        )
    else:
        trailing_literal_newlines = trailing_newlines
    self.logger.debug(
        "Templated trailing newlines: %s", trailing_literal_newlines
    )

    if not trailing_literal_newlines:
        # We make an edit to create this segment after the child of the
        # FileSegment.
        fix_anchor_segment = (
            segment[0] if len(parent_stack) == 1 else parent_stack[1]
        )
        self.logger.debug("Anchor on: %s", fix_anchor_segment)
        return LintResult(
            anchor=segment[0],
            fixes=[
                LintFix.create_after(fix_anchor_segment, [NewlineSegment()])
            ],
        )

    if len(trailing_literal_newlines) > 1:
        # Delete extra newlines, keeping the first one.
        return LintResult(
            anchor=segment[0],
            fixes=[LintFix.delete(d) for d in trailing_literal_newlines[1:]],
        )

    # Single newline, no need for fix.
    return None
def _eval(self, context: RuleContext) -> LintResult:
    """Unnecessary trailing whitespace.

    When the most recent entry of the raw stack is whitespace, the
    uninterrupted run of whitespace at the end of the stack is reported and
    deleted. Otherwise an empty ``LintResult()`` is returned.
    """
    raw_stack = context.raw_stack
    ends_with_whitespace = len(raw_stack) > 0 and raw_stack[-1].is_type(
        "whitespace"
    )
    if not ends_with_whitespace:
        return LintResult()

    # Walk backwards from the end of the stack for as long as we keep
    # seeing whitespace.
    trailing_whitespace = (
        FunctionalContext(context)
        .raw_stack.reversed()
        .select(loop_while=sp.is_type("whitespace"))
    )
    # NOTE: The presence of a loop marker should prevent false
    # flagging of newlines before jinja loop tags.
    return LintResult(
        anchor=trailing_whitespace[-1],
        fixes=[LintFix.delete(seg) for seg in trailing_whitespace],
    )
def _eval(self, context: RuleContext) -> EvalResultType:
    """Override base class for dialects that use structs, or SELECT aliases.

    Returns an empty ``LintResult()`` for struct dialects unless
    ``self.force_enable`` is set, the list of results produced by
    ``self._visit_queries`` when a query tree is available for a segment
    with no ancestor of the ``_START_TYPES``, and ``None`` otherwise.
    """
    # Config type hints
    self.force_enable: bool

    # Some dialects use structs (e.g. column.field) which look like
    # table references and so incorrectly trigger this rule.
    if context.dialect.name in self._dialects_with_structs:
        if not self.force_enable:
            return LintResult()
        # Forced on for a struct dialect: remember that on the rule.
        self._is_struct_dialect = True

    if FunctionalContext(context).parent_stack.any(sp.is_type(*_START_TYPES)):
        return None

    crawler = SelectCrawler(context.segment, context.dialect)
    if not crawler.query_tree:
        return None

    # Recursively visit and check each query in the tree.
    return list(self._visit_queries(crawler.query_tree))
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Unnecessary CASE statement.

    Handles CASE expressions with at most one WHEN clause using two
    methods:

    1. THEN and ELSE are opposite boolean literals: suggest COALESCE of the
       condition with ``false`` (negated when THEN is ``FALSE``).
    2. The condition is a single ``IS [NOT] NULL`` test on a column
       reference that also appears as the THEN or ELSE expression: suggest
       either the bare column or a COALESCE.

    Returns ``None`` when neither method applies.
    """
    # Look for CASE expression.
    if context.segment.segments[0].raw_upper != "CASE":
        return None

    # Find all 'WHEN' clauses and the optional 'ELSE' clause.
    children = FunctionalContext(context).segment.children()
    when_clauses = children.select(sp.is_type("when_clause"))
    else_clauses = children.select(sp.is_type("else_clause"))

    # Can't fix if multiple WHEN clauses.
    if len(when_clauses) > 1:
        return None

    # Find condition and then expressions.
    when_expressions = when_clauses.children(sp.is_type("expression"))
    condition_expression = when_expressions[0]
    then_expression = when_expressions[1]

    # Method 1: Check if THEN/ELSE expressions are both Boolean and can
    # therefore be reduced.
    else_expression = None
    if else_clauses:
        else_expression = else_clauses.children(sp.is_type("expression"))[0]
        upper_bools = ["TRUE", "FALSE"]
        then_raw = then_expression.raw_upper
        else_raw = else_expression.raw_upper
        if (
            then_raw in upper_bools
            and else_raw in upper_bools
            and then_raw != else_raw
        ):
            return LintResult(
                anchor=condition_expression,
                fixes=self._coalesce_fix_list(
                    context,
                    condition_expression,
                    KeywordSegment("false"),
                    then_raw == "FALSE",
                ),
                description="Unnecessary CASE statement. "
                "Use COALESCE function instead.",
            )

    # Method 2: Check if the condition expression is comparing a column
    # reference to NULL and whether that column reference is also in either
    # the THEN/ELSE expression. We can only apply this method when there is
    # only one condition in the condition expression.
    condition_raws = {
        segment.raw_upper for segment in condition_expression.segments
    }
    if not {"IS", "NULL"}.issubset(condition_raws):
        return None
    if condition_raws.intersection({"AND", "OR"}):
        return None

    # Check if the comparison is to NULL or NOT NULL.
    is_not_prefix = "NOT" in condition_raws

    # Locate column reference in condition expression.
    column_reference_segment = (
        Segments(condition_expression)
        .children(sp.is_type("column_reference"))
        .get()
    )
    # Return None if none found (this condition does not apply to functions)
    if not column_reference_segment:
        return None

    def _column_only_result():
        # Can just specify the column on its own rather than using a
        # COALESCE function.
        return LintResult(
            anchor=condition_expression,
            fixes=self._column_only_fix_list(
                context,
                column_reference_segment,
            ),
            description="Unnecessary CASE statement. "
            f"Just use column '{column_reference_segment.raw}'.",
        )

    column_raw = column_reference_segment.raw_upper

    if not else_clauses:
        # In this case no ELSE statement is equivalent to ELSE NULL.
        if column_raw == then_expression.raw_upper:
            return _column_only_result()
        return None

    # Check if we can reduce the CASE expression to a single coalesce
    # function.
    if not is_not_prefix and column_raw == else_expression.raw_upper:
        coalesce_arg_1, coalesce_arg_2 = else_expression, then_expression
    elif is_not_prefix and column_raw == then_expression.raw_upper:
        coalesce_arg_1, coalesce_arg_2 = then_expression, else_expression
    else:
        return None

    if coalesce_arg_2.raw_upper == "NULL":
        return _column_only_result()

    return LintResult(
        anchor=condition_expression,
        fixes=self._coalesce_fix_list(
            context,
            coalesce_arg_1,
            coalesce_arg_2,
        ),
        description="Unnecessary CASE statement. "
        "Use COALESCE function instead.",
    )
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Find rule violations and provide fixes.

    For a ``COUNT`` function with exactly one meaningful argument, rewrite
    that argument to the preferred form: ``1`` when ``self.prefer_count_1``
    is set, else ``0`` when ``self.prefer_count_0`` is set, else ``*``.
    Returns ``None`` when the function is not COUNT or already complies.
    """
    # Config type hints
    self.prefer_count_0: bool
    self.prefer_count_1: bool

    # We already know we're in a function because of the crawl_behaviour
    if context.segment.get_child("function_name").raw_upper != "COUNT":
        return None

    # Get bracketed content, ignoring metas, the brackets themselves and
    # any whitespace/newlines.
    f_content = (
        FunctionalContext(context)
        .segment.children(sp.is_type("bracketed"))
        .children(
            sp.and_(
                sp.not_(sp.is_meta()),
                sp.not_(
                    sp.is_type(
                        "start_bracket", "end_bracket", "whitespace", "newline"
                    )
                ),
            )
        )
    )
    if len(f_content) != 1:  # pragma: no cover
        return None
    argument = f_content[0]

    if self.prefer_count_1:
        preferred = "1"
    elif self.prefer_count_0:
        preferred = "0"
    else:
        preferred = "*"

    # COUNT(*) while a numeric literal is preferred.
    if argument.is_type("star") and (
        self.prefer_count_1 or self.prefer_count_0
    ):
        return LintResult(
            anchor=context.segment,
            fixes=[
                LintFix.replace(
                    argument,
                    [argument.edit(argument.raw.replace("*", preferred))],
                ),
            ],
        )

    # COUNT(0) / COUNT(1) while something else is preferred.
    if argument.is_type("expression"):
        expression_content = [
            seg for seg in argument.segments if not seg.is_meta
        ]
        if len(expression_content) == 1:
            literal = expression_content[0]
            if (
                literal.is_type("literal")
                and literal.raw in ["0", "1"]
                and literal.raw != preferred
            ):
                # The literal's whole raw text becomes the preferred form.
                return LintResult(
                    anchor=context.segment,
                    fixes=[
                        LintFix.replace(literal, [literal.edit(preferred)]),
                    ],
                )

    return None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Inconsistent column references in GROUP BY/ORDER BY clauses.

    Direct children of the clause are classified as "explicit" (column
    references and expressions) or "implicit" (numeric literals). With the
    ``"consistent"`` style a clause is flagged when it mixes both
    categories or contradicts the convention stored in ``context.memory``
    by an earlier clause; with any other style it is flagged when a
    category differs from the configured style. ``context.memory`` is
    passed along on every result.
    """
    # Config type hints
    self.group_by_and_order_by_style: str

    # We only care about GROUP BY/ORDER BY clauses.
    assert context.segment.is_type("groupby_clause", "orderby_clause")

    # Ignore Windowing clauses
    if FunctionalContext(context).parent_stack.any(
        sp.is_type(*self._ignore_types)
    ):
        return LintResult(memory=context.memory)

    def _violation():
        # Result flagging the whole clause, carrying memory forward.
        return LintResult(anchor=context.segment, memory=context.memory)

    # Map child segment types to either the implicit or explicit category.
    column_reference_category_map = {
        "column_reference": "explicit",
        "expression": "explicit",
        "numeric_literal": "implicit",
    }
    categories = set()
    for child in context.segment.segments:
        if child.is_type(*column_reference_category_map.keys()):
            categories.add(column_reference_category_map[child.get_type()])

    # If there are no column references then just return
    if not categories:  # pragma: no cover
        return LintResult(memory=context.memory)

    if self.group_by_and_order_by_style != "consistent":
        # If explicit or implicit naming then raise lint error
        # if the opposite reference type is detected.
        if any(
            category != self.group_by_and_order_by_style
            for category in categories
        ):
            return _violation()
        # Return memory for later clauses.
        return LintResult(memory=context.memory)

    # Consistent naming: raise a lint error if either...
    if len(categories) > 1:
        # 1. Both implicit and explicit column references are found in the
        #    same clause.
        return _violation()

    # 2. A clause is found to contain column name references that
    #    contradict the precedent set in earlier clauses.
    current_convention = categories.pop()
    prior_convention = context.memory.get("prior_group_by_order_by_convention")
    if prior_convention and prior_convention != current_convention:
        return _violation()

    context.memory["prior_group_by_order_by_convention"] = current_convention
    # Return memory for later clauses.
    return LintResult(memory=context.memory)
def _eval(self, context: RuleContext) -> List[LintResult]:
    """Set operators should be surrounded by newlines.

    For each set operator among the segment's children, look at the
    segments between it and the nearest code segment on either side and
    check that a newline is present on both sides (multiple newlines are
    allowed). One result is produced per operator that is missing a newline
    before it, after it, or both.
    """
    segment = FunctionalContext(context).segment
    expression = segment.children()
    set_operator_segments = segment.children(sp.is_type(*self._target_elems))
    # We should always find some as children because of the
    # ParentOfSegmentCrawler
    assert set_operator_segments

    results: List[LintResult] = []
    for set_operator in set_operator_segments:
        preceding_code = (
            expression.reversed()
            .select(start_seg=set_operator)
            .first(sp.is_code())
        )
        following_code = expression.select(start_seg=set_operator).first(
            sp.is_code()
        )
        before = expression.select(
            start_seg=preceding_code.get(), stop_seg=set_operator
        )
        after = expression.select(
            start_seg=set_operator, stop_seg=following_code.get()
        )

        newline_before = before.first(sp.is_type("newline"))
        newline_after = after.first(sp.is_type("newline"))

        # If there is a whitespace directly preceding/following the set
        # operator we are replacing it with a newline later.
        preceding_whitespace = before.first(sp.is_type("whitespace")).get()
        following_whitespace = after.first(sp.is_type("whitespace")).get()

        if newline_before and newline_after:
            continue

        if newline_after:
            # Only the newline in front of the operator is missing.
            description = (
                "Set operators should be surrounded by newlines. "
                f"Missing newline before set operator {set_operator.raw}."
            )
            fixes = _generate_fixes(whitespace_segment=preceding_whitespace)
        elif newline_before:
            # Only the newline behind the operator is missing.
            description = (
                "Set operators should be surrounded by newlines. "
                f"Missing newline after set operator {set_operator.raw}."
            )
            fixes = _generate_fixes(whitespace_segment=following_whitespace)
        else:
            before_fixes = _generate_fixes(
                whitespace_segment=preceding_whitespace
            )
            after_fixes = _generate_fixes(
                whitespace_segment=following_whitespace
            )
            # make mypy happy
            assert isinstance(before_fixes, Iterable)
            assert isinstance(after_fixes, Iterable)
            fixes = [*before_fixes, *after_fixes]
            description = (
                "Set operators should be surrounded by newlines. "
                "Missing newline before and after set operator "
                f"{set_operator.raw}."
            )

        results.append(
            LintResult(
                anchor=set_operator,
                description=description,
                fixes=fixes,
            )
        )

    return results
def _eval(self, context: RuleContext):
    """WITH clause closing bracket should be aligned with WITH keyword.

    Look for a with clause and evaluate the position of the closing bracket
    of each CTE query. A closing bracket that is on a different line from
    the WITH keyword, and is preceded on its own line by something other
    than whitespace/dedent, is reported with a fix that inserts a newline
    before it.

    Returns:
        An empty ``LintResult()`` if no WITH keyword is found and the
        statement contains an unparsable section, a ``LintResult`` for the
        first offending closing bracket, or ``None`` when nothing is wrong.

    Raises:
        RuntimeError: If no WITH keyword is found in a parsable statement.
    """
    assert context.segment.is_type("with_compound_statement")

    # Look for the with keyword
    for seg in context.segment.segments:
        if seg.name.lower() == "with":
            seg_line_no = seg.pos_marker.line_no
            break
    else:  # pragma: no cover
        # This *could* happen if the with statement is unparsable,
        # in which case then the user will have to fix that first.
        if any(s.is_type("unparsable") for s in context.segment.segments):
            return LintResult()
        # If it's parsable but we still didn't find a with, then
        # we should raise that.
        raise RuntimeError("Didn't find WITH keyword!")

    # Find the end brackets for the CTE *query* (i.e. ignore optional
    # list of CTE columns).
    cte_end_brackets = IdentitySet()
    for cte in (
        FunctionalContext(context)
        .segment.children(sp.is_type("common_table_expression"))
        .iterate_segments()
    ):
        cte_end_bracket = (
            cte.children()
            .last(sp.is_type("bracketed"))
            .children()
            .last(sp.is_type("end_bracket"))
        )
        if cte_end_bracket:
            cte_end_brackets.add(cte_end_bracket[0])

    # Loop-invariant: fetch the raw segments once instead of per bracket.
    raw_segments = context.segment.raw_segments

    for seg in context.segment.iter_segments(
        expanding=["common_table_expression", "bracketed"], pass_through=True
    ):
        # Only the closing brackets of CTE queries are of interest.
        if seg not in cte_end_brackets:
            continue

        # Check the marker exists *before* reading positions from it.
        assert seg.pos_marker
        bracket_line_no = seg.pos_marker.line_no
        bracket_line_pos = seg.pos_marker.line_pos

        if bracket_line_no == seg_line_no:
            # Skip if it's the one-line version. That's ok
            continue

        # Is it all whitespace before the bracket on this line?
        contains_non_whitespace = False
        for elem in raw_segments:
            elem_marker = cast(PositionMarker, elem.pos_marker)
            if (
                elem_marker.line_no == bracket_line_no
                and elem_marker.line_pos <= bracket_line_pos
            ):
                if elem is seg:
                    break
                elif elem.is_type("newline"):
                    contains_non_whitespace = False
                elif not elem.is_type("dedent") and not elem.is_type(
                    "whitespace"
                ):
                    contains_non_whitespace = True

        if contains_non_whitespace:
            # We have to move it to a newline
            return LintResult(
                anchor=seg,
                fixes=[
                    LintFix.create_before(
                        seg,
                        [
                            NewlineSegment(),
                        ],
                    )
                ],
            )

    return None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Unnecessary quoted identifier.

    Depending on ``prefer_quoted_identifiers`` this either reports naked
    identifiers (no fix attached) or reports quoted identifiers whose
    contents fully match the dialect's naked identifier template and do not
    match its anti-template, replacing them with a naked identifier.
    """
    # Config type hints
    self.prefer_quoted_identifiers: bool
    self.ignore_words: str
    self.ignore_words_regex: str
    self.force_enable: bool

    # Some dialects allow quotes as PART OF the column name. In other words,
    # these are two different columns:
    # - date
    # - "date"
    # For safety, disable this rule by default in those dialects.
    quotes_are_significant = (
        context.dialect.name in self._dialects_allowing_quotes_in_column_names
    )
    if quotes_are_significant and not self.force_enable:
        return LintResult()

    # Ignore some segment types
    ignored_parent = sp.is_type(*self._ignore_types)
    if FunctionalContext(context).parent_stack.any(ignored_parent):
        return None

    # Which segment type is the violation, and what text do we compare?
    if self.prefer_quoted_identifiers:
        context_policy = "naked_identifier"
        identifier_contents = context.segment.raw
    else:
        context_policy = "quoted_identifier"
        # Drop the first and last character (the surrounding quotes).
        identifier_contents = context.segment.raw[1:-1]

    # Get the ignore_words_list configuration.
    try:
        ignore_words_list = self.ignore_words_list
    except AttributeError:
        # First-time only, read the settings from configuration. This is
        # very slow.
        ignore_words_list = self._init_ignore_words_list()

    # Skip if in ignore list
    if ignore_words_list:
        if identifier_contents.lower() in ignore_words_list:
            return None

    # Skip if matches ignore regex
    if self.ignore_words_regex:
        if regex.search(self.ignore_words_regex, identifier_contents):
            return LintResult(memory=context.memory)

    # Ignore the segments that are not of the same type as the defined
    # policy above. Also TSQL has a keyword called QUOTED_IDENTIFIER which
    # maps to the name, so that needs to be checked for explicitly.
    reserved_names = (
        "quoted_identifier",
        "naked_identifier",
    )
    is_candidate = (
        context.segment.is_type(context_policy)
        and context.segment.raw.lower() not in reserved_names
    )
    if not is_candidate:
        return None

    # Manage cases of identifiers must be quoted first.
    # Naked identifiers are _de facto_ making this rule fail as
    # configuration forces them to be quoted.
    # In this case, it cannot be fixed as which quote to use is dialect
    # dependent.
    if self.prefer_quoted_identifiers:
        return LintResult(
            context.segment,
            description=f"Missing quoted identifier {identifier_contents}.",
        )

    # From here on prefer_quoted_identifiers is False.
    # Extract contents of outer quotes.
    unquoted = context.segment.raw[1:-1]

    # Retrieve NakedIdentifierSegment RegexParser for the dialect.
    naked_parser = context.dialect._library["NakedIdentifierSegment"]
    IdentifierSegment = cast(
        Type[CodeSegment], context.dialect.get_segment("IdentifierSegment")
    )

    # The contents must be a valid naked identifier ...
    matches_template = regex.fullmatch(
        naked_parser.template,
        unquoted,
        regex.IGNORECASE,
    )
    if matches_template is None:
        return None

    # ... and must not match the anti-template (reserved keywords).
    matches_anti_template = regex.fullmatch(
        naked_parser.anti_template,
        unquoted,
        regex.IGNORECASE,
    )
    if matches_anti_template is not None:
        return None

    replacement = IdentifierSegment(
        raw=unquoted,
        type="naked_identifier",
    )
    return LintResult(
        context.segment,
        fixes=[LintFix.replace(context.segment, [replacement])],
        description=f"Unnecessary quoted identifier {context.segment.raw}.",
    )
def _eval(self, context: RuleContext) -> EvalResultType:
    """Look for non-literal segments.

    For a raw, non-literal segment, every templated / block_start /
    block_end raw slice that starts with ``{`` and ends with ``}`` is
    checked: the whitespace just inside the tag must be a single space or
    contain a newline. Returns a list of results (one per offending tag),
    or a single memory-only result when there is nothing to report.
    """
    segment = context.segment
    assert segment.pos_marker

    # Only raw, non-literal segments are inspected.
    if not segment.is_raw() or segment.pos_marker.is_literal():
        return LintResult(memory=context.memory)

    # Start a fresh set when no (or empty) memory has been carried over.
    memory = context.memory or set()

    # Get any templated raw slices.
    # NOTE: We use this function because a single segment
    # may include multiple raw templated sections:
    # e.g. a single identifier with many templated tags.
    tag_slices = FunctionalContext(context).segment.raw_slices.select(
        rsp.is_slice_type("templated", "block_start", "block_end")
    )

    violations = []
    # Iterate through any tags found.
    for raw_slice in tag_slices:
        stripped = raw_slice.raw.strip()
        looks_like_tag = (
            bool(stripped) and stripped[0] == "{" and stripped[-1] == "}"
        )
        if not looks_like_tag:
            continue  # pragma: no cover

        self.logger.debug("Tag found @ %s: %r ", segment.pos_marker, stripped)

        # Dedupe using a memory of source indexes.
        # This is important because several positions in the
        # templated file may refer to the same position in the
        # source file and we only want to get one violation.
        src_idx = raw_slice.source_idx
        if context.memory and src_idx in context.memory:
            continue
        memory.add(src_idx)

        # Partition and Position
        parts = self._get_whitespace_ends(stripped)
        tag_pre, ws_pre, inner, ws_post, tag_post = parts
        position = raw_slice.raw.find(stripped[0])
        self.logger.debug(
            "Tag string segments: %r | %r | %r | %r | %r @ %s + %s",
            tag_pre,
            ws_pre,
            inner,
            ws_post,
            tag_post,
            src_idx,
            position,
        )

        # For the following section, whitespace should be a single
        # whitespace OR it should contain a newline.
        pre_ok = bool(ws_pre) and (ws_pre == " " or "\n" in ws_pre)
        post_ok = bool(ws_post) and (ws_post == " " or "\n" in ws_post)
        if pre_ok and post_ok:
            continue

        fixed = (
            tag_pre
            + (ws_pre if pre_ok else " ")
            + inner
            + (ws_post if post_ok else " ")
            + tag_post
        )
        start = src_idx + position
        src_fix = [
            SourceFix(
                fixed,
                slice(start, start + len(stripped)),
                # NOTE: The templated slice here is
                # going to be a little imprecise, but
                # the one that really matters is the
                # source slice.
                segment.pos_marker.templated_slice,
            )
        ]
        violations.append(
            LintResult(
                memory=memory,
                anchor=segment,
                description=f"Jinja tags should have a single "
                f"whitespace on either side: {stripped}",
                fixes=[
                    LintFix.replace(
                        segment,
                        [segment.edit(source_fixes=src_fix)],
                    )
                ],
            )
        )

    if violations:
        return violations
    return LintResult(memory=memory)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Quoted literals should use the preferred quote style.

    Args:
        context: Rule context for the segment being evaluated.

    Returns:
        ``None`` when the segment is not a ``quoted_literal`` or is already
        in the preferred style. A memory-only ``LintResult`` when the rule
        does not apply to the dialect, or when the literal is partially
        templated and its quote characters are not outside the template.
        Otherwise a ``LintResult`` whose fix replaces the literal with the
        normalised string.
    """
    # Config type hints
    self.preferred_quoted_literal_style: str
    self.force_enable: bool

    # Only care about quoted literal segments.
    if not context.segment.is_type("quoted_literal"):
        return None

    if not (
        self.force_enable
        or context.dialect.name in self._dialects_with_double_quoted_strings
    ):
        return LintResult(memory=context.memory)

    # This rule can also cover quoted literals that are partially templated.
    # I.e. when the quotes characters are _not_ part of the template we can
    # meaningfully apply this rule.
    templated_raw_slices = FunctionalContext(context).segment.raw_slices.select(
        rsp.is_slice_type("templated")
    )
    # The check only depends on the segment's source string, so it is done
    # once rather than once per templated slice.
    if templated_raw_slices:
        pos_marker = context.segment.pos_marker
        # This is to make mypy happy.
        assert isinstance(pos_marker, PositionMarker)
        source_str = pos_marker.source_str()

        # Check whether the quote characters are inside the template.
        # For the leading quote we need to account for string prefix
        # characters. Strip them from the whole string and slice instead of
        # indexing: a two-character prefix (or an empty source string) must
        # not raise IndexError.
        leading_char = source_str.lstrip(self._string_prefix_chars)[:1]
        trailing_char = source_str[-1:]
        leading_quote_inside_template = leading_char not in ['"', "'"]
        trailing_quote_inside_template = trailing_char not in ['"', "'"]

        # quotes are not entirely outside of a template, nothing we can do
        if leading_quote_inside_template or trailing_quote_inside_template:
            return LintResult(memory=context.memory)

    # If quoting style is set to consistent we use the quoting style of the
    # first quoted_literal that we encounter.
    if self.preferred_quoted_literal_style == "consistent":
        memory = context.memory
        preferred_quoted_literal_style = memory.get(
            "preferred_quoted_literal_style"
        )

        if not preferred_quoted_literal_style:
            # Getting the quote from LAST character to be able to handle
            # STRING prefixes
            preferred_quoted_literal_style = (
                "double_quotes"
                if context.segment.raw[-1] == '"'
                else "single_quotes"
            )
            memory["preferred_quoted_literal_style"] = (
                preferred_quoted_literal_style
            )
            self.logger.debug(
                "Preferred string quotes is set to `consistent`. Derived quoting "
                "style %s from first quoted literal.",
                preferred_quoted_literal_style,
            )
    else:
        preferred_quoted_literal_style = self.preferred_quoted_literal_style

    quote_style = self._quotes_mapping[preferred_quoted_literal_style]
    fixed_string = self._normalize_preferred_quoted_literal_style(
        context.segment.raw,
        preferred_quote_char=quote_style["preferred_quote_char"],
        alternate_quote_char=quote_style["alternate_quote_char"],
    )

    if fixed_string != context.segment.raw:
        return LintResult(
            anchor=context.segment,
            memory=context.memory,
            fixes=[
                LintFix.replace(
                    context.segment,
                    [
                        LiteralSegment(
                            raw=fixed_string,
                            type="quoted_literal",
                        )
                    ],
                )
            ],
            description=(
                "Inconsistent use of preferred quote style '"
                f"{quote_style['common_name']}"
                f"'. Use {fixed_string} instead of {context.segment.raw}."
            ),
        )

    return None
def _get_indexes(context: RuleContext) -> "SelectTargetsInfo":
    """Locate the key child segments of the current segment.

    All indexes are positions within ``segment.children()`` as reported by
    ``children.find(...)``; ``-1`` is used where this function itself has
    nothing to look up (no SELECT keyword, no newline, no comment).

    Args:
        context: Rule context for the segment being evaluated.

    Returns:
        A ``SelectTargetsInfo`` built from, in order: the SELECT keyword
        index, the first newline index, the first select target index, the
        index of the first whitespace selected after the first newline, the
        index of a comment found between SELECT and the first newline, the
        select target segments, the first following ``from_clause`` sibling,
        and the whitespace siblings selected up to that FROM clause.
    """
    functional_context = FunctionalContext(context)
    children = functional_context.segment.children()

    select_targets = children.select(sp.is_type("select_clause_element"))
    first_select_target_idx = children.find(select_targets.get())

    selects = children.select(sp.is_keyword("select"))
    select_idx = children.find(selects.get()) if selects else -1

    newlines = children.select(sp.is_type("newline"))
    first_new_line_idx = children.find(newlines.get()) if newlines else -1

    comment_after_select_idx = -1
    if newlines:
        # Look for a comment between SELECT and the first newline, walking
        # only over comments, whitespace and meta segments.
        comment_after_select = children.select(
            sp.is_type("comment"),
            start_seg=selects.get(),
            stop_seg=newlines.get(),
            loop_while=sp.or_(
                sp.is_type("comment"), sp.is_type("whitespace"), sp.is_meta()
            ),
        )
        if comment_after_select:
            comment_after_select_idx = children.find(comment_after_select.get())

    first_whitespace_idx = -1
    if first_new_line_idx != -1:
        # TRICKY: Ignore whitespace prior to the first newline, e.g. if
        # the line with "SELECT" (before any select targets) has trailing
        # whitespace.
        segments_after_first_line = children.select(
            sp.is_type("whitespace"), start_seg=children[first_new_line_idx]
        )
        first_whitespace_idx = children.find(segments_after_first_line.get())

    siblings_post = functional_context.siblings_post
    from_segment = siblings_post.first(sp.is_type("from_clause")).first().get()
    pre_from_whitespace = siblings_post.select(
        sp.is_type("whitespace"), stop_seg=from_segment
    )

    return SelectTargetsInfo(
        select_idx,
        first_new_line_idx,
        first_select_target_idx,
        first_whitespace_idx,
        comment_after_select_idx,
        select_targets,
        from_segment,
        list(pre_from_whitespace),
    )
def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
    """Relational operators should not be used to check for NULL values.

    Each ``=``, ``!=`` or ``<>`` child whose next code segment is a NULL
    literal yields one result that replaces the operator with the sequence
    built by ``_create_base_is_null_sequence``, padded with whitespace where
    ``_missing_whitespace`` reports it is needed. Returns ``None`` when
    there is nothing to report.
    """
    # Context/motivation for this rule:
    # https://news.ycombinator.com/item?id=28772289
    # https://stackoverflow.com/questions/9581745/sql-is-null-and-null
    if len(context.segment.segments) <= 2:
        return None  # pragma: no cover

    # Allow assignments in SET clauses
    exempt_types = ("set_clause_list", "execute_script_statement")
    if context.parent_stack and context.parent_stack[-1].is_type(*exempt_types):
        return None

    # Allow assignments in EXEC clauses
    if context.segment.is_type(*exempt_types):
        return None

    segment = FunctionalContext(context).segment
    # Iterate through children of this segment looking for equals or "not
    # equals". Once found, check if the next code segment is a NULL literal.
    children = segment.children()
    operators = segment.children(sp.raw_is("=", "!=", "<>"))
    if not operators:
        return None

    self.logger.debug("Operators found: %s", operators)

    results: List[LintResult] = []
    # We may have many operators
    for operator in operators:
        self.logger.debug("Children found: %s", children)

        following = children.select(start_seg=operator)
        # If nothing comes after operator then skip
        if not following:
            continue  # pragma: no cover

        next_code = following.first(sp.is_code())
        # if the next bit of code isnt a NULL then we are good
        if not next_code.all(sp.is_type("null_literal")):
            continue

        null_seg = next_code.get()
        assert null_seg, "TypeGuard: Segment must exist"
        self.logger.debug(
            "Found NULL literal following equals/not equals @%s: %r",
            null_seg.pos_marker,
            null_seg.raw,
        )

        edit = _create_base_is_null_sequence(
            is_upper=null_seg.raw[0] == "N",
            operator_raw=operator.raw,
        )

        # NOTE(review): the first lookup is taken from the segments selected
        # with start_seg=operator and the second from those selected with
        # stop_seg=operator, so these presumably sit after and before the
        # operator respectively. The before= flags are passed exactly as
        # before; confirm against _missing_whitespace.
        seg_from_start_select = following.first().get()
        seg_from_stop_select = children.select(stop_seg=operator).last().get()

        if self._missing_whitespace(seg_from_start_select, before=True):
            leading_ws: CorrectionListType = [WhitespaceSegment()]
            edit = leading_ws + edit
        if self._missing_whitespace(seg_from_stop_select, before=False):
            edit = edit + [WhitespaceSegment()]

        results.append(
            LintResult(
                anchor=operator,
                fixes=[
                    LintFix.replace(
                        operator,
                        edit,
                    )
                ],
            )
        )

    return results or None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Select clause modifiers must appear on same line as SELECT.

    Returns ``None`` when the clause has no ``select_clause_modifier`` or
    when no newline is selected between the first child and the modifier.
    Otherwise returns a result, anchored on the select clause, whose fixes
    re-create the modifier directly after the first child and delete the
    original modifier together with the surrounding newlines/whitespace.
    """
    # We only care about select_clause.
    assert context.segment.is_type("select_clause")

    # Children of the select_clause; the first one is used as the SELECT
    # keyword anchor.
    clause_children = FunctionalContext(context).segment.children()
    select_kw = clause_children[0]

    # Rule doesn't apply if there's no select clause modifier.
    modifier_matches = clause_children.first(sp.is_type("select_clause_modifier"))
    if not modifier_matches:
        return None
    modifier = modifier_matches[0]

    def _select_in_gap(wanted, anchor):
        """Select `wanted` segments while walking whitespace/meta after `anchor`."""
        return clause_children.select(
            select_if=wanted,
            loop_while=sp.or_(sp.is_whitespace(), sp.is_meta()),
            start_seg=anchor,
        )

    # Are there any newlines between the select keyword
    # and the select clause modifier.
    leading_newlines = _select_in_gap(sp.is_type("newline"), select_kw)
    # Rule doesn't apply if select clause modifier
    # is already on the same line as the select keyword.
    if not leading_newlines:
        return None

    # Whitespace before the modifier, removed as part of the fix.
    leading_whitespace = _select_in_gap(sp.is_type("whitespace"), select_kw)
    # Is the following select clause element on the modifier's line?
    trailing_newlines = _select_in_gap(sp.is_type("newline"), modifier)
    modifier_ends_line = bool(trailing_newlines)

    # We will insert these segments directly after the select keyword.
    edit_segments = [WhitespaceSegment(), modifier]
    if not modifier_ends_line:
        # if the first select clause element is on the same line
        # as the select clause modifier then also insert a newline.
        edit_segments.append(NewlineSegment())

    # Move select clause modifier after select keyword.
    fixes = [
        LintFix.create_after(
            anchor_segment=select_kw,
            edit_segments=edit_segments,
        )
    ]

    # Delete original newlines and whitespace between select keyword
    # and select clause modifier. Without a newline after the modifier only
    # the newlines go; with one, the whitespace goes as well.
    if modifier_ends_line:
        to_remove = leading_newlines + leading_whitespace
    else:
        to_remove = leading_newlines
    fixes += [LintFix.delete(seg) for seg in to_remove]

    # Delete the original select clause modifier.
    fixes.append(LintFix.delete(modifier))

    # If there is whitespace (on the same line) after the select clause
    # modifier then also delete this.
    trailing_whitespace = clause_children.select(
        select_if=sp.is_whitespace(),
        loop_while=sp.or_(sp.is_type("whitespace"), sp.is_meta()),
        start_seg=modifier,
    )
    if trailing_whitespace:
        fixes += [LintFix.delete(seg) for seg in trailing_whitespace]

    return LintResult(
        anchor=context.segment,
        fixes=fixes,
    )
def _eval_single_select_target_element(
    self, select_targets_info, context: RuleContext
) -> Optional[LintResult]:
    """Build the fixes that pull a select target up onto the SELECT line.

    Only acts when the first newline sits between the SELECT keyword and the
    first select target (``select_idx < first_new_line_idx <
    first_select_target_idx``).

    Args:
        select_targets_info: Index/segment bundle for the select clause; the
            fields read here are ``select_idx``, ``first_new_line_idx``,
            ``first_select_target_idx`` and ``comment_after_select_idx``.
        context: Rule context for the select clause being evaluated.

    Returns:
        ``None`` when the ordering condition above does not hold. Otherwise
        a ``LintResult`` anchored on the select clause. Its fixes are empty
        when a comment follows SELECT on the same line
        (``comment_after_select_idx != -1``).
    """
    select_clause = FunctionalContext(context).segment
    parent_stack = context.parent_stack

    if not (
        select_targets_info.select_idx
        < select_targets_info.first_new_line_idx
        < select_targets_info.first_select_target_idx
    ):
        return None

    # Do we have a modifier?
    select_children = select_clause.children()
    modifier: Optional[Segments]
    modifier = select_children.first(sp.is_type("select_clause_modifier"))

    # Prepare the select clause which will be inserted
    insert_buff = [
        WhitespaceSegment(),
        select_children[select_targets_info.first_select_target_idx],
    ]

    # Check if the modifier is one we care about: if it's already on the
    # first line, ignore it.
    if modifier and (
        select_children.index(modifier.get())
        < select_targets_info.first_new_line_idx
    ):
        modifier = None

    fixes = [
        # Delete the first select target from its original location.
        # We'll add it to the right section at the end, once we know
        # what to add.
        LintFix.delete(
            select_children[select_targets_info.first_select_target_idx],
        ),
    ]

    # If we have a modifier to move:
    if modifier:
        # Add it to the insert
        insert_buff = [WhitespaceSegment(), modifier[0]] + insert_buff

        modifier_idx = select_children.index(modifier.get())
        # Delete the whitespace after it (which is two after, thanks to
        # indent). The length guard must cover index modifier_idx + 2,
        # otherwise a modifier with only one segment after it raises
        # IndexError.
        if (
            len(select_children) > modifier_idx + 2
            and select_children[modifier_idx + 2].is_whitespace
        ):
            fixes += [
                LintFix.delete(
                    select_children[modifier_idx + 2],
                ),
            ]

        # Delete the modifier itself
        fixes += [
            LintFix.delete(
                modifier[0],
            ),
        ]

        # Set the position marker for removing the preceding
        # whitespace and newline, which we'll use below.
        start_idx = modifier_idx
    else:
        # Set the position marker for removing the preceding
        # whitespace and newline, which we'll use below.
        start_idx = select_targets_info.first_select_target_idx

    if parent_stack and parent_stack[-1].is_type("select_statement"):
        select_stmt = parent_stack[-1]
        select_clause_idx = select_stmt.segments.index(select_clause.get())
        after_select_clause_idx = select_clause_idx + 1
        if len(select_stmt.segments) > after_select_clause_idx:

            def _fixes_for_move_after_select_clause(
                stop_seg: BaseSegment,
                delete_segments: Optional[Segments] = None,
                add_newline: bool = True,
            ) -> List[LintFix]:
                """Cleans up by moving leftover select_clause segments.

                Context: Some of the other fixes we make in
                _eval_single_select_target_element() leave leftover
                child segments that need to be moved to become
                *siblings* of the select_clause.

                Deletes for ``delete_segments`` are appended straight to
                the enclosing ``fixes`` list; the deletes and the
                ``create_after`` for the moved segments are returned.
                """
                start_seg = (
                    modifier[0]
                    if modifier
                    else select_children[select_targets_info.first_new_line_idx]
                )
                move_after_select_clause = select_children.select(
                    start_seg=start_seg,
                    stop_seg=stop_seg,
                )
                # :TRICKY: Below, we have a couple places where we
                # filter to guard against deleting the same segment
                # multiple times -- this is illegal.
                # :TRICKY: Use IdentitySet rather than set() since
                # different segments may compare as equal.
                all_deletes = IdentitySet(
                    fix.anchor for fix in fixes if fix.edit_type == "delete"
                )
                fixes_ = []
                for seg in delete_segments or []:
                    if seg not in all_deletes:
                        fixes.append(LintFix.delete(seg))
                        all_deletes.add(seg)
                fixes_ += [
                    LintFix.delete(seg)
                    for seg in move_after_select_clause
                    if seg not in all_deletes
                ]
                fixes_.append(
                    LintFix.create_after(
                        select_clause[0],
                        ([NewlineSegment()] if add_newline else [])
                        + list(move_after_select_clause),
                    )
                )
                return fixes_

            after_select_clause = select_stmt.segments[after_select_clause_idx]

            if after_select_clause.is_type("newline"):
                # Since we're deleting the newline, we should also delete all
                # whitespace before it or it will add random whitespace to
                # following statements. So walk back through the segment
                # deleting whitespace until you get the previous newline, or
                # something else.
                to_delete = select_children.reversed().select(
                    loop_while=sp.is_type("whitespace"),
                    start_seg=select_children[start_idx],
                )
                if to_delete:
                    # The select_clause is immediately followed by a
                    # newline. Delete the newline in order to avoid leaving
                    # behind an empty line after fix, *unless* we stopped
                    # due to something other than a newline.
                    delete_last_newline = select_children[
                        start_idx - len(to_delete) - 1
                    ].is_type("newline")

                    # Delete the newline if we decided to.
                    if delete_last_newline:
                        fixes.append(
                            LintFix.delete(
                                after_select_clause,
                            )
                        )

                    fixes += _fixes_for_move_after_select_clause(
                        to_delete[-1], to_delete
                    )
            elif after_select_clause.is_type("whitespace"):
                # The select_clause has stuff after (most likely a comment)
                # Delete the whitespace immediately after the select clause
                # so the other stuff aligns nicely based on where the select
                # clause started.
                fixes += [
                    LintFix.delete(
                        after_select_clause,
                    ),
                ]
                fixes += _fixes_for_move_after_select_clause(
                    select_children[select_targets_info.first_select_target_idx],
                )
            elif after_select_clause.is_type("dedent"):
                # Again let's strip back the whitespace, but simpler
                # as don't need to worry about new line so just break
                # if see non-whitespace
                # NOTE(review): select_clause_idx is an index into
                # select_stmt.segments but is used here to index
                # select_children; left as-is, confirm this is intended.
                to_delete = select_children.reversed().select(
                    loop_while=sp.is_type("whitespace"),
                    start_seg=select_children[select_clause_idx - 1],
                )
                if to_delete:
                    fixes += _fixes_for_move_after_select_clause(
                        to_delete[-1],
                        to_delete,
                        # If we deleted a newline, create a newline.
                        any(seg.is_type("newline") for seg in to_delete),
                    )
            else:
                fixes += _fixes_for_move_after_select_clause(
                    select_children[select_targets_info.first_select_target_idx],
                )

    if select_targets_info.comment_after_select_idx == -1:
        fixes += [
            # Insert the select_clause in place of the first newline in the
            # Select statement
            LintFix.replace(
                select_children[select_targets_info.first_new_line_idx],
                insert_buff,
            ),
        ]
    else:
        # The SELECT is followed by a comment on the same line. In order
        # to autofix this, we'd need to move the select target between
        # SELECT and the comment and potentially delete the entire line
        # where the select target was (if it is now empty). This is
        # *fairly tricky and complex*, in part because the newline on
        # the select target's line is several levels higher in the
        # parser tree. Hence, we currently don't autofix this. Could be
        # autofixed in the future if/when we have the time.
        fixes = []

    return LintResult(
        anchor=select_clause.get(),
        fixes=fixes,
    )