def _handle_semicolon_newline(
    self, context: RuleContext, info: SegmentMoveContext
) -> Optional[LintResult]:
    """Place a statement's semicolon on a line of its own.

    Returns ``None`` when the segments preceding the semicolon (after
    accounting for preceding inline comments) consist of a single newline,
    i.e. the semicolon is already where it should be. Otherwise returns a
    ``LintResult`` whose fixes insert a newline followed by a semicolon at
    the chosen anchor.
    """
    # Inline comments can carry noqa directives, so the newline has to go
    # *after* any inline comment that precedes the semicolon.
    preceding, anchor = self._handle_preceding_inline_comments(
        info.before_segment, info.anchor_segment
    )

    already_on_own_line = len(preceding) == 1 and all(
        seg.is_type("newline") for seg in preceding
    )
    if already_on_own_line:
        return None

    # Edge case: an inline comment that comes *after* the semicolon.
    anchor = self._handle_trailing_inline_comments(context, anchor)

    newline_and_semicolon = [
        NewlineSegment(),
        SymbolSegment(raw=";", type="symbol", name="semicolon"),
    ]
    if anchor is context.segment:
        # The semicolon itself is the anchor: prefix it with a newline.
        fixes = [LintFix.replace(anchor, newline_and_semicolon)]
    else:
        # Re-create the semicolon after the anchor, then delete the old
        # semicolon together with the whitespace that preceded it.
        fixes = [
            LintFix.replace(anchor, [anchor] + newline_and_semicolon),
            LintFix.delete(context.segment),
        ]
        fixes += [LintFix.delete(ws) for ws in info.whitespace_deletions]

    return LintResult(anchor=anchor, fixes=fixes)
def _eval(self, context: RuleContext) -> LintResult:
    """Look for UNION keyword not immediately followed by DISTINCT or ALL.

    ``UNION DISTINCT`` is valid; the rule only reports a bare ``UNION`` inside
    a ``set_operator`` segment. Only the ansi, bigquery, hive, mysql and
    redshift dialects are checked, since only some dialects support the
    ``UNION DISTINCT`` syntax.

    The fix replaces the first child of the set operator with
    ``union distinct`` when the raw text contains lower-case ``union``, and
    with ``UNION DISTINCT`` otherwise.
    """
    supported_dialects = ("ansi", "bigquery", "hive", "mysql", "redshift")
    if context.dialect.name not in supported_dialects:
        return LintResult()
    if not context.segment.is_type("set_operator"):
        return LintResult()

    raw = context.segment.raw
    raw_upper = raw.upper()
    # Already explicit about duplicate handling: nothing to report.
    if "ALL" in raw_upper or "DISTINCT" in raw_upper:
        return LintResult()

    # Mirror the author's casing: lower-case only if "union" appears verbatim.
    if "union" in raw:
        union_word, distinct_word = "union", "distinct"
    elif "UNION" in raw_upper:
        union_word, distinct_word = "UNION", "DISTINCT"
    else:
        return LintResult()

    replacement = [
        KeywordSegment(union_word),
        WhitespaceSegment(),
        KeywordSegment(distinct_word),
    ]
    return LintResult(
        anchor=context.segment,
        fixes=[LintFix.replace(context.segment.segments[0], replacement)],
    )
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Require exactly one space after each comma.

    Nothing is reported when nothing follows the comma. No space is inserted
    when the comma is immediately followed by a newline. Excess whitespace is
    left alone when the next non-whitespace segment is a comment.
    """
    if context.segment.name != "comma":
        return None

    whitespace_after, next_segment = self._get_subsequent_whitespace(context)
    if next_segment is None:
        return None

    if not whitespace_after:
        if next_segment.is_type("newline"):
            return None
        # Nothing separates the comma from what follows: add one space.
        return LintResult(
            anchor=next_segment,
            fixes=[LintFix.create_after(context.segment, [WhitespaceSegment()])],
        )

    if whitespace_after.raw != " " and not next_segment.is_comment:
        # Too much (or the wrong kind of) whitespace: shrink to one space.
        return LintResult(
            anchor=whitespace_after,
            fixes=[LintFix.replace(whitespace_after, [WhitespaceSegment()])],
        )
    return None
def _coalesce_fix_list(
    context: RuleContext,
    coalesce_arg_1: BaseSegment,
    coalesce_arg_2: BaseSegment,
    preceding_not: bool = False,
) -> List[LintFix]:
    """Generate list of fixes to convert CASE statement to COALESCE function.

    The segment at ``context.segment`` is replaced by
    ``coalesce(<coalesce_arg_1>, <coalesce_arg_2>)``, prefixed with
    ``not `` when ``preceding_not`` is true.
    """
    replacement: List[BaseSegment] = []
    if preceding_not:
        replacement += [KeywordSegment("not"), WhitespaceSegment()]
    replacement += [
        KeywordSegment("coalesce"),
        SymbolSegment("(", name="start_bracket", type="start_bracket"),
        coalesce_arg_1,
        SymbolSegment(",", name="comma", type="comma"),
        WhitespaceSegment(),
        coalesce_arg_2,
        SymbolSegment(")", name="end_bracket", type="end_bracket"),
    ]
    return [LintFix.replace(context.segment, replacement)]
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Use ``COALESCE`` instead of ``IFNULL`` or ``NVL``.

    Triggers on ``function_name_identifier`` segments whose upper-cased raw
    text is ``IFNULL`` or ``NVL`` and replaces them with ``COALESCE``.
    """
    segment = context.segment
    is_legacy_null_function = (
        segment.name == "function_name_identifier"
        and segment.raw_upper in {"IFNULL", "NVL"}
    )
    if not is_legacy_null_function:
        return None

    coalesce_name = CodeSegment(
        raw="COALESCE",
        name="function_name_identifier",
        type="function_name_identifier",
    )
    return LintResult(
        anchor=segment,
        fixes=[LintFix.replace(segment, [coalesce_name])],
        description=f"Use 'COALESCE' instead of '{segment.raw_upper}'.",
    )
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Trailing commas within select clause.

    Inspects the last code child of a ``select_clause``. When
    ``select_clause_trailing_comma`` is ``"forbid"``, a trailing comma is
    deleted. When it is ``"require"``, a comma is appended after the last
    element if missing.
    """
    # Config type hints
    self.select_clause_trailing_comma: str

    segment = context.functional.segment
    children = segment.children()
    if not segment.all(sp.is_type("select_clause")):
        return None

    last_content: BaseSegment = children.last(sp.is_code())[0]
    ends_with_comma = last_content.is_type("comma")
    mode = self.select_clause_trailing_comma

    if mode == "forbid" and ends_with_comma:
        return LintResult(
            anchor=last_content,
            fixes=[LintFix.delete(last_content)],
            description="Trailing comma in select statement forbidden",
        )
    if mode == "require" and not ends_with_comma:
        new_comma = SymbolSegment(",", name="comma", type="comma")
        return LintResult(
            anchor=last_content,
            fixes=[LintFix.replace(last_content, [last_content, new_comma])],
            description="Trailing comma in select statement required",
        )
    return None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Mixed Tabs and Spaces in single whitespace.

    Only whitespace segments that contain both a space and a tab, and that
    sit at the very start of the file or directly after a newline, are
    flagged. The fix replaces every tab with ``tab_space_size`` spaces.
    """
    # Config type hints
    self.tab_space_size: int

    segment = context.segment
    if not segment.is_type("whitespace"):
        return None

    raw = segment.raw
    if " " not in raw or "\t" not in raw:
        return None

    at_line_start = not context.raw_stack or context.raw_stack[-1].is_type(
        "newline"
    )
    if not at_line_start:
        return None

    spaces_only = raw.replace("\t", " " * self.tab_space_size)
    return LintResult(
        anchor=segment,
        fixes=[LintFix.replace(segment, [segment.edit(spaces_only)])],
    )
def _get_fix(self, segment, fixed_raw):
    """Build the ``LintFix`` for a segment found to need fixing.

    The base implementation replaces ``segment`` with
    ``segment.edit(fixed_raw)``. Subclasses may override this when the parse
    tree structure differs from this simple case.
    """
    edited_segment = segment.edit(fixed_raw)
    return LintFix.replace(segment, [edited_segment])
def _mock_crawl(rule, segment, ignore_mask, templated_file=None, *args, **kwargs):
    """Replacement for ``BaseRule.crawl`` used by test__cli__fix_loop_limit_behavior.

    For comment segments whose raw text contains ``"Comment"``, produce a fix
    rewriting the comment to ``-- Comment <n>``, where ``n`` is a module-level
    counter incremented on every call. Since the fix never repeats, the
    linter keeps finding work and hits its loop limit. Every other segment is
    handed to the original crawl implementation (``_old_crawl``).
    """
    global _fix_counter

    is_marked_comment = segment.is_type("comment") and "Comment" in segment.raw
    if not is_marked_comment:
        # Star-args are bound positionally regardless of where they appear,
        # so this is the same call as passing templated_file first.
        return _old_crawl(
            rule, segment, ignore_mask, *args, templated_file=templated_file, **kwargs
        )

    _fix_counter += 1
    new_comment = CommentSegment(f"-- Comment {_fix_counter}")
    result = LintResult(segment, fixes=[LintFix.replace(segment, [new_comment])])
    errors = []
    fixes = []
    rule._process_lint_result(result, templated_file, ignore_mask, errors, fixes)
    return errors, None, fixes, None
def _eval(self, context):
    """Stars make newlines.

    Test rule: every whitespace segment is replaced by a whitespace segment
    one space longer. Non-whitespace segments produce no result.
    """
    segment = context.segment
    if not segment.is_type("whitespace"):
        return None
    widened = WhitespaceSegment(segment.raw + " ")
    return LintResult(
        anchor=segment,
        fixes=[LintFix.replace(segment, [widened])],
    )
def _column_only_fix_list(
    context: RuleContext,
    column_reference_segment: BaseSegment,
) -> List[LintFix]:
    """Generate list of fixes to reduce CASE statement to a single column.

    The segment at ``context.segment`` is replaced wholesale by
    ``column_reference_segment``.
    """
    replace_with_column = LintFix.replace(
        context.segment, [column_reference_segment]
    )
    return [replace_with_column]
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Use ``!=`` instead of ``<>`` for "not equal to" comparison.

    Only ``not_equal_to`` segments whose raw comparison operators are exactly
    ``<`` followed by ``>`` are reported.
    """
    if context.segment.name != "not_equal_to":
        return None

    operators = context.functional.segment.children().select(
        select_if=sp.is_type("raw_comparison_operator")
    )
    if [op.raw for op in operators] != ["<", ">"]:
        return None

    # ``<>`` is two separate symbol segments, so rewrite each individually:
    # ``<`` becomes ``!`` and ``>`` becomes ``=``.
    rewrites = (
        (operators[0], "!", "raw_not"),
        (operators[1], "=", "raw_equals"),
    )
    fixes = [
        LintFix.replace(
            target,
            [CodeSegment(raw=new_raw, name=new_name, type="raw_comparison_operator")],
        )
        for target, new_raw, new_name in rewrites
    ]
    return LintResult(context.segment, fixes)
def _eval(self, context: RuleContext) -> LintResult:
    """Incorrect indentation found in file.

    The "wrong" indent is a tab when ``indent_unit`` is ``"space"``, and a
    run of ``tab_space_size`` spaces otherwise. A whitespace segment that
    contains it is reported. It is only auto-fixed (each occurrence replaced
    by ``self.indent``) when it starts a line and, for tab indentation, its
    space count is a multiple of ``tab_space_size``. Otherwise the
    description explains why a manual fix is needed.
    """
    # Config type hints
    self.tab_space_size: int
    self.indent_unit: str

    space = " "
    correct_indent = self.indent
    if self.indent_unit == "space":
        wrong_indent = "\t"
    else:
        wrong_indent = space * self.tab_space_size

    segment = context.segment
    if not segment.is_type("whitespace") or wrong_indent not in segment.raw:
        return LintResult()

    description = "Incorrect indentation type found in file."
    fixes = []
    at_line_start = len(context.raw_stack) == 0 or context.raw_stack[-1].is_type(
        "newline"
    )
    # Converting tabs to spaces is always safe; converting spaces to tabs
    # is only clean when the spaces divide evenly into whole tabs.
    convertible = (
        self.indent_unit == "space"
        or segment.raw.count(space) % self.tab_space_size == 0
    )

    if convertible and at_line_start:
        fixed_indent = segment.raw.replace(wrong_indent, correct_indent)
        fixes = [LintFix.replace(segment, [WhitespaceSegment(raw=fixed_indent)])]
    elif not at_line_start:
        description += (
            " The indent occurs after other text, so a manual fix is needed."
        )
    else:
        # Tab indentation with a space count that isn't a tab multiple.
        description += (
            " The number of spaces is not a multiple of "
            "tab_space_size, so a manual fix is needed."
        )
    return LintResult(anchor=segment, fixes=fixes, description=description)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Looking for DISTINCT before a bracket.

    Look for DISTINCT keyword immediately followed by open parenthesis. In a
    ``select_clause`` with a modifier, when the first select element (or its
    expression) contains a bracketed segment:

    * the brackets are removed if they are the expression's only content;
    * a space is inserted between the modifier and the first element if
      there is no whitespace between them.
    """
    if not context.segment.is_type("select_clause"):
        return None

    children = context.functional.segment.children()
    modifier = children.select(sp.is_type("select_clause_modifier"))
    first_element = children.select(sp.is_type("select_clause_element")).first()
    if not modifier or not first_element:
        return None

    expression = (
        first_element.children(sp.is_type("expression")).first() or first_element
    )
    bracketed = expression.children(sp.is_type("bracketed")).first()
    if not bracketed:
        return None

    fixes = []
    if len(expression[0].segments) == 1:
        # Unwrap: keep the bracket contents, dropping meta segments and the
        # bracket symbols themselves.
        unwrapped = self.filter_meta(bracketed[0].segments)[1:-1]
        fixes.append(LintFix.replace(bracketed[0], unwrapped))

    gap = children.select(
        sp.is_whitespace(), start_seg=modifier[0], stop_seg=first_element[0]
    )
    if not gap:
        fixes.append(LintFix.create_before(first_element[0], [WhitespaceSegment()]))

    if not fixes:
        return None
    return LintResult(anchor=modifier[0], fixes=fixes)
def _coerce_indent_to(
    self,
    desired_indent: str,
    current_indent_buffer: Tuple[RawSegment, ...],
    current_anchor: BaseSegment,
) -> List[LintFix]:
    """Generate fixes to make an indent a certain size.

    * Empty ``desired_indent``: delete every segment of the current indent.
    * Current indent has no raw text: insert ``desired_indent`` before
      ``current_anchor``.
    * Otherwise: rewrite the first indent segment to ``desired_indent`` and
      delete the rest.
    """
    if not desired_indent:
        return [LintFix.delete(elem) for elem in current_indent_buffer]

    existing_indent = "".join(elem.raw for elem in current_indent_buffer)
    new_indent = [WhitespaceSegment(raw=desired_indent)]
    if not existing_indent:
        return [LintFix.create_before(current_anchor, new_indent)]

    first, *others = current_indent_buffer
    return [LintFix.replace(first, new_indent)] + [
        LintFix.delete(elem) for elem in others
    ]
def _handle_semicolon_same_line(
    context: RuleContext, info: SegmentMoveContext
) -> Optional[LintResult]:
    """Move a semicolon so it directly follows its statement on the same line.

    Returns ``None`` when nothing precedes the semicolon. Otherwise the
    fixes append a semicolon after the anchor segment, delete the original
    semicolon, and delete the whitespace listed in
    ``info.whitespace_deletions``.
    """
    if not info.before_segment:
        return None

    anchor = info.anchor_segment
    semicolon = SymbolSegment(raw=";", type="symbol", name="semicolon")
    fixes = [
        LintFix.replace(anchor, [anchor, semicolon]),
        LintFix.delete(context.segment),
        *(LintFix.delete(ws) for ws in info.whitespace_deletions),
    ]
    return LintResult(anchor=anchor, fixes=fixes)
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Top-level statements should not be wrapped in brackets.

    Applies to ``bracketed`` segments whose ancestors (ignoring ``batch``
    segments) are exactly a file and a statement. The fix replaces the
    bracketed segment with its children, minus bracket symbols and meta
    segments.
    """
    segment = context.segment
    if not segment.is_type("bracketed"):
        return None

    ancestry = [s.type for s in context.parent_stack if s.type != "batch"]
    if ancestry != ["file", "statement"]:
        return None

    bracket_names = {"start_bracket", "end_bracket"}
    inner = [
        child
        for child in segment.segments
        if child.name not in bracket_names and not child.is_meta
    ]
    return LintResult(anchor=segment, fixes=[LintFix.replace(segment, inner)])
def _ensure_final_semicolon(self, context: RuleContext) -> Optional[LintResult]:
    """Add a semicolon after the last statement of the file if it is missing.

    Only runs on the final segment of the file. It walks backwards to the
    last code segment and notes whether a semicolon was seen on the way. If
    none was, a semicolon is appended after the anchor segment. It goes on
    the same line for one-line statements or when ``multiline_newline`` is
    falsy. Otherwise it goes on a new line, placed after any preceding
    inline comments.
    """
    if not self.is_final_segment(context):
        return None

    # The raw stack excludes the current segment, so add it for the walk.
    complete_stack: List[BaseSegment] = [*context.raw_stack, context.segment]

    anchor_segment = context.segment
    semi_colon_exist_flag = False
    is_one_line = False
    before_segment = []
    for segment in reversed(complete_stack):
        if segment.name == "semicolon":
            semi_colon_exist_flag = True
        elif segment.is_code:
            is_one_line = self._is_one_line_statement(context, segment)
            break
        elif not segment.is_meta:
            before_segment.append(segment)
        anchor_segment = segment

    semicolon_newline = self.multiline_newline if not is_one_line else False
    if semi_colon_exist_flag:
        return None

    if semicolon_newline:
        # Inline comments may hold noqa directives; keep them before the
        # new line.
        _, anchor_segment = self._handle_preceding_inline_comments(
            before_segment, anchor_segment
        )
        tail = [
            NewlineSegment(),
            SymbolSegment(raw=";", type="symbol", name="semicolon"),
        ]
    else:
        tail = [SymbolSegment(raw=";", type="symbol", name="semicolon")]

    fixes = [LintFix.replace(anchor_segment, [anchor_segment, *tail])]
    return LintResult(anchor=anchor_segment, fixes=fixes)
def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
    """Relational operators should not be used to check for NULL values.

    For each ``equals`` / ``not_equal_to`` child whose next code segment is
    a ``null_literal``, replace the operator with the sequence built by
    ``_create_base_is_null_sequence``. Whitespace is added on either side
    when ``_missing_whitespace`` reports it is needed. Assignments in SET
    and EXEC clauses are exempt.
    """
    # Context/motivation for this rule:
    # https://news.ycombinator.com/item?id=28772289
    # https://stackoverflow.com/questions/9581745/sql-is-null-and-null
    if len(context.segment.segments) <= 2:
        return None

    # Assignments in SET / EXEC clauses legitimately use "= NULL".
    assignment_types = ("set_clause_list", "execute_script_statement")
    if context.parent_stack and context.parent_stack[-1].is_type(*assignment_types):
        return None
    if context.segment.is_type(*assignment_types):
        return None

    segment = context.functional.segment
    children = segment.children()
    comparison_ops = segment.children(sp.is_name("equals", "not_equal_to"))
    if not comparison_ops:
        return None

    results: List[LintResult] = []
    for operator in comparison_ops:
        following = children.select(start_seg=operator)
        first_code_after = following.first(sp.is_code())
        # Only comparisons against a NULL literal are of interest.
        if not first_code_after.all(sp.is_name("null_literal")):
            continue

        null_seg = first_code_after.get()
        assert null_seg, "TypeGaurd: Segement Must exist Must exist"
        self.logger.debug(
            "Found NULL literal following equals/not equals @%s: %r",
            null_seg.pos_marker,
            null_seg.raw,
        )

        edit = _create_base_is_null_sequence(
            is_upper=null_seg.raw[0] == "N",
            operator_name=operator.name,
        )
        seg_after_op = following.first().get()
        seg_before_op = children.select(stop_seg=operator).last().get()
        if self._missing_whitespace(seg_after_op, before=True):
            leading_space: CorrectionListType = [WhitespaceSegment()]
            edit = leading_space + edit
        if self._missing_whitespace(seg_before_op, before=False):
            edit = edit + [WhitespaceSegment()]

        results.append(
            LintResult(anchor=operator, fixes=[LintFix.replace(operator, edit)])
        )

    return results or None
def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
    """Unnecessary whitespace.

    Looks at the direct children of ``context.segment``:

    * whitespace that follows non-newline content and precedes a non-comment
      element must be exactly one space (leading indentation is ignored);
    * whitespace inside an ``object_reference`` is deleted;
    * whitespace inside a ``comparison_operator`` is deleted;
    * when the segment itself is a ``casting_operator``, whitespace directly
      before and after it is deleted.
    """
    violations = []
    after_newline = True
    pending_whitespace = None

    for seg in context.segment.segments:
        if seg.is_meta:
            continue

        if seg.is_type("newline"):
            after_newline, pending_whitespace = True, None
        elif seg.is_type("whitespace"):
            # Whitespace straight after a newline is indentation: ignore it.
            # after_newline deliberately stays as-is, in case other rule
            # fixes inserted several indents (see #1713).
            if not after_newline:
                pending_whitespace = seg
        elif seg.is_type("comment"):
            after_newline, pending_whitespace = False, None
        else:
            if pending_whitespace and pending_whitespace.raw != " ":
                violations.append(
                    LintResult(
                        anchor=pending_whitespace,
                        fixes=[
                            LintFix.replace(pending_whitespace, [WhitespaceSegment()])
                        ],
                    )
                )
            after_newline, pending_whitespace = False, None

        if seg.is_type("object_reference"):
            violations.extend(
                LintResult(anchor=raw_seg, fixes=[LintFix.delete(raw_seg)])
                for raw_seg in seg.get_raw_segments()
                if raw_seg.is_whitespace
            )

        if seg.is_type("comparison_operator"):
            delete_fixes = [
                LintFix.delete(raw_seg)
                for raw_seg in seg.get_raw_segments()
                if raw_seg.is_whitespace
            ]
            if delete_fixes:
                violations.append(
                    LintResult(anchor=delete_fixes[0].anchor, fixes=delete_fixes)
                )

    if context.segment.is_type("casting_operator"):
        whitespace_or_meta = sp.or_(sp.is_whitespace(), sp.is_meta())
        leading = context.functional.raw_stack.reversed().select(
            select_if=sp.is_whitespace(),
            loop_while=whitespace_or_meta,
        )
        trailing = context.functional.siblings_post.raw_segments.select(
            select_if=sp.is_whitespace(),
            loop_while=whitespace_or_meta,
        )
        cast_fixes: List[LintFix] = [
            LintFix.delete(ws) for ws in (*leading, *trailing)
        ]
        if cast_fixes:
            violations.append(LintResult(anchor=context.segment, fixes=cast_fixes))

    return violations or None
def _eval(self, context: RuleContext) -> Optional[List[LintResult]]:
    """Check that select targets are ordered from simple to complex.

    Each ``select_clause_element`` is assigned to a "band" from
    ``select_element_order_preference``. Wildcards come first, then simple
    targets (references, literals, casts, ``cast(...)`` functions, and
    simple cast expressions). Anything unmatched is 'complex' and must go
    last.

    When a violation has been recorded, a single ``LintResult`` anchored on
    the select clause replaces each out-of-place target with the one that
    belongs in that position.

    Select clauses are skipped when they sit inside an ``insert_statement``
    or ``set_expression`` (two levels up the parent stack), or inside a
    ``create_table_statement`` (three levels up), because column order is
    significant there.

    Returns:
        ``[LintResult]`` when a reordering is needed, otherwise ``None``.
    """
    # Reset per-call state.
    # NOTE(review): ``self._validate`` (defined elsewhere) presumably records
    # the matched band in ``current_element_band`` / ``seen_band_elements``
    # and sets ``violation_exists`` -- confirm against its definition.
    self.violation_buff = []
    self.violation_exists = False
    # Bands of select targets in order to be enforced
    select_element_order_preference = (
        ("wildcard_expression",),
        (
            "object_reference",
            "literal",
            "cast_expression",
            ("function", "cast"),
            ("expression", "cast_expression"),
        ),
    )

    # Track which bands have been seen, with additional empty list for the
    # non-matching elements. If we find a matching target element, we append the
    # element to the corresponding index.
    self.seen_band_elements: List[List[BaseSegment]] = [
        [] for _ in select_element_order_preference
    ] + [
        []
    ]  # type: ignore

    if context.segment.is_type("select_clause"):
        # Ignore select clauses which belong to:
        # - set expression, which is most commonly a union
        # - insert_statement
        # - create table statement
        #
        # In each of these contexts, the order of columns in a select should
        # be preserved.
        if len(context.parent_stack) >= 2 and context.parent_stack[-2].is_type(
            "insert_statement", "set_expression"
        ):
            return None
        if len(context.parent_stack) >= 3 and context.parent_stack[-3].is_type(
            "create_table_statement"
        ):
            return None

        select_clause_segment = context.segment
        select_target_elements = context.segment.get_children(
            "select_clause_element"
        )
        if not select_target_elements:
            return None

        # Iterate through all the select targets to find any order violations
        for segment in select_target_elements:
            # The band index of the current segment in
            # select_element_order_preference
            self.current_element_band = None

            # Compare the segment to the bands in select_element_order_preference
            for i, band in enumerate(select_element_order_preference):
                for e in band:
                    # Identify simple select target (``e`` is a segment type
                    # name here; for tuple entries this lookup is expected to
                    # find nothing and fall through to the branches below).
                    if segment.get_child(e):
                        self._validate(i, segment)

                    # Identify function: ("function", <name>) matches a
                    # function call whose function_name raw text is <name>.
                    elif type(e) == tuple and e[0] == "function":
                        try:
                            if (
                                segment.get_child("function")
                                .get_child("function_name")
                                .raw
                                == e[1]
                            ):
                                self._validate(i, segment)
                        except AttributeError:
                            # If the segment doesn't match
                            pass

                    # Identify simple expression: ("expression", <type>)
                    # matches an expression of exactly two children whose
                    # first child is a reference/literal and which contains
                    # a <type> child.
                    elif type(e) == tuple and e[0] == "expression":
                        try:
                            if (
                                segment.get_child("expression").get_child(e[1])
                                and segment.get_child("expression").segments[0].type
                                in (
                                    "column_reference",
                                    "object_reference",
                                    "literal",
                                )
                                # len == 2 to ensure the expression is 'simple'
                                and len(segment.get_child("expression").segments) == 2
                            ):
                                self._validate(i, segment)
                        except AttributeError:
                            # If the segment doesn't match
                            pass

            # If the target doesn't exist in select_element_order_preference then it
            # is 'complex' and must go last
            if self.current_element_band is None:
                self.seen_band_elements[-1].append(segment)

        if self.violation_exists:
            # Create a list of all the edit fixes
            # We have to do this at the end of iterating through all the
            # select_target_elements to get the order correct. This means we can't
            # add a lint fix to each individual LintResult as we go
            ordered_select_target_elements = [
                segment for band in self.seen_band_elements for segment in band
            ]
            # TODO: The "if" in the loop below compares corresponding items
            # to avoid creating "do-nothing" edits. A potentially better
            # approach would leverage difflib.SequenceMatcher.get_opcodes(),
            # which generates a list of edit actions (similar to the
            # command-line "diff" tool in Linux). This is more complex to
            # implement, but minimizing the number of LintFixes makes the
            # final application of patches (in "sqlfluff fix") more robust.
            fixes = [
                LintFix.replace(
                    initial_select_target_element,
                    [replace_select_target_element],
                )
                for initial_select_target_element, replace_select_target_element in zip(  # noqa: E501
                    select_target_elements, ordered_select_target_elements
                )
                if initial_select_target_element is not replace_select_target_element
            ]
            # Anchoring on the select statement segment ensures that
            # select statements which include macro targets are ignored
            # when ignore_templated_areas is set
            lint_result = LintResult(anchor=select_clause_segment, fixes=fixes)
            self.violation_buff = [lint_result]

    return self.violation_buff or None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Unnecessary quoted identifier.

    With ``prefer_quoted_identifiers`` set, naked identifiers are reported,
    without a fix, since the quote character is dialect dependent. Otherwise,
    quoted identifiers whose contents would be valid as a naked identifier
    (match the dialect's ``NakedIdentifierSegment`` template and not its
    anti-template, i.e. not a reserved word) are reported and unquoted.

    Segments under any of ``self._ignore_types`` and identifiers listed in
    the ``ignore_words`` configuration are skipped.
    """
    # Config type hints
    self.prefer_quoted_identifiers: bool
    self.ignore_words: str

    # Ignore some segment types
    if context.functional.parent_stack.any(sp.is_type(*self._ignore_types)):
        return None

    if self.prefer_quoted_identifiers:
        context_policy = "naked_identifier"
        identifier_contents = context.segment.raw
    else:
        context_policy = "quoted_identifier"
        identifier_contents = context.segment.raw[1:-1]

    # Get the ignore_words_list configuration.
    try:
        ignore_words_list = self.ignore_words_list
    except AttributeError:
        # First-time only, read the settings from configuration. This is
        # very slow.
        ignore_words_list = self._init_ignore_words_list()

    # Skip if in ignore list
    if ignore_words_list and identifier_contents.lower() in ignore_words_list:
        return None

    # Ignore the segments that are not of the same type as the defined policy
    # above. ``context_policy`` is a plain string, so this must be an equality
    # test: ``not in`` would be a *substring* test that lets names such as
    # "identifier" or "" through.
    # Also TSQL has a keyword called QUOTED_IDENTIFIER which maps to the name so
    # need to explicitly check for that.
    if context.segment.name != context_policy or context.segment.raw.lower() in (
        "quoted_identifier",
        "naked_identifier",
    ):
        return None

    # Manage cases of identifiers must be quoted first.
    # Naked identifiers are _de facto_ making this rule fail as configuration forces
    # them to be quoted.
    # In this case, it cannot be fixed as which quote to use is dialect dependent
    if self.prefer_quoted_identifiers:
        return LintResult(
            context.segment,
            description=f"Missing quoted identifier {identifier_contents}.",
        )

    # Now we only deal with NOT forced quoted identifiers configuration
    # (meaning prefer_quoted_identifiers=False).

    # Extract contents of outer quotes.
    quoted_identifier_contents = context.segment.raw[1:-1]

    # Retrieve NakedIdentifierSegment RegexParser for the dialect.
    naked_identifier_parser = context.dialect._library["NakedIdentifierSegment"]

    # Check if quoted_identifier_contents could be a valid naked identifier
    # and that it is not a reserved keyword.
    matches_naked_template = (
        regex.fullmatch(
            naked_identifier_parser.template,
            quoted_identifier_contents,
            regex.IGNORECASE,
        )
        is not None
    )
    if matches_naked_template and (
        regex.fullmatch(
            naked_identifier_parser.anti_template,
            quoted_identifier_contents,
            regex.IGNORECASE,
        )
        is None
    ):
        return LintResult(
            context.segment,
            fixes=[
                LintFix.replace(
                    context.segment,
                    [
                        CodeSegment(
                            raw=quoted_identifier_contents,
                            name="naked_identifier",
                            type="identifier",
                        )
                    ],
                )
            ],
            description=f"Unnecessary quoted identifier {context.segment.raw}.",
        )

    return None
def _eval_single_select_target_element(
    self, select_targets_info, context: RuleContext
):
    """Handle a select clause that has a single select target.

    Two layouts are fixed:

    1. ``SELECT`` followed by a newline and then the (non-wildcard) target,
       optionally preceded by a modifier such as ``DISTINCT``. The target
       (and modifier) is moved up onto the ``SELECT`` line. Leftover
       whitespace, newlines and comments are cleaned up or moved after the
       select clause.
    2. A wildcard target on the same line as ``FROM`` but not on the
       ``SELECT`` line. ``FROM`` is moved onto its own line.

    Returns a ``LintResult`` anchored on the select clause, or ``None``.
    """
    select_clause = context.functional.segment
    parent_stack = context.parent_stack
    wildcards = select_clause.children(sp.is_type("select_clause_element")).children(
        sp.is_type("wildcard_expression")
    )
    is_wildcard = bool(wildcards)
    if is_wildcard:
        wildcard_select_clause_element = wildcards[0]

    if (
        select_targets_info.select_idx
        < select_targets_info.first_new_line_idx
        < select_targets_info.first_select_target_idx
    ) and (not is_wildcard):
        # Do we have a modifier?
        select_children = select_clause.children()
        modifier: Optional[Segments]
        modifier = select_children.first(sp.is_type("select_clause_modifier"))

        # Prepare the select clause which will be inserted
        insert_buff = [
            WhitespaceSegment(),
            select_children[select_targets_info.first_select_target_idx],
        ]

        # Check if the modifier is one we care about
        if modifier:
            # If it's already on the first line, ignore it.
            if (
                select_children.index(modifier.get())
                < select_targets_info.first_new_line_idx
            ):
                modifier = None
        fixes = [
            # Delete the first select target from its original location.
            # We'll add it to the right section at the end, once we know
            # what to add.
            LintFix.delete(
                select_children[select_targets_info.first_select_target_idx],
            ),
        ]

        # If we have a modifier to move:
        if modifier:
            # Add it to the insert
            insert_buff = [WhitespaceSegment(), modifier[0]] + insert_buff

            modifier_idx = select_children.index(modifier.get())
            # Delete the whitespace after it (which is two after, thanks to
            # indent). The bounds check must cover the index we actually read
            # (modifier_idx + 2); checking only modifier_idx + 1 raised
            # IndexError when the modifier was the second-to-last child.
            if (
                len(select_children) > modifier_idx + 2
                and select_children[modifier_idx + 2].is_whitespace
            ):
                fixes += [
                    LintFix.delete(select_children[modifier_idx + 2]),
                ]

            # Delete the modifier itself
            fixes += [
                LintFix.delete(modifier[0]),
            ]

            # Set the position marker for removing the preceding
            # whitespace and newline, which we'll use below.
            start_idx = modifier_idx
        else:
            # Set the position marker for removing the preceding
            # whitespace and newline, which we'll use below.
            start_idx = select_targets_info.first_select_target_idx

        if parent_stack and parent_stack[-1].is_type("select_statement"):
            select_stmt = parent_stack[-1]
            select_clause_idx = select_stmt.segments.index(select_clause.get())
            after_select_clause_idx = select_clause_idx + 1
            if len(select_stmt.segments) > after_select_clause_idx:

                def _fixes_for_move_after_select_clause(
                    stop_seg: BaseSegment,
                    add_newline: bool = True,
                ) -> List[LintFix]:
                    """Cleans up by moving leftover select_clause segments.

                    Context: Some of the other fixes we make in
                    _eval_single_select_target_element() leave the
                    select_clause in an illegal state -- a select_clause's
                    *rightmost children cannot be whitespace or comments*.
                    This function addresses that by moving these segments
                    up the parse tree to an ancestor segment chosen by
                    _choose_anchor_segment(). After these fixes are applied,
                    these segments may, for example, be *siblings* of
                    select_clause.
                    """
                    start_seg = (
                        modifier[0]
                        if modifier
                        else select_children[select_targets_info.first_new_line_idx]
                    )
                    move_after_select_clause = select_children.select(
                        start_seg=start_seg,
                        stop_seg=stop_seg,
                    )
                    fixes = [LintFix.delete(seg) for seg in move_after_select_clause]
                    fixes.append(
                        LintFix.create_after(
                            self._choose_anchor_segment(
                                context,
                                "create_after",
                                select_clause[0],
                                filter_meta=True,
                            ),
                            ([NewlineSegment()] if add_newline else [])
                            + list(move_after_select_clause),
                        )
                    )
                    return fixes

                if select_stmt.segments[after_select_clause_idx].is_type("newline"):
                    # Since we're deleting the newline, we should also delete all
                    # whitespace before it or it will add random whitespace to
                    # following statements. So walk back through the segment
                    # deleting whitespace until you get the previous newline, or
                    # something else.
                    to_delete = select_children.reversed().select(
                        loop_while=sp.is_type("whitespace"),
                        start_seg=select_children[start_idx],
                    )
                    fixes += [LintFix.delete(seg) for seg in to_delete]

                    # The select_clause is immediately followed by a
                    # newline. Delete the newline in order to avoid leaving
                    # behind an empty line after fix, *unless* we stopped
                    # due to something other than a newline.
                    delete_last_newline = select_children[
                        start_idx - len(to_delete) - 1
                    ].is_type("newline")

                    # Delete the newline if we decided to.
                    if delete_last_newline:
                        fixes.append(
                            LintFix.delete(
                                select_stmt.segments[after_select_clause_idx],
                            )
                        )

                    fixes += _fixes_for_move_after_select_clause(
                        to_delete[-1],
                    )
                elif select_stmt.segments[after_select_clause_idx].is_type(
                    "whitespace"
                ):
                    # The select_clause has stuff after (most likely a comment)
                    # Delete the whitespace immediately after the select clause
                    # so the other stuff aligns nicely based on where the select
                    # clause started
                    fixes += [
                        LintFix.delete(
                            select_stmt.segments[after_select_clause_idx],
                        ),
                    ]
                    fixes += _fixes_for_move_after_select_clause(
                        select_children[select_targets_info.first_select_target_idx],
                    )
                elif select_stmt.segments[after_select_clause_idx].is_type("dedent"):
                    # Again let's strip back the whitespace, but simpler
                    # as don't need to worry about new line so just break
                    # if see non-whitespace
                    # NOTE(review): select_clause_idx is an index into
                    # select_stmt.segments but is used here to index
                    # select_children -- confirm this is intended.
                    to_delete = select_children.reversed().select(
                        loop_while=sp.is_type("whitespace"),
                        start_seg=select_children[select_clause_idx - 1],
                    )
                    fixes += [LintFix.delete(seg) for seg in to_delete]

                    if to_delete:
                        fixes += _fixes_for_move_after_select_clause(
                            to_delete[-1],
                            # If we deleted a newline, create a newline.
                            any(seg for seg in to_delete if seg.is_type("newline")),
                        )
                else:
                    fixes += _fixes_for_move_after_select_clause(
                        select_children[select_targets_info.first_select_target_idx],
                    )

        fixes += [
            # Insert the select_clause in place of the first newline in the
            # Select statement
            LintFix.replace(
                select_children[select_targets_info.first_new_line_idx],
                insert_buff,
            ),
        ]

        return LintResult(
            anchor=select_clause.get(),
            fixes=fixes,
        )

    # If we have a wildcard on the same line as the FROM keyword, but not the same
    # line as the SELECT keyword, we need to move the FROM keyword to its own line.
    # i.e.
    # SELECT
    #   * FROM foo
    if select_targets_info.from_segment:
        from_line_no = select_targets_info.from_segment.pos_marker.working_line_no
        if (
            is_wildcard
            and (select_clause[0].pos_marker.working_line_no != from_line_no)
            and (wildcard_select_clause_element.pos_marker.working_line_no == from_line_no)
        ):
            fixes = [
                LintFix.delete(ws) for ws in select_targets_info.pre_from_whitespace
            ]
            fixes.append(
                LintFix.create_before(
                    select_targets_info.from_segment,
                    [NewlineSegment()],
                )
            )
            return LintResult(anchor=select_clause.get(), fixes=fixes)

    return None
def _eval(self, context: RuleContext) -> Optional[LintResult]:
    """Flag ``COUNT`` calls whose argument differs from the configured preference.

    The preferred argument is ``*`` unless ``prefer_count_1`` (``1``) or
    ``prefer_count_0`` (``0``) is set; ``prefer_count_1`` wins when both are
    set. Only a bare ``*`` or an expression consisting of a single ``0``/``1``
    literal is considered; any other argument is left alone.

    Args:
        context: The rule context; ``context.segment`` is the segment under
            evaluation.

    Returns:
        A ``LintResult`` anchored on the function segment, carrying a fix that
        rewrites the argument to the preferred form, or ``None`` when there is
        nothing to report.
    """
    # Config type hints
    self.prefer_count_0: bool
    self.prefer_count_1: bool

    if not context.segment.is_type("function"):
        return None
    function_name = context.segment.get_child("function_name")
    # Guard against a function segment without a name child instead of
    # raising AttributeError on ``None.raw_upper``.
    if function_name is None or function_name.raw_upper != "COUNT":
        return None

    # Get bracketed content, ignoring meta segments, brackets and whitespace.
    f_content = context.functional.segment.children(
        sp.is_type("bracketed")
    ).children(
        sp.and_(
            sp.not_(sp.is_meta()),
            sp.not_(
                sp.is_type("start_bracket", "end_bracket", "whitespace", "newline")
            ),
        )
    )
    if len(f_content) != 1:  # pragma: no cover
        return None
    argument = f_content[0]

    if self.prefer_count_1:
        preferred = "1"
    elif self.prefer_count_0:
        preferred = "0"
    else:
        preferred = "*"

    # COUNT(*) only needs fixing when a numeric form is preferred.
    if argument.is_type("star") and (self.prefer_count_1 or self.prefer_count_0):
        return LintResult(
            anchor=context.segment,
            fixes=[
                LintFix.replace(
                    argument,
                    [argument.edit(argument.raw.replace("*", preferred))],
                ),
            ],
        )

    if argument.is_type("expression"):
        expression_content = [seg for seg in argument.segments if not seg.is_meta]
        if (
            len(expression_content) == 1
            and expression_content[0].is_type("literal")
            and expression_content[0].raw in ("0", "1")
            and expression_content[0].raw != preferred
        ):
            literal = expression_content[0]
            return LintResult(
                anchor=context.segment,
                fixes=[
                    # The whole literal is replaced by the preferred form.
                    LintFix.replace(literal, [literal.edit(preferred)]),
                ],
            )

    return None
def _lint_aliases_in_join(self, base_table, from_expression_elements,
                          column_reference_segments, segment):
    """Lint and fix all aliases in joins - except for self-joins.

    For every table expression yielded by ``self._filter_table_expressions``,
    emit a result whose fixes delete the alias expression (and the whitespace
    reference recorded alongside it) and rewrite the alias, plus every
    reference to it, to the table's raw name. A table that appears more than
    once in the FROM clause under different aliases is skipped, because the
    aliases may be what keeps each occurrence independent.

    Args:
        base_table: Passed through to ``_filter_table_expressions``.
        from_expression_elements: FROM-clause elements to inspect.
        column_reference_segments: Column references in the query that may be
            qualified by an alias (``alias.column``).
        segment: The statement segment; its ``select_clause`` child is
            searched for object references that use an alias.

    Returns:
        A list of ``LintResult`` objects, or ``None`` if there are none.
    """
    # A buffer to keep any violations.
    violation_buff = []

    to_check = list(
        self._filter_table_expressions(base_table, from_expression_elements))

    # How many times does each table appear in the FROM clause?
    table_counts = Counter(ai.table_ref.raw for ai in to_check)

    # What is the set of aliases used for each table? (We are mainly
    # interested in the NUMBER of different aliases used.)
    table_aliases = defaultdict(set)
    for ai in to_check:
        if ai and ai.table_ref and ai.alias_identifier_ref:
            table_aliases[ai.table_ref.raw].add(ai.alias_identifier_ref.raw)

    # The select clause is the same for every table; look it up once.
    select_clause = segment.get_child("select_clause")

    # For each aliased table, check whether to keep or remove it.
    for alias_info in to_check:
        table_name = alias_info.table_ref.raw
        # If the same table appears more than once in the FROM clause with
        # different alias names, do not consider removing its aliases.
        # The aliases may have been introduced simply to make each
        # occurrence of the table independent within the query.
        if table_counts[table_name] > 1 and len(table_aliases[table_name]) > 1:
            continue

        ids_refs = []
        # Both searches match on ``alias_name``, so both live under this
        # guard; otherwise a table without an alias identifier could match
        # references using an unbound or stale name from an earlier table.
        if alias_info.alias_identifier_ref:
            alias_name = alias_info.alias_identifier_ref.raw

            # Find all references to alias in select clause
            if select_clause is not None:
                for alias_with_column in select_clause.recursive_crawl(
                        "object_reference"):
                    used_alias_ref = alias_with_column.get_child("identifier")
                    if used_alias_ref and used_alias_ref.raw == alias_name:
                        ids_refs.append(used_alias_ref)

            # Find all references to alias in column references
            for exp_ref in column_reference_segments:
                used_alias_ref = exp_ref.get_child("identifier")
                # exp_ref.get_child('dot') ensures that the column reference
                # includes a table reference
                if (used_alias_ref and used_alias_ref.raw == alias_name
                        and exp_ref.get_child("dot")):
                    ids_refs.append(used_alias_ref)

        # Fixes for deleting ` as sth` and for editing references to aliased
        # tables. Unparsable input can leave these segments missing (see
        # #2484), so only act on segments that are present.
        fixes = [
            *[
                LintFix.delete(d)
                for d in [alias_info.alias_exp_ref, alias_info.whitespace_ref]
                if d
            ],
            *[
                LintFix.replace(
                    alias,
                    [alias.edit(alias_info.table_ref.raw)],
                    source=[alias_info.table_ref],
                )
                for alias in [alias_info.alias_identifier_ref, *ids_refs]
                if alias
            ],
        ]

        violation_buff.append(
            LintResult(
                anchor=alias_info.alias_identifier_ref,
                description="Avoid aliases in from clauses and join conditions.",
                fixes=fixes,
            ))

    return violation_buff or None