def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
    """Rewrite a ``NamedTuple`` subclass as a frozen dataclass.

    Any base whose qualified name matches ``self.qualified_namedtuple`` is
    removed from the class bases. Its import is scheduled for removal if it
    becomes unused, and a ``@dataclass(frozen=True)`` decorator (from the
    stdlib ``dataclasses`` module) is appended after the existing decorators.

    Args:
        original_node: The class before this transform. Its bases are
            inspected because they are the nodes tied to the import metadata.
        updated_node: The class after its children were transformed.

    Returns:
        ``updated_node`` unchanged when no base matches; otherwise the
        rewritten class.
    """
    new_bases: List[cst.Arg] = []
    namedtuple_base: Optional[cst.Arg] = None

    # Examine the original node's bases: they are directly tied to the
    # qualified-name metadata, while updated_node's bases may be new objects.
    for base_class in original_node.bases:
        if QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
            namedtuple_base = base_class
        else:
            # Keep every base that is not the expected typing.NamedTuple.
            new_bases.append(base_class)

    # Still return updated_node so modifications to its children are kept.
    if namedtuple_base is None:
        return updated_node

    # Request exactly one `dataclass` import. Requesting it from several
    # modules binds the same name more than once, and the import that lands
    # last would silently shadow the others.
    AddImportsVisitor.add_needed_import(self.context, "dataclasses", "dataclass")
    RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

    # frozen=True preserves NamedTuple semantics: instances stay immutable and
    # hashable (a non-frozen dataclass with eq=True gets __hash__ = None).
    call = cst.ensure_type(
        cst.parse_expression("dataclass(frozen=True)", config=self.module.config_for_parsing),
        cst.Call,
    )
    return updated_node.with_changes(
        # Reset the bases parentheses to DEFAULT because the base list has
        # changed (it may now be empty); presumably this lets codegen decide
        # whether to emit them.
        lpar=cst.MaybeSentinel.DEFAULT,
        rpar=cst.MaybeSentinel.DEFAULT,
        bases=new_bases,
        decorators=[*original_node.decorators, cst.Decorator(decorator=call)],
    )
def _handle_import(self, node: Union[Import, ImportFrom]) -> None:
    """Schedule each non-ignored alias of an import statement for removal if unused.

    Nothing is scheduled when the statement starts on a line listed in
    ``self._ignored_lines`` or when it is a star import. Any alias that spans
    an ignored line is skipped. Every other alias is passed to
    ``RemoveImportsVisitor.remove_unused_import``. A plain ``import`` is keyed
    by module and alias. A ``from ... import`` is keyed by its absolute
    module, the imported name and the alias.

    Raises:
        ValueError: if the absolute module of a ``from ... import`` cannot be
            resolved from ``self.context.full_module_name``.
    """
    ignored = self._ignored_lines
    first_line = self.get_metadata(PositionProvider, node).start.line
    if first_line in ignored:
        return

    aliases = node.names
    if isinstance(aliases, ImportStar):
        return

    is_plain_import = isinstance(node, Import)
    for alias in aliases:
        span = self.get_metadata(PositionProvider, alias)
        covered = set(range(span.start.line, span.end.line + 1))
        if not covered.isdisjoint(ignored):
            # At least one line of this alias is explicitly ignored.
            continue

        if is_plain_import:
            RemoveImportsVisitor.remove_unused_import(
                self.context,
                module=alias.evaluated_name,
                asname=alias.evaluated_alias,
            )
            continue

        # `from X import Y` needs X resolved to an absolute module path
        # (relative imports are resolved against the current module).
        absolute_module = get_absolute_module_for_import(
            self.context.full_module_name, node
        )
        if absolute_module is None:
            raise ValueError(
                f"Couldn't get absolute module name for {alias.evaluated_name}"
            )
        RemoveImportsVisitor.remove_unused_import(
            self.context,
            module=absolute_module,
            obj=alias.evaluated_name,
            asname=alias.evaluated_alias,
        )
def _leave_foo_bar(
    self,
    original_node: cst.SimpleStatementLine,
    updated_node: cst.SimpleStatementLine,
) -> cst.RemovalSentinel:
    """Drop the statement line, scheduling any imports it referenced for removal.

    Returns:
        ``cst.RemoveFromParent()`` so the line is removed from its parent.
    """
    context = self.context
    # Schedule removal of imports referenced by the original (pre-transform)
    # node before the statement itself disappears.
    RemoveImportsVisitor.remove_unused_import_by_node(context, original_node)
    removal = cst.RemoveFromParent()
    return removal
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Flush scheduled import removals and request the replacement import.

    Every node in ``self.scheduled_removals`` is passed to
    ``RemoveImportsVisitor.remove_unused_import_by_node``. A new import of
    ``self.new_module`` is requested only when all three of these hold:

    - ``self.bypass_import`` is false;
    - at least one removal was scheduled;
    - ``self.new_module`` is truthy.

    Its ``obj`` is the first dotted component of ``self.new_mod_or_obj``,
    or ``None`` when that value is falsy.

    Returns:
        ``updated_node``, unmodified.
    """
    context = self.context
    pending = self.scheduled_removals
    for node_to_drop in pending:
        RemoveImportsVisitor.remove_unused_import_by_node(context, node_to_drop)

    # A false bypass_import means no import statement was renamed directly.
    # A non-empty scheduled_removals means a matching `old_name` was seen
    # in the code, so the new name must be imported.
    needs_new_import = not self.bypass_import and bool(pending)
    if not needs_new_import or not self.new_module:
        return updated_node

    new_obj: Optional[str] = None
    if self.new_mod_or_obj:
        new_obj = self.new_mod_or_obj.split(".")[0]
    AddImportsVisitor.add_needed_import(context, module=self.new_module, obj=new_obj)
    return updated_node
def _update_imports(self):
    """Replace pytz / datetime.timezone imports with ``UTC`` from bulb timezones.

    The following are scheduled for removal if unused:

    - the ``pytz`` module;
    - ``pytz.utc`` and ``pytz.UTC``;
    - ``datetime.timezone``.

    Then ``from bulb.platform.common.timezones import UTC`` is requested.
    """
    # Each entry holds the positional arguments passed after the context.
    obsolete_imports = (
        ("pytz",),
        ("pytz", "utc"),
        ("pytz", "UTC"),
        ("datetime", "timezone"),
    )
    for import_args in obsolete_imports:
        RemoveImportsVisitor.remove_unused_import(self.context, *import_args)

    AddImportsVisitor.add_needed_import(
        self.context, "bulb.platform.common.timezones", "UTC"
    )
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Rewrite calls to renamed wx symbols so they go through the ``wx`` package.

    Short-form calls come first, using ``(symbol, matcher, renamed)`` triples
    from ``self.matchers_short_map``. A triple applies when ``symbol`` is in
    ``self.wx_imports`` and ``matcher`` matches the call. Then:

    - imports referenced by ``original_node`` are scheduled for removal;
    - ``import wx`` is requested;
    - the callee becomes ``wx.<renamed>``, or ``wx.<renamed[0]>.<renamed[1]>``
      when ``renamed`` is a tuple.

    Full-form calls use ``(matcher, renamed)`` pairs from
    ``self.matchers_full_map``. A tuple ``renamed`` rebuilds the callee as
    ``wx.<renamed[0]>.<renamed[1]>``. Otherwise only the ``attr`` of the
    existing callee is replaced. The first match wins. With no match,
    ``updated_node`` is returned unchanged.
    """

    def wx_nested_attribute(first: str, second: str) -> cst.Attribute:
        # Builds the expression `wx.<first>.<second>`.
        return cst.Attribute(
            value=cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=first)),
            attr=cst.Name(value=second),
        )

    # Calls to symbols imported from wx and used without the `wx.` prefix.
    for symbol, matcher, renamed in self.matchers_short_map:
        if symbol not in self.wx_imports:
            continue
        if not matchers.matches(updated_node, matcher):
            continue
        # The symbol's own import may now be unused; the wx package is needed.
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
        AddImportsVisitor.add_needed_import(self.context, "wx")
        if isinstance(renamed, tuple):
            new_func = wx_nested_attribute(renamed[0], renamed[1])
        else:
            new_func = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=renamed))
        return updated_node.with_changes(func=new_func)

    # Calls already written in full, e.g. `wx.MySymbol(...)`.
    for matcher, renamed in self.matchers_full_map:
        if not matchers.matches(updated_node, matcher):
            continue
        if isinstance(renamed, tuple):
            return updated_node.with_changes(func=wx_nested_attribute(renamed[0], renamed[1]))
        # Presumably the matcher guarantees `func` is a cst.Attribute here;
        # only its final attribute name is replaced.
        renamed_attr = cst.Name(value=renamed)
        return updated_node.with_changes(func=updated_node.func.with_changes(attr=renamed_attr))

    return updated_node
def visit_ImportFrom(self, node: ImportFrom) -> bool:
    """Schedule the imports referenced by this ``from ... import`` for removal if unused.

    Returns:
        ``False``, which in libcst means the node's children are not visited.
    """
    context = self.context
    RemoveImportsVisitor.remove_unused_import_by_node(context, node)
    visit_children = False
    return visit_children
def visit_Module(self, node: cst.Module) -> None:
    """Request ``from foo import quux`` and schedule ``from foo import bar`` for removal if unused."""
    context = self.context
    AddImportsVisitor.add_needed_import(context, module="foo", obj="quux")
    RemoveImportsVisitor.remove_unused_import(context, module="foo", obj="bar")
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
    """Schedule any imports referenced by ``node`` for removal if they end up unused."""
    context = self.context
    import_node = node
    RemoveImportsVisitor.remove_unused_import_by_node(context, import_node)