def visit_Call(self, node: cst.Call) -> Optional[bool]:
    """Assert what ``QualifiedNameProvider.has_name`` reports for ``node``.

    Each query (a plain dotted string or a ``QualifiedName``) is paired with
    the result that ``has_name`` is expected to produce for it; the check is
    performed through the assertion helpers on ``self.test``.
    """
    expectations = (
        ("a.b.c", True),
        ("a.b", False),
        (QualifiedName("a.b.c", QualifiedNameSource.IMPORT), True),
        (QualifiedName("a.b.c", QualifiedNameSource.LOCAL), False),
    )
    for query, expected in expectations:
        check = self.test.assertTrue if expected else self.test.assertFalse
        check(QualifiedNameProvider.has_name(self, node, query))
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Report ``dataclasses.dataclass`` decorators that lack a ``frozen`` keyword.

    For every decorator on ``node`` whose qualified name is the imported
    ``dataclasses.dataclass``, a replacement decorator is built that calls the
    same callable with the existing arguments plus ``frozen=True``, and the
    decorator is reported with that replacement. Decorators that already pass
    a ``frozen`` keyword are left alone.
    """
    dataclass_name = QualifiedName(
        name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
    )
    for wrapper in node.decorators:
        expr = wrapper.decorator
        if not QualifiedNameProvider.has_name(self, expr, dataclass_name):
            continue

        if isinstance(expr, cst.Call):
            # ``@dataclass(...)``: keep the callee and whatever was passed.
            callee, existing_args = expr.func, expr.args
        else:
            # Bare ``@dataclass``: the decorator is a cst.Name or cst.Attribute.
            callee, existing_args = expr, ()

        # pyre-fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]], typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
        if any(m.matches(arg.keyword, m.Name("frozen")) for arg in existing_args):
            continue

        # ``frozen=True`` rendered with no whitespace around the ``=``.
        frozen_arg = cst.Arg(
            keyword=cst.Name("frozen"),
            value=cst.Name("True"),
            equal=cst.AssignEqual(
                whitespace_before=SimpleWhitespace(value=""),
                whitespace_after=SimpleWhitespace(value=""),
            ),
        )
        new_decorator = cst.Call(func=callee, args=[*existing_args, frozen_arg])
        self.report(wrapper, replacement=wrapper.with_changes(decorator=new_decorator))
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
    """Report classmethods whose first parameter is not named ``CLS``.

    Functions without a ``builtins.classmethod`` decorator are ignored. When
    the function has no parameters at all, the report carries a replacement
    that adds a ``CLS`` parameter. Otherwise the first parameter, and every
    assignment/reference to it that the scope metadata knows about, is renamed
    to ``CLS`` -- unless scope metadata is missing or the scope already has an
    assignment to ``CLS``, in which case the node is reported without an
    autofix.
    """
    classmethod_name = QualifiedName(
        name="builtins.classmethod", source=QualifiedNameSource.BUILTIN
    )
    is_classmethod = any(
        QualifiedNameProvider.has_name(self, dec.decorator, classmethod_name)
        for dec in node.decorators
    )
    if not is_classmethod:
        # Only @classmethod-decorated functions are of interest.
        return

    positional = node.params.params
    if not positional:
        # No params at all, so the mandatory first param is missing.
        # pyre[47] already flags this, but reporting here lets us attach
        # an autofix that inserts the parameter.
        with_cls = node.params.with_changes(
            params=(cst.Param(name=cst.Name(value=CLS)),)
        )
        self.report(node, replacement=node.with_changes(params=with_cls))
        return

    first_name = positional[0].name
    if first_name.value == CLS:
        # Already named correctly.
        return

    # The rename covers every assignment and reference of the first param
    # inside the function scope, provided it goes through a Name node. The
    # Param node's scope is the classmethod's own FunctionScope.
    scope = self.get_metadata(ScopeProvider, first_name, None)

    # Two situations get a report without an autofix:
    #  * no scope metadata is available (a defensive check; it is unclear how
    #    to trigger this in a unit test), or
    #  * the scope already contains an assignment to CLS, so renaming the
    #    first param to the same name could produce broken code.
    if not scope or scope[CLS]:
        self.report(node)
        return

    rename_targets: List[Union[cst.Name, cst.Attribute]] = []
    for assignment in scope[first_name.value]:
        if isinstance(assignment, Assignment):
            target = assignment.node
            if isinstance(target, cst.Name):
                rename_targets.append(target)
            elif isinstance(target, cst.Param):
                rename_targets.append(target.name)
            # Other assignment node kinds (ClassDef, FunctionDef, Import,
            # etc.) are deliberately not handled here.
        rename_targets.extend(ref.node for ref in assignment.references)

    renamed = node.visit(_RenameTransformer(rename_targets, CLS))
    self.report(node, replacement=renamed)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
    """Swap a base matching ``self.qualified_namedtuple`` for a ``dataclass`` decorator.

    Returns ``updated_node`` untouched when none of the class's bases match.
    Otherwise schedules the ``dataclass`` imports, schedules removal of the
    matched base's import if it becomes unused, and returns ``updated_node``
    with the matched base dropped and ``@dataclass(frozen=False)`` appended
    to the decorators.
    """
    kept_bases: List[cst.Arg] = []
    matched_base: Optional[cst.Arg] = None

    # The original node's bases are inspected because import metadata is
    # attached to the original tree.
    for base in original_node.bases:
        if QualifiedNameProvider.has_name(self, base.value, self.qualified_namedtuple):
            matched_base = base
        else:
            # Every base that is not the targeted one is preserved.
            kept_bases.append(base)

    if matched_base is None:
        # Still hand back the updated node so changes made to children survive.
        return updated_node

    AddImportsVisitor.add_needed_import(self.context, "attr", "dataclass")
    AddImportsVisitor.add_needed_import(self.context, "pydantic.dataclasses", "dataclass")
    RemoveImportsVisitor.remove_unused_import_by_node(self.context, matched_base.value)

    decorator_call = cst.ensure_type(
        cst.parse_expression("dataclass(frozen=False)", config=self.module.config_for_parsing),
        cst.Call,
    )
    # NOTE(review): decorators and bases are taken from ``original_node``, so
    # edits a child visitor made to them would not carry over -- confirm this
    # is intended.
    new_decorators = [*original_node.decorators, cst.Decorator(decorator=decorator_call)]
    return updated_node.with_changes(
        lpar=cst.MaybeSentinel.DEFAULT,
        rpar=cst.MaybeSentinel.DEFAULT,
        bases=kept_bases,
        decorators=new_decorators,
    )
def leave_Attribute(
    self, original_node: cst.Attribute, updated_node: cst.Attribute
) -> Union[cst.Name, cst.Attribute]:
    """Rewrite an attribute access that refers to ``self.old_name``.

    The node is rewritten when its qualified name matches ``self.old_name``,
    or when it carries no qualified-name metadata and its generated
    replacement equals ``self.new_name``. Any other node is returned as
    ``updated_node`` unchanged.

    Raises:
        Exception: if no full dotted name can be derived for ``original_node``.
    """
    dotted_name = get_full_name_for_node(original_node)
    if dotted_name is None:
        raise Exception("Could not parse full name for Attribute node.")
    replacement_name = self.gen_replacement(dotted_name)

    # A node without any associated QualifiedName means we are still inside
    # an import statement.
    in_import: bool = not self.get_metadata(QualifiedNameProvider, original_node, set())

    should_rename = QualifiedNameProvider.has_name(
        self, original_node, self.old_name
    ) or (in_import and replacement_name == self.new_name)
    if not should_rename:
        return updated_node

    module_part, object_part = self.new_module, self.new_mod_or_obj
    if not in_import:
        self.scheduled_removals.add(original_node.value)

    if replacement_name != self.new_name:
        return self.gen_name_or_attr_node(object_part)

    return updated_node.with_changes(
        value=cst.parse_expression(module_part),
        attr=cst.Name(value=object_part.rstrip(".")),
    )
def collect_targets(
    self, stack: Tuple[cst.BaseExpression, ...]
) -> Tuple[
    List[cst.BaseExpression], Dict[cst.BaseExpression, List[cst.BaseExpression]]
]:
    """Group two-argument ``_ISINSTANCE`` calls in ``stack`` by their first argument.

    Returns ``(operands, targets)``. ``targets`` maps each distinct first
    argument (compared with ``deep_equals``) to the list of second arguments
    it was checked against. ``operands`` holds, in order of appearance, every
    expression that was not such a call, plus each distinct first argument at
    the position where it was first seen. Calls whose second argument is a
    tuple do not match the shape and are kept as ordinary operands.
    """
    grouped: Dict[cst.BaseExpression, List[cst.BaseExpression]] = {}
    remaining: List[cst.BaseExpression] = []

    # Any call with exactly two arguments where the second is not a tuple.
    two_arg_call = m.Call(func=m.DoNotCare(), args=[m.Arg(), m.Arg(~m.Tuple())])

    for expr in stack:
        if not m.matches(expr, two_arg_call):
            remaining.append(expr)
            continue

        call = cst.ensure_type(expr, cst.Call)
        if not QualifiedNameProvider.has_name(self, call, _ISINSTANCE):
            remaining.append(expr)
            continue

        subject, checked_type = call.args[0].value, call.args[1].value
        # Look for an earlier, structurally equal subject to merge into.
        known = next(
            (candidate for candidate in grouped if subject.deep_equals(candidate)),
            None,
        )
        if known is None:
            remaining.append(subject)
            grouped[subject] = [checked_type]
        else:
            grouped[known].append(checked_type)

    return remaining, grouped
def partition_bases(self, original_bases: Sequence[cst.Arg]) -> Tuple[Optional[cst.Arg], List[cst.Arg]]:
    """Split ``original_bases`` around the base matching ``self.qualified_typeddict``.

    Returns a pair: the base whose qualified name matches
    ``self.qualified_typeddict`` (``None`` if there is none; the last one if
    several match), and the list of all remaining bases in their original
    order.
    """
    matched_base: Optional[cst.Arg] = None
    other_bases: List[cst.Arg] = []
    for base in original_bases:
        is_match = QualifiedNameProvider.has_name(
            self, base.value, self.qualified_typeddict
        )
        if not is_match:
            other_bases.append(base)
            continue
        matched_base = base
    return matched_base, other_bases
def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> Union[cst.Attribute, cst.Name]:
    """Rewrite a bare name that refers to ``self.old_name``.

    The node is rewritten when its qualified name matches ``self.old_name``,
    or when it carries no qualified-name metadata and its generated
    replacement equals ``self.new_name``. Any other node is returned as
    ``updated_node`` unchanged.
    """
    replacement_name = self.gen_replacement(original_node.value)

    # A node without any associated QualifiedName means we are still inside
    # an import statement.
    in_import: bool = not self.get_metadata(QualifiedNameProvider, original_node, set())

    should_rename = QualifiedNameProvider.has_name(
        self, original_node, self.old_name
    ) or (in_import and replacement_name == self.new_name)
    if not should_rename:
        return updated_node

    # Fall back to the full new name when no replacement was generated.
    replacement_name = replacement_name or self.new_name
    if not in_import:
        self.scheduled_removals.add(original_node)
    return self.gen_name_or_attr_node(replacement_name)