def test_multiple_qualified_names(self) -> None:
    """A name bound in several places resolves to every candidate qualified name.

    ``f`` is defined locally in the ``if`` branch, imported via
    ``from b import f`` and ``import f`` in the other branches, and rebound by
    ``import a.b as f``. The alias names inside the import statements map to
    an empty set, the ``def`` maps to a local ``f``, and the final ``f()``
    call reports all four possibilities.
    """
    module, qnames = get_qualified_name_metadata_provider(
        """
        if False:
            def f(): pass
        elif False:
            from b import f
        else:
            import f
        import a.b as f

        f()
        """
    )
    if_stmt = ensure_type(module.body[0], cst.If)
    elif_stmt = ensure_type(if_stmt.orelse, cst.If)
    else_clause = ensure_type(elif_stmt.orelse, cst.Else)

    # ``def f(): pass`` in the ``if`` branch.
    local_def = ensure_type(if_stmt.body.body[0], cst.FunctionDef)

    # ``from b import f`` in the ``elif`` branch.
    from_line = ensure_type(elif_stmt.body.body[0], cst.SimpleStatementLine)
    from_aliases = ensure_type(from_line.body[0], cst.ImportFrom).names
    self.assertFalse(isinstance(from_aliases, cst.ImportStar))
    from_import_name = from_aliases[0].name

    # ``import f`` in the ``else`` branch.
    plain_line = ensure_type(else_clause.body.body[0], cst.SimpleStatementLine)
    plain_aliases = ensure_type(plain_line.body[0], cst.Import).names
    self.assertFalse(isinstance(plain_aliases, cst.ImportStar))
    plain_import_name = plain_aliases[0].name

    # ``import a.b as f`` at module level; the bound name lives in the AsName.
    as_line = ensure_type(module.body[1], cst.SimpleStatementLine)
    as_import = ensure_type(as_line.body[0], cst.Import)
    as_name = ensure_type(as_import.names[0].asname, cst.AsName).name

    # The trailing ``f()`` call.
    call_line = ensure_type(module.body[2], cst.SimpleStatementLine)
    call = ensure_type(ensure_type(call_line.body[0], cst.Expr).value, cst.Call)

    self.assertEqual(qnames[local_def], {QualifiedName("f", QualifiedNameSource.LOCAL)})
    for alias_name in (from_import_name, plain_import_name, as_name):
        self.assertEqual(qnames[alias_name], set())
    self.assertEqual(
        qnames[call],
        {
            QualifiedName("f", QualifiedNameSource.IMPORT),
            QualifiedName("b.f", QualifiedNameSource.IMPORT),
            QualifiedName("f", QualifiedNameSource.LOCAL),
            QualifiedName("a.b", QualifiedNameSource.IMPORT),
        },
    )
def test_simple_qualified_names(self) -> None:
    """Qualified names for imported calls, local calls and locals inside a method.

    Covers a string return annotation (no qualified name), a call to the
    imported ``c`` (``a.b.c``), a call to the module-level function ``g``,
    and both a plain-name target and a subscript target of the local ``d``.
    """
    m, names = get_qualified_name_metadata_provider(
        """
        from a.b import c
        class Cls:
            def f(self) -> "c":
                c()
                d = {}
                d['key'] = 0
        def g():
            pass
        g()
        """
    )
    cls = ensure_type(m.body[1], cst.ClassDef)
    f = ensure_type(cls.body.body[0], cst.FunctionDef)
    # A string annotation is not resolved to a qualified name.
    self.assertEqual(
        names[ensure_type(f.returns, cst.Annotation).annotation], set()
    )

    c_call = ensure_type(
        ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0], cst.Expr
    ).value
    self.assertEqual(
        names[c_call], {QualifiedName("a.b.c", QualifiedNameSource.IMPORT)}
    )

    g_call = ensure_type(
        ensure_type(m.body[3], cst.SimpleStatementLine).body[0], cst.Expr
    ).value
    self.assertEqual(names[g_call], {QualifiedName("g", QualifiedNameSource.LOCAL)})

    d_name = (
        ensure_type(
            ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0], cst.Assign
        )
        .targets[0]
        .target
    )
    self.assertEqual(
        names[d_name],
        {QualifiedName("Cls.f.<locals>.d", QualifiedNameSource.LOCAL)},
    )

    # ``d['key'] = 0``: the Subscript target is qualified as the underlying ``d``.
    d_subscript = (
        ensure_type(
            ensure_type(f.body.body[2], cst.SimpleStatementLine).body[0], cst.Assign
        )
        .targets[0]
        .target
    )
    self.assertEqual(
        names[d_subscript],
        {QualifiedName("Cls.f.<locals>.d", QualifiedNameSource.LOCAL)},
    )
def test_comprehension(self) -> None:
    """Comprehension targets are qualified with ``<comprehension>`` segments.

    The outermost iterable of a comprehension is evaluated in the enclosing
    scope, so ``j`` sits one comprehension level deep (like ``i``), while
    ``k`` in the element expression sits two levels deep.
    """
    module, qnames = get_qualified_name_metadata_provider(
        """
        class C:
            def fn(self) -> None:
                [[k for k in i] for i in [j for j in range(10)]]
                # Note:
                # The qualified name of i is straightforward to be "C.fn.<locals>.<comprehension>.i".
                # ListComp j is evaluated outside of the ListComp i.
                # so j has qualified name "C.fn.<locals>.<comprehension>.j".
                # ListComp k is evaluated inside ListComp i.
                # so k has qualified name "C.fn.<locals>.<comprehension>.<comprehension>.k".
        """
    )
    method = ensure_type(
        ensure_type(module.body[0], cst.ClassDef).body.body[0], cst.FunctionDef
    )
    stmt = ensure_type(method.body.body[0], cst.SimpleStatementLine)
    outer = ensure_type(ensure_type(stmt.body[0], cst.Expr).value, cst.ListComp)

    def expect_local(target: cst.CSTNode, qualified: str) -> None:
        # Each comprehension target should resolve to exactly one LOCAL name.
        self.assertEqual(
            qnames[target],
            {QualifiedName(name=qualified, source=QualifiedNameSource.LOCAL)},
        )

    expect_local(outer.for_in.target, "C.fn.<locals>.<comprehension>.i")
    expect_local(
        ensure_type(outer.for_in.iter, cst.ListComp).for_in.target,
        "C.fn.<locals>.<comprehension>.j",
    )
    expect_local(
        ensure_type(outer.elt, cst.ListComp).for_in.target,
        "C.fn.<locals>.<comprehension>.<comprehension>.k",
    )
def visit_Call(self, node: cst.Call) -> Optional[bool]:
    """Assert how ``has_name`` matches a call whose target is ``a.b.c``.

    A plain string matches on the dotted name alone. A ``QualifiedName``
    must also agree on the source: IMPORT matches and LOCAL does not.
    """
    check = self.test
    has = QualifiedNameProvider.has_name
    check.assertTrue(has(self, node, "a.b.c"))
    check.assertFalse(has(self, node, "a.b"))
    imported = QualifiedName("a.b.c", QualifiedNameSource.IMPORT)
    local = QualifiedName("a.b.c", QualifiedNameSource.LOCAL)
    check.assertTrue(has(self, node, imported))
    check.assertFalse(has(self, node, local))
def test_nested_qualified_names(self) -> None:
    """Functions and classes nested inside functions get ``<locals>`` segments.

    Methods are qualified by their class (``A.f1``). Anything defined inside
    a function body is qualified as ``<function>.<locals>.<name>``, and this
    repeats for each level of nesting.
    """
    m, names = get_qualified_name_metadata_provider(
        """
        class A:
            def f1(self):
                def f2():
                    pass
                f2()

            def f3(self):
                class B():
                    ...
                B()
        def f4():
            def f5():
                class C:
                    pass
                C()
            f5()
        """
    )
    cls_a = ensure_type(m.body[0], cst.ClassDef)
    self.assertEqual(names[cls_a], {QualifiedName("A", QualifiedNameSource.LOCAL)})
    func_f1 = ensure_type(cls_a.body.body[0], cst.FunctionDef)
    self.assertEqual(
        names[func_f1], {QualifiedName("A.f1", QualifiedNameSource.LOCAL)}
    )
    func_f2_call = ensure_type(
        ensure_type(func_f1.body.body[1], cst.SimpleStatementLine).body[0], cst.Expr
    ).value
    self.assertEqual(
        names[func_f2_call],
        {QualifiedName("A.f1.<locals>.f2", QualifiedNameSource.LOCAL)},
    )
    func_f3 = ensure_type(cls_a.body.body[1], cst.FunctionDef)
    self.assertEqual(
        names[func_f3], {QualifiedName("A.f3", QualifiedNameSource.LOCAL)}
    )
    call_b = ensure_type(
        ensure_type(func_f3.body.body[1], cst.SimpleStatementLine).body[0], cst.Expr
    ).value
    self.assertEqual(
        names[call_b], {QualifiedName("A.f3.<locals>.B", QualifiedNameSource.LOCAL)}
    )
    func_f4 = ensure_type(m.body[1], cst.FunctionDef)
    self.assertEqual(
        names[func_f4], {QualifiedName("f4", QualifiedNameSource.LOCAL)}
    )
    func_f5 = ensure_type(func_f4.body.body[0], cst.FunctionDef)
    self.assertEqual(
        names[func_f5], {QualifiedName("f4.<locals>.f5", QualifiedNameSource.LOCAL)}
    )
    # Narrow like every other node in this test so a structural change in the
    # fixture fails loudly here rather than as a confusing metadata KeyError.
    cls_c = ensure_type(func_f5.body.body[0], cst.ClassDef)
    self.assertEqual(
        names[cls_c],
        {QualifiedName("f4.<locals>.f5.<locals>.C", QualifiedNameSource.LOCAL)},
    )
class ConvertNamedTupleToDataclassCommand(VisitorBasedCodemodCommand):
    """
    Convert NamedTuple class declarations to Python 3.7 dataclasses.

    This only performs a conversion at the class declaration level.
    It does not perform type annotation conversions, nor does it convert
    NamedTuple-specific attributes and methods.
    """

    DESCRIPTION: str = "Convert NamedTuple class declarations to Python 3.7 dataclasses using the @dataclass decorator."
    METADATA_DEPENDENCIES: Sequence[ProviderT] = (QualifiedNameProvider,)

    # The 'NamedTuple' we are interested in
    qualified_namedtuple: QualifiedName = QualifiedName(name="typing.NamedTuple", source=QualifiedNameSource.IMPORT)
    # The stdlib decorator that replaces the NamedTuple base.
    qualified_dataclass: QualifiedName = QualifiedName(name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT)
    # Not referenced within this class; retained for backward compatibility.
    attr_dataclass: QualifiedName = QualifiedName(name="attr.dataclass", source=QualifiedNameSource.IMPORT)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        """Swap a ``typing.NamedTuple`` base for a ``@dataclass(frozen=True)`` decorator.

        Classes without a NamedTuple base are returned as ``updated_node``
        unchanged. Otherwise the NamedTuple base is dropped, the now-unused
        NamedTuple import is scheduled for removal, ``dataclass`` is imported
        from the stdlib ``dataclasses`` module and the decorator is appended.
        """
        new_bases: List[cst.Arg] = []
        namedtuple_base: Optional[cst.Arg] = None

        # Need to examine the original node's bases since they are directly tied to import metadata
        for base_class in original_node.bases:
            # Compare the base class's qualified name against the expected typing.NamedTuple
            if QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
                namedtuple_base = base_class
            else:
                # Keep all bases that are not of type typing.NamedTuple
                new_bases.append(base_class)

        # We still want to return the updated node in case some of its children have been modified
        if namedtuple_base is None:
            return updated_node

        # Import exactly one ``dataclass``: the stdlib one. Requesting several
        # same-named symbols (e.g. from ``attr`` and ``pydantic.dataclasses``)
        # makes the imports shadow each other and adds third-party dependencies.
        AddImportsVisitor.add_needed_import(self.context, "dataclasses", "dataclass")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

        # NamedTuple instances are immutable; ``frozen=True`` keeps that guarantee.
        call = cst.ensure_type(
            cst.parse_expression("dataclass(frozen=True)", config=self.module.config_for_parsing),
            cst.Call,
        )
        return updated_node.with_changes(
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
            bases=new_bases,
            # Build on the updated decorators so changes made to them by child
            # visits are not thrown away.
            decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
        )
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
    """Report ``@classmethod`` functions whose first parameter is not ``cls``.

    An autofix is attached when it is safe:

    * If there is no positional parameter at all, a ``CLS`` parameter is
      inserted.
    * Otherwise the first parameter, and every Name-based assignment and
      reference to it within the function scope, is renamed to ``CLS``.

    No autofix is offered (only a report) when scope metadata is missing,
    or when the scope already binds ``CLS``.
    """
    is_classmethod = any(
        QualifiedNameProvider.has_name(
            self,
            decorator.decorator,
            QualifiedName(name="builtins.classmethod", source=QualifiedNameSource.BUILTIN),
        )
        for decorator in node.decorators
    )
    if not is_classmethod:
        return  # If it's not a @classmethod, we are not interested.

    # A positional-only first parameter (``def m(klass, /): ...``) is stored in
    # ``posonly_params``. Checking only ``params`` would skip it, and the rule
    # would then either rename the wrong parameter or insert a redundant ``cls``.
    positional_params = [*node.params.posonly_params, *node.params.params]
    if not positional_params:
        # No params, but there must be the 'cls' param.
        # Note that pyre[47] already catches this, but we also generate
        # an autofix, so it still makes sense for us to report it here.
        new_params = node.params.with_changes(params=(cst.Param(name=cst.Name(value=CLS)),))
        repl = node.with_changes(params=new_params)
        self.report(node, replacement=repl)
        return

    p0_name = positional_params[0].name
    if p0_name.value == CLS:
        return  # All good.

    # Rename all assignments and references of the first param within the
    # function scope, as long as they are done via a Name node.
    # We rely on the parser to correctly derive all
    # assignments and references within the FunctionScope.
    # The Param node's scope is our classmethod's FunctionScope.
    scope = self.get_metadata(ScopeProvider, p0_name, None)
    if not scope:
        # Cannot autofix without scope metadata. Only report in this case.
        # Not sure how to repro+cover this in a unit test...
        # If metadata creation fails, then the whole lint fails, and if it succeeds,
        # then there is valid metadata. But many other lint rule implementations contain
        # a defensive scope None check like this one, so I assume it is necessary.
        self.report(node)
        return

    if scope[CLS]:
        # The scope already has another assignment to "cls".
        # Trying to rename the first param to "cls" as well may produce broken code.
        # We should therefore refrain from suggesting an autofix in this case.
        self.report(node)
        return

    refs: List[Union[cst.Name, cst.Attribute]] = []
    assignments = scope[p0_name.value]
    for a in assignments:
        if isinstance(a, Assignment):
            assign_node = a.node
            if isinstance(assign_node, cst.Name):
                refs.append(assign_node)
            elif isinstance(assign_node, cst.Param):
                refs.append(assign_node.name)
            # There are other types of possible assignment nodes: ClassDef,
            # FunctionDef, Import, etc. We deliberately do not handle those here.
        refs += [r.node for r in a.references]

    repl = node.visit(_RenameTransformer(refs, CLS))
    self.report(node, replacement=repl)
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Flag ``@dataclasses.dataclass`` decorators that do not pass ``frozen``.

    For each such decorator, the suggested replacement keeps the existing
    arguments and appends ``frozen=True``, so a bare ``@dataclass`` becomes
    ``@dataclass(frozen=True)``. A decorator that already passes any
    ``frozen=`` value is left alone.
    """
    dataclass_qname = QualifiedName(
        name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
    )
    for deco_node in node.decorators:
        expr = deco_node.decorator
        if not QualifiedNameProvider.has_name(self, expr, dataclass_qname):
            continue

        # ``@dataclass(...)`` is a Call. A bare ``@dataclass`` is a Name or an
        # Attribute and carries no arguments.
        callee, existing_args = (
            (expr.func, expr.args) if isinstance(expr, cst.Call) else (expr, ())
        )

        # pyre-fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]], typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
        if any(m.matches(arg.keyword, m.Name("frozen")) for arg in existing_args):
            continue

        frozen_kwarg = cst.Arg(
            keyword=cst.Name("frozen"),
            value=cst.Name("True"),
            equal=cst.AssignEqual(
                whitespace_before=SimpleWhitespace(value=""),
                whitespace_after=SimpleWhitespace(value=""),
            ),
        )
        frozen_call = cst.Call(func=callee, args=[*existing_args, frozen_kwarg])
        self.report(deco_node, replacement=deco_node.with_changes(decorator=frozen_call))
class RemoveBarTransformer(VisitorBasedCodemodCommand):
    """Delete standalone ``foo.bar()`` statements and clean up their now-unused import."""

    METADATA_DEPENDENCIES = (QualifiedNameProvider, ScopeProvider)

    @m.leave(
        m.SimpleStatementLine(
            body=[
                m.Expr(
                    value=m.Call(
                        metadata=m.MatchMetadata(
                            QualifiedNameProvider,
                            {
                                QualifiedName(
                                    source=QualifiedNameSource.IMPORT,
                                    name="foo.bar",
                                )
                            },
                        )
                    )
                )
            ]
        )
    )
    def _leave_foo_bar(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.RemovalSentinel:
        # Ask the import-removal pass to drop whatever import this statement
        # was using, then remove the statement itself.
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
        return cst.RemoveFromParent()
def test_multiple_assignments(self) -> None:
    """An alias imported in two branches resolves to both imported sources."""
    module, qnames = get_qualified_name_metadata_provider(
        """
        if 1:
            from a import b as c
        elif 2:
            from d import e as c
        c()
        """
    )
    call_stmt = ensure_type(module.body[1], cst.SimpleStatementLine)
    call_expr = ensure_type(call_stmt.body[0], cst.Expr).value
    expected = {
        QualifiedName(name="a.b", source=QualifiedNameSource.IMPORT),
        QualifiedName(name="d.e", source=QualifiedNameSource.IMPORT),
    }
    self.assertEqual(qnames[call_expr], expected)
def test_with_full_repo_manager(self) -> None:
    """FullyQualifiedNameProvider, driven by FullRepoManager, names a module by its dotted path.

    A file ``pkg/mod.py`` under a temporary repository root is expected to
    resolve to the LOCAL qualified name ``pkg.mod``.
    """
    # ``tmpdir`` rather than ``dir`` to avoid shadowing the builtin.
    with TemporaryDirectory() as tmpdir:
        fname = "pkg/mod.py"
        root = Path(tmpdir)
        (root / "pkg").mkdir()
        (root / fname).touch()
        mgr = FullRepoManager(tmpdir, [fname], [FullyQualifiedNameProvider])
        wrapper = mgr.get_metadata_wrapper_for_path(fname)
        fqnames = wrapper.resolve(FullyQualifiedNameProvider)
        (mod, names) = next(iter(fqnames.items()))
        self.assertIsInstance(mod, cst.Module)
        self.assertEqual(
            names,
            {QualifiedName(name="pkg.mod", source=QualifiedNameSource.LOCAL)},
        )
def test_local_qualification(self) -> None:
    """Relative (``.x``, ``..x``) and bare local names are qualified against the base module."""
    base_module = "some.test.module"
    cases = (
        (".foo", "some.test.foo"),
        ("..bar", "some.bar"),
        ("foo", "some.test.module.foo"),
    )
    for local_name, want in cases:
        with self.subTest(name=local_name):
            qualified = QualifiedName(name=local_name, source=QualifiedNameSource.LOCAL)
            got = FullyQualifiedNameVisitor._fully_qualify_local(base_module, qualified)
            self.assertEqual(got, want)
def test_name_in_attribute(self) -> None:
    """An attribute access is qualified as a whole; its ``attr`` Name alone is not."""
    m, names = get_qualified_name_metadata_provider(
        """
        obj = object()
        obj.eval
        """
    )
    attr = ensure_type(
        ensure_type(
            ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
        ).value,
        cst.Attribute,
    )
    self.assertEqual(
        names[attr],
        {QualifiedName(name="obj.eval", source=QualifiedNameSource.LOCAL)},
    )
    # ``eval_name`` rather than ``eval`` to avoid shadowing the builtin.
    eval_name = attr.attr
    self.assertEqual(names[eval_name], set())
def test_builtins(self) -> None:
    """Builtins used in a module resolve to ``builtins.*`` names with BUILTIN source.

    The module's own LOCAL qualified name is also present and is excluded
    before checking the builtins.
    """
    qnames = get_fully_qualified_names(
        "test/module.py",
        """
        int(None)
        """,
    )
    module_qname = QualifiedName(name="test.module", source=QualifiedNameSource.LOCAL)
    self.assertIn(module_qname, qnames)
    others = qnames - {module_qname}
    self.assertEqual({"builtins.int", "builtins.None"}, {q.name for q in others})
    for q in others:
        self.assertEqual(q.source, QualifiedNameSource.BUILTIN, msg=f"{q}")
def test_repeated_values_in_qualified_name(self) -> None:
    """An attribute chain with repetitive components (``a.aa.aaa``) is qualified intact."""
    module, qnames = get_qualified_name_metadata_provider(
        """
        import a
        class Foo:
            bar: a.aa.aaa
        """
    )
    foo_cls = ensure_type(module.body[1], cst.ClassDef)
    class_body = ensure_type(foo_cls.body, cst.IndentedBlock)
    stmt = ensure_type(class_body.body[0], cst.SimpleStatementLine)
    ann_assign = ensure_type(stmt.body[0], cst.AnnAssign)
    annotation_expr = ensure_type(
        ensure_type(ann_assign.annotation, cst.Annotation).annotation,
        cst.Attribute,
    )
    self.assertEqual(
        qnames[annotation_expr],
        {QualifiedName("a.aa.aaa", QualifiedNameSource.IMPORT)},
    )
def test_imports(self) -> None:
    """Absolute and relative imports resolve to fully qualified, IMPORT-sourced names.

    Relative imports are resolved against the package of
    ``some/test/module.py``; the module's own LOCAL name is excluded first.
    """
    qnames = get_fully_qualified_names(
        "some/test/module.py",
        """
        from a.b import c as d
        from . import rel
        from .lol import rel2
        from .. import thing as rel3
        d, rel, rel2, rel3
        """,
    )
    module_qname = QualifiedName(name="some.test.module", source=QualifiedNameSource.LOCAL)
    self.assertIn(module_qname, qnames)
    imported = qnames - {module_qname}
    self.assertEqual(
        {"a.b.c", "some.test.rel", "some.test.lol.rel2", "some.thing"},
        {q.name for q in imported},
    )
    for q in imported:
        self.assertEqual(q.source, QualifiedNameSource.IMPORT, msg=f"{q}")
from typing import Dict, Iterator, List, Set, Tuple import libcst as cst import libcst.matchers as m from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource from fixit import ( CstContext, CstLintRule, InvalidTestCase as Invalid, ValidTestCase as Valid, ) _ISINSTANCE = QualifiedName( name="builtins.isinstance", source=QualifiedNameSource.BUILTIN ) class CollapseIsinstanceChecksRule(CstLintRule): """ The built-in ``isinstance`` function, instead of a single type, can take a tuple of types and check whether given target suits any of them. Rather than chaining multiple ``isinstance`` calls with a boolean-or operation, a single ``isinstance`` call where the second argument is a tuple of all types can be used. """ MESSAGE: str = ( "Multiple isinstance calls with the same target but " + "different types can be collapsed into a single call "
class NoTypedDictRule(CstLintRule):
    """
    Enforce the use of the ``dataclasses.dataclass`` decorator instead of ``TypedDict`` for cleaner customization
    and inheritance. Dataclasses support default values, combining fields through inheritance, and omitting
    optional fields at instantiation. See `PEP 557 <https://www.python.org/dev/peps/pep-0557>`_.

    Classes inheriting from ``typing.TypedDict`` or ``typing_extensions.TypedDict`` are reported. The suggested
    replacement drops the TypedDict base and the TypedDict-only ``total=`` keyword, and adds
    ``@dataclass(frozen=True)``.
    """

    MESSAGE: str = "Instead of TypedDict, consider using the @dataclass decorator from dataclasses instead for simplicity, efficiency and consistency."
    METADATA_DEPENDENCIES = (QualifiedNameProvider,)

    VALID = [
        Valid(
            """
            @dataclass(frozen=True)
            class Foo:
                pass
            """
        ),
        Valid(
            """
            @dataclass(frozen=False)
            class Foo:
                pass
            """
        ),
        Valid(
            """
            class Foo:
                pass
            """
        ),
        Valid(
            """
            class Foo(SomeOtherBase):
                pass
            """
        ),
        Valid(
            """
            @some_other_decorator
            class Foo:
                pass
            """
        ),
        Valid(
            """
            @some_other_decorator
            class Foo(SomeOtherBase):
                pass
            """
        ),
        Valid(
            """
            from typing import NamedTuple

            class Foo(NamedTuple):
                pass
            """
        ),
    ]

    INVALID = [
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing_extensions import TypedDict

            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing_extensions import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict as TD

            class Foo(TD):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict as TD

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            import typing as typ

            class Foo(typ.TypedDict):
                pass
            """,
            expected_replacement="""
            import typing as typ

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict, total=False):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class Foo(TypedDict, AnotherBase, YetAnotherBase):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @dataclass(frozen=True)
            class Foo(AnotherBase, YetAnotherBase):
                pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            class OuterClass(SomeBase):
                class InnerClass(TypedDict):
                    pass
            """,
            expected_replacement="""
            from typing import TypedDict

            class OuterClass(SomeBase):
                @dataclass(frozen=True)
                class InnerClass:
                    pass
            """,
        ),
        Invalid(
            code="""
            from typing import TypedDict

            @some_other_decorator
            class Foo(TypedDict):
                pass
            """,
            expected_replacement="""
            from typing import TypedDict

            @some_other_decorator
            @dataclass(frozen=True)
            class Foo:
                pass
            """,
        ),
    ]

    # The TypedDict bases we are interested in. ``typing.TypedDict`` is in the
    # standard library since Python 3.8; ``typing_extensions`` backports it.
    qualified_typeddict: QualifiedName = QualifiedName(
        name="typing_extensions.TypedDict", source=QualifiedNameSource.IMPORT
    )
    qualified_typing_typeddict: QualifiedName = QualifiedName(
        name="typing.TypedDict", source=QualifiedNameSource.IMPORT
    )

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Report a class with a TypedDict base and suggest a frozen dataclass instead."""
        (typeddict_base, new_bases) = self.partition_bases(original_node.bases)
        if typeddict_base is None:
            return
        # ``total=`` is only understood by TypedDict; on a plain class it would
        # make class creation raise a TypeError.
        new_keywords = [
            kw
            for kw in original_node.keywords
            if not (kw.keyword is not None and kw.keyword.value == "total")
        ]
        call = ensure_type(parse_expression("dataclass(frozen=True)"), cst.Call)
        replacement = original_node.with_changes(
            lpar=MaybeSentinel.DEFAULT,
            rpar=MaybeSentinel.DEFAULT,
            bases=new_bases,
            keywords=new_keywords,
            decorators=list(original_node.decorators) + [cst.Decorator(decorator=call)],
        )
        self.report(original_node, replacement=replacement)

    def partition_bases(self, original_bases: Sequence[cst.Arg]) -> Tuple[Optional[cst.Arg], List[cst.Arg]]:
        """Split ``original_bases`` into the TypedDict base (or ``None``) and all other bases."""
        typeddict_names = (self.qualified_typeddict, self.qualified_typing_typeddict)
        typeddict_base: Optional[cst.Arg] = None
        new_bases: List[cst.Arg] = []
        for base_class in original_bases:
            if any(
                QualifiedNameProvider.has_name(self, base_class.value, qname)
                for qname in typeddict_names
            ):
                typeddict_base = base_class
            else:
                new_bases.append(base_class)
        return (typeddict_base, new_bases)
class CompareSingletonPrimitivesByIsRule(CstLintRule):
    """
    Enforces the use of `is` and `is not` in comparisons to singleton primitives (None, True, False) rather than == and !=.
    The == operator checks equality, whereas here we want to check identity.
    See Flake8 rules E711 (https://www.flake8rules.com/rules/E711.html) and E712 (https://www.flake8rules.com/rules/E712.html).
    """

    MESSAGE: str = (
        "Comparisons to singleton primitives should not be done with == or !=, as they check equality rather than identity."
        + " Use `is` or `is not` instead."
    )
    METADATA_DEPENDENCIES = (QualifiedNameProvider,)
    VALID = [
        Valid("if x: pass"),
        Valid("if not x: pass"),
        Valid("x is True"),
        Valid("x is False"),
        Valid("x is None"),
        Valid("x is not None"),
        Valid("x is True is not y"),
        Valid("y is None is not x"),
        Valid("None is y"),
        Valid("True is x"),
        Valid("False is x"),
        Valid("x == 2"),
        Valid("2 != x"),
    ]
    INVALID = [
        Invalid(
            code="x != True",
            expected_replacement="x is not True",
        ),
        Invalid(
            code="x != False",
            expected_replacement="x is not False",
        ),
        Invalid(
            code="x == False",
            expected_replacement="x is False",
        ),
        Invalid(
            code="x == None",
            expected_replacement="x is None",
        ),
        Invalid(
            code="x != None",
            expected_replacement="x is not None",
        ),
        Invalid(
            code="False == x",
            expected_replacement="False is x",
        ),
        Invalid(
            code="x is True == y",
            expected_replacement="x is True is y",
        ),
    ]

    # Qualified names of the builtin singletons ``True``, ``False`` and ``None``.
    QUALIFIED_SINGLETON_PRIMITIVES: FrozenSet[QualifiedName] = frozenset(
        {
            QualifiedName(name=f"builtins.{name}", source=QualifiedNameSource.BUILTIN)
            for name in ("True", "False", "None")
        }
    )

    def visit_Comparison(self, node: cst.Comparison) -> None:
        """Report ``==``/``!=`` comparisons where either operand is a builtin singleton.

        Each link of a chained comparison is examined against its own left and
        right operands. The replacement rewrites only the offending operators.
        """
        # Initialize the needs_report flag as False to begin with
        needs_report = False
        left_comp = node.left
        altered_comparisons = []
        for target in node.comparisons:
            operator, right_comp = target.operator, target.comparator
            if isinstance(operator, (cst.Equal, cst.NotEqual)) and (
                not self.QUALIFIED_SINGLETON_PRIMITIVES.isdisjoint(
                    self.get_metadata(QualifiedNameProvider, left_comp, set())
                )
                or not self.QUALIFIED_SINGLETON_PRIMITIVES.isdisjoint(
                    self.get_metadata(QualifiedNameProvider, right_comp, set())
                )
            ):
                needs_report = True
                altered_comparisons.append(
                    target.with_changes(operator=self.alter_operator(operator))
                )
            else:
                altered_comparisons.append(target)
            # Continue the check down the line of comparisons, if more than one
            left_comp = right_comp

        if needs_report:
            self.report(
                node, replacement=node.with_changes(comparisons=altered_comparisons)
            )

    def alter_operator(
        self, original_op: Union[cst.Equal, cst.NotEqual]
    ) -> Union[cst.Is, cst.IsNot]:
        """Map ``!=`` to ``is not`` and ``==`` to ``is``, keeping the surrounding whitespace."""
        return (
            cst.IsNot(
                whitespace_before=original_op.whitespace_before,
                whitespace_after=original_op.whitespace_after,
            )
            if isinstance(original_op, cst.NotEqual)
            else cst.Is(
                whitespace_before=original_op.whitespace_before,
                whitespace_after=original_op.whitespace_after,
            )
        )