def leave_ImportFrom(
    self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
    """Replace ``url`` in a matching ``from ... import`` with ``re_path``.

    Only statements accepted by ``self._test_import_from`` are touched; all
    others are delegated to the parent class. Every ``url`` alias is dropped
    from the statement and ``re_path`` from ``django.urls`` is scheduled via
    ``AddImportsVisitor`` instead. The remaining aliases are sorted by name.

    Returns:
        ``RemoveFromParent()`` when no aliases remain, otherwise the updated
        statement (its module, relative dots, parens and whitespace kept).
    """
    # Local import: the block only needs MaybeSentinel for the comma fix-up.
    from libcst import MaybeSentinel

    if not self._test_import_from(updated_node):
        return super().leave_ImportFrom(original_node, updated_node)

    new_names = []
    # Iterate the *updated* aliases so rewrites made by child visitors survive.
    for import_alias in updated_node.names:
        if import_alias.evaluated_name == "url":
            AddImportsVisitor.add_needed_import(
                self.context,
                "django.urls",
                "re_path",
            )
        else:
            new_names.append(import_alias)

    if not new_names:
        return RemoveFromParent()

    new_names = sorted(new_names, key=lambda alias: alias.evaluated_name)
    # Sorting (or dropping the former last alias) can leave an alias carrying an
    # explicit comma at the end, e.g. ``from x import include, url`` ->
    # ``from x import include,`` -- invalid without parentheses. Reset it.
    last_name = new_names[-1]
    if last_name.comma != MaybeSentinel.DEFAULT:
        new_names[-1] = last_name.with_changes(comma=MaybeSentinel.DEFAULT)
    # with_changes keeps module, relative-import dots and formatting intact,
    # unlike constructing a brand-new ImportFrom.
    return updated_node.with_changes(names=new_names)
Exemple #2
0
    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        """Turn a class deriving from ``typing.NamedTuple`` into a decorated dataclass.

        Every base whose qualified name matches ``self.qualified_namedtuple`` is
        dropped, the import it referenced is passed to
        ``RemoveImportsVisitor.remove_unused_import_by_node``, and a
        ``dataclass(frozen=False)`` decorator is appended after the existing
        decorators. Classes without such a base are returned as ``updated_node``.
        """
        original_bases = original_node.bases
        # Qualified-name metadata is tied to the ORIGINAL nodes, so the check must
        # run on them; the bases we keep, however, should be the UPDATED ones so
        # rewrites performed by child transforms are not discarded. Bases are
        # paired by position; if a child transform changed how many bases there
        # are, fall back to the original bases.
        updated_bases = updated_node.bases
        if len(updated_bases) != len(original_bases):
            updated_bases = original_bases

        new_bases: List[cst.Arg] = []
        namedtuple_base: Optional[cst.Arg] = None
        for original_base, updated_base in zip(original_bases, updated_bases):
            # Compare the base class's qualified name against the expected typing.NamedTuple
            if QualifiedNameProvider.has_name(self, original_base.value, self.qualified_namedtuple):
                # Remember the ORIGINAL node: import removal relies on its metadata.
                namedtuple_base = original_base
            else:
                # Keep all bases that are not of type typing.NamedTuple
                new_bases.append(updated_base)

        # We still want to return the updated node in case some of its children have been modified
        if namedtuple_base is None:
            return updated_node

        # NOTE(review): both imports below bind the same name ``dataclass``, so the
        # decorator can only refer to one of them -- confirm which library is
        # intended and drop the other import.
        AddImportsVisitor.add_needed_import(self.context, "attr", "dataclass")
        AddImportsVisitor.add_needed_import(self.context, "pydantic.dataclasses", "dataclass")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

        # NOTE(review): NamedTuple instances are immutable, whereas frozen=False
        # yields a mutable class -- confirm this is deliberate.
        call = cst.ensure_type(
            cst.parse_expression("dataclass(frozen=False)", config=self.module.config_for_parsing),
            cst.Call,
        )
        return updated_node.with_changes(
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
            bases=new_bases,
            # Build on the updated decorators so child rewrites are preserved.
            decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
        )
 def leave_SimpleString(
     self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
 ) -> Union[libcst.SimpleString, libcst.BaseExpression]:
     """Replace a string literal with the expression its text spells out.

     Schedules ``from __future__ import annotations`` first, then evaluates
     the literal's source with ``literal_eval`` and parses the resulting text
     as an expression using this module's parsing config. Presumably only
     reached for quoted annotations -- confirm against the visitor's callers.
     """
     AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
     unquoted_text = literal_eval(updated_node.value)
     parsing_config = self.module.config_for_parsing
     return parse_expression(unquoted_text, config=parsing_config)
Exemple #4
0
 def leave_ImportFrom(
     self, original_node: ImportFrom, updated_node: ImportFrom
 ) -> Union[BaseSmallStatement, RemovalSentinel]:
     """Move ``self.old_name`` out of a matching ``from ... import`` statement.

     Statements rejected by ``self._test_import_from`` go to the parent class.
     Otherwise each alias named ``self.old_name`` is dropped and an import of
     ``self.new_name`` from ``self.new_module_parts`` is scheduled, carrying
     over any ``as`` alias. Remaining aliases are sorted by name and the last
     one's explicit comma is reset; the statement is removed when none remain.
     """
     if not self._test_import_from(updated_node):
         return super().leave_ImportFrom(original_node, updated_node)

     kept_aliases = []
     for alias in updated_node.names:
         if alias.evaluated_name != self.old_name:
             kept_aliases.append(alias)
             continue
         # Re-import the new name under the same local alias, if there was one.
         local_alias = alias.asname.name.value if alias.asname else None
         AddImportsVisitor.add_needed_import(
             context=self.context,
             module=".".join(self.new_module_parts),
             obj=self.new_name,
             asname=local_alias,
         )

     if not kept_aliases:
         return RemoveFromParent()

     kept_aliases.sort(key=lambda item: item.evaluated_name)
     # A trailing comma on the final alias is only legal inside parentheses.
     final_alias = kept_aliases[-1]
     if final_alias.comma != MaybeSentinel.DEFAULT:
         kept_aliases[-1] = final_alias.with_changes(comma=MaybeSentinel.DEFAULT)
     return updated_node.with_changes(names=kept_aliases)
Exemple #5
0
    def _check_import_from_child(
        self, updated_node: ImportFrom
    ) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]:
        """
        Handle an import of a member of the module being codemodded.

        When `parent.module.the_thing` is transformed, this detects imports
        such as:

            from parent.module.the_thing import something

        Returns ``None`` for ``import *`` or when the statement does not match
        ``self.old_all_parts``. Otherwise every imported name is re-imported
        (keeping its alias) from ``self.new_all_parts`` and the original
        statement is removed.
        """
        imported = updated_node.names
        # Star imports cannot be rewritten name by name; non-matching modules are ignored.
        if isinstance(imported, ImportStar) or not import_from_matches(updated_node, self.old_all_parts):
            return None
        for alias in imported:
            AddImportsVisitor.add_needed_import(
                context=self.context,
                module=".".join(self.new_all_parts),
                obj=alias.evaluated_name,
                asname=alias.evaluated_alias,
            )
        return RemoveFromParent()
 def get_transforms(self) -> Generator[Type[Codemod], None, None]:
     """Yield the one transform needed to add the configured import.

     Reads ``module``, ``entity`` and ``alias`` from ``self.context.scratch``,
     registers that import with ``AddImportsVisitor`` and then yields the
     visitor class. As a generator, nothing runs until first iteration.
     """
     scratch = self.context.scratch
     module_name = scratch["module"]
     entity_name = scratch["entity"]
     alias_name = scratch["alias"]
     AddImportsVisitor.add_needed_import(self.context, module_name, entity_name, alias_name)
     yield AddImportsVisitor
Exemple #7
0
 def leave_SimpleString(
     self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
 ) -> Union[libcst.SimpleString, libcst.BaseExpression]:
     """Replace a string literal with the expression written inside it.

     Schedules ``from __future__ import annotations``, then takes the
     literal's evaluated value (via LibCST) and parses it as an expression
     with this module's parsing config.
     """
     AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
     # LibCST evaluates the literal for us; its contents become the new annotation.
     inner_source = updated_node.evaluated_value
     return parse_expression(inner_source, config=self.module.config_for_parsing)
Exemple #8
0
 def add_new_import(self, evaluated_name: Optional[str] = None) -> None:
     """Schedule an import of the replacement name from the new module.

     The imported object is ``self.new_name`` when it is truthy, otherwise
     ``evaluated_name``. If the original entity was imported under an alias
     (``self.entity_imported_as``), that alias is reused.
     """
     imported_as = self.entity_imported_as
     alias = None
     if imported_as:
         alias = imported_as.name.value
     AddImportsVisitor.add_needed_import(
         context=self.context,
         module=".".join(self.new_module_parts),
         obj=self.new_name or evaluated_name,
         asname=alias,
     )
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Rewrite ``print(x)`` into ``pprint(x)`` and schedule the import.

     Only calls with exactly one plain positional argument are rewritten.
     ``pprint.pprint``'s second positional parameter is ``stream`` and it has
     no ``sep``/``end``/``file`` keywords, so ``print(a, b)`` or
     ``print(x, end="")`` would turn into broken calls; those (and bare
     ``print()``) are left alone. Everything not rewritten is delegated to the
     parent class.
     """
     single_positional_print = m.Call(
         func=m.Name("print"),
         args=[m.Arg(keyword=None, star="")],
     )
     if m.matches(updated_node, single_positional_print):
         AddImportsVisitor.add_needed_import(
             self.context,
             "pprint",
             "pprint",
         )
         return updated_node.with_changes(func=Name("pprint"))
     return super().leave_Call(original_node, updated_node)
Exemple #10
0
	def _update_imports(self):
		"""Swap pytz / datetime.timezone imports for the project UTC helper.

		Passes ``pytz``, ``pytz.utc``, ``pytz.UTC`` and ``datetime.timezone``
		to ``RemoveImportsVisitor.remove_unused_import``, then schedules
		``from bulb.platform.common.timezones import UTC``.
		"""
		# Each tuple is the positional (module[, obj]) argument list for one removal.
		obsolete_imports = (
			("pytz",),
			("pytz", "utc"),
			("pytz", "UTC"),
			("datetime", "timezone"),
		)
		for removal_args in obsolete_imports:
			RemoveImportsVisitor.remove_unused_import(self.context, *removal_args)
		AddImportsVisitor.add_needed_import(
			self.context, "bulb.platform.common.timezones", "UTC"
		)
Exemple #11
0
 def _import_annotations_from_future(self) -> None:
     """We need this because the original sqlalchemy types aren't generic
     and will fail at runtime.

     Adds ``from __future__ import annotations`` to every file in
     ``self.paths``. Files are read and written as bytes so LibCST detects each
     file's own source encoding and preserves its line endings, instead of
     going through the locale's text codec and newline translation. Files
     whose content does not change are not rewritten.
     """
     LOG.info("Importing necessary annotations...")
     context = CodemodContext()
     AddImportsVisitor.add_needed_import(context, "__future__",
                                         "annotations")
     for path in self.paths:
         original_bytes = path.read_bytes()
         source = libcst.parse_module(original_bytes)
         modified_tree = AddImportsVisitor(context).transform_module(source)
         # Module.bytes re-encodes with the encoding detected at parse time.
         new_bytes = modified_tree.bytes
         if new_bytes != original_bytes:
             path.write_bytes(new_bytes)
 def update_call(self, updated_node: Call) -> BaseExpression:
     """Update `url` call with either `path` or `re_path`.

     ``update_call_to_path`` is tried first. If it raises
     ``PatternNotSupported``, an import of ``self.new_name`` from
     ``self.new_module_parts`` is scheduled and the parent implementation
     handles the call (the ``re_path()`` fallback).
     """
     try:
         path_call = self.update_call_to_path(updated_node)
     except PatternNotSupported:
         # Pattern too complex for path(): fall back to re_path().
         fallback_module = ".".join(self.new_module_parts)
         AddImportsVisitor.add_needed_import(
             context=self.context,
             module=fallback_module,
             obj=self.new_name,
         )
         return super().update_call(updated_node)
     return path_call
Exemple #13
0
    def leave_Attribute(self, original_node: cst.Attribute,
                        updated_node: cst.Attribute) -> cst.Attribute:
        """Re-root an attribute matched by ``self.matchers`` on ``wx.adv``.

        If any matcher matches, an import of the ``wx.adv`` module is scheduled
        and the attribute's ``value`` is replaced by ``wx.adv`` (its ``attr`` is
        kept). Otherwise the node is returned unchanged.
        """
        if not any(matchers.matches(updated_node, matcher) for matcher in self.matchers):
            return updated_node

        AddImportsVisitor.add_needed_import(self.context, "wx.adv")
        wx_adv = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value="adv"))
        return updated_node.with_changes(value=wx_adv)
Exemple #14
0
 def update_call_args(self, node: Call) -> Sequence[Arg]:
     """Update first argument to convert integer for minutes to timedelta.

     ``f(5, ...)`` becomes ``f(timedelta(minutes=5), ...)``; the remaining
     arguments are returned unchanged. The ``from datetime import timedelta``
     import is only scheduled after the first argument has been validated, so
     a rejected call (if the caller catches the error) leaves no unused import.

     Raises:
         AssertionError: if the first argument is not an integer literal.
         ValueError: if the call has no arguments at all (tuple unpacking).
     """
     offset_arg, *other_args = node.args
     integer_value = offset_arg.value
     if not isinstance(integer_value, Integer):
         raise AssertionError(f"Unexpected type for: {integer_value}")
     AddImportsVisitor.add_needed_import(
         context=self.context,
         module="datetime",
         obj="timedelta",
     )
     # The Integer node's ``value`` is spliced into source text for parsing.
     timedelta_call = parse_expression(f"timedelta(minutes={integer_value.value})")
     new_offset_arg = offset_arg.with_changes(value=timedelta_call)
     return (new_offset_arg, *other_args)
Exemple #15
0
 def update_call_to_path(self, updated_node: Call) -> Call:
     """Update an URL pattern to `path()` in simple cases.

     Raises ``PatternNotSupported`` unless the first argument is a plain
     string literal. Otherwise the call is rebuilt by ``build_path_call`` from
     the literal's value and the remaining arguments, and an import of
     ``path`` from ``self.new_module_parts`` is scheduled.
     """
     first_arg, *remaining_args = updated_node.args
     pattern_node = first_arg.value
     if isinstance(pattern_node, SimpleString):
         path_call = self.build_path_call(pattern_node.evaluated_value, remaining_args)
         AddImportsVisitor.add_needed_import(
             context=self.context,
             module=".".join(self.new_module_parts),
             obj="path",
         )
         return path_call
     raise PatternNotSupported()
Exemple #16
0
 def leave_Module(self, original_node: cst.Module,
                  updated_node: cst.Module) -> cst.Module:
     """Flush scheduled import removals and, if needed, add the new import.

     Each node in ``self.scheduled_removals`` is handed to
     ``RemoveImportsVisitor.remove_unused_import_by_node``. An import from
     ``self.new_module`` is then scheduled only if ``self.bypass_import`` is
     false, some removal was scheduled, and ``self.new_module`` is set; its
     object is the first dotted component of ``self.new_mod_or_obj``, or
     ``None`` when that is empty. The module itself is returned unchanged.
     """
     for node in self.scheduled_removals:
         RemoveImportsVisitor.remove_unused_import_by_node(self.context, node)

     # bypass_import False => no import statement was renamed directly; a non-empty
     # scheduled_removals => a matching `old_name` was encountered in the code.
     if self.bypass_import or not self.scheduled_removals or not self.new_module:
         return updated_node

     first_component: Optional[str] = None
     if self.new_mod_or_obj:
         first_component = self.new_mod_or_obj.split(".")[0]
     AddImportsVisitor.add_needed_import(self.context,
                                         module=self.new_module,
                                         obj=first_component)
     return updated_node
Exemple #17
0
 def update_call_to_path(self, updated_node: Call):
     """Update an URL pattern to `path()` in simple cases.

     The first argument is passed to ``check_not_simple_string`` and its string
     value to ``check_missing_start`` (both defined elsewhere; presumably they
     raise to abort the rewrite). The call is then rebuilt by
     ``build_path_call`` and an import of ``path`` from
     ``self.new_module_parts`` is scheduled.
     """
     first_arg, *remaining_args = updated_node.args
     self.check_not_simple_string(first_arg)
     url_pattern = first_arg.value.evaluated_value
     self.check_missing_start(url_pattern)
     # Past the checks, path() is a candidate replacement.
     path_call = self.build_path_call(url_pattern, remaining_args)
     target_module = ".".join(self.new_module_parts)
     AddImportsVisitor.add_needed_import(
         context=self.context,
         module=target_module,
         obj="path",
     )
     return path_call
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Append ``on_delete=models.CASCADE`` to relation fields lacking it.

     Applies to calls that ``is_one_to_one_field`` or ``is_foreign_key``
     recognise (checked on the original node) and for which ``has_on_delete``
     is false; ``from django.db import models`` is scheduled alongside. All
     other calls are delegated to the parent class.
     """
     is_relation_field = is_one_to_one_field(original_node) or is_foreign_key(original_node)
     if not is_relation_field or has_on_delete(original_node):
         return super().leave_Call(original_node, updated_node)

     AddImportsVisitor.add_needed_import(
         context=self.context,
         module="django.db",
         obj="models",
     )
     on_delete_arg = parse_arg("on_delete=models.CASCADE")
     return updated_node.with_changes(args=(*updated_node.args, on_delete_arg))
Exemple #19
0
    def _check_import_from_parent(
        self, original_node: ImportFrom, updated_node: ImportFrom
    ) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]:
        """
        Check for when the parent module of thing to replace is imported.

        When `parent.module.the_thing` is transformed, detect such import:

            from parent import module

        Returns ``None`` for ``import *`` or when the statement does not import
        from ``self.old_parent_module_parts``. For an alias importing
        ``self.old_parent_name``, the import scope is saved, and a matcher for
        ``<alias or name>.<old_name>`` plus its replacement
        ``<alias or new_parent_name>.<new_name>`` are stored in
        ``self.context.scratch``. If the new parent differs from the old one
        (different module path OR different name), that alias is dropped here
        and an import of ``self.new_parent_name`` (keeping the ``as`` alias) is
        scheduled instead.

        Returns ``RemoveFromParent()`` when no aliases remain, otherwise the
        statement with the remaining aliases passed through
        ``clean_new_import_aliases``.
        """
        # First, exit early if 'import *' is used
        if isinstance(updated_node.names, ImportStar):
            return None
        # Check whether parent module is imported
        if not import_from_matches(updated_node, self.old_parent_module_parts):
            return None
        # Match, update the node and return it
        new_import_aliases = []
        for import_alias in updated_node.names:
            if import_alias.evaluated_name == self.old_parent_name:
                self.save_import_scope(original_node)
                module_name_str = (import_alias.evaluated_alias
                                   or import_alias.evaluated_name)
                self.context.scratch[self.ctx_key_name_matcher] = m.Attribute(
                    value=m.Name(module_name_str),
                    attr=m.Name(self.old_name),
                )
                self.context.scratch[self.ctx_key_new_func] = Attribute(
                    attr=Name(self.new_name),
                    value=Name(import_alias.evaluated_alias
                               or self.new_parent_name),
                )
                # The existing alias can only be kept when it already imports the
                # new parent: same module path AND same name. Otherwise the
                # replacement above would reference ``new_parent_name`` (never
                # bound by this statement) or an alias still bound to the old
                # module.
                if (
                    self.old_parent_module_parts != self.new_parent_module_parts
                    or self.old_parent_name != self.new_parent_name
                ):
                    # import statement needs updating
                    AddImportsVisitor.add_needed_import(
                        context=self.context,
                        module=".".join(self.new_parent_module_parts),
                        obj=self.new_parent_name,
                        asname=import_alias.evaluated_alias,
                    )
                    continue
            new_import_aliases.append(import_alias)
        if not new_import_aliases:
            # Nothing left in the import statement: remove it
            return RemoveFromParent()
        # Some imports are left, update the statement
        new_import_aliases = clean_new_import_aliases(new_import_aliases)
        return updated_node.with_changes(names=new_import_aliases)
 def leave_FunctionDef(
     self, original_node: FunctionDef, updated_node: FunctionDef
 ) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
     """Drop the decorator matched by ``self.decorator_matcher``.

     Only acts while ``self.visiting_permalink_method`` is truthy. The first
     decorator matching ``self.decorator_matcher`` is removed, ``from
     django.urls import reverse`` is scheduled and ``self.ctx_key_inside_method``
     is popped from the context scratch. In all other cases the parent class
     handles the node.
     """
     if not self.visiting_permalink_method:
         return super().leave_FunctionDef(original_node, updated_node)

     matched_decorator = next(
         (dec for dec in updated_node.decorators if m.matches(dec, self.decorator_matcher)),
         None,
     )
     if matched_decorator is None:
         return super().leave_FunctionDef(original_node, updated_node)

     AddImportsVisitor.add_needed_import(
         context=self.context,
         module="django.urls",
         obj="reverse",
     )
     remaining_decorators = list(updated_node.decorators)
     remaining_decorators.remove(matched_decorator)
     self.context.scratch.pop(self.ctx_key_inside_method, None)
     return updated_node.with_changes(decorators=tuple(remaining_decorators))
Exemple #21
0
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rewrite wx calls whose callee has been renamed.

        Two passes, first match wins:

        * ``self.matchers_short_map``: calls to a bare symbol that appears in
          ``self.wx_imports``. Imports referenced by the original call are passed
          to ``RemoveImportsVisitor.remove_unused_import_by_node``, ``import wx``
          is scheduled, and the callee becomes ``wx.<renamed>``.
        * ``self.matchers_full_map``: calls already written as ``wx.<symbol>``;
          the attribute name is replaced in place.

        In both passes a tuple ``renamed`` of ``(sub, name)`` yields
        ``wx.sub.name``. Unmatched calls are returned unchanged.
        """

        def nested_wx_attribute(parts):
            # Build ``wx.<parts[0]>.<parts[1]>``.
            return cst.Attribute(
                value=cst.Attribute(value=cst.Name(value="wx"),
                                    attr=cst.Name(value=parts[0])),
                attr=cst.Name(value=parts[1]),
            )

        # Calls using a symbol imported from wx without the ``wx.`` prefix.
        for symbol, matcher, renamed in self.matchers_short_map:
            if symbol not in self.wx_imports or not matchers.matches(updated_node, matcher):
                continue
            RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
            AddImportsVisitor.add_needed_import(self.context, "wx")
            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                new_func = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        # Fully qualified calls such as ``wx.MySymbol(...)``.
        for matcher, renamed in self.matchers_full_map:
            if not matchers.matches(updated_node, matcher):
                continue
            if isinstance(renamed, tuple):
                return updated_node.with_changes(func=nested_wx_attribute(renamed))
            return updated_node.with_changes(
                func=updated_node.func.with_changes(attr=cst.Name(value=renamed)))

        return updated_node
 def visit_Module(self, node: cst.Module) -> None:
     """Schedule this codemod's import changes when the module is entered.

     Registers ``from foo import quux`` with ``AddImportsVisitor`` and passes
     ``from foo import bar`` to ``RemoveImportsVisitor.remove_unused_import``.
     """
     AddImportsVisitor.add_needed_import(self.context, "foo",
                                         "quux")
     RemoveImportsVisitor.remove_unused_import(
         self.context, "foo", "bar")