def leave_ImportFrom(
    self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
    """Drop ``url`` from a matching ``from ... import`` and import ``re_path`` instead.

    For statements accepted by ``self._test_import_from``, every alias named
    ``url`` is removed and ``from django.urls import re_path`` is scheduled via
    ``AddImportsVisitor``. The remaining aliases are sorted by name. If none
    remain, the whole statement is removed.

    Returns:
        The rewritten statement, ``RemoveFromParent()`` when it becomes empty,
        or the parent class's result for statements that do not match.
    """
    if not self._test_import_from(updated_node):
        return super().leave_ImportFrom(original_node, updated_node)

    # Local import: only needed to strip a dangling trailing comma below.
    from libcst import MaybeSentinel

    new_names = []
    # Iterate the *updated* aliases so modifications made to children are kept.
    for import_alias in updated_node.names:
        if import_alias.evaluated_name == "url":
            AddImportsVisitor.add_needed_import(
                self.context,
                "django.urls",
                "re_path",
            )
        else:
            new_names.append(import_alias)

    if not new_names:
        return RemoveFromParent()

    new_names = sorted(new_names, key=lambda n: n.evaluated_name)
    # After removing/reordering aliases the new last one may still carry the
    # comma that separated it from its old neighbour. LibCST rejects a
    # trailing comma in an ImportFrom without parentheses, so drop it.
    last_name = new_names[-1]
    if last_name.comma != MaybeSentinel.DEFAULT:
        new_names[-1] = last_name.with_changes(comma=MaybeSentinel.DEFAULT)
    # with_changes keeps `relative`, parentheses and whitespace of the original
    # statement, which a freshly constructed ImportFrom would lose.
    return updated_node.with_changes(names=new_names)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
    """Rewrite a ``typing.NamedTuple`` subclass as a ``@dataclass(frozen=True)`` class.

    The NamedTuple base is identified by comparing each base's qualified name
    (via ``QualifiedNameProvider``) against ``self.qualified_namedtuple``. When
    found, it is removed from the bases and its import is scheduled for removal
    if unused. ``from dataclasses import dataclass`` is scheduled, and a
    ``@dataclass(frozen=True)`` decorator is appended.

    Returns:
        ``updated_node`` unchanged when no NamedTuple base is present,
        otherwise the rewritten class definition.
    """
    new_bases: List[cst.Arg] = []
    namedtuple_base: Optional[cst.Arg] = None

    # Need to examine the original node's bases since they are directly tied to import metadata
    for base_class in original_node.bases:
        if QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
            namedtuple_base = base_class
        else:
            # Keep all bases that are not of type typing.NamedTuple
            new_bases.append(base_class)

    # We still want to return the updated node in case some of its children have been modified
    if namedtuple_base is None:
        return updated_node

    # Import `dataclass` from exactly one place. Scheduling it from several
    # modules (as before: attr and pydantic.dataclasses) creates conflicting
    # bindings of the same name.
    AddImportsVisitor.add_needed_import(self.context, "dataclasses", "dataclass")
    RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

    # frozen=True keeps NamedTuple's immutability; frozen=False would silently
    # make previously read-only instances mutable.
    call = cst.ensure_type(
        cst.parse_expression("dataclass(frozen=True)", config=self.module.config_for_parsing),
        cst.Call,
    )
    return updated_node.with_changes(
        # Let LibCST decide whether parentheses are needed (e.g. none when no bases remain).
        lpar=cst.MaybeSentinel.DEFAULT,
        rpar=cst.MaybeSentinel.DEFAULT,
        bases=new_bases,
        # Build on the updated decorators so changes made to them by children are kept.
        decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
    )
def leave_SimpleString(
    self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
) -> Union[libcst.SimpleString, libcst.BaseExpression]:
    """Replace a quoted string with the expression its contents spell out.

    NOTE(review): assumes the visitor only reaches strings used as annotations.
    Confirm against the ``visit_*`` methods of the enclosing transformer.

    Side effect: schedules ``from __future__ import annotations`` so the now
    unquoted annotation is not evaluated at definition time.
    """
    AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
    # literal_eval strips the quotes and resolves escapes, giving the source text.
    annotation_source = literal_eval(updated_node.value)
    parser_config = self.module.config_for_parsing
    return parse_expression(annotation_source, config=parser_config)
def leave_ImportFrom(
    self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
    """Move ``self.old_name`` out of a matching ``from ... import`` statement.

    Each alias named ``self.old_name`` is removed. An import of ``self.new_name``
    from ``self.new_module_parts`` is scheduled for it, keeping any ``as``
    alias. The surviving aliases are sorted by name and the trailing comma on
    the last one is stripped.

    Returns:
        The rewritten statement, ``RemoveFromParent()`` if nothing is left,
        or the parent class's result for non-matching statements.
    """
    if not self._test_import_from(updated_node):
        return super().leave_ImportFrom(original_node, updated_node)

    kept_aliases = []
    for alias in updated_node.names:
        if alias.evaluated_name != self.old_name:
            kept_aliases.append(alias)
            continue
        alias_name = alias.asname.name.value if alias.asname else None
        AddImportsVisitor.add_needed_import(
            context=self.context,
            module=".".join(self.new_module_parts),
            obj=self.new_name,
            asname=alias_name,
        )

    if not kept_aliases:
        return RemoveFromParent()

    # Stable in-place sort, alphabetical by imported name.
    kept_aliases.sort(key=lambda alias: alias.evaluated_name)
    # Avoid a dangling comma after the (new) last alias.
    final_alias = kept_aliases[-1]
    if final_alias.comma != MaybeSentinel.DEFAULT:
        kept_aliases[-1] = final_alias.with_changes(comma=MaybeSentinel.DEFAULT)
    return updated_node.with_changes(names=kept_aliases)
def _check_import_from_child(
    self, updated_node: ImportFrom
) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]:
    """Handle imports of members *of* the module being codemodded.

    When ``parent.module.the_thing`` is transformed, this detects an import
    such as::

        from parent.module.the_thing import something

    Every imported name (with its ``as`` alias) is re-imported from
    ``self.new_all_parts`` via ``AddImportsVisitor``.

    Returns:
        ``None`` for star imports or statements not importing from
        ``self.old_all_parts``; otherwise ``RemoveFromParent()``.
    """
    # A star import cannot be rewritten name by name.
    if isinstance(updated_node.names, ImportStar):
        return None
    if not import_from_matches(updated_node, self.old_all_parts):
        return None

    target_module = ".".join(self.new_all_parts)
    for alias in updated_node.names:
        AddImportsVisitor.add_needed_import(
            context=self.context,
            module=target_module,
            obj=alias.evaluated_name,
            asname=alias.evaluated_alias,
        )
    # All names are now imported from the new location; drop the old statement.
    return RemoveFromParent()
def get_transforms(self) -> Generator[Type[Codemod], None, None]:
    """Schedule the import described in the context scratch, then yield its applier.

    Reads ``module``, ``entity`` and ``alias`` from ``self.context.scratch``,
    registers them with ``AddImportsVisitor``, and yields that visitor class
    so the import actually gets inserted.
    """
    scratch = self.context.scratch
    module_name = scratch["module"]
    entity_name = scratch["entity"]
    alias_name = scratch["alias"]
    AddImportsVisitor.add_needed_import(self.context, module_name, entity_name, alias_name)
    yield AddImportsVisitor
def leave_SimpleString(
    self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
) -> Union[libcst.SimpleString, libcst.BaseExpression]:
    """Unquote a string into the expression it contains.

    NOTE(review): assumes only annotation strings reach this method. Verify
    against the enclosing transformer's ``visit_*`` methods.

    Side effect: schedules ``from __future__ import annotations`` so the
    unquoted annotation is not evaluated at definition time.
    """
    AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
    # LibCST's own string evaluation gives the text between the quotes,
    # which is then parsed as the new annotation expression.
    inner_text = updated_node.evaluated_value
    parser_config = self.module.config_for_parsing
    return parse_expression(inner_text, config=parser_config)
def add_new_import(self, evaluated_name: Optional[str] = None) -> None:
    """Schedule the import of the replacement name from ``self.new_module_parts``.

    Args:
        evaluated_name: Name to import when ``self.new_name`` is falsy.

    The ``as`` alias recorded in ``self.entity_imported_as`` (if any) is reused.
    """
    alias_node = self.entity_imported_as
    AddImportsVisitor.add_needed_import(
        context=self.context,
        module=".".join(self.new_module_parts),
        obj=self.new_name or evaluated_name,
        asname=alias_node.name.value if alias_node else None,
    )
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Rewrite ``print(x)`` as ``pprint(x)`` and schedule ``from pprint import pprint``.

    Only calls with exactly one plain positional argument are rewritten.
    ``pprint.pprint`` takes ``stream`` as its second positional parameter and
    does not accept ``print`` keywords such as ``sep``/``end``/``file``. Rewriting
    ``print()``, ``print(a, b)``, ``print(*xs)`` or ``print(x, end="")`` would
    therefore produce code that fails or writes to the wrong place. Such calls
    are left to the parent implementation.
    """
    args = updated_node.args
    is_simple_print = (
        m.matches(updated_node, m.Call(func=m.Name("print")))
        and len(args) == 1
        and args[0].keyword is None
        and args[0].star == ""
    )
    if is_simple_print:
        AddImportsVisitor.add_needed_import(
            self.context,
            "pprint",
            "pprint",
        )
        return updated_node.with_changes(func=Name("pprint"))
    return super().leave_Call(original_node, updated_node)
def _update_imports(self) -> None:
    """Replace pytz / ``datetime.timezone`` imports with the project ``UTC`` constant.

    Schedules removal (when unused) of ``import pytz``, ``from pytz import utc``,
    ``from pytz import UTC`` and ``from datetime import timezone``. Then
    schedules ``from bulb.platform.common.timezones import UTC``.
    """
    # None = the bare `import pytz` module import.
    for pytz_member in (None, "utc", "UTC"):
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", pytz_member)
    RemoveImportsVisitor.remove_unused_import(self.context, "datetime", "timezone")
    AddImportsVisitor.add_needed_import(self.context, "bulb.platform.common.timezones", "UTC")
def _import_annotations_from_future(self) -> None:
    """Add ``from __future__ import annotations`` to every file in ``self.paths``.

    We need this because the original sqlalchemy types aren't generic and will
    fail at runtime.

    Files are read and written as bytes. LibCST then detects each file's own
    encoding and preserves its newline style, instead of using the platform
    locale encoding and newline translation of ``read_text``/``write_text``.
    Files whose content does not change are not rewritten.
    """
    LOG.info("Importing necessary annotations...")
    context = CodemodContext()
    # One shared context: every file receives the same single import.
    AddImportsVisitor.add_needed_import(context, "__future__", "annotations")
    for path in self.paths:
        original_bytes = path.read_bytes()
        source = libcst.parse_module(original_bytes)
        modified_tree = AddImportsVisitor(context).transform_module(source)
        new_bytes = modified_tree.bytes
        if new_bytes != original_bytes:
            path.write_bytes(new_bytes)
def update_call(self, updated_node: Call) -> BaseExpression:
    """Update `url` call with either `path` or `re_path`.

    ``update_call_to_path`` is tried first. If it raises
    ``PatternNotSupported``, ``self.new_name`` is imported from
    ``self.new_module_parts`` and the parent implementation handles the call.
    """
    try:
        path_call = self.update_call_to_path(updated_node)
    except PatternNotSupported:
        # Safe fallback to re_path()
        fallback_module = ".".join(self.new_module_parts)
        AddImportsVisitor.add_needed_import(
            context=self.context,
            module=fallback_module,
            obj=self.new_name,
        )
        return super().update_call(updated_node)
    return path_call
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.Attribute:
    """Re-root a matching attribute access onto ``wx.adv``.

    If ``updated_node`` matches any matcher in ``self.matchers``, ``import wx.adv``
    is scheduled and the attribute's ``value`` becomes ``wx.adv``. Otherwise
    the node is returned unchanged.
    """
    if not any(matchers.matches(updated_node, candidate) for candidate in self.matchers):
        return updated_node
    AddImportsVisitor.add_needed_import(self.context, "wx.adv")
    wx_adv = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value="adv"))
    return updated_node.with_changes(value=wx_adv)
def update_call_args(self, node: Call) -> Sequence[Arg]:
    """Update first argument to convert integer for minutes to timedelta.

    ``f(5, ...)`` becomes ``f(timedelta(minutes=5), ...)``. The remaining
    arguments are passed through untouched.

    Raises:
        AssertionError: if the first argument is not an integer literal.

    ``from datetime import timedelta`` is scheduled only after the argument has
    been validated, so a rejected call leaves no stray import request behind.
    """
    offset_arg, *other_args = node.args
    integer_value = offset_arg.value
    if not isinstance(integer_value, Integer):
        raise AssertionError(f"Unexpected type for: {integer_value}")
    AddImportsVisitor.add_needed_import(
        context=self.context,
        module="datetime",
        obj="timedelta",
    )
    # Integer.value is the literal's source text, so it can be spliced in directly.
    timedelta_call = parse_expression(f"timedelta(minutes={integer_value.value})")
    new_offset_arg = offset_arg.with_changes(value=timedelta_call)
    return (new_offset_arg, *other_args)
def update_call_to_path(self, updated_node: Call) -> Call:
    """Update an URL pattern to `path()` in simple cases.

    Raises:
        PatternNotSupported: if the first argument is not a plain string literal.

    On success, ``path`` is imported from ``self.new_module_parts``.
    """
    pattern_arg, *remaining_args = updated_node.args
    pattern_node = pattern_arg.value
    if not isinstance(pattern_node, SimpleString):
        raise PatternNotSupported()
    # The literal's decoded text is the URL pattern to translate.
    path_call = self.build_path_call(pattern_node.evaluated_value, remaining_args)
    AddImportsVisitor.add_needed_import(
        context=self.context,
        module=".".join(self.new_module_parts),
        obj="path",
    )
    return path_call
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Flush scheduled import removals and, if needed, add the replacement import.

    Every node in ``self.scheduled_removals`` has its import scheduled for
    removal (when unused). When imports were not renamed directly
    (``bypass_import`` is False), removals exist, and ``self.new_module`` is set,
    ``self.new_module`` is imported. The imported object is the first dotted
    segment of ``self.new_mod_or_obj``, or the module itself when that is empty.
    The module tree is returned unchanged.
    """
    for node_to_remove in self.scheduled_removals:
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, node_to_remove)

    # If bypass_import is False, we know that no import statements were directly renamed, and the fact
    # that we have any `self.scheduled_removals` tells us we encountered a matching `old_name` in the code.
    if self.bypass_import or not self.scheduled_removals or not self.new_module:
        return updated_node

    first_segment = self.new_mod_or_obj.partition(".")[0] if self.new_mod_or_obj else None
    new_obj: Optional[str] = first_segment
    AddImportsVisitor.add_needed_import(self.context, module=self.new_module, obj=new_obj)
    return updated_node
def update_call_to_path(self, updated_node: Call):
    """Update an URL pattern to `path()` in simple cases.

    ``check_not_simple_string`` and ``check_missing_start`` validate the first
    argument and its pattern, and may raise. When they pass, the call is rebuilt
    with ``build_path_call`` and ``path`` is imported from
    ``self.new_module_parts``.
    """
    pattern_arg, *remaining_args = updated_node.args
    self.check_not_simple_string(pattern_arg)

    url_pattern = pattern_arg.value.evaluated_value
    self.check_missing_start(url_pattern)

    # Validation passed: `path()` can express this pattern.
    path_call = self.build_path_call(url_pattern, remaining_args)
    AddImportsVisitor.add_needed_import(
        context=self.context,
        module=".".join(self.new_module_parts),
        obj="path",
    )
    return path_call
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Add ``on_delete=models.CASCADE`` to relation fields that lack it.

    Applies to calls recognised by ``is_one_to_one_field`` or ``is_foreign_key``
    for which ``has_on_delete`` is false. ``from django.db import models`` is
    scheduled and the keyword argument is appended. All other calls go to the
    parent implementation.
    """
    is_relation = is_one_to_one_field(original_node) or is_foreign_key(original_node)
    if not (is_relation and not has_on_delete(original_node)):
        return super().leave_Call(original_node, updated_node)

    AddImportsVisitor.add_needed_import(
        context=self.context,
        module="django.db",
        obj="models",
    )
    cascade_arg = parse_arg("on_delete=models.CASCADE")
    return updated_node.with_changes(args=(*updated_node.args, cascade_arg))
def _check_import_from_parent(
    self, original_node: ImportFrom, updated_node: ImportFrom
) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]:
    """
    Check for when the parent module of thing to replace is imported.

    When `parent.module.the_thing` is transformed, detect such import:

        from parent import module

    For the alias importing ``self.old_parent_name``, this records a matcher
    for ``<module>.<old_name>`` and the replacement ``<module>.<new_name>``
    attribute in ``self.context.scratch``. If the parent module moved, the
    alias is replaced by an import from ``self.new_parent_module_parts``.

    Returns:
        ``None`` for star imports or non-matching statements,
        ``RemoveFromParent()`` if no alias is left, otherwise the updated
        statement.
    """
    # First, exit early if 'import *' is used
    if isinstance(updated_node.names, ImportStar):
        return None

    # Check whether parent module is imported
    if not import_from_matches(updated_node, self.old_parent_module_parts):
        return None

    # Match, update the node and return it
    new_import_aliases = []
    for import_alias in updated_node.names:
        if import_alias.evaluated_name == self.old_parent_name:
            self.save_import_scope(original_node)
            module_name_str = import_alias.evaluated_alias or import_alias.evaluated_name
            self.context.scratch[self.ctx_key_name_matcher] = m.Attribute(
                value=m.Name(module_name_str),
                attr=m.Name(self.old_name),
            )
            self.context.scratch[self.ctx_key_new_func] = Attribute(
                attr=Name(self.new_name),
                value=Name(import_alias.evaluated_alias or self.new_parent_name),
            )
            if self.old_parent_module_parts != self.new_parent_module_parts:
                # import statement needs updating: the replacement import is
                # scheduled, so this alias can be dropped from the statement.
                AddImportsVisitor.add_needed_import(
                    context=self.context,
                    module=".".join(self.new_parent_module_parts),
                    obj=self.new_parent_name,
                    asname=import_alias.evaluated_alias,
                )
                continue
            # Parent module location unchanged: no replacement import is
            # scheduled, so keep this alias or the module would stop being imported.
        new_import_aliases.append(import_alias)

    if not new_import_aliases:
        # Nothing left in the import statement: remove it
        return RemoveFromParent()

    # Some imports are left, update the statement
    new_import_aliases = clean_new_import_aliases(new_import_aliases)
    return updated_node.with_changes(names=new_import_aliases)
def leave_FunctionDef(
    self, original_node: FunctionDef, updated_node: FunctionDef
) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
    """Strip the permalink decorator from a method being visited.

    While ``self.visiting_permalink_method`` is true, the first decorator
    matching ``self.decorator_matcher`` is removed. In that case
    ``from django.urls import reverse`` is scheduled and the
    ``ctx_key_inside_method`` scratch entry is cleared. In every other case the
    parent implementation handles the node.
    """
    if self.visiting_permalink_method:
        decorators = updated_node.decorators
        for position, decorator in enumerate(decorators):
            if not m.matches(decorator, self.decorator_matcher):
                continue
            AddImportsVisitor.add_needed_import(
                context=self.context,
                module="django.urls",
                obj="reverse",
            )
            remaining = (*decorators[:position], *decorators[position + 1:])
            # Leaving the method: forget the "inside method" marker.
            self.context.scratch.pop(self.ctx_key_inside_method, None)
            return updated_node.with_changes(decorators=remaining)
    return super().leave_FunctionDef(original_node, updated_node)
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Rewrite calls to renamed wx symbols.

    Two call shapes are handled:

    * bare calls (``MySymbol(...)``) whose symbol is in ``self.wx_imports`` and
      matches ``self.matchers_short_map``. The import referenced by the call is
      scheduled for removal (if unused), ``import wx`` is scheduled, and the
      callee becomes ``wx.<renamed>``;
    * qualified calls (``wx.MySymbol(...)``) matching ``self.matchers_full_map``.
      Only the callee is renamed.

    A tuple ``renamed`` value ``(sub, name)`` produces ``wx.<sub>.<name>``.
    Non-matching calls are returned unchanged.
    """

    def nested_wx_callee(renamed_pair):
        # Build the `wx.<renamed_pair[0]>.<renamed_pair[1]>` attribute chain.
        return cst.Attribute(
            value=cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=renamed_pair[0])),
            attr=cst.Name(value=renamed_pair[1]),
        )

    # Calls of symbols used without the wx prefix
    for symbol, matcher, renamed in self.matchers_short_map:
        if symbol not in self.wx_imports or not matchers.matches(updated_node, matcher):
            continue
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
        AddImportsVisitor.add_needed_import(self.context, "wx")
        if isinstance(renamed, tuple):
            new_callee = nested_wx_callee(renamed)
        else:
            new_callee = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=renamed))
        return updated_node.with_changes(func=new_callee)

    # Fully qualified calls like wx.MySymbol
    for matcher, renamed in self.matchers_full_map:
        if not matchers.matches(updated_node, matcher):
            continue
        if isinstance(renamed, tuple):
            new_callee = nested_wx_callee(renamed)
        else:
            new_callee = updated_node.func.with_changes(attr=cst.Name(value=renamed))
        return updated_node.with_changes(func=new_callee)

    return updated_node
def visit_Module(self, node: cst.Module) -> None:
    """Schedule ``from foo import quux`` and the removal of an unused ``from foo import bar``."""
    AddImportsVisitor.add_needed_import(self.context, module="foo", obj="quux")
    RemoveImportsVisitor.remove_unused_import(self.context, "foo", "bar")