class RequireDescriptiveNameRule(CstLintRule):
    """
    Report classes, functions and parameters whose name is a single character,
    since such names say nothing about what they hold or do.
    """

    VALID = [
        Valid(
            """
            class DescriptiveName:
                pass
            """
        ),
        Valid(
            """
            class ThisClass:
                def this_method(self):
                    pass
            """
        ),
        Valid(
            """
            def descriptive_function():
                pass
            """
        ),
        Valid(
            """
            def function(descriptive, parameter):
                pass
            """
        ),
    ]

    INVALID = [
        Invalid(
            """
            class T:
                pass
            """
        ),
        Invalid(
            """
            class ThisClass:
                def m(self):
                    pass
            """
        ),
        Invalid(
            """
            def f():
                pass
            """
        ),
        Invalid(
            """
            def fun(a):
                pass
            """
        ),
    ]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Check the name of every class definition."""
        self._validate_name_length(node, "class")

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Check the name of every function or method definition."""
        self._validate_name_length(node, "function")

    def visit_Param(self, node: cst.Param) -> None:
        """Check the name of every parameter."""
        self._validate_name_length(node, "parameter")

    def _validate_name_length(
        self,
        node: Union[cst.ClassDef, cst.FunctionDef, cst.Param],
        nodetype: str,
    ) -> None:
        """Report *node* when its identifier is exactly one character long.

        *nodetype* is a human-readable label ("class", "function", "parameter")
        interpolated into the module-level ``MESSAGE`` template.
        """
        name = node.name.value
        if len(name) != 1:
            return
        self.report(node, message=MESSAGE.format(nodetype=nodetype, nodename=name))
class NoRedundantLambdaRule(CstLintRule):
    """
    A lambda function whose only purpose is to pass all of its arguments, unchanged
    and in order, to another callable can be safely replaced by that callable.
    """

    VALID = [
        Valid("lambda x: foo(y)"),
        Valid("lambda x: foo(x, y)"),
        Valid("lambda x, y: foo(x)"),
        Valid("lambda *, x: foo(x)"),
        Valid("lambda x = y: foo(x)"),
        Valid("lambda x, y: foo(y, x)"),
        Valid("lambda self: self.func()"),
        Valid("lambda x, y: foo(y=x, x=y)"),
        Valid("lambda x, y, *z: foo(x, y, z)"),
        Valid("lambda x, y, **z: foo(x, y, z)"),
        # The callee itself uses a lambda parameter, so it cannot be hoisted out
        # of the lambda without leaving that name unbound.
        Valid("lambda x: x.method(x)"),
        Valid("lambda x: x(x)"),
    ]

    INVALID = [
        Invalid("lambda: self.func()", expected_replacement="self.func"),
        Invalid("lambda x: foo(x)", expected_replacement="foo"),
        Invalid(
            "lambda x, y, z: (t + u).math_call(x, y, z)",
            expected_replacement="(t + u).math_call",
        ),
    ]

    @staticmethod
    def _is_simple_parameter_spec(node: cst.Parameters) -> bool:
        """Return True if *node* contains only plain positional params without defaults.

        Any ``**kwargs``, keyword-only, positional-only or ``*args`` parameter
        (including a bare ``*``) makes the lambda's signature differ from a direct
        call, so those lambdas are never considered redundant.
        """
        if (
            node.star_kwarg is not None
            or len(node.kwonly_params) > 0
            or len(node.posonly_params) > 0
            or not isinstance(node.star_arg, cst.MaybeSentinel)
        ):
            return False
        return all(param.default is None for param in node.params)

    def visit_Lambda(self, node: cst.Lambda) -> None:
        """Report lambdas that forward their parameters verbatim to a single call."""
        if not m.matches(
            node,
            m.Lambda(
                params=m.MatchIfTrue(self._is_simple_parameter_spec),
                # The call must receive exactly the lambda's parameters, positionally
                # and in the same order, with no ``*``/``**`` unpacking.
                body=m.Call(
                    args=[
                        m.Arg(value=m.Name(value=param.name.value), star="", keyword=None)
                        for param in node.params.params
                    ]
                ),
            ),
        ):
            return

        call = cst.ensure_type(node.body, cst.Call)

        # If the callee expression refers to one of the lambda's parameters
        # (e.g. ``lambda x: x.method(x)``), replacing the lambda with ``call.func``
        # would reference a name that is no longer bound. Any ``Name`` node is
        # considered, which errs on the side of not reporting.
        param_names = {param.name.value for param in node.params.params}
        if any(
            cst.ensure_type(name_node, cst.Name).value in param_names
            for name_node in m.findall(call.func, m.Name())
        ):
            return

        full_name = get_full_name_for_node(call)
        if full_name is None:
            full_name = "function"

        self.report(
            node,
            UNNECESSARY_LAMBDA.format(function=full_name),
            replacement=call.func,
        )
class NoRedundantFStringRule(CstLintRule):
    """
    Remove redundant f-string without placeholders.
    """

    MESSAGE: str = "f-string doesn't have placeholders, remove redundant f-string."

    VALID = [
        Valid('good: str = "good"'),
        Valid('good: str = f"with_arg{arg}"'),
        Valid('good = "good{arg1}".format(1234)'),
        Valid('good = "good".format()'),
        Valid('good = "good" % {}'),
        Valid('good = "good" % ()'),
        Valid('good = rf"good\t+{bar}"'),
    ]

    INVALID = [
        Invalid(
            'bad: str = f"bad" + "bad"',
            line=1,
            expected_replacement='bad: str = "bad" + "bad"',
        ),
        Invalid(
            "bad: str = f'bad'",
            line=1,
            expected_replacement="bad: str = 'bad'",
        ),
        Invalid(
            "bad: str = rf'bad\t+'",
            line=1,
            expected_replacement="bad: str = r'bad\t+'",
        ),
        Invalid(
            'bad: str = f"no args but messing up {{ braces }}"',
            line=1,
            expected_replacement='bad: str = "no args but messing up { braces }"',
        ),
        # An empty f-string has no placeholders either.
        Invalid(
            'bad: str = f""',
            line=1,
            expected_replacement='bad: str = ""',
        ),
    ]

    def visit_FormattedString(self, node: cst.FormattedString) -> None:
        """Replace an f-string made only of literal text with a plain string literal."""
        if len(node.parts) == 0:
            # ``f""`` parses with no parts at all.
            inner_text = ""
        elif m.matches(node, m.FormattedString(parts=(m.FormattedStringText(),))):
            inner_text = cst.ensure_type(node.parts[0], cst.FormattedStringText).value
            # ``{{``/``}}`` are escapes inside an f-string only; a plain string
            # needs single braces to keep the same runtime value.
            inner_text = inner_text.replace("{{", "{").replace("}}", "}")
        else:
            # At least one placeholder (or a mix of text and placeholders).
            return

        # Drop the ``f``/``F`` prefix but keep any others (e.g. ``r``) and the quotes.
        new_string_literal = (
            node.start.replace("f", "").replace("F", "") + inner_text + node.end
        )
        self.report(node, replacement=cst.SimpleString(new_string_literal))
class NoInheritFromObjectRule(CstLintRule):
    """
    In Python 3, a class is inherited from ``object`` by default.
    Explicitly inheriting from ``object`` is redundant, so removing it keeps the code simpler.
    """

    MESSAGE = "Inheriting from object is a no-op.  'class Foo:' is just fine =)"

    VALID = [
        Valid("class A(something): pass"),
        Valid(
            """
            class A:
                pass"""
        ),
    ]

    INVALID = [
        Invalid(
            """
            class B(object):
                pass""",
            line=1,
            column=1,
            expected_replacement="""
            class B:
                pass""",
        ),
        Invalid(
            """
            class B(object, A):
                pass""",
            line=1,
            column=1,
            expected_replacement="""
            class B(A):
                pass""",
        ),
        Invalid(
            """
            class B(A, object):
                pass""",
            line=1,
            column=1,
            expected_replacement="""
            class B(A):
                pass""",
        ),
    ]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report and autofix class definitions that list ``object`` as a base."""
        new_bases = [
            base for base in node.bases if not m.matches(base.value, m.Name("object"))
        ]
        if len(new_bases) == len(node.bases):
            # Nothing was filtered out: ``object`` is not an explicit base.
            return

        if new_bases and not node.keywords:
            # When ``object`` was the last base (``class B(A, object)``), the base
            # before it still carries the separating comma, which would render as
            # ``class B(A, )``. Let libcst decide whether a comma is needed.
            new_bases[-1] = new_bases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)

        # Reconstruct the classdef. DEFAULT parens are omitted by libcst when there
        # are no bases and no keywords left.
        new_classdef = node.with_changes(
            bases=new_bases,
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
        )
        # report warning and autofix
        self.report(node, replacement=new_classdef)
class NoRedundantListComprehensionRule(CstLintRule):
    """
    A derivative of flake8-comprehensions's C407 rule.
    """

    VALID = [
        Valid("any(val for val in iterable)"),
        Valid("all(val for val in iterable)"),
        # C407 would complain about these, but we won't
        Valid("frozenset([val for val in iterable])"),
        Valid("max([val for val in iterable])"),
        Valid("min([val for val in iterable])"),
        Valid("sorted([val for val in iterable])"),
        Valid("sum([val for val in iterable])"),
        Valid("tuple([val for val in iterable])"),
    ]

    INVALID = [
        Invalid(
            "any([val for val in iterable])",
            expected_replacement="any(val for val in iterable)",
        ),
        Invalid(
            "all([val for val in iterable])",
            expected_replacement="all(val for val in iterable)",
        ),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Suggest a generator instead of a list comprehension passed to any/all."""
        # Only ``all`` and ``any`` short-circuit, so only they benefit from a lazy
        # generator. C407 also covers frozenset, max, min, sorted, sum and tuple,
        # but those consume every element anyway and are deliberately excluded.
        short_circuit_call = m.Call(
            func=m.Name("all") | m.Name("any"),
            args=[m.Arg(value=m.ListComp())],
        )
        if not m.matches(node, short_circuit_call):
            return

        comprehension = cst.ensure_type(node.args[0].value, cst.ListComp)
        callee_name = cst.ensure_type(node.func, cst.Name).value
        # The call's own parentheses serve as the generator's, so none are added.
        generator = cst.GeneratorExp(
            elt=comprehension.elt,
            for_in=comprehension.for_in,
            lpar=[],
            rpar=[],
        )
        self.report(
            node,
            UNNECESSARY_LIST_COMPREHENSION.format(func=callee_name),
            replacement=node.deep_replace(comprehension, generator),
        )
class NoAssertEqualsRule(CstLintRule):
    """
    Flags ``self.assertEquals(...)``: it is a deprecated alias of ``assertEqual``
    (see https://docs.python.org/2/library/unittest.html#deprecated-aliases and
    https://bugs.python.org/issue9424). The autofix renames the call to the
    standardized ``assertEqual``.
    """

    MESSAGE: str = (
        '"assertEquals" is deprecated, use "assertEqual" instead.\n'
        + "See https://docs.python.org/2/library/unittest.html#deprecated-aliases and https://bugs.python.org/issue9424."
    )

    VALID = [Valid("self.assertEqual(a, b)")]

    INVALID = [
        Invalid(
            "self.assertEquals(a, b)",
            expected_replacement="self.assertEqual(a, b)",
        )
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Rewrite ``self.assertEquals(...)`` into ``self.assertEqual(...)``."""
        deprecated_call = m.Call(
            func=m.Attribute(value=m.Name("self"), attr=m.Name("assertEquals"))
        )
        if not m.matches(node, deprecated_call):
            return

        # Only the attribute name changes; receiver and arguments are untouched.
        method_name_node = cst.ensure_type(node.func, cst.Attribute).attr
        self.report(
            node,
            replacement=node.with_deep_changes(
                old_node=method_name_node,
                value="assertEqual",
            ),
        )
class UseLintFixmeCommentRule(CstLintRule):
    """
    To silence a lint warning, use a ``lint-fixme`` comment (when you plan to fix
    the issue later) or a ``lint-ignore`` comment (when the lint warning is not
    valid). The comment must be on its own line and follow the format
    ``lint-fixme: RULE_NAMES EXTRA_COMMENTS``; it suppresses warnings from
    RULE_NAMES on the next line. RULE_NAMES may be one or more lint rule names
    separated by commas.

    ``noqa`` is deprecated and not supported. Explicitly naming the lint rules to
    suppress in a lint-fixme comment is preferred over an implicit noqa comment,
    which can accidentally silence unrelated warnings.
    """

    MESSAGE: str = "noqa is deprecated. Use `lint-fixme` or `lint-ignore` instead."

    VALID = [
        Valid(
            """
            # lint-fixme: UseFstringRule
            "%s" % "hi"
            """
        ),
        Valid(
            """
            # lint-ignore: UsePlusForStringConcatRule
            'ab' 'cd'
            """
        ),
    ]

    INVALID = [
        Invalid("fn() # noqa"),
        Invalid(
            """
            (
             1,
             2,  # noqa
            )
            """
        ),
        Invalid(
            """
            class C:
                # noqa
                ...
            """
        ),
    ]

    def visit_Comment(self, node: cst.Comment) -> None:
        """Report any comment that begins with ``# noqa`` (case-insensitive)."""
        noqa_marker = "# noqa"
        leading_text = node.value[: len(noqa_marker)]
        if leading_text.lower() != noqa_marker:
            return
        self.report(node)
class UseFstringRule(CstLintRule):
    """Report ``str.format()`` calls and printf-style ``%`` formatting on string literals."""

    MESSAGE: str = (
        "As mentioned in the [Contributing Guidelines]"
        "(https://github.com/TheAlgorithms/Python/blob/master/CONTRIBUTING.md), "
        "please do not use printf style formatting or `str.format()`. "
        "Use [f-string](https://realpython.com/python-f-strings/) instead to be "
        "more readable and efficient."
    )

    VALID = [
        Valid("assigned='string'; f'testing {assigned}'"),
        Valid("'simple string'"),
        Valid("'concatenated' + 'string'"),
        Valid("b'bytes %s' % 'string'.encode('utf-8')"),
    ]

    INVALID = [
        Invalid("'hello, {name}'.format(name='you')"),
        Invalid("'hello, %s' % 'you'"),
        Invalid("r'raw string value=%s' % val"),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Report ``"literal".format(...)`` calls."""
        literal_format_call = m.Call(
            func=m.Attribute(value=m.SimpleString(), attr=m.Name(value="format"))
        )
        if m.matches(node, literal_format_call):
            self.report(node)

    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        """Report ``"literal" % args`` when the literal is a str (not bytes)."""
        if not m.matches(
            node, m.BinaryOperation(left=m.SimpleString(), operator=m.Modulo())
        ):
            return
        # A SimpleString may be a bytes literal, and f-strings have no bytes form
        # (https://www.python.org/dev/peps/pep-0498/#no-binary-f-strings).
        literal = cst.ensure_type(node.left, cst.SimpleString)
        if isinstance(literal.evaluated_value, str):
            self.report(node)
class AwaitAsyncCallRule(CstLintRule):
    """
    Enforces calls to coroutines are preceded by the ``await`` keyword.

    Awaiting on a coroutine will execute it while simply calling a coroutine returns a coroutine object
    (https://docs.python.org/3/library/asyncio-task.html#coroutines).
    """

    MESSAGE: str = (
        "Async function call will only be executed with `await` statement. Did you forget to add `await`? "
        + "If you intend to not await, please add comment to disable this warning: # lint-fixme: AwaitAsyncCallRule "
    )

    # Types come from TypeInferenceProvider; the checks below compare its reported
    # annotation strings against "typing.Coroutine" / callable forms.
    METADATA_DEPENDENCIES = (TypeInferenceProvider, )

    VALID = [
        Valid(
            """
            async def async_func():
                await async_foo()
            """
        ),
        Valid(
            """
            def foo(): pass
            foo()
            """
        ),
        Valid(
            """
            async def foo(): pass
            async def bar():
                await foo()
            """
        ),
        Valid(
            """
            async def foo(): pass
            async def bar():
                x = await foo()
            """
        ),
        Valid(
            """
            async def foo() -> bool: pass
            async def bar():
                while not await foo(): pass
            """
        ),
        Valid(
            """
            import asyncio
            async def foo(): pass
            asyncio.run(foo())
            """
        ),
    ]

    INVALID = [
        Invalid(
            """
            async def foo(): pass
            async def bar():
                foo()
            """,
            expected_replacement="""
            async def foo(): pass
            async def bar():
                await foo()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            obj._attr
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            await obj._attr
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _method(self): pass
            obj = Foo()
            obj._method()
            """,
            expected_replacement="""
            class Foo:
                async def _method(self): pass
            obj = Foo()
            await obj._method()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _method(self): pass
            obj = Foo()
            result = obj._method()
            """,
            expected_replacement="""
            class Foo:
                async def _method(self): pass
            obj = Foo()
            result = await obj._method()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def bar(): pass
            class NodeUser:
                async def get():
                    do_stuff()
                    return Foo()
            user = NodeUser.get().bar()
            """,
            expected_replacement="""
            class Foo:
                async def bar(): pass
            class NodeUser:
                async def get():
                    do_stuff()
                    return Foo()
            user = await NodeUser.get().bar()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            attribute = obj._attr
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
            obj = Foo()
            attribute = await obj._attr
            """,
        ),
        Invalid(
            code="""
            async def foo() -> bool: pass
            x = True
            if x and foo(): pass
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            x = True
            if x and await foo(): pass
            """,
        ),
        Invalid(
            code="""
            async def foo() -> bool: pass
            x = True
            are_both_true = x and foo()
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            x = True
            are_both_true = x and await foo()
            """,
        ),
        Invalid(
            """
            async def foo() -> bool: pass
            if foo():
                do_stuff()
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            if await foo():
                do_stuff()
            """,
        ),
        Invalid(
            """
            async def foo() -> bool: pass
            if not foo():
                do_stuff()
            """,
            expected_replacement="""
            async def foo() -> bool: pass
            if not await foo():
                do_stuff()
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if self._attr: pass
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if await self._attr: pass
            """,
        ),
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if not self._attr: pass
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
                def bar(self):
                    if not await self._attr: pass
            """,
        ),
        # Case where only cst.Attribute node's `attr` returns awaitable
        Invalid(
            """
            class Foo:
                async def _attr(self): pass
            def bar() -> Foo:
                return Foo()
            attribute = bar()._attr
            """,
            expected_replacement="""
            class Foo:
                async def _attr(self): pass
            def bar() -> Foo:
                return Foo()
            attribute = await bar()._attr
            """,
        ),
        # Case where only cst.Attribute node's `value` returns awaitable
        Invalid(
            """
            class Foo:
                def _attr(self): pass
            async def bar():
                await do_stuff()
                return Foo()
            attribute = bar()._attr
            """,
            expected_replacement="""
            class Foo:
                def _attr(self): pass
            async def bar():
                await do_stuff()
                return Foo()
            attribute = await bar()._attr
            """,
        ),
        Invalid(
            """
            async def bar() -> bool: pass
            while bar(): pass
            """,
            expected_replacement="""
            async def bar() -> bool: pass
            while await bar(): pass
            """,
        ),
        Invalid(
            """
            async def bar() -> bool: pass
            while not bar(): pass
            """,
            expected_replacement="""
            async def bar() -> bool: pass
            while not await bar(): pass
            """,
        ),
        Invalid(
            """
            class Foo:
                @classmethod
                async def _method(cls): pass
            Foo._method()
            """,
            expected_replacement="""
            class Foo:
                @classmethod
                async def _method(cls): pass
            await Foo._method()
            """,
        ),
        Invalid(
            """
            class Foo:
                @staticmethod
                async def _method(self): pass
            Foo._method()
            """,
            expected_replacement="""
            class Foo:
                @staticmethod
                async def _method(self): pass
            await Foo._method()
            """,
        ),
    ]

    @staticmethod
    def _is_awaitable_callable(annotation: str) -> bool:
        """Return True if *annotation* is a callable type whose return type is ``typing.Coroutine``.

        Args:
            annotation: a type string as reported by TypeInferenceProvider.
        """
        # NOTE(review): the first two prefixes are "typing."-qualified but
        # "StaticMethod" is not — confirm against the type strings Pyre reports.
        if not (annotation.startswith("typing.Callable")
                or annotation.startswith("typing.ClassMethod")
                or annotation.startswith("StaticMethod")):
            # Exit early if this is not even a `typing.Callable` annotation.
            return False
        try:
            # Wrap this in a try-except since the type annotation may not be parse-able as a module.
            # If it is not parse-able, we know it's not what we are looking for anyway, so return `False`.
            parsed_ann = cst.parse_module(annotation)
        except Exception:
            return False
        # If passed annotation does not match the expected annotation structure for a `typing.Callable` with
        # typing.Coroutine as the return type, matched_callable_ann will simply be `None`.
        # The expected structure of an awaitable callable annotation from Pyre is: typing.Callable()[[...], typing.Coroutine[...]]
        matched_callable_ann: Optional[Dict[str, Union[
            Sequence[cst.CSTNode], cst.CSTNode]]] = m.extract(
                parsed_ann,
                m.Module(body=[
                    m.SimpleStatementLine(body=[
                        m.Expr(value=m.Subscript(slice=[
                            # First element: the parameter list (ignored).
                            m.SubscriptElement(),
                            # Second element: the return type, e.g. ``typing.Coroutine[...]``;
                            # only its base (the attribute before ``[``) is captured.
                            m.SubscriptElement(slice=m.Index(value=m.Subscript(
                                value=m.SaveMatchedNode(
                                    m.Attribute(),
                                    "base_return_type",
                                )))),
                        ], ))
                    ]),
                ]),
            )
        if (matched_callable_ann is not None
                and "base_return_type" in matched_callable_ann):
            base_return_type = get_full_name_for_node(
                cst.ensure_type(matched_callable_ann["base_return_type"],
                                cst.CSTNode))
            return (base_return_type is not None
                    and base_return_type == "typing.Coroutine")
        return False

    def _get_awaitable_replacement(self, node: cst.CSTNode) -> Optional[cst.CSTNode]:
        """Wrap *node* in ``await`` if its inferred type is a coroutine or awaitable callable.

        Returns None when no type is available, the type is not awaitable, or
        *node* is not an expression.
        """
        annotation = self.get_metadata(TypeInferenceProvider, node, None)
        if annotation is not None and (
                annotation.startswith("typing.Coroutine")
                or self._is_awaitable_callable(annotation)):
            if isinstance(node, cst.BaseExpression):
                return cst.Await(expression=node)
        return None

    def _get_async_attr_replacement(
            self, node: cst.Attribute) -> Optional[cst.CSTNode]:
        """Return a replacement for attribute access *node* that awaits what needs awaiting, or None."""
        value = node.value
        if m.matches(value, m.Call()):
            value = cast(cst.Call, value)
            value_replacement = self._get_async_call_replacement(value)
            if value_replacement is not None:
                # NOTE(review): no parentheses are added around the Await, so the
                # result renders as ``await f().attr``, which Python parses as
                # ``await (f().attr)`` rather than ``(await f()).attr``.
                return node.with_changes(value=value_replacement)
        # Otherwise check whether the attribute expression itself is awaitable.
        return self._get_awaitable_replacement(node)

    def _get_async_call_replacement(self,
                                    node: cst.Call) -> Optional[cst.CSTNode]:
        """Return a replacement for call *node* that awaits what needs awaiting, or None."""
        func = node.func
        if m.matches(func, m.Attribute()):
            func = cast(cst.Attribute, func)
            # e.g. ``a.b()`` where ``a`` comes from an un-awaited call.
            attr_func_replacement = self._get_async_attr_replacement(func)
            if attr_func_replacement is not None:
                return node.with_changes(func=attr_func_replacement)
        return self._get_awaitable_replacement(node)

    def _get_async_expr_replacement(
            self, node: cst.CSTNode) -> Optional[cst.CSTNode]:
        """Dispatch on *node*'s kind (call, attribute, ``not``, boolean op) and build a replacement.

        Returns None when nothing inside *node* needs an ``await``.
        """
        if m.matches(node, m.Call()):
            node = cast(cst.Call, node)
            return self._get_async_call_replacement(node)
        elif m.matches(node, m.Attribute()):
            node = cast(cst.Attribute, node)
            return self._get_async_attr_replacement(node)
        elif m.matches(node, m.UnaryOperation(operator=m.Not())):
            node = cast(cst.UnaryOperation, node)
            replacement_expression = self._get_async_expr_replacement(
                node.expression)
            if replacement_expression is not None:
                return node.with_changes(expression=replacement_expression)
        elif m.matches(node, m.BooleanOperation()):
            node = cast(cst.BooleanOperation, node)
            # Both operands are checked so that one report fixes both sides.
            maybe_left = self._get_async_expr_replacement(node.left)
            maybe_right = self._get_async_expr_replacement(node.right)
            if maybe_left is not None or maybe_right is not None:
                left_replacement = maybe_left if maybe_left is not None else node.left
                right_replacement = (maybe_right if maybe_right is not None
                                     else node.right)
                return node.with_changes(left=left_replacement,
                                         right=right_replacement)
        return None

    def _maybe_autofix_node(self, node: cst.CSTNode,
                            attribute_name: str) -> None:
        """Report *node* with an autofix if its child named *attribute_name* needs an ``await``."""
        replacement_value = self._get_async_expr_replacement(
            getattr(node, attribute_name))
        if replacement_value is not None:
            replacement = node.with_changes(
                **{attribute_name: replacement_value})
            self.report(node, replacement=replacement)

    def visit_If(self, node: cst.If) -> None:
        self._maybe_autofix_node(node, "test")

    def visit_While(self, node: cst.While) -> None:
        self._maybe_autofix_node(node, "test")

    def visit_Assign(self, node: cst.Assign) -> None:
        self._maybe_autofix_node(node, "value")

    def visit_Expr(self, node: cst.Expr) -> None:
        self._maybe_autofix_node(node, "value")
class RequireTypeHintRule(CstLintRule):
    """
    Require a return annotation on every function and an annotation on every
    parameter, except parameters in ``IGNORE_PARAM`` and parameters of ``lambda``
    expressions (which cannot be annotated).
    """

    VALID = [
        Valid(
            """
            def func() -> str:
                pass
            """
        ),
        Valid(
            """
            def func() -> None:
                pass
            """
        ),
        Valid(
            """
            def func(some: str, other: int) -> None:
                pass
            """
        ),
        Valid(
            """
            class Random:
                def random_method(self, value: int) -> None:
                    pass
            """
        ),
        Valid(
            """
            class Random:
                @classmethod
                def initiate(cls, value: str) -> str:
                    pass
            """
        ),
        Valid(
            """
            lambda ignore: ignore
            """
        ),
        Valid(
            """
            lambda closure: lambda inside: closure + inside
            """
        ),
    ]

    INVALID = [
        Invalid(
            """
            def func():
                pass
            """
        ),
        Invalid(
            """
            def func(num: int, val: str):
                pass
            """
        ),
        Invalid(
            """
            def func(num: int, val) -> None:
                pass
            """
        ),
        Invalid(
            """
            class Random:
                def __init__(self, val) -> None:
                    pass
            """
        ),
        Invalid(
            """
            class Random:
                @classmethod
                def from_class(cls, val) -> None:
                    pass
            """
        ),
        Invalid(
            """
            def spam() -> None:
                foo = lambda bar: str(bar)
                def wrapper(call) -> None:
                    pass
                return wrapper(foo)
            """
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # How many ``lambda`` expressions currently enclose the visitor position.
        self._lambda_counter: int = 0

    def visit_Lambda(self, node: cst.Lambda) -> None:
        """Enter a lambda: its parameters are exempt from the annotation check."""
        self._lambda_counter += 1

    def leave_Lambda(self, original_node: cst.Lambda) -> None:
        """Leave a lambda, restoring the enclosing depth."""
        self._lambda_counter -= 1

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Report functions that declare no return annotation."""
        if node.returns is not None:
            return
        self.report(node, MISSING_RETURN_TYPE_HINT.format(nodename=node.name.value))

    def visit_Param(self, node: cst.Param) -> None:
        """Report unannotated parameters outside lambdas (``IGNORE_PARAM`` names excepted)."""
        if self._lambda_counter != 0:
            # Annotating parameters in ``lambda`` is not possible.
            return
        param_name = node.name.value
        if node.annotation is not None or param_name in IGNORE_PARAM:
            return
        self.report(node, MISSING_TYPE_HINT.format(nodename=param_name))
class UseIsNoneOnOptionalRule(CstLintRule):
    """
    Enforces explicit use of ``is None`` or ``is not None`` when checking whether an Optional has a value.
    Directly testing an object (e.g. ``if x``) implicitly tests for a truth value which returns ``True`` unless the
    object's ``__bool__()`` method returns False, its ``__len__()`` method returns '0', or it is one of the constants ``None`` or ``False``.
    (https://docs.python.org/3.8/library/stdtypes.html#truth-value-testing).
    """

    METADATA_DEPENDENCIES = (TypeInferenceProvider,)
    MESSAGE: str = (
        "When checking if an `Optional` has a value, avoid using it as a boolean since it implicitly checks the object's `__bool__()`, `__len__()` is not `0`, or the value is not `None`. "
        + "Instead, use `is None` or `is not None` to be explicit."
    )

    VALID: List[Valid] = [
        Valid(
            """
            from typing import Optional
            a: Optional[str]
            if a is not None:
                pass
            """,
        ),
        Valid(
            """
            a: bool
            if a:
                pass
            """,
        ),
    ]

    INVALID: List[Invalid] = [
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            if a:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            if a is not None:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if x and a:
                ...
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if x and a is not None:
                ...
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if a and x:
                ...
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = False
            if a is not None and x:
                ...
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = not a
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            x: bool = a is None
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x or a: pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x or a is not None: pass
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x: pass
            elif a: pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str]
            x: bool
            if x: pass
            elif a is not None: pass
            """,
        ),
        Invalid(
            code="""
            from typing import Optional
            a: Optional[str] = None
            b: Optional[str] = None
            if a: pass
            elif b: pass
            """,
            expected_replacement="""
            from typing import Optional
            a: Optional[str] = None
            b: Optional[str] = None
            if a is not None: pass
            elif b is not None: pass
            """,
        ),
    ]

    def leave_If(self, original_node: cst.If) -> None:
        """Rewrite ``if x`` on an Optional ``x`` and fold any fix of the ``elif`` branch into this report."""
        changes: Dict[str, cst.CSTNode] = {}
        test_expression: cst.BaseExpression = original_node.test
        if m.matches(test_expression, m.Name()):
            # We are inside a simple check such as "if x".
            test_expression = cast(cst.Name, test_expression)
            if self._is_optional_type(test_expression):
                # We want to replace "if x" with "if x is not None".
                replacement_comparison: cst.Comparison = self._gen_comparison_to_none(
                    variable_name=test_expression.value, operator=cst.IsNot()
                )
                changes["test"] = replacement_comparison

        orelse = original_node.orelse
        if orelse is not None and m.matches(orelse, m.If()):
            # We want to catch this case upon leaving an `If` node so that we generate an `elif` statement correctly.
            # We check if the orelse node was reported, and if so, remove the report and generate a new report on
            # the current parent `If` node.
            new_reports = []
            orelse_report: Optional[CstLintRuleReport] = None
            for report in self.context.reports:
                if isinstance(report, CstLintRuleReport):
                    # Check whether the lint rule code matches this lint rule's code so we don't remove another
                    # lint rule's report.
                    if report.node is orelse and report.code == self.__class__.__name__:
                        orelse_report = report
                    else:
                        new_reports.append(report)
                else:
                    new_reports.append(report)
            if orelse_report is not None:
                self.context.reports = new_reports
                replacement_orelse = orelse_report.replacement_node
                changes["orelse"] = cst.ensure_type(replacement_orelse, cst.CSTNode)

        if changes:
            self.report(
                original_node, replacement=original_node.with_changes(**changes)
            )

    def visit_BooleanOperation(self, node: cst.BooleanOperation) -> None:
        """Rewrite Optional-typed bare-name operands of ``and``/``or`` as ``x is not None``.

        Both operands are collected into a single report, so ``a and b`` with two
        Optional names yields one warning whose autofix rewrites both sides
        (previously two competing reports each fixed only one side).
        """
        changes: Dict[str, cst.CSTNode] = {}
        for side in ("left", "right"):
            operand: cst.BaseExpression = getattr(node, side)
            if not m.matches(operand, m.Name()):
                # Eg: only bare names such as "x and y" are handled.
                continue
            operand_name = cast(cst.Name, operand)
            if self._is_optional_type(operand_name):
                changes[side] = self._gen_comparison_to_none(
                    variable_name=operand_name.value, operator=cst.IsNot()
                )
        if changes:
            self.report(node, replacement=node.with_changes(**changes))

    def visit_UnaryOperation(self, node: cst.UnaryOperation) -> None:
        """Rewrite ``not x`` on an Optional ``x`` as ``x is None``."""
        if m.matches(node, m.UnaryOperation(operator=m.Not(), expression=m.Name())):
            # Eg: "not x".
            expression: cst.Name = cast(cst.Name, node.expression)
            if self._is_optional_type(expression):
                replacement_comparison = self._gen_comparison_to_none(
                    variable_name=expression.value, operator=cst.Is()
                )
                self.report(node, replacement=replacement_comparison)

    def _is_optional_type(self, node: cst.Name) -> bool:
        """Return True if the inferred type of *node* is ``typing.Optional[...]``."""
        reported_type = self.get_metadata(TypeInferenceProvider, node, None)
        # We want to use `startswith()` here since the type data will take on the form 'typing.Optional[SomeType]'.
        if reported_type is not None and reported_type.startswith("typing.Optional"):
            return True
        return False

    def _gen_comparison_to_none(
        self, variable_name: str, operator: Union[cst.Is, cst.IsNot]
    ) -> cst.Comparison:
        """Build ``<variable_name> is None`` or ``<variable_name> is not None``."""
        return cst.Comparison(
            left=cst.Name(value=variable_name),
            comparisons=[
                cst.ComparisonTarget(
                    operator=operator, comparator=cst.Name(value="None")
                )
            ],
        )
class NamingConventionRule(CstLintRule):
    """
    Enforce ``NamingConvention.SNAKE_CASE`` for functions, parameters, ``for``
    targets, walrus targets, assignment targets and ``self`` attributes that are
    assigned to, and ``NamingConvention.CAMEL_CASE`` for classes.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider, )  # type: ignore

    # Assignments whose value comes from one of these modules are not checked:
    # the target may be a type alias (``Matrix = List[int]``) or a class made with
    # ``collections.namedtuple``, which are conventionally CamelCase.
    _SKIPPED_VALUE_MODULES = ("typing", "typing_extensions", "collections")

    VALID = [
        Valid("type_hint: str"),
        Valid("type_hint_var: int = 5"),
        Valid("CONSTANT_WITH_UNDERSCORE12 = 10"),
        Valid("hello = 'world'"),
        Valid("snake_case = 'assign'"),
        Valid("for iteration in range(5): pass"),
        Valid("class _PrivateClass: pass"),
        Valid("class SomeClass: pass"),
        Valid("class One: pass"),
        Valid("def oneword(): pass"),
        Valid("def some_extra_words(): pass"),
        Valid("all = names_are = valid_in_multiple_assign = 5"),
        Valid("(walrus := 'operator')"),
        Valid("multiple, valid, assignments = 1, 2, 3"),
        Valid(
            """
            class Spam:
                def __init__(self, valid, another_valid):
                    self.valid = valid
                    self.another_valid = another_valid
                    self._private = None
                    self.__extreme_private = None

                def bar(self):
                    # This is just to test that the access is not being tested.
                    return self.some_Invalid_NaMe
            """
        ),
        Valid(
            """
            from typing import List
            from collections import namedtuple

            Matrix = List[int]
            Point = namedtuple("Point", "x, y")
            some_matrix: Matrix = [1, 2]
            """
        ),
    ]

    INVALID = [
        Invalid("type_Hint_Var: int = 5"),
        Invalid("hellO = 'world'"),
        Invalid("ranDom_UpPercAse = 'testing'"),
        Invalid("for RandomCaps in range(5): pass"),
        Invalid("class _Invalid_PrivateClass: pass"),
        Invalid("class _invalidPrivateClass: pass"),
        Invalid("class lowerPascalCase: pass"),
        Invalid("class all_lower_case: pass"),
        Invalid("def oneWordInvalid(): pass"),
        Invalid("def Pascal_Case(): pass"),
        Invalid("valid = another_valid = Invalid = 5"),
        Invalid("(waLRus := 'operator')"),
        Invalid("def func(invalidParam, valid_param): pass"),
        Invalid("multiple, inValid, assignments = 1, 2, 3"),
        Invalid("[inside, list, inValid] = Invalid, 2, 3"),
        Invalid(
            """
            class Spam:
                def __init__(self, foo, bar):
                    self.foo = foo
                    self._Bar = bar
            """
        ),
        # A name that merely *starts with* "typing" is not the typing module.
        Invalid(
            """
            def typing_helper():
                pass
            someValue = typing_helper()
            """
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Number of AssignTarget nodes currently being visited; attributes and
        # list/tuple elements are only checked while this is positive.
        self._assigntarget_counter: int = 0

    @classmethod
    def _is_from_skipped_module(cls, name: str) -> bool:
        """Return True if qualified *name* is, or lives inside, a skipped module.

        Whole dotted components are compared so that e.g. a local
        ``typing_helper`` function is not mistaken for the ``typing`` module.
        """
        return any(
            name == module or name.startswith(module + ".")
            for module in cls._SKIPPED_VALUE_MODULES
        )

    def visit_Assign(self, node: cst.Assign) -> None:
        """Check simple-name targets of ``=`` unless the value is a typing/collections object."""
        metadata: Optional[Collection[QualifiedName]] = self.get_metadata(
            QualifiedNameProvider, node.value, None
        )
        if metadata is not None:
            for qualname in metadata:
                # If the assignment is done with some objects from the typing or
                # collections module, then we will skip the check as the assignment
                # could be a type alias or the variable could be a class made using
                # ``collections.namedtuple``.
                if self._is_from_skipped_module(qualname.name):
                    return None
        for target_node in node.targets:
            if m.matches(target_node, m.AssignTarget(target=m.Name())):
                nodename = cst.ensure_type(target_node.target, cst.Name).value
                self._validate_nodename(node, nodename, NamingConvention.SNAKE_CASE)

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        """Check the target of an annotated assignment that also assigns a value."""
        # The assignment value is optional, as it is possible to annotate an
        # expression without assigning to it: ``var: int``. Only annotated
        # assignments that do carry a value are checked here.
        if m.matches(
            node,
            m.AnnAssign(
                target=m.Name(),
                value=m.MatchIfTrue(lambda value: value is not None),
            ),
        ):
            nodename = cst.ensure_type(node.target, cst.Name).value
            self._validate_nodename(node, nodename, NamingConvention.SNAKE_CASE)

    def visit_AssignTarget(self, node: cst.AssignTarget) -> None:
        self._assigntarget_counter += 1

    def leave_AssignTarget(self, node: cst.AssignTarget) -> None:
        self._assigntarget_counter -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self._validate_nodename(node, node.name.value, NamingConvention.CAMEL_CASE)

    def visit_Attribute(self, node: cst.Attribute) -> None:
        """Check ``self.<attr>`` names that appear inside an assignment target."""
        # The attribute node can come through other context as well but we only care
        # about the ones coming from assignments.
        if self._assigntarget_counter > 0:
            # We only care about assignment attribute to *self*.
            if m.matches(node, m.Attribute(value=m.Name(value="self"))):
                self._validate_nodename(
                    node, node.attr.value, NamingConvention.SNAKE_CASE
                )

    def visit_Element(self, node: cst.Element) -> None:
        """Check bare names inside list/tuple unpacking targets."""
        # We only care about elements in *List* or *Tuple* specifically coming from
        # inside the multiple assignments.
        if self._assigntarget_counter > 0:
            if m.matches(node, m.Element(value=m.Name())):
                nodename = cst.ensure_type(node.value, cst.Name).value
                self._validate_nodename(node, nodename, NamingConvention.SNAKE_CASE)

    def visit_For(self, node: cst.For) -> None:
        if m.matches(node, m.For(target=m.Name())):
            nodename = cst.ensure_type(node.target, cst.Name).value
            self._validate_nodename(node, nodename, NamingConvention.SNAKE_CASE)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self._validate_nodename(node, node.name.value, NamingConvention.SNAKE_CASE)

    def visit_NamedExpr(self, node: cst.NamedExpr) -> None:
        if m.matches(node, m.NamedExpr(target=m.Name())):
            nodename = cst.ensure_type(node.target, cst.Name).value
            self._validate_nodename(node, nodename, NamingConvention.SNAKE_CASE)

    def visit_Param(self, node: cst.Param) -> None:
        self._validate_nodename(node, node.name.value, NamingConvention.SNAKE_CASE)

    def _validate_nodename(
        self, node: cst.CSTNode, nodename: str, naming_convention: NamingConvention
    ) -> None:
        """Validate the provided *nodename* as per the given *naming_convention*.

        This is a convenience method as the same steps will be repeated for every
        visit functions which are to validate the name and report if found invalid.
        """
        if not naming_convention.valid(nodename):
            self.report(node, naming_convention.value.format(nodename=nodename))
class AvoidOrInExceptRule(CstLintRule):
    """
    Discourages use of ``or`` in except clauses. If an except clause needs to catch multiple exceptions,
    they must be expressed as a parenthesized tuple, for example:
    ``except (ValueError, TypeError)``
    (https://docs.python.org/3/tutorial/errors.html#handling-exceptions)

    When ``or`` is used, only the first operand exception type of the conditional statement will be caught,
    because ``Exc1 or Exc2`` simply evaluates to ``Exc1``. For example::

        In [1]: class Exc1(Exception):
           ...:     pass
           ...:

        In [2]: class Exc2(Exception):
           ...:     pass
           ...:

        In [3]: try:
           ...:     raise Exception()
           ...: except Exc1 or Exc2:
           ...:     print("caught!")
           ...:
        ---------------------------------------------------------------------------
        Exception                                 Traceback (most recent call last)
        <ipython-input-3-3340d66a006c> in <module>
              1 try:
        ----> 2     raise Exception()
              3 except Exc1 or Exc2:
              4     print("caught!")
              5

        Exception:

        In [4]: try:
           ...:     raise Exc1()
           ...: except Exc1 or Exc2:
           ...:     print("caught!")
           ...:
        caught!

        In [5]: try:
           ...:     raise Exc2()
           ...: except Exc1 or Exc2:
           ...:     print("caught!")
           ...:
        ---------------------------------------------------------------------------
        Exc2                                      Traceback (most recent call last)
        <ipython-input-5-5d29c1589cc0> in <module>
              1 try:
        ----> 2     raise Exc2()
              3 except Exc1 or Exc2:
              4     print("caught!")
              5

        Exc2:
    """

    MESSAGE: str = (
        "Avoid using 'or' in an except block. For example: "
        + "'except ValueError or TypeError' only catches 'ValueError'. Instead, use "
        + "parentheses, 'except (ValueError, TypeError)'"
    )

    VALID = [
        Valid(
            """
            try:
                print()
            except (ValueError, TypeError) as err:
                pass
            """
        ),
        Valid(
            """
            try:
                print()
            except KeyError:
                pass
            except (ValueError, TypeError):
                pass
            """
        ),
    ]

    INVALID = [
        Invalid(
            """
            try:
                print()
            except ValueError or TypeError:
                pass
            """,
        ),
        # The offending handler need not be the only one.
        Invalid(
            """
            try:
                print()
            except KeyError:
                pass
            except ValueError or TypeError:
                pass
            """,
        ),
    ]

    def visit_Try(self, node: cst.Try) -> None:
        """Report a ``try`` statement if any of its handlers uses ``or`` in the exception type."""
        or_handler = m.ExceptHandler(type=m.BooleanOperation(operator=m.Or()))
        # Check handlers individually: ``m.Try(handlers=[or_handler])`` would only
        # match a ``try`` that has exactly one handler.
        if any(m.matches(handler, or_handler) for handler in node.handlers):
            self.report(node)
class NoTypedDictRule(CstLintRule):
    """
    Enforce the use of the ``dataclasses.dataclass`` decorator instead of
    ``TypedDict`` (imported from ``typing`` or ``typing_extensions``) for cleaner
    customization and inheritance. Dataclasses support default values,
    combining fields through inheritance, and omitting optional fields at
    instantiation. See `PEP 557 <https://www.python.org/dev/peps/pep-0557>`_.

    The autofix removes the ``TypedDict`` base and appends a
    ``@dataclass(frozen=True)`` decorator; it does not add an import for
    ``dataclass``. When the class passes keyword arguments (such as
    ``total=False``, which only ``TypedDict`` understands) the class is reported
    without an autofix.
    """

    MESSAGE: str = "Instead of TypedDict, consider using the @dataclass decorator from dataclasses instead for simplicity, efficiency and consistency."
    METADATA_DEPENDENCIES = (QualifiedNameProvider,)

    VALID = [
        Valid(
            """
            @dataclass(frozen=True)
            class Foo:
                pass
            """
        ),
        Valid(
            """
            @dataclass(frozen=False)
            class Foo:
                pass
            """
        ),
        Valid(
            """
            class Foo:
                pass
            """
        ),
        Valid(
            """
            class Foo(SomeOtherBase):
                pass
            """
        ),
        Valid(
            """
            @some_other_decorator
            class Foo:
                pass
            """
        ),
        Valid(
            """
            @some_other_decorator
            class Foo(SomeOtherBase):
                pass
            """
        ),
        # Only the real TypedDict is flagged, not a same-named class elsewhere.
        Valid(
            """
            from mylib import TypedDict

            class Foo(TypedDict):
                pass
            """
        ),
    ]

    INVALID = [
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing_extensions import TypedDict

            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing_extensions import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict as TD

            class Foo(TD):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict as TD

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            import typing as typ

            class Foo(typ.TypedDict):
                pass
            """,
            expected_replacement="""
            import typing as typ

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict, AnotherBase, YetAnotherBase):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo(AnotherBase, YetAnotherBase):
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class OuterClass(SomeBase):
                class InnerClass(TypedDict):
                    pass
            """,
            expected_replacement="""
            from typing import TypedDict

            class OuterClass(SomeBase):
                @dataclass(frozen=True)
                class InnerClass:
                    pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            @some_other_decorator
            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @some_other_decorator
            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        # TypedDict-only class keywords: reported, but no autofix is offered.
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict, total=False):
                pass
            """,
        ),
    ]

    # ``TypedDict`` may be imported from either module; both are recognized.
    qualified_typeddict = QualifiedName(
        name="typing_extensions.TypedDict", source=QualifiedNameSource.IMPORT
    )
    qualified_typing_typeddict = QualifiedName(
        name="typing.TypedDict", source=QualifiedNameSource.IMPORT
    )

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Report classes deriving from ``TypedDict``, with a dataclass autofix."""
        (typeddict_base, new_bases) = self.partition_bases(original_node.bases)
        if typeddict_base is None:
            return
        if original_node.keywords:
            # Keywords like ``total=False`` would be passed to a plain class
            # after the rewrite and fail at class creation, so don't autofix.
            self.report(original_node)
            return
        call = ensure_type(parse_expression("dataclass(frozen=True)"), cst.Call)
        replacement = original_node.with_changes(
            # DEFAULT parens let libcst drop the ``()`` when no bases remain.
            lpar=MaybeSentinel.DEFAULT,
            rpar=MaybeSentinel.DEFAULT,
            bases=new_bases,
            decorators=list(original_node.decorators) + [cst.Decorator(decorator=call)],
        )
        self.report(original_node, replacement=replacement)

    def partition_bases(
        self, original_bases: Sequence[cst.Arg]
    ) -> Tuple[Optional[cst.Arg], List[cst.Arg]]:
        """Split bases into the ``TypedDict`` base (or ``None``) and all other bases.

        The non-``TypedDict`` bases keep their original order.
        """
        typeddict_base: Optional[cst.Arg] = None
        new_bases: List[cst.Arg] = []
        typeddict_names = (self.qualified_typing_typeddict, self.qualified_typeddict)
        for base_class in original_bases:
            if any(
                QualifiedNameProvider.has_name(self, base_class.value, qualified_name)
                for qualified_name in typeddict_names
            ):
                typeddict_base = base_class
            else:
                new_bases.append(base_class)
        return (typeddict_base, new_bases)
class RewriteToComprehensionRule(CstLintRule):
    """
    A derivative of flake8-comprehensions's C400-C402 and C403-C404.
    Comprehensions are more efficient than function calls. C400-C402 suggest
    using ``dict``/``set``/``list`` comprehensions rather than the respective
    function calls whenever possible. C403-C404 suggest removing an unnecessary
    list comprehension inside a ``set``/``dict`` call and replacing it with a
    set/dict comprehension.

    Only a call whose sole argument is a plain positional generator or list
    comprehension is rewritten. A ``dict`` call is rewritten only when the
    comprehension's element is a literal two-element tuple or list, which maps
    directly onto ``key: value``.
    """

    VALID = [
        Valid("[val for val in iterable]"),
        Valid("{val for val in iterable}"),
        Valid("{val: val+1 for val in iterable}"),
        # A function call is valid if the elt is a function that returns a tuple.
        Valid("dict(line.strip().split('=', 1) for line in attr_file)"),
        # Elements that are not exactly (key, value) pairs have no dict-comprehension form.
        Valid("dict((a, b, c) for a, b, c in iterable)"),
        Valid("dict((a,) for a in iterable)"),
        # A keyword argument builds a one-entry dict; it is not a conversion.
        Valid("dict(mapping=((k, v) for k, v in iterable))"),
    ]

    INVALID = [
        Invalid(
            "list(val for val in iterable)",
            expected_replacement="[val for val in iterable]",
        ),
        # Nested list comprehension
        Invalid(
            "list(val for row in matrix for val in row)",
            expected_replacement="[val for row in matrix for val in row]",
        ),
        Invalid(
            "set(val for val in iterable)",
            expected_replacement="{val for val in iterable}",
        ),
        Invalid(
            "dict((x, f(x)) for val in iterable)",
            expected_replacement="{x: f(x) for val in iterable}",
        ),
        Invalid(
            "dict((x, y) for y, x in iterable)",
            expected_replacement="{x: y for y, x in iterable}",
        ),
        Invalid(
            "dict([val, val+1] for val in iterable)",
            expected_replacement="{val: val+1 for val in iterable}",
        ),
        Invalid(
            'dict((x["name"], json.loads(x["data"])) for x in responses)',
            expected_replacement='{x["name"]: json.loads(x["data"]) for x in responses}',
        ),
        # Nested dict comprehension
        Invalid(
            "dict((k, v) for k, v in iter for iter in iters)",
            expected_replacement="{k: v for k, v in iter for iter in iters}",
        ),
        Invalid(
            "set([val for val in iterable])",
            expected_replacement="{val for val in iterable}",
        ),
        Invalid(
            "dict([[val, val+1] for val in iterable])",
            expected_replacement="{val: val+1 for val in iterable}",
        ),
        Invalid(
            "dict([(x, f(x)) for x in foo])",
            expected_replacement="{x: f(x) for x in foo}",
        ),
        Invalid(
            "dict([(x, y) for y, x in iterable])",
            expected_replacement="{x: y for y, x in iterable}",
        ),
        Invalid(
            "set([val for row in matrix for val in row])",
            expected_replacement="{val for row in matrix for val in row}",
        ),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Suggest a comprehension for ``list``/``set``/``dict`` over a generator or listcomp."""
        if not m.matches(
            node,
            m.Call(
                func=m.Name("list") | m.Name("set") | m.Name("dict"),
                # A single plain positional argument only: ``dict(k=(...))``
                # builds a one-entry mapping and ``list(*(...))`` unpacks the
                # generator into separate arguments.
                args=[
                    m.Arg(
                        value=m.GeneratorExp() | m.ListComp(),
                        keyword=None,
                        star="",
                    )
                ],
            ),
        ):
            return

        call_name = cst.ensure_type(node.func, cst.Name).value
        arg_value = node.args[0].value
        if m.matches(arg_value, m.GeneratorExp()):
            exp = cst.ensure_type(arg_value, cst.GeneratorExp)
            message_formatter = UNNECESSARY_GENERATOR
        else:
            exp = cst.ensure_type(arg_value, cst.ListComp)
            message_formatter = UNNECESSARY_LIST_COMPREHENSION

        if call_name == "list":
            replacement = cst.ListComp(elt=exp.elt, for_in=exp.for_in)
        elif call_name == "set":
            replacement = cst.SetComp(elt=exp.elt, for_in=exp.for_in)
        else:  # dict
            elt = exp.elt
            # Exactly two plain (non-starred) elements. Note that
            # ``m.Tuple(m.DoNotCare(), m.DoNotCare())`` would bind ``elements``
            # and ``lpar`` positionally and match a tuple of any length.
            pair = [m.Element(), m.Element()]
            if m.matches(elt, m.Tuple(elements=pair)):
                pair_elements = cst.ensure_type(elt, cst.Tuple).elements
            elif m.matches(elt, m.List(elements=pair)):
                pair_elements = cst.ensure_type(elt, cst.List).elements
            else:
                # Unrecognized form
                return
            replacement = cst.DictComp(
                # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
                # param but got `BaseExpression`.
                key=pair_elements[0].value,
                value=pair_elements[1].value,
                for_in=exp.for_in,
            )

        self.report(
            node,
            message_formatter.format(func=call_name),
            replacement=replacement,
        )
class RewriteToLiteralRule(CstLintRule):
    """
    A derivative of flake8-comprehensions' C405-C406 and C409-C410. It's
    unnecessary to use a list or tuple literal within a call to tuple, list,
    set, or dict since there is literal syntax for these types.

    Empty ``tuple()``, ``list()`` and ``dict()`` calls are also rewritten to
    their empty literals. ``set()`` has no empty literal and is left alone.
    Only a single plain positional argument counts as a literal conversion.
    """

    VALID = [
        Valid("(1, 2)"),
        Valid("()"),
        Valid("[1, 2]"),
        Valid("[]"),
        Valid("{1, 2}"),
        Valid("set()"),
        Valid("{1: 2, 3: 4}"),
        Valid("{}"),
        # Keyword and unpacked arguments are not conversions of the literal.
        Valid("dict(a=[])"),
        Valid("dict(a=[(1, 2)])"),
        Valid("list(*[[1, 2]])"),
        # A starred element is not a key/value pair.
        Valid("dict([(*key, value)])"),
    ]

    INVALID = [
        Invalid("tuple([1, 2])", expected_replacement="(1, 2)"),
        Invalid("tuple((1, 2))", expected_replacement="(1, 2)"),
        Invalid("tuple([])", expected_replacement="()"),
        Invalid("list([1, 2, 3])", expected_replacement="[1, 2, 3]"),
        Invalid("list((1, 2, 3))", expected_replacement="[1, 2, 3]"),
        Invalid("list([])", expected_replacement="[]"),
        Invalid("set([1, 2, 3])", expected_replacement="{1, 2, 3}"),
        Invalid("set((1, 2, 3))", expected_replacement="{1, 2, 3}"),
        Invalid("set([])", expected_replacement="set()"),
        Invalid(
            "dict([(1, 2), (3, 4)])",
            expected_replacement="{1: 2, 3: 4}",
        ),
        Invalid(
            "dict(((1, 2), (3, 4)))",
            expected_replacement="{1: 2, 3: 4}",
        ),
        Invalid(
            "dict([[1, 2], [3, 4], [5, 6]])",
            expected_replacement="{1: 2, 3: 4, 5: 6}",
        ),
        Invalid("dict([])", expected_replacement="{}"),
        Invalid("tuple()", expected_replacement="()"),
        Invalid("list()", expected_replacement="[]"),
        Invalid("dict()", expected_replacement="{}"),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Suggest the literal form for container calls over list/tuple literals or no args."""
        literal_arg_call = m.Call(
            func=m.Name("tuple") | m.Name("list") | m.Name("set") | m.Name("dict"),
            # ``dict(a=[])`` or ``list(*[x])`` mean something else entirely, so
            # only a lone plain positional argument qualifies.
            args=[m.Arg(value=m.List() | m.Tuple(), keyword=None, star="")],
        )
        # ``set()`` is excluded: there is no empty-set literal.
        empty_call = m.Call(
            func=m.Name("tuple") | m.Name("list") | m.Name("dict"), args=[]
        )
        if not m.matches(node, literal_arg_call | empty_call):
            return

        call_name = cst.ensure_type(node.func, cst.Name).value
        if node.args:
            arg = node.args[0].value
            elements = cst.ensure_type(
                arg, cst.List if isinstance(arg, cst.List) else cst.Tuple
            ).elements
            message_formatter = UNNECESSARY_LITERAL
        else:
            # An empty call is an unnecessary call; rewrite it to the empty literal.
            elements = []
            message_formatter = UNNCESSARY_CALL

        if call_name == "tuple":
            new_node = cst.Tuple(elements=elements)
        elif call_name == "list":
            new_node = cst.List(elements=elements)
        elif call_name == "set":
            # set() doesn't have an equivalent literal. If it was matched here
            # with an empty literal argument, suggest dropping that literal.
            if len(elements) == 0:
                self.report(
                    node,
                    UNNECESSARY_LITERAL.format(func=call_name),
                    replacement=cst.Call(func=cst.Name("set")),
                )
                return
            new_node = cst.Set(elements=elements)
        else:  # dict
            # Each item must be a literal pair of exactly two plain elements.
            pair = [m.Element(), m.Element()]
            pairs_matcher = m.ZeroOrMore(
                m.Element(value=m.Tuple(elements=pair) | m.List(elements=pair))
            )
            if len(elements) > 0 and not m.matches(
                node.args[0].value,
                m.Tuple(elements=[pairs_matcher]) | m.List(elements=[pairs_matcher]),
            ):
                # Unrecognized form
                return
            dict_elements = []
            for element in elements:
                pair_node = cst.ensure_type(
                    element.value,
                    cst.Tuple if isinstance(element.value, cst.Tuple) else cst.List,
                )
                dict_elements.append(
                    cst.DictElement(
                        pair_node.elements[0].value, pair_node.elements[1].value
                    )
                )
            new_node = cst.Dict(elements=dict_elements)

        self.report(
            node,
            message_formatter.format(func=call_name),
            replacement=new_node,
        )
class NoAssertTrueForComparisonsRule(CstLintRule):
    """
    Finds incorrect use of ``assertTrue`` when the intention is to compare two
    values. These calls are replaced with ``assertEqual``. Comparisons with
    True, False and None are replaced with one-argument calls to
    ``assertTrue``, ``assertFalse`` and ``assertIsNone``.

    Only calls with exactly two plain positional arguments are considered. A
    keyword second argument such as ``msg=None`` is the failure message, not a
    value to compare against.
    """

    MESSAGE: str = '"assertTrue" does not compare its arguments, use "assertEqual" or other ' + "appropriate functions."

    VALID = [
        Valid("self.assertTrue(a == b)"),
        Valid('self.assertTrue(data.is_valid(), "is_valid() method")'),
        Valid("self.assertTrue(validate(len(obj.getName(type=SHORT))))"),
        Valid("self.assertTrue(condition, message_string)"),
        Valid("self.assertTrue(condition, msg=None)"),
    ]

    INVALID = [
        Invalid("self.assertTrue(a, 3)", expected_replacement="self.assertEqual(a, 3)"),
        Invalid(
            "self.assertTrue(hash(s[:4]), 0x1234)",
            expected_replacement="self.assertEqual(hash(s[:4]), 0x1234)",
        ),
        Invalid(
            "self.assertTrue(list, [1, 3])",
            expected_replacement="self.assertEqual(list, [1, 3])",
        ),
        Invalid(
            "self.assertTrue(optional, None)",
            expected_replacement="self.assertIsNone(optional)",
        ),
        Invalid(
            "self.assertTrue(b == a, True)",
            expected_replacement="self.assertTrue(b == a)",
        ),
        Invalid(
            "self.assertTrue(b == a, False)",
            expected_replacement="self.assertFalse(b == a)",
        ),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        """Rewrite ``self.assertTrue(x, <literal>)`` to the matching assertion."""
        result = m.extract(
            node,
            m.Call(
                func=m.Attribute(value=m.Name("self"), attr=m.Name("assertTrue")),
                args=[
                    # Both arguments must be plain positionals; see class docstring.
                    m.Arg(keyword=None, star=""),
                    m.Arg(
                        value=m.SaveMatchedNode(
                            m.OneOf(
                                m.Integer(),
                                m.Float(),
                                m.Imaginary(),
                                m.Tuple(),
                                m.List(),
                                m.Set(),
                                m.Dict(),
                                m.Name("None"),
                                m.Name("True"),
                                m.Name("False"),
                            ),
                            "second",
                        ),
                        keyword=None,
                        star="",
                    ),
                ],
            ),
        )
        if not result:
            return

        second_arg = result["second"]
        if isinstance(second_arg, Sequence):
            second_arg = second_arg[0]

        attr = cst.ensure_type(node.func, cst.Attribute).attr
        # When the second argument is dropped, also drop the first argument's
        # comma so the call doesn't end in ``, )``.
        first_arg_only = [node.args[0].with_changes(comma=cst.MaybeSentinel.DEFAULT)]

        if m.matches(second_arg, m.Name("True")):
            new_call = node.with_changes(args=first_arg_only)
        elif m.matches(second_arg, m.Name("None")):
            new_call = node.with_changes(
                func=node.func.with_deep_changes(old_node=attr, value="assertIsNone"),
                args=first_arg_only,
            )
        elif m.matches(second_arg, m.Name("False")):
            new_call = node.with_changes(
                func=node.func.with_deep_changes(old_node=attr, value="assertFalse"),
                args=first_arg_only,
            )
        else:
            new_call = node.with_deep_changes(old_node=attr, value="assertEqual")

        self.report(node, replacement=new_call)
class NoStaticIfConditionRule(CstLintRule):
    """
    Discourages ``if`` conditions which evaluate to a static value (e.g.
    ``or True``, ``and False``, etc).

    Only the literals ``True``/``False``, ``not`` and the boolean operators
    ``and``/``or`` are reasoned about. Any other sub-expression (calls, names,
    awaits, ...) is treated as unknown.
    """

    MESSAGE: str = (
        "Your if condition appears to evaluate to a static value (e.g. `or True`, `and False`). "
        + "Please double check this logic and if it is actually temporary debug code."
    )

    VALID = [
        Valid(
            """
            if my_func() or not else_func():
                pass
            """
        ),
        Valid(
            """
            if function_call(True):
                pass
            """
        ),
        Valid(
            """
            # ew who would this???
            def true():
                return False
            if true() and else_call():  # True or False
                pass
            """
        ),
        Valid(
            """
            # ew who would this???
            if False or some_func():
                pass
            """
        ),
    ]

    INVALID = [
        Invalid(
            """
            if True:
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression or True:
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression and False:
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression and not True:
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression or not False:
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression or (something() or True):
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression and (something() and (not True)):
                do_something()
            """,
        ),
        Invalid(
            """
            if crazy_expression and (something() and (other_func() and not True)):
                do_something()
            """,
        ),
        Invalid(
            """
            if (crazy_expression and (something() and (not True))) or True:
                do_something()
            """,
        ),
        Invalid(
            """
            async def some_func() -> none:
                if (await expression()) and False:
                    pass
            """,
        ),
        # Both operands static and agreeing.
        Invalid(
            """
            if False or False:
                do_something()
            """,
        ),
        Invalid(
            """
            if True and not False:
                do_something()
            """,
        ),
    ]

    @classmethod
    def _extract_static_bool(cls, node: cst.BaseExpression) -> Optional[bool]:
        """Return the constant truth value of ``node``, or ``None`` if unknown."""
        if m.matches(node, m.Call()):
            # cannot reason about function calls
            return None
        if m.matches(node, m.UnaryOperation(operator=m.Not())):
            sub_value = cls._extract_static_bool(
                cst.ensure_type(node, cst.UnaryOperation).expression
            )
            return None if sub_value is None else not sub_value
        if m.matches(node, m.Name("True")):
            return True
        if m.matches(node, m.Name("False")):
            return False
        if m.matches(node, m.BooleanOperation()):
            node = cst.ensure_type(node, cst.BooleanOperation)
            left_value = cls._extract_static_bool(node.left)
            right_value = cls._extract_static_bool(node.right)
            if m.matches(node.operator, m.Or()):
                # One truthy operand decides ``or``; both must be falsy otherwise.
                if left_value is True or right_value is True:
                    return True
                if left_value is False and right_value is False:
                    return False
            elif m.matches(node.operator, m.And()):
                # One falsy operand decides ``and``; both must be truthy otherwise.
                if left_value is False or right_value is False:
                    return False
                if left_value is True and right_value is True:
                    return True
        return None

    def visit_If(self, node: cst.If) -> None:
        """Report ``if``/``elif`` statements whose test is statically known."""
        if self._extract_static_bool(node.test) is not None:
            self.report(node)
class UseTypesFromTypingRule(CstLintRule):
    """
    Enforces the use of types from the ``typing`` module in type annotations in
    place of ``builtins.{builtin_type}`` since the type system doesn't recognize
    the latter as a valid type.

    A builtin name is reported only while an ``Annotation`` is being visited.
    An autofix to the capitalized name is offered only when that capitalized
    name is already defined in the enclosing scope.
    """

    METADATA_DEPENDENCIES = (
        QualifiedNameProvider,
        ScopeProvider,
    )

    VALID = [
        Valid(
            """
            def fuction(list: List[str]) -> None:
                pass
            """
        ),
        Valid(
            """
            def function() -> None:
                thing: Dict[str, str] = {}
            """
        ),
        Valid(
            """
            def function() -> None:
                thing: Tuple[str]
            """
        ),
        Valid(
            """
            from typing import Dict, List
            def function() -> bool:
                return Dict == List
            """
        ),
        Valid(
            """
            from typing import List as list
            from graphene import List

            def function(a: list[int]) -> List[int]:
                return []
            """
        ),
    ]

    INVALID = [
        Invalid(
            """
            from typing import List
            def whatever(list: list[str]) -> None:
                pass
            """,
            expected_replacement="""
            from typing import List
            def whatever(list: List[str]) -> None:
                pass
            """,
        ),
        Invalid(
            """
            def function(list: list[str]) -> None:
                pass
            """,
        ),
        Invalid(
            """
            def func() -> None:
                thing: dict[str, str] = {}
            """,
        ),
        Invalid(
            """
            def func() -> None:
                thing: tuple[str]
            """,
        ),
        Invalid(
            """
            from typing import Dict
            def func() -> None:
                thing: dict[str, str] = {}
            """,
            expected_replacement="""
            from typing import Dict
            def func() -> None:
                thing: Dict[str, str] = {}
            """,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Depth of nested annotations currently being visited.
        self.annotation_counter: int = 0

    def visit_Annotation(self, node: libcst.Annotation) -> None:
        self.annotation_counter += 1

    def leave_Annotation(self, original_node: libcst.Annotation) -> None:
        self.annotation_counter -= 1

    def visit_Name(self, node: libcst.Name) -> None:
        """Report builtin container names used inside type annotations."""
        if self.annotation_counter <= 0 or node.value not in BUILTINS_TO_REPLACE:
            return
        # The qualified-name check avoids a false positive in this scenario:
        #
        # ```
        # from typing import List as list
        # from graphene import List
        # ```
        qualified_names = self.get_metadata(QualifiedNameProvider, node, set())
        if any(
            qualified.name not in QUALIFIED_BUILTINS_TO_REPLACE
            for qualified in qualified_names
        ):
            return

        builtin_type = node.value
        correct_type = builtin_type.title()
        scope = self.get_metadata(ScopeProvider, node)
        # Only autofix when the capitalized name is already available in scope.
        if scope is not None and correct_type in scope:
            replacement = node.with_changes(value=correct_type)
        else:
            replacement = None
        self.report(
            node,
            REPLACE_BUILTIN_TYPE_ANNOTATION.format(
                builtin_type=builtin_type, correct_type=correct_type
            ),
            replacement=replacement,
        )
class UseClassNameAsCodeRule(CstLintRule):
    """
    Meta lint rule which checks that codes of lint rules are migrated to new
    format in lint rule class definitions.

    It strips a leading ``IG<digits> `` prefix from string literals, and it
    removes string arguments that are exactly ``"IG<digits>"`` from calls to
    ``Invalid(...)``.
    """

    MESSAGE = "`IG`-series codes are deprecated. Use class name as code instead."

    VALID = [
        Valid(
            """
            MESSAGE = "This is a message"
            """
        ),
        Valid(
            """
            from fixit.common.base import CstLintRule
            class FakeRule(CstLintRule):
                MESSAGE = "This is a message"
            """
        ),
        Valid(
            """
            from fixit.common.base import CstLintRule
            class FakeRule(CstLintRule):
                INVALID = [
                    Invalid(
                        code=""
                    )
                ]
            """
        ),
    ]

    INVALID = [
        Invalid(
            code="""
            MESSAGE = "IG90000 Message"
            """,
            expected_replacement="""
            MESSAGE = "Message"
            """,
        ),
        Invalid(
            code="""
            from fixit.common.base import CstLintRule
            class FakeRule(CstLintRule):
                INVALID = [
                    Invalid(
                        code="",
                        kind="IG000"
                    )
                ]
            """,
            expected_replacement="""
            from fixit.common.base import CstLintRule
            class FakeRule(CstLintRule):
                INVALID = [
                    Invalid(
                        code="",
                    )
                ]
            """,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # True while the visitor is inside a call to ``Invalid``.
        self.inside_invalid_call: bool = False

    @staticmethod
    def _calls_invalid(call: cst.Call) -> bool:
        """Return True when ``call`` is a direct call to the bare name ``Invalid``."""
        callee = call.func
        return m.matches(callee, m.Name()) and cast(cst.Name, callee).value == "Invalid"

    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        """Strip a leading ``IG<digits> `` code from a string literal."""
        found = re.match(r"^(\'|\")(?P<igcode>IG\d+ )\S", node.value)
        if found is None:
            return
        stripped = node.value.replace(found.group("igcode"), "", 1)
        self.report(
            node,
            self.MESSAGE,
            replacement=node.with_changes(value=stripped),
        )

    def visit_Call(self, node: cst.Call) -> None:
        if self._calls_invalid(node):
            self.inside_invalid_call = True

    def leave_Call(self, original_node: cst.Call) -> None:
        if self._calls_invalid(original_node):
            self.inside_invalid_call = False

    def visit_Arg(self, node: cst.Arg) -> None:
        """Remove ``"IG..."`` (`kind`) arguments from ``Invalid`` test cases; they are no longer needed."""
        if not self.inside_invalid_call or not m.matches(node.value, m.SimpleString()):
            return
        literal = cast(cst.SimpleString, node.value).value
        if re.match(r"^(\'|\")(?P<igcode>IG\d+)(\'|\")\Z", literal):
            self.report(
                node,
                self.MESSAGE,
                replacement=cst.RemovalSentinel.REMOVE,
            )
class ImportConstraintsRule(CstLintRule):
    """
    Rule to impose import constraints in certain directories to improve runtime
    performance. The directories specified in the ImportConstraintsRule setting
    in the ``.fixit.config.yaml`` file's ``rule_config`` section can impose
    import constraints for that directory and its children as follows::

        rule_config:
            ImportConstraintsRule:
                dir_under_repo_root:
                    rules: [
                        ["module_under_repo_root", "allow"],
                        ["another_module_under_repo_root", "deny"],
                        ["*", "deny"]
                    ]
                    ignore_tests: True
                    ignore_types: True

    Each rule under ``rules`` is evaluated in order from top to bottom and the
    last rule for each directory should be a wildcard rule.
    ``ignore_tests`` and ``ignore_types`` should carry boolean values and can be
    omitted. They are both set to ``True`` by default.
    If ``ignore_types`` is True, this rule will ignore imports inside
    ``if TYPE_CHECKING`` (or ``if typing.TYPE_CHECKING``) blocks since those
    imports do not have an effect on runtime performance.
    If ``ignore_tests`` is True, this rule will not lint any files found in a
    testing module.

    A file uses the settings of its closest configured ancestor directory.
    Configured directories that resolve outside the repository root are ignored.
    Relative imports are resolved against the file's location; those that climb
    above the repository root are ignored. Imports whose top-level package is not
    a local root of the repository are also ignored.
    """

    _config: Optional[_ImportConfig]
    _repo_root: Path
    _type_checking_stack: List[cst.If]
    _abs_file_path: Path

    MESSAGE: str = (
        "According to the settings for this directory in the .fixit.config.yaml configuration file, "
        + "{imported} cannot be imported from within {current_file}. "
    )

    VALID = [
        # Everything is allowed
        Valid("import common"),
        Valid(
            "import common",
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "allow"]]}}),
            filename="some_dir/file.py",
        ),
        # This import is allowlisted
        Valid(
            "import common",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common", "allow"], ["*", "deny"]]}}
            ),
            filename="some_dir/file.py",
        ),
        # Allow children of a allowlisted module
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common", "allow"], ["*", "deny"]]}}
            ),
            filename="some_dir/file.py",
        ),
        # Validate rules are evaluated in order
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "some_dir": {
                        "rules": [
                            ["common.foo.bar", "allow"],
                            ["common", "deny"],
                            ["*", "deny"],
                        ]
                    }
                }
            ),
            filename="some_dir/file.py",
        ),
        # Built-in modules are fine
        Valid(
            "import ast",
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "deny"]]}}),
            filename="some_dir/file.py",
        ),
        # Relative imports
        Valid(
            "from . import module",
            config=_gen_testcase_config(
                {".": {"rules": [["common.safe", "allow"], ["*", "deny"]]}}
            ),
            filename="common/safe/file.py",
        ),
        Valid(
            "from ..safe import module",
            config=_gen_testcase_config(
                {"common": {"rules": [["common.safe", "allow"], ["*", "deny"]]}}
            ),
            filename="common/unsafe/file.py",
        ),
        # Ignore some relative module that leaves the repo root
        Valid(
            "from ....................................... import module",
            config=_gen_testcase_config({".": {"rules": [["*", "deny"]]}}),
            filename="file.py",
        ),
        # Imports under ``typing.TYPE_CHECKING`` are ignored like bare ``TYPE_CHECKING``
        Valid(
            """
            import typing
            if typing.TYPE_CHECKING:
                import common
            """,
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "deny"]]}}),
            filename="some_dir/file.py",
        ),
        # File belongs to more than one directory setting (should enforce closest parent directory)
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "dir_1/dir_2": {
                        "rules": [["common.foo.bar", "allow"], ["*", "deny"]]
                    },
                    "dir_1": {"rules": [["common.foo.bar", "deny"], ["*", "deny"]]},
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
        # File belongs to more than one directory setting, flipped order (should enforce closest parent directory)
        Valid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "dir_1": {"rules": [["common.foo.bar", "deny"], ["*", "deny"]]},
                    "dir_1/dir_2": {
                        "rules": [["common.foo.bar", "allow"], ["*", "deny"]]
                    },
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
    ]

    INVALID = [
        # Everything is denied
        Invalid(
            "import common",
            config=_gen_testcase_config({"some_dir": {"rules": [["*", "deny"]]}}),
            filename="some_dir/file.py",
        ),
        # Validate rules are evaluated in order
        Invalid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "some_dir": {
                        "rules": [
                            ["common.foo.bar", "deny"],
                            ["common", "allow"],
                            ["*", "allow"],
                        ]
                    }
                }
            ),
            filename="some_dir/file.py",
        ),
        # We should match against the real name, not the aliased name
        Invalid(
            "import common as not_common",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common", "deny"], ["*", "allow"]]}}
            ),
            filename="some_dir/file.py",
        ),
        Invalid(
            "from common import bar as not_bar",
            config=_gen_testcase_config(
                {"some_dir": {"rules": [["common.bar", "deny"], ["*", "allow"]]}}
            ),
            filename="some_dir/file.py",
        ),
        # Relative imports
        Invalid(
            "from . import b",
            config=_gen_testcase_config({"common": {"rules": [["*", "deny"]]}}),
            filename="common/a.py",
        ),
        # Relative import from a file at the repo root resolves to a top-level module
        Invalid(
            "from . import common",
            config=_gen_testcase_config({".": {"rules": [["*", "deny"]]}}),
            filename="file.py",
        ),
        # File belongs to more than one directory setting, import from
        Invalid(
            "from common.foo import bar",
            config=_gen_testcase_config(
                {
                    "dir_1/dir_2": {"rules": [["*", "deny"]]},
                    "dir_1": {
                        "rules": [["common.foo.bar", "allow"], ["*", "deny"]],
                    },
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
        # File belongs to more than one directory setting, import
        Invalid(
            "import common",
            config=_gen_testcase_config(
                {
                    "dir_1/dir_2": {"rules": [["*", "deny"]]},
                    "dir_1": {
                        "rules": [["common", "allow"], ["*", "deny"]],
                    },
                }
            ),
            filename="dir_1/dir_2/file.py",
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        self._repo_root = Path(self.context.config.repo_root).resolve()
        self._config = None
        self._abs_file_path = (self._repo_root / context.file_path).resolve()
        import_constraints_config = self.context.config.rule_config.get(
            self.__class__.__name__, None
        )
        if import_constraints_config is not None:
            formatted_config: Dict[
                Path, Dict[object, object]
            ] = self._parse_and_format_config(import_constraints_config)
            # Walk the file's ancestors from nearest to farthest and stop at the
            # first configured directory: the closest parent's settings win.
            for parent_dir in self._abs_file_path.parents:
                if parent_dir in formatted_config:
                    settings_for_dir = formatted_config[parent_dir]
                    self._config = _ImportConfig.from_config(settings_for_dir)
                    break
        self._type_checking_stack = []

    def _is_within_repo(self, path: Path) -> bool:
        """Return True if ``path`` is the repo root or one of its descendants."""
        # Component-wise comparison: "/repo2" is not inside "/repo" even though
        # the string starts with it.
        return path == self._repo_root or self._repo_root in path.parents

    def _parse_and_format_config(
        self, import_constraints_config: Dict[str, object]
    ) -> Dict[Path, Dict[object, object]]:
        """Map each configured directory, as a normalized absolute path, to its settings.

        Relative directory names are resolved against the repo root. Entries
        that resolve outside the repo root (e.g. ``../elsewhere`` or an
        unrelated absolute path) are dropped.

        Raises:
            ValueError: if a directory's settings are not a mapping.
        """
        formatted_config: Dict[Path, Dict[object, object]] = {}
        for dirname, dir_settings in import_constraints_config.items():
            # os.path.join keeps an absolute ``dirname`` unchanged and anchors a
            # relative one at the repo root. normpath collapses "." and ".." so
            # the path compares equal to the entries of Path.parents.
            candidate = Path(os.path.normpath(os.path.join(self._repo_root, dirname)))
            if not isinstance(dir_settings, dict):
                raise ValueError(
                    f"Invalid entry `{dir_settings}`.\n"
                    + "You must specify settings in key-value format under a directory."
                )
            if self._is_within_repo(candidate):
                formatted_config[candidate] = dir_settings
        return formatted_config

    def should_skip_file(self) -> bool:
        config = self._config
        return config is None or (config.ignore_tests and self.context.in_tests)

    def visit_If(self, node: cst.If) -> None:
        # Recognizes ``TYPE_CHECKING`` and ``typing.TYPE_CHECKING``; aliased
        # module imports (``import typing as t``) are not resolved.
        # NOTE(review): the stack stays non-empty for the whole ``If`` node,
        # including its ``else``/``elif`` branches, so runtime imports placed
        # there are skipped as well.
        if m.matches(
            node.test,
            m.Name("TYPE_CHECKING")
            | m.Attribute(value=m.Name("typing"), attr=m.Name("TYPE_CHECKING")),
        ):
            self._type_checking_stack.append(node)

    def leave_If(self, original_node: cst.If) -> None:
        if self._type_checking_stack and self._type_checking_stack[-1] is original_node:
            self._type_checking_stack.pop()

    def visit_Import(self, node: cst.Import) -> None:
        self._check_names(
            node, (get_full_name_for_node_or_raise(alias.name) for alias in node.names)
        )

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        module = node.module
        abs_module = self._to_absolute_module(
            get_full_name_for_node_or_raise(module) if module is not None else "",
            len(node.relative),
        )
        if abs_module is None:
            # Relative import that climbs above the repo root.
            return
        names = node.names
        if not isinstance(names, Sequence):
            # ``from x import *`` has no individual names to check.
            return
        # ``abs_module`` is empty for ``from . import x`` in a repo-root file;
        # then the imported name is itself the top-level module.
        self._check_names(
            node,
            (
                f"{abs_module}.{alias.name.value}" if abs_module else alias.name.value
                for alias in names
            ),
        )

    def _to_absolute_module(self, module: Optional[str], level: int) -> Optional[str]:
        """Resolve a (possibly relative) module name to a dotted path from the repo root.

        Args:
            module: the module text after the dots; may be empty or None.
            level: the number of leading dots (0 for absolute imports).

        Returns:
            The dotted module path. It is an empty string when a relative import
            resolves to the repo root with no module. Returns None when the import
            climbs above the repo root. For ``level == 0``, ``module`` is returned
            unchanged.
        """
        if level == 0:
            return module
        current_dir = self._abs_file_path
        for _ in range(level):
            current_dir = current_dir.parent
        if not self._is_within_repo(current_dir):
            return None
        parts = list(current_dir.relative_to(self._repo_root).parts)
        if module:
            parts.append(module)
        # Joining only non-empty parts avoids names like "pkg." or ".mod".
        return ".".join(parts)

    def _check_names(self, node: cst.CSTNode, names: Iterable[str]) -> None:
        """Report every local-root module in ``names`` denied by the active config."""
        config = self._config
        if config is None or (config.ignore_types and self._type_checking_stack):
            return
        for name in names:
            if name.split(".", 1)[0] not in _get_local_roots(self._repo_root):
                continue
            rule = config.match(name)
            if not rule.allow:
                self.report(
                    node,
                    self.MESSAGE.format(
                        imported=name, current_file=self.context.file_path
                    ),
                )
class UsePrintfLoggingRule(CstLintRule):
    """
    Flags f-strings, ``str.format()`` calls and ``%``-interpolation of string
    literals that appear anywhere inside a logging call.

    A call counts as a logging call when ``match_logger_calls`` accepts it for
    the tracked logger names. The tracked names start as ``logging`` and grow
    with every plain-name target assigned from ``logging.getLogger()``, as
    detected by ``check_getLogger``.
    """

    MESSAGE: str = (
        "UsePrintfLoggingRule: Use %s style strings instead of f-string or format() "
        + "for python logging. Pass %s values as arguments or in the 'extra' argument "
        + "in loggers."
    )

    VALID = [
        # Printf strings in log statements are best practice
        Valid('logging.error("Hello %s", my_val)'),
        Valid('logging.info("Hello %(my_str)s!", {"my_str": "world"})'),
        Valid(
            """
            logger = logging.getLogger()
            logger.log(logging.DEBUG, "Concat " "Printf string %s", vals)
            logger.debug("Hello %(my_str)s!", {"my_str": "world"})
            """,
        ),
        Valid(
            """
            mylog = logging.getLogger()
            mylog.warning("A printf string %s, %d", 'George', 732)
            """,
        ),
        # Don't report if logger isn't a logging.getLogger
        Valid('logger.error("Hello %s", my_val)'),
        Valid(
            """
            logger = custom_logger.getLogger()
            logger.error("Hello %s", my_val)
            """,
        ),
        # fstrings should be allowed elsewhere
        Valid(
            """
            logging.warning("simple error %s", test)
            fn(f"formatted string {my_var}")
            """,
        ),
        Valid(
            """
            logger: logging.Logger = logging.getLogger()
            logger.warning("simple error %s", test)
            test_var = f"test string {msg}"
            func(3, other_var, "string format {}".format(test_var))
            """,
        ),
        # %s interpolation allowed outside of log calls
        Valid('test = "my %s" % var'),
    ]

    INVALID = [
        # Using fstring in a log
        Invalid('logging.error(f"Hello {my_var}")'),
        Invalid(
            """
            logger = logging.getLogger()
            logger.error("Cat" f"Hello {my_var}")
            """,
        ),
        # Using str.format() in a log
        Invalid('logging.info("Hello {}".format(my_str))'),
        # Also invalid to use either in loggers
        Invalid(
            """
            logger: logging.Logger = logging.getLogger()
            logger.log("DEBUG", f"Format string {vals}")
            """,
        ),
        Invalid(
            """
            log = logging.getLogger()
            log.warning("This string formats {}".format(foo))
            """,
        ),
        # Do not interpolate %s strings in log
        Invalid('logging.error("my error: %s" % msg)'),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Logging calls that enclose the node currently being visited.
        self.logging_stack: Set[cst.Call] = set()
        # Names known to refer to ``logging`` or to a ``getLogger()`` result.
        self.logger_names: Set[str] = {"logging"}

    def _record_logger_name(self, target: cst.BaseAssignTargetExpression) -> None:
        """Track ``target`` as a logger name when it is a plain name."""
        if isinstance(target, cst.Name):
            self.logger_names.add(target.value)

    def visit_Assign(self, node: cst.Assign) -> None:
        # logger = logging.getLogger()
        if check_getLogger(node):
            self._record_logger_name(node.targets[0].target)

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        # logger: logging.Logger = logging.getLogger()
        if check_getLogger(node):
            self._record_logger_name(node.target)

    def visit_Call(self, node: cst.Call) -> None:
        # Enter a logging call first, so that the check below sees it.
        if match_logger_calls(node, self.logging_stack and self.logger_names or self.logger_names):
            self.logging_stack.add(node)
        # Any str.format() call nested within a logging call is reported.
        if self.logging_stack and match_calls_str_format(node):
            self.report(node)

    def leave_Call(self, original_node: cst.Call) -> None:
        # Leaving a logging call: stop tracking it (no-op for other calls).
        self.logging_stack.discard(original_node)

    def visit_FormattedString(self, node: cst.FormattedString) -> None:
        if self.logging_stack:
            self.report(node)

    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        if not self.logging_stack:
            return
        percent_on_literal = m.BinaryOperation(
            left=m.OneOf(m.SimpleString(), m.ConcatenatedString()),
            operator=m.Modulo(),
        )
        if m.matches(node, percent_on_literal):
            self.report(node)
class UseFstringRule(CstLintRule):
    """
    Encourages the use of f-string instead of %-formatting or .format() for high code quality and efficiency.

    The following two cases are not covered:

    1. Arguments longer than 30 characters, for better readability. For example::

        1: "this is the answer: %d" % (a_long_function_call() + b_another_long_function_call())
        2: f"this is the answer: {a_long_function_call() + b_another_long_function_call()}"
        3: result = a_long_function_call() + b_another_long_function_call()
           f"this is the answer: {result}"

       Line 1 is more readable than line 2. Ideally, we'd like developers to manually fix this case to line 3.

    2. Only %s placeholders are linted against for now. We leave it as future work to support other placeholders.
       For example, %d raises TypeError for non-numeric objects, whereas f"{x:d}" raises ValueError.
       This discrepancy in the type of exception raised could potentially break the logic in the code where the
       exception is handled.
    """

    MESSAGE: str = (
        "Do not use printf style formatting or .format(). "
        + "Use f-string instead to be more readable and efficient. "
        + "See https://www.python.org/dev/peps/pep-0498/"
    )
    VALID = [
        Valid("somebody='you'; f\"Hey, {somebody}.\""),
        Valid('"hey"'),
        Valid('"hey" + "there"'),
        Valid('b"a type %s" % var'),
        Valid('logging.error("printf style logging %s", my_var)'),
    ]
    INVALID = [
        Invalid('"Hey, {somebody}.".format(somebody="you")'),
        Invalid('"%s" % "hi"', expected_replacement='''f"{'hi'}"'''),
        Invalid('"a name: %s" % name', expected_replacement='f"a name: {name}"'),
        Invalid(
            '"an attribute %s ." % obj.attr',
            expected_replacement='f"an attribute {obj.attr} ."',
        ),
        Invalid(
            'r"raw string value=%s" % val',
            expected_replacement='fr"raw string value={val}"',
        ),
        Invalid('"{%s}" % val', expected_replacement='f"{{{val}}}"'),
        Invalid('"{%s" % val', expected_replacement='f"{{{val}"'),
        Invalid(
            '"The type of var: %s" % type(var)',
            expected_replacement='f"The type of var: {type(var)}"',
        ),
        Invalid(
            '"%s" % obj.this_is_a_very_long_expression(parameter)["a_very_long_key"]',
        ),
        Invalid(
            '"type of var: %s, value of var: %s" % (type(var), var)',
            expected_replacement='f"type of var: {type(var)}, value of var: {var}"',
        ),
        Invalid(
            "'%s\" double quote is used' % var",
            expected_replacement="f'{var}\" double quote is used'",
        ),
        Invalid(
            '"var1: %s, var2: %s, var3: %s, var4: %s" % (class_object.attribute, dict_lookup["some_key"], some_module.some_function(), var4)',
            expected_replacement='''f"var1: {class_object.attribute}, var2: {dict_lookup['some_key']}, var3: {some_module.some_function()}, var4: {var4}"''',
        ),
        Invalid(
            '"a list: %s" % " ".join(var)',
            expected_replacement='''f"a list: {' '.join(var)}"''',
        ),
        # More values than placeholders raises TypeError at runtime; an autofix
        # that silently drops `b` would change behavior, so only report it.
        Invalid('"%s" % (a, b)'),
    ]

    def visit_Call(self, node: cst.Call) -> None:
        # `"...".format(...)` on a plain string literal.
        if m.matches(
            node,
            m.Call(func=m.Attribute(value=m.SimpleString(), attr=m.Name(value="format"))),
        ):
            self.report(node)

    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        """Report `"...%s..." % expr` and, when safe, autofix it into an f-string."""
        expr_key = "expr"
        extracts = m.extract(
            node,
            m.BinaryOperation(
                left=m.MatchIfTrue(_match_simple_string),
                operator=m.Modulo(),
                right=m.SaveMatchedNode(
                    m.MatchIfTrue(_gen_match_simple_expression(self.context.wrapper.module)),
                    expr_key,
                ),
            ),
        )

        if extracts:
            expr = extracts[expr_key]
            simple_string = cst.ensure_type(node.left, cst.SimpleString)
            # Braces are literal text in a %-string but special in an f-string.
            # NOTE(review): `%%` escapes are not translated here; presumably
            # `_match_simple_string` rejects such strings -- confirm.
            innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
            tokens = innards.split("%s")
            expressions = (
                [elm.value for elm in expr.elements] if isinstance(expr, cst.Tuple) else [expr]
            )

            # Only generate a warning (no autofix) when the number of %s
            # placeholders differs from the number of values in either direction.
            if len(expressions) != len(tokens) - 1:
                self.report(node)
                return

            escape_transformer = EscapeStringQuote(simple_string.quote)
            parts = []
            for index, token in enumerate(tokens):
                if index > 0:
                    try:
                        escaped = expressions[index - 1].visit(escape_transformer)
                    except Exception:
                        # Best effort: if the expression cannot be embedded, still
                        # report the violation but don't offer an autofix.
                        self.report(node)
                        return
                    parts.append(
                        cst.FormattedStringExpression(
                            expression=cast(cst.BaseExpression, escaped)
                        )
                    )
                if token:
                    parts.append(cst.FormattedStringText(value=token))

            start = f"f{simple_string.prefix}{simple_string.quote}"
            replacement = cst.FormattedString(parts=parts, start=start, end=simple_string.quote)
            self.report(node, replacement=replacement)
        elif m.matches(
            node, m.BinaryOperation(left=m.SimpleString(), operator=m.Modulo())
        ) and isinstance(cst.ensure_type(node.left, cst.SimpleString).evaluated_value, str):
            # Any other %-formatting on a text (not bytes) literal: report only.
            self.report(node)
class SortedAttributesRule(CstLintRule):
    """
    Ever wanted to sort a bunch of class attributes alphabetically?
    Well now it's easy! Just add "@sorted-attributes" in the doc string of
    a class definition and lint will automatically sort all attributes alphabetically.

    Feel free to add other methods and such -- it should only affect class attributes.
    Only simple single-name assignments (``NAME = value``) are sorted; other
    statements such as tuple unpacking are left untouched.
    """

    INVALID = [
        Invalid(
            """
            class MyUnsortedConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                z = "hehehe"
                B = 'aaa234'
                A = 'zzz123'
                cab = "foo bar"
                Daaa = "banana"

                @classmethod
                def get_foo(cls) -> str:
                    return "some random thing"
            """,
            expected_replacement="""
            class MyUnsortedConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                A = 'zzz123'
                B = 'aaa234'
                Daaa = "banana"
                cab = "foo bar"
                z = "hehehe"

                @classmethod
                def get_foo(cls) -> str:
                    return "some random thing"
            """,
        )
    ]
    VALID = [
        Valid(
            """
            class MyConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                A = 'zzz123'
                B = 'aaa234'

            class MyUnsortedConstants:
                B = 'aaa234'
                A = 'zzz123'
            """
        ),
        # Tuple-unpacking targets have no single name to sort by and must not crash.
        Valid(
            """
            class MyConstants:
                \"\"\"
                @sorted-attributes
                \"\"\"
                A = 'zzz123'
                X, Y = 1, 2
                B = 'aaa234'
            """
        ),
    ]
    MESSAGE: str = "It appears you are using the @sorted-attributes directive and the class variables are unsorted. See the lint autofix suggestion."

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report (with autofix) unsorted ``NAME = value`` lines in opted-in classes."""
        doc_string = node.get_docstring()
        if not doc_string or "@sorted-attributes" not in doc_string:
            return

        found_any_assign = False
        pre_assign_lines: List[LineType] = []
        assign_lines: List[LineType] = []
        post_assign_lines: List[LineType] = []

        for line in node.body.body:
            # Restrict to a single Name target so the sort key below always has a
            # string `.value` (tuple/attribute targets would crash the sort).
            if m.matches(
                line,
                m.SimpleStatementLine(
                    body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
                ),
            ):
                found_any_assign = True
                assign_lines.append(line)
            elif found_any_assign:
                post_assign_lines.append(line)
            else:
                pre_assign_lines.append(line)

        sorted_assign_lines = sorted(
            assign_lines, key=lambda stmt: stmt.body[0].targets[0].target.value
        )
        if sorted_assign_lines == assign_lines:
            return
        self.report(
            node,
            replacement=node.with_changes(
                body=node.body.with_changes(
                    body=pre_assign_lines + sorted_assign_lines + post_assign_lines
                )
            ),
        )
class NoStringTypeAnnotationRule(CstLintRule):
    """
    Enforce the use of type identifier instead of using string type hints for simplicity and better syntax highlighting.

    Starting in Python 3.7, ``from __future__ import annotations`` can postpone evaluation of type annotations
    `PEP 563 <https://www.python.org/dev/peps/pep-0563/#forward-references>`_
    and thus forward references no longer need to use string annotation style.
    Strings inside ``Literal[...]`` (from ``typing`` or ``typing_extensions``) are values, not types, and are left alone.
    """

    MESSAGE = "String type hints are no longer necessary in Python, use the type identifier directly."

    # Qualified names whose subscripts contain string *values* rather than types.
    _LITERAL_QUALIFIED_NAMES = frozenset({"typing.Literal", "typing_extensions.Literal"})

    VALID = [
        # Usage of a Class for instantiation and typing.
        Valid(
            """
            from a.b import Class

            def foo() -> Class:
                return Class()
            """
        ),
        Valid(
            """
            import typing
            from a.b import Class

            def foo() -> typing.Type[Class]:
                return Class
            """
        ),
        Valid(
            """
            import typing
            from a.b import Class
            from c import func

            def foo() -> typing.Optional[typing.Type[Class]]:
                return Class if func() else None
            """
        ),
        Valid(
            """
            from a.b import Class

            def foo(arg: Class) -> None:
                pass

            foo(Class())
            """
        ),
        Valid(
            """
            from a.b import Class

            module_var: Class = Class()
            """
        ),
        Valid(
            """
            from typing import Literal

            def foo() -> Literal["a", "b"]:
                return "a"
            """
        ),
        Valid(
            """
            import typing

            def foo() -> typing.Optional[typing.Literal["a", "b"]]:
                return "a"
            """
        ),
        # typing.Literal values must not be rewritten even with postponed annotations.
        Valid(
            """
            from __future__ import annotations

            from typing import Literal

            def foo() -> Literal["a", "b"]:
                return "a"
            """
        ),
    ]
    INVALID = [
        # Using string type hints isn't needed
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            def foo() -> "Class":
                return Class()
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            def foo() -> Class:
                return Class()
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            async def foo() -> "Class":
                return await Class()
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            async def foo() -> Class:
                return await Class()
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            import typing
            from a.b import Class

            def foo() -> typing.Type["Class"]:
                return Class
            """,
            line=6,
            expected_replacement="""
            from __future__ import annotations

            import typing
            from a.b import Class

            def foo() -> typing.Type[Class]:
                return Class
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            import typing
            from a.b import Class
            from c import func

            def foo() -> Optional[typing.Type["Class"]]:
                return Class if func() else None
            """,
            line=7,
            expected_replacement="""
            from __future__ import annotations

            import typing
            from a.b import Class
            from c import func

            def foo() -> Optional[typing.Type[Class]]:
                return Class if func() else None
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            def foo(arg: "Class") -> None:
                pass

            foo(Class())
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            def foo(arg: Class) -> None:
                pass

            foo(Class())
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            from a.b import Class

            module_var: "Class" = Class()
            """,
            line=5,
            expected_replacement="""
            from __future__ import annotations

            from a.b import Class

            module_var: Class = Class()
            """,
        ),
        Invalid(
            """
            from __future__ import annotations

            import typing
            from typing_extensions import Literal
            from a.b import Class

            def foo() -> typing.Tuple[Literal["a", "b"], "Class"]:
                return Class()
            """,
            line=7,
            expected_replacement="""
            from __future__ import annotations

            import typing
            from typing_extensions import Literal
            from a.b import Class

            def foo() -> typing.Tuple[Literal["a", "b"], Class]:
                return Class()
            """,
        ),
        # A string that is not a valid expression is reported without an autofix
        # instead of crashing the rule.
        Invalid(
            """
            from __future__ import annotations

            def foo() -> "Class(":
                pass
            """,
            line=3,
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        # Annotation nodes currently being traversed (non-empty == inside an annotation).
        self.in_annotation: Set[cst.Annotation] = set()
        # Literal[...] subscripts currently being traversed.
        self.in_literal: Set[cst.Subscript] = set()
        self.has_future_annotations_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        # Detect `from __future__ import annotations` (possibly among other names).
        if m.matches(
            node,
            m.ImportFrom(
                module=m.Name("__future__"),
                names=[
                    m.ZeroOrMore(),
                    m.ImportAlias(name=m.Name("annotations")),
                    m.ZeroOrMore(),
                ],
            ),
        ):
            self.has_future_annotations_import = True

    def visit_Annotation(self, node: cst.Annotation) -> None:
        self.in_annotation.add(node)

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        self.in_annotation.remove(original_node)

    def visit_Subscript(self, node: cst.Subscript) -> None:
        if not self.has_future_annotations_import:
            return
        if self.in_annotation:
            literal_names = self._LITERAL_QUALIFIED_NAMES
            if m.matches(
                node,
                m.Subscript(
                    metadata=m.MatchMetadataIfTrue(
                        QualifiedNameProvider,
                        lambda qualnames: any(n.name in literal_names for n in qualnames),
                    )
                ),
                metadata_resolver=self.context.wrapper,
            ):
                self.in_literal.add(node)

    def leave_Subscript(self, original_node: cst.Subscript) -> None:
        if not self.has_future_annotations_import:
            return
        if original_node in self.in_literal:
            self.in_literal.remove(original_node)

    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        if not self.has_future_annotations_import:
            return
        if self.in_annotation and not self.in_literal:
            # This is not allowed past Python3.7 since it's no longer necessary.
            try:
                replacement = cst.parse_expression(
                    node.evaluated_value,
                    config=self.context.wrapper.module.config_for_parsing,
                )
            except cst.ParserSyntaxError:
                # The string's contents are not an expression; still report, but
                # there is no safe autofix.
                self.report(node)
                return
            self.report(node, replacement=replacement)
class RequireDoctestRule(CstLintRule):
    """
    Requires every function (except ``__init__``) to have a doctest in its docstring,
    unless the module has a doctest/test function/test class, or the enclosing class
    docstring contains a doctest.
    """

    VALID = [
        # Module-level docstring contains doctest.
        Valid(
            """
            '''
            Module-level docstring contains doctest
            >>> foo()
            None
            '''
            def foo():
                pass

            class Bar:
                def baz(self):
                    pass

            def bar():
                pass
            """
        ),
        # Module contains a test function.
        Valid(
            """
            def foo():
                pass

            def bar():
                pass

            # Contains a test function
            def test_foo():
                pass

            class Baz:
                def baz(self):
                    pass

            def spam():
                pass
            """
        ),
        # Module contains multiple test function.
        Valid(
            """
            def foo():
                pass

            def bar():
                pass

            def test_foo():
                pass

            def test_bar():
                pass

            class Baz:
                def baz(self):
                    pass

            def spam():
                pass
            """
        ),
        # Module contains a test class.
        Valid(
            """
            def foo():
                pass

            class Baz:
                def baz(self):
                    pass

            def bar():
                pass

            # Contains a test class
            class TestSpam:
                def test_spam(self):
                    pass

            def egg():
                pass
            """
        ),
        # Class level docstring contains doctest, so skip doctest checking only
        # for that class.
        Valid(
            """
            def foo():
                '''
                >>> foo()
                '''
                pass

            class Spam:
                '''
                Class-level docstring contains doctest
                >>> Spam()
                '''
                def foo(self):
                    pass

                def spam(self):
                    pass

            def bar():
                '''
                >>> bar()
                '''
                pass
            """
        ),
        # No doctest required for the ``__init__`` function.
        Valid(
            """
            def spam():
                '''
                >>> spam()
                '''
                pass

            class Bar:
                # No doctest needed for the init function
                def __init__(self):
                    pass

                def bar(self):
                    '''
                    >>> bar()
                    '''
                    pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            def bar():
                pass
            """
        ),
        # Only the ``__init__`` function does not require doctest.
        Invalid(
            """
            def foo():
                '''
                >>> foo()
                '''
                pass

            class Spam:
                def __init__(self):
                    pass

                def spam(self):
                    pass
            """
        ),
        # Check that the `_skip_doctest` attribute is reset after leaving the class.
        Invalid(
            """
            def bar():
                '''
                >>> bar()
                '''
                pass

            class Spam:
                '''
                >>> Spam()
                '''
                def spam():
                    pass

            def egg():
                pass
            """
        ),
        # Nested classes must restore the flag of their own enclosing scope:
        # after leaving `Outer`, `bar` is checked again.
        Invalid(
            """
            class Outer:
                '''
                >>> Outer()
                '''
                class Inner:
                    def inner(self):
                        pass

            def bar():
                pass
            """
        ),
    ]

    def __init__(self, context: CstContext) -> None:
        super().__init__(context)
        self._skip_doctest: bool = False
        # Saved ``_skip_doctest`` values, one per enclosing class, restored on leave.
        # A stack (rather than a single slot) keeps nested classes correct.
        self._skip_doctest_stack: List[bool] = []

    def visit_Module(self, node: cst.Module) -> None:
        self._skip_doctest = self._has_testnode(node) or self._has_doctest(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # If the class-level docstring contains a doctest, checks are skipped only for
        # that class's methods, not for other functions/classes in the module. Save the
        # current value so it can be restored when the class is left.
        self._skip_doctest_stack.append(self._skip_doctest)
        self._skip_doctest = self._has_doctest(node)

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self._skip_doctest = self._skip_doctest_stack.pop()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        nodename = node.name.value
        if nodename != INIT and not self._has_doctest(node):
            self.report(
                node,
                MISSING_DOCTEST.format(filepath=self.context.file_path, nodename=nodename),
            )

    def _has_doctest(self, node: Union[cst.Module, cst.ClassDef, cst.FunctionDef]) -> bool:
        """Check whether the given node contains doctests.

        If the ``_skip_doctest`` attribute is ``True``, the function will by default return
        ``True``, otherwise it will extract the docstring and look for doctest patterns (>>> )
        in it. If there is no docstring for the node, this will mean the absence of doctest.
        """
        if not self._skip_doctest:
            docstring = node.get_docstring()
            if docstring is not None:
                for line in docstring.splitlines():
                    if line.strip().startswith(">>> "):
                        return True
            return False
        return True

    @staticmethod
    def _has_testnode(node: cst.Module) -> bool:
        """Return True if the module body has a top-level ``test_*`` function or ``Test*`` class."""
        return m.matches(
            node,
            m.Module(
                body=[
                    # Sequence wildcard matchers match LibCST nodes in a row in a
                    # sequence. They do not implicitly match on partial sequences, so
                    # when matching against a sequence we need to provide a complete
                    # pattern. This often means using helpers such as ``ZeroOrMore()``
                    # as the first and last element of the sequence.
                    m.ZeroOrMore(),
                    m.AtLeastN(
                        n=1,
                        matcher=m.OneOf(
                            m.FunctionDef(
                                name=m.Name(
                                    value=m.MatchIfTrue(lambda value: value.startswith("test_"))
                                )
                            ),
                            m.ClassDef(
                                name=m.Name(
                                    value=m.MatchIfTrue(lambda value: value.startswith("Test"))
                                )
                            ),
                        ),
                    ),
                    m.ZeroOrMore(),
                ]
            ),
        )
class ExplicitFrozenDataclassRule(CstLintRule):
    """
    Encourages the use of frozen dataclass objects by telling users to specify the
    kwarg.

    Without this lint rule, most users of dataclass won't know to use the kwarg, and
    may unintentionally end up with mutable objects.
    """

    MESSAGE: str = (
        "When using dataclasses, explicitly specify a frozen keyword argument. "
        + "Example: `@dataclass(frozen=True)` or `@dataclass(frozen=False)`. "
        + "Docs: https://docs.python.org/3/library/dataclasses.html"
    )
    METADATA_DEPENDENCIES = (QualifiedNameProvider,)
    VALID = [
        Valid(
            """
            @some_other_decorator
            class Cls: pass
            """
        ),
        Valid(
            """
            from dataclasses import dataclass
            @dataclass(frozen=False)
            class Cls: pass
            """
        ),
        Valid(
            """
            import dataclasses
            @dataclasses.dataclass(frozen=False)
            class Cls: pass
            """
        ),
        Valid(
            """
            import dataclasses as dc
            @dc.dataclass(frozen=False)
            class Cls: pass
            """
        ),
        Valid(
            """
            from dataclasses import dataclass as dc
            @dc(frozen=False)
            class Cls: pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            from dataclasses import dataclass
            @some_unrelated_decorator
            @dataclass  # not called as a function
            @another_unrelated_decorator
            class Cls: pass
            """,
            line=3,
            expected_replacement="""
            from dataclasses import dataclass
            @some_unrelated_decorator
            @dataclass(frozen=True)  # not called as a function
            @another_unrelated_decorator
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass
            @dataclass()  # called as a function, no kwargs
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass
            @dataclass(frozen=True)  # called as a function, no kwargs
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass
            @dataclass(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass
            @dataclass(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses
            @dataclasses.dataclass
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses
            @dataclasses.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses
            @dataclasses.dataclass()
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses
            @dataclasses.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses
            @dataclasses.dataclass(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses
            @dataclasses.dataclass(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass as dc
            @dc
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass as dc
            @dc(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass as dc
            @dc()
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass as dc
            @dc(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            from dataclasses import dataclass as dc
            @dc(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            from dataclasses import dataclass as dc
            @dc(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses as dc
            @dc.dataclass
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses as dc
            @dc.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses as dc
            @dc.dataclass()
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses as dc
            @dc.dataclass(frozen=True)
            class Cls: pass
            """,
        ),
        Invalid(
            """
            import dataclasses as dc
            @dc.dataclass(other_kwarg=False)
            class Cls: pass
            """,
            line=2,
            expected_replacement="""
            import dataclasses as dc
            @dc.dataclass(other_kwarg=False, frozen=True)
            class Cls: pass
            """,
        ),
    ]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report every ``dataclasses.dataclass`` decorator lacking a ``frozen=`` kwarg."""
        dataclass_qualname = QualifiedName(
            name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
        )
        for dec in node.decorators:
            expression = dec.decorator
            if not QualifiedNameProvider.has_name(self, expression, dataclass_qualname):
                continue

            if isinstance(expression, cst.Call):
                callee, current_args = expression.func, list(expression.args)
            else:
                # Bare `@dataclass` / `@dataclasses.dataclass` (a Name or Attribute).
                callee, current_args = expression, []

            if any(m.matches(arg.keyword, m.Name("frozen")) for arg in current_args):
                continue

            frozen_kwarg = cst.Arg(
                keyword=cst.Name("frozen"),
                value=cst.Name("True"),
                equal=cst.AssignEqual(
                    whitespace_before=SimpleWhitespace(value=""),
                    whitespace_after=SimpleWhitespace(value=""),
                ),
            )
            fixed_decorator = cst.Call(func=callee, args=current_args + [frozen_kwarg])
            self.report(dec, replacement=dec.with_changes(decorator=fixed_decorator))
class UseClsInClassmethodRule(CstLintRule):
    """
    Enforces using ``cls`` as the first argument in a ``@classmethod``.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider, ScopeProvider)
    MESSAGE = "When using @classmethod, the first argument must be `cls`."
    VALID = [
        Valid(
            """
            class foo:
                # classmethod with cls first arg.
                @classmethod
                def cm(cls, a, b, c):
                    pass
            """
        ),
        Valid(
            """
            class foo:
                # non-classmethod with non-cls first arg.
                def nm(self, a, b, c):
                    pass
            """
        ),
        Valid(
            """
            class foo:
                # staticmethod with non-cls first arg.
                @staticmethod
                def sm(a):
                    pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            class foo:
                # No args at all.
                @classmethod
                def cm():
                    pass
            """,
            expected_replacement="""
            class foo:
                # No args at all.
                @classmethod
                def cm(cls):
                    pass
            """,
        ),
        Invalid(
            """
            class foo:
                # Single arg + reference.
                @classmethod
                def cm(a):
                    return a
            """,
            expected_replacement="""
            class foo:
                # Single arg + reference.
                @classmethod
                def cm(cls):
                    return cls
            """,
        ),
        Invalid(
            """
            class foo:
                # Another "cls" exists: do not autofix.
                @classmethod
                def cm(a):
                    cls = 2
            """,
        ),
        Invalid(
            """
            class foo:
                # Multiple args + references.
                @classmethod
                async def cm(a, b):
                    b = a
                    b = a.__name__
            """,
            expected_replacement="""
            class foo:
                # Multiple args + references.
                @classmethod
                async def cm(cls, b):
                    b = cls
                    b = cls.__name__
            """,
        ),
        Invalid(
            """
            class foo:
                # Do not replace in nested scopes.
                @classmethod
                async def cm(a, b):
                    b = a
                    b = lambda _: a.__name__
                    def g():
                        return a.__name__

                    # Same-named vars in sub-scopes should not be replaced.
                    b = [a for a in [1,2,3]]
                    def f(a):
                        return a + 1
            """,
            expected_replacement="""
            class foo:
                # Do not replace in nested scopes.
                @classmethod
                async def cm(cls, b):
                    b = cls
                    b = lambda _: cls.__name__
                    def g():
                        return cls.__name__

                    # Same-named vars in sub-scopes should not be replaced.
                    b = [a for a in [1,2,3]]
                    def f(a):
                        return a + 1
            """,
        ),
        Invalid(
            """
            # Do not replace in surrounding scopes.
            a = 1

            class foo:
                a = 2

                def im(a):
                    a = a

                @classmethod
                def cm(a):
                    a[1] = foo.cm(a=a)
            """,
            expected_replacement="""
            # Do not replace in surrounding scopes.
            a = 1

            class foo:
                a = 2

                def im(a):
                    a = a

                @classmethod
                def cm(cls):
                    cls[1] = foo.cm(a=cls)
            """,
        ),
        Invalid(
            """
            def another_decorator(x): pass

            class foo:
                # Multiple decorators.
                @another_decorator
                @classmethod
                @another_decorator
                async def cm(a, b, c):
                    pass
            """,
            expected_replacement="""
            def another_decorator(x): pass

            class foo:
                # Multiple decorators.
                @another_decorator
                @classmethod
                @another_decorator
                async def cm(cls, b, c):
                    pass
            """,
        ),
        Invalid(
            """
            class foo:
                # Positional-only first arg.
                @classmethod
                def cm(a, /, b):
                    return a
            """,
            expected_replacement="""
            class foo:
                # Positional-only first arg.
                @classmethod
                def cm(cls, /, b):
                    return cls
            """,
        ),
    ]

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Report classmethods whose first positional parameter is not ``cls``, with autofix."""
        if not any(
            QualifiedNameProvider.has_name(
                self,
                decorator.decorator,
                QualifiedName(name="builtins.classmethod", source=QualifiedNameSource.BUILTIN),
            )
            for decorator in node.decorators
        ):
            return  # If it's not a @classmethod, we are not interested.

        # Positional-only parameters (before `/`) come first, so the first argument
        # may live in `posonly_params` rather than `params`.
        positional_params = [*node.params.posonly_params, *node.params.params]
        if not positional_params:
            # No params, but there must be the 'cls' param.
            # Note that pyre[47] already catches this, but we also generate
            # an autofix, so it still makes sense for us to report it here.
            new_params = node.params.with_changes(params=(cst.Param(name=cst.Name(value=CLS)),))
            repl = node.with_changes(params=new_params)
            self.report(node, replacement=repl)
            return

        p0_name = positional_params[0].name
        if p0_name.value == CLS:
            return  # All good.

        # Rename all assignments and references of the first param within the
        # function scope, as long as they are done via a Name node.
        # We rely on the parser to correctly derive all
        # assignments and references within the FunctionScope.
        # The Param node's scope is our classmethod's FunctionScope.
        scope = self.get_metadata(ScopeProvider, p0_name, None)
        if not scope:
            # Cannot autofix without scope metadata. Only report in this case.
            # Not sure how to repro+cover this in a unit test...
            # If metadata creation fails, then the whole lint fails, and if it succeeds,
            # then there is valid metadata. But many other lint rule implementations contain
            # a defensive scope None check like this one, so I assume it is necessary.
            self.report(node)
            return

        if scope[CLS]:
            # The scope already has another assignment to "cls".
            # Trying to rename the first param to "cls" as well may produce broken code.
            # We should therefore refrain from suggesting an autofix in this case.
            self.report(node)
            return

        refs: List[Union[cst.Name, cst.Attribute]] = []
        assignments = scope[p0_name.value]
        for a in assignments:
            if isinstance(a, Assignment):
                assign_node = a.node
                if isinstance(assign_node, cst.Name):
                    refs.append(assign_node)
                elif isinstance(assign_node, cst.Param):
                    refs.append(assign_node.name)
                # There are other types of possible assignment nodes: ClassDef,
                # FunctionDef, Import, etc. We deliberately do not handle those here.
            refs += [r.node for r in a.references]

        repl = node.visit(_RenameTransformer(refs, CLS))
        self.report(node, replacement=repl)
class ReplaceUnionWithOptionalRule(CstLintRule):
    """
    Enforces the use of ``Optional[T]`` over ``Union[T, None]`` and ``Union[None, T]``.
    See https://docs.python.org/3/library/typing.html#typing.Optional to learn more about Optionals.
    """

    MESSAGE: str = (
        "`Optional[T]` is preferred over `Union[T, None]` or `Union[None, T]`. "
        + "Learn more: https://docs.python.org/3/library/typing.html#typing.Optional"
    )
    METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider,)
    VALID = [
        Valid(
            """
            def func() -> Optional[str]:
                pass
            """
        ),
        Valid(
            """
            def func() -> Optional[Dict]:
                pass
            """
        ),
        Valid(
            """
            def func() -> Union[str, int, None]:
                pass
            """
        ),
    ]
    INVALID = [
        Invalid(
            """
            def func() -> Union[str, None]:
                pass
            """,
        ),
        Invalid(
            """
            from typing import Optional
            def func() -> Union[Dict[str, int], None]:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            def func() -> Optional[Dict[str, int]]:
                pass
            """,
        ),
        Invalid(
            """
            from typing import Optional
            def func() -> Union[str, None]:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            def func() -> Optional[str]:
                pass
            """,
        ),
        Invalid(
            """
            from typing import Optional
            def func() -> Union[Dict, None]:
                pass
            """,
            expected_replacement="""
            from typing import Optional
            def func() -> Optional[Dict]:
                pass
            """,
        ),
    ]

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        """Report two-element ``Union`` annotations containing ``None``.

        An ``Optional[...]`` autofix is attached only when the name ``Optional`` is
        visible in the annotation's scope and exactly one non-None member remains.
        """
        if not self.contains_union_with_none(original_node):
            return

        fix = None
        scope = self.get_metadata(cst.metadata.ScopeProvider, original_node, None)
        if scope is not None and "Optional" in scope:
            union_subscript = cst.ensure_type(original_node.annotation, cst.Subscript)
            none_count = 0
            other_members = []
            for element in union_subscript.slice:
                if m.matches(element, m.SubscriptElement(m.Index(m.Name("None")))):
                    none_count += 1
                else:
                    other_members.append(element.slice)
            if none_count <= 1 and len(other_members) == 1:
                fix = original_node.with_changes(
                    annotation=cst.Subscript(
                        value=cst.Name("Optional"),
                        slice=(cst.SubscriptElement(other_members[0]),),
                    )
                )
        # TODO(T57106602) refactor lint replacement once extract exists
        self.report(original_node, replacement=fix)

    def contains_union_with_none(self, node: cst.Annotation) -> bool:
        """True for ``Union[X, None]`` or ``Union[None, X]`` (exactly two members)."""
        none_element = m.SubscriptElement(m.Index(m.Name("None")))
        any_element = m.SubscriptElement(m.Index())
        return m.matches(
            node,
            m.Annotation(
                m.Subscript(
                    value=m.Name("Union"),
                    slice=m.OneOf(
                        [any_element, none_element],
                        [none_element, any_element],
                    ),
                )
            ),
        )
class GatherSequentialAwaitRule(CstLintRule):
    """
    Discourages awaiting coroutines in a loop as this will run them sequentially. Using ``asyncio.gather()``
    will run them concurrently.
    """

    MESSAGE: str = (
        "Using await in a loop will run async function sequentially. Use "
        + "asyncio.gather() to run async functions concurrently."
    )
    VALID = [
        Valid(
            """
            async def async_foo():
                return await async_bar()
            """
        ),
        Valid(
            """
            # await in a loop is fine if it's a test.
            # filename="foo/tests/test_foo.py"
            async def async_check_call():
                for _i in range(0, 2):
                    await async_foo()
            """,
            filename="foo/tests/test_foo.py",
        ),
        # A top-level await has a shallow node stack and must not crash the rule.
        Valid("await async_foo()"),
    ]
    INVALID = [
        Invalid(
            """
            async def async_check_call():
                for _i in range(0, 2):
                    await async_foo()
            """,
            line=3,
        ),
        Invalid(
            """
            async def async_check_assignment():
                for _i in range(0, 2):
                    x = await async_foo()
            """,
            line=3,
        ),
        Invalid(
            """
            async def async_check_list_comprehension():
                [await async_foo() for _i in range(0, 2)]
            """,
            line=2,
        ),
        # One-line loop bodies are loops too.
        Invalid(
            """
            async def async_check_one_line_loop():
                for _i in range(0, 2): await async_foo()
            """,
            line=2,
        ),
    ]

    def should_skip_file(self) -> bool:
        return self.context.in_tests

    @staticmethod
    def _is_direct_loop_body_statement(stack: List[cst.CSTNode]) -> bool:
        """True if ``stack[-2]`` is a statement placed directly in a for/while body.

        ``stack[-1]`` is the Await being visited and ``stack[-2]`` its statement.
        """
        loop_types = (cst.For, cst.While)
        # Indented body: loop -> IndentedBlock -> SimpleStatementLine -> stmt -> Await
        if (
            len(stack) >= 5
            and isinstance(stack[-3], cst.SimpleStatementLine)
            and isinstance(stack[-4], cst.IndentedBlock)
            and isinstance(stack[-5], loop_types)
        ):
            return True
        # One-line body: loop -> SimpleStatementSuite -> stmt -> Await
        # (a one-line `else:` sits under an Else node, so it is not matched).
        return (
            len(stack) >= 4
            and isinstance(stack[-3], cst.SimpleStatementSuite)
            and isinstance(stack[-4], loop_types)
        )

    def visit_Await(self, node: cst.Await) -> None:
        stack = self.context.node_stack
        if len(stack) < 2:
            return
        parent = stack[-2]
        # `await f()` / `x = await f()` directly in a loop body.
        if isinstance(parent, (cst.Expr, cst.Assign)) and parent.value is node:
            if self._is_direct_loop_body_statement(stack):
                self.report(node)
        # `[await f() for ...]` and friends evaluate the awaits one at a time.
        if (
            isinstance(parent, (cst.ListComp, cst.SetComp, cst.GeneratorExp))
            and parent.elt is node
        ):
            self.report(node)