Example #1
0
 def visit_Call(self, node: cst.Call) -> Optional[bool]:
     """Check which qualified names the visited call node does and does not have.

     The expectations below encode that a bare string query matches on the
     dotted name alone, while a ``QualifiedName`` query must also agree on its
     ``QualifiedNameSource``: an imported ``a.b.c`` is present, whereas ``a.b``
     and a *local* ``a.b.c`` are not.
     """
     expectations = (
         ("a.b.c", True),
         ("a.b", False),
         (QualifiedName("a.b.c", QualifiedNameSource.IMPORT), True),
         (QualifiedName("a.b.c", QualifiedNameSource.LOCAL), False),
     )
     for query, expected in expectations:
         found = QualifiedNameProvider.has_name(self, node, query)
         if expected:
             self.test.assertTrue(found)
         else:
             self.test.assertFalse(found)
    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report ``@dataclasses.dataclass`` decorators that do not pass ``frozen``.

        Each decorator that resolves to an imported ``dataclasses.dataclass`` and
        has no ``frozen=`` keyword is reported with an autofix that adds
        ``frozen=True``. A decorator that already passes ``frozen`` (whatever its
        value) is left alone.
        """
        for d in node.decorators:
            decorator = d.decorator
            if not QualifiedNameProvider.has_name(
                self,
                decorator,
                QualifiedName(
                    name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
                ),
            ):
                continue

            frozen_arg = cst.Arg(
                keyword=cst.Name("frozen"),
                value=cst.Name("True"),
                equal=cst.AssignEqual(
                    whitespace_before=SimpleWhitespace(value=""),
                    whitespace_after=SimpleWhitespace(value=""),
                ),
            )
            if isinstance(decorator, cst.Call):
                # An explicit frozen=... (True or False) is respected as-is.
                if any(
                    m.matches(arg.keyword, m.Name("frozen")) for arg in decorator.args
                ):
                    continue
                # Extend the existing call instead of building a new cst.Call so
                # its parentheses and the whitespace after the callee / after the
                # opening paren survive in the fix (a fresh Call resets them to
                # defaults, mangling multi-line decorator calls).
                new_decorator = decorator.with_changes(
                    args=[*decorator.args, frozen_arg]
                )
            else:
                # Bare ``@dataclass`` / ``@dataclasses.dataclass`` (a Name or an
                # Attribute): turn it into a call carrying frozen=True.
                new_decorator = cst.Call(func=decorator, args=[frozen_arg])
            self.report(d, replacement=d.with_changes(decorator=new_decorator))
    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Report ``@classmethod``s whose first positional parameter is not ``CLS``.

        * No positional parameters at all: report with an autofix that adds a
          ``CLS`` parameter.
        * First positional parameter already named ``CLS``: nothing to report.
        * Otherwise: report. When possible, the autofix renames that parameter
          and every Name-node assignment/reference to it in the function scope
          to ``CLS``. No autofix is offered when scope metadata is unavailable
          or the scope already has an assignment to ``CLS``.
        """
        if not any(
            QualifiedNameProvider.has_name(
                self,
                decorator.decorator,
                QualifiedName(name="builtins.classmethod", source=QualifiedNameSource.BUILTIN),
            )
            for decorator in node.decorators
        ):
            return  # If it's not a @classmethod, we are not interested.

        # The implicit class argument binds to the first positional parameter,
        # and positional-only ones (declared before ``/``) come first. Looking at
        # ``params`` alone would treat ``def f(klass, /)`` as parameterless and
        # produce the broken fix ``def f(klass, /, cls)``.
        positional_params = [*node.params.posonly_params, *node.params.params]
        if not positional_params:
            # No positional params, but there must be the 'cls' param.
            # Note that pyre[47] already catches this, but we also generate
            # an autofix, so it still makes sense for us to report it here.
            new_params = node.params.with_changes(params=(cst.Param(name=cst.Name(value=CLS)),))
            repl = node.with_changes(params=new_params)
            self.report(node, replacement=repl)
            return

        p0_name = positional_params[0].name
        if p0_name.value == CLS:
            return  # All good.

        # Rename all assignments and references of the first param within the
        # function scope, as long as they are done via a Name node.
        # We rely on the parser to correctly derive all
        # assignments and references within the FunctionScope.
        # The Param node's scope is our classmethod's FunctionScope.
        scope = self.get_metadata(ScopeProvider, p0_name, None)
        if not scope:
            # Cannot autofix without scope metadata. Only report in this case.
            # Not sure how to repro+cover this in a unit test...
            # If metadata creation fails, then the whole lint fails, and if it succeeds,
            # then there is valid metadata. But many other lint rule implementations contain
            # a defensive scope None check like this one, so I assume it is necessary.
            self.report(node)
            return

        if scope[CLS]:
            # The scope already has another assignment to "cls".
            # Trying to rename the first param to "cls" as well may produce broken code.
            # We should therefore refrain from suggesting an autofix in this case.
            self.report(node)
            return

        refs: List[Union[cst.Name, cst.Attribute]] = []
        assignments = scope[p0_name.value]
        for a in assignments:
            if isinstance(a, Assignment):
                assign_node = a.node
                if isinstance(assign_node, cst.Name):
                    refs.append(assign_node)
                elif isinstance(assign_node, cst.Param):
                    refs.append(assign_node.name)
                # There are other types of possible assignment nodes: ClassDef,
                # FunctionDef, Import, etc. We deliberately do not handle those here.
            refs += [r.node for r in a.references]

        repl = node.visit(_RenameTransformer(refs, CLS))
        self.report(node, replacement=repl)
Example #4
0
    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        """Replace a ``typing.NamedTuple`` base class with a ``dataclass`` decorator.

        If no base of ``original_node`` matches ``self.qualified_namedtuple``,
        ``updated_node`` is returned unchanged. Otherwise:

        * the matching base is dropped from the base list;
        * its import is scheduled for removal if it becomes unused;
        * ``dataclass`` imports are requested;
        * ``@dataclass(frozen=False)`` is appended after the existing decorators.
        """
        new_bases: List[cst.Arg] = []
        namedtuple_base: Optional[cst.Arg] = None

        # Need to examine the original node's bases since they are directly tied to import metadata.
        # NOTE(review): as a consequence, changes this transformer made inside the
        # kept bases (visible only on updated_node) are not carried over.
        for base_class in original_node.bases:
            # Compare the base class's qualified name against the expected typing.NamedTuple
            if not QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
                # Keep all bases that are not of type typing.NamedTuple
                new_bases.append(base_class)
            else:
                namedtuple_base = base_class

        # We still want to return the updated node in case some of its children have been modified
        if namedtuple_base is None:
            return updated_node

        if new_bases and original_node.bases[-1] is namedtuple_base:
            # The removed base was last, so the new last base still carries the
            # comma that separated it from the removed one (``(Base, NamedTuple)``
            # would become ``(Base, )``). Give it the removed base's trailing
            # comma instead, which also keeps a deliberate trailing comma in a
            # multi-line base list.
            new_bases[-1] = new_bases[-1].with_changes(comma=namedtuple_base.comma)

        # NOTE(review): both imports below bind the name ``dataclass``; whichever
        # ends up later in the module shadows the other, so the emitted
        # ``@dataclass(...)`` may not refer to the intended library. Confirm which
        # one is wanted and drop the other.
        AddImportsVisitor.add_needed_import(self.context, "attr", "dataclass")
        AddImportsVisitor.add_needed_import(self.context, "pydantic.dataclasses", "dataclass")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

        # NOTE(review): NamedTuple instances are immutable, but frozen=False makes
        # the converted class mutable. Confirm this change in semantics is intended.
        call = cst.ensure_type(
            cst.parse_expression("dataclass(frozen=False)", config=self.module.config_for_parsing),
            cst.Call,
        )
        return updated_node.with_changes(
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
            bases=new_bases,
            # Build on updated_node's decorators so modifications this transformer
            # made to them are not discarded.
            decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
        )
Example #5
0
    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> Union[cst.Name, cst.Attribute]:
        """Rewrite an attribute chain that refers to ``self.old_name``.

        The node is rewritten when it resolves to ``self.old_name``. It is also
        rewritten when it sits inside an import statement (detected by having
        no QualifiedName metadata) and its generated replacement equals
        ``self.new_name``.

        How it is rewritten:

        * If the replacement equals ``self.new_name``, the attribute is rebuilt
          from ``self.new_module`` and ``self.new_mod_or_obj``.
        * Otherwise ``self.gen_name_or_attr_node`` builds the result.
        * Outside imports, ``original_node.value`` is also added to
          ``self.scheduled_removals``.

        Raises:
            Exception: if no full dotted name can be derived from ``original_node``.
        """
        dotted_name = get_full_name_for_node(original_node)
        if dotted_name is None:
            raise Exception("Could not parse full name for Attribute node.")
        replacement = self.gen_replacement(dotted_name)

        # Import-statement nodes carry no QualifiedName metadata.
        qualified_names = self.get_metadata(QualifiedNameProvider, original_node, set())
        in_import: bool = not qualified_names

        should_rewrite = QualifiedNameProvider.has_name(
            self, original_node, self.old_name
        ) or (in_import and replacement == self.new_name)
        if not should_rewrite:
            return updated_node

        module_part, attr_part = self.new_module, self.new_mod_or_obj
        if not in_import:
            self.scheduled_removals.add(original_node.value)

        if replacement == self.new_name:
            return updated_node.with_changes(
                value=cst.parse_expression(module_part),
                attr=cst.Name(value=attr_part.rstrip(".")),
            )
        return self.gen_name_or_attr_node(attr_part)
Example #6
0
    def collect_targets(
        self, stack: Tuple[cst.BaseExpression, ...]
    ) -> Tuple[
        List[cst.BaseExpression], Dict[cst.BaseExpression, List[cst.BaseExpression]]
    ]:
        """Group two-argument ``isinstance`` calls in ``stack`` by their first argument.

        Returns ``(operands, targets)``:

        * ``operands`` keeps the expressions of ``stack`` in order, with one
          exception: each matching ``isinstance(x, T)`` call is replaced by
          ``x`` the first time a structurally equal ``x`` (``deep_equals``) is
          seen, and omitted on later occurrences.
        * ``targets`` maps each such first-seen ``x`` to the list of second
          arguments ``T`` collected for it, in order.

        Calls with a tuple second argument are kept unchanged in ``operands``.
        So are calls that do not have exactly two arguments and calls that do
        not resolve to ``_ISINSTANCE``.
        """
        grouped: Dict[cst.BaseExpression, List[cst.BaseExpression]] = {}
        remaining: List[cst.BaseExpression] = []
        two_arg_call = m.Call(func=m.DoNotCare(), args=[m.Arg(), m.Arg(~m.Tuple())])

        for expr in stack:
            if not m.matches(expr, two_arg_call):
                remaining.append(expr)
                continue

            call = cst.ensure_type(expr, cst.Call)
            if not QualifiedNameProvider.has_name(self, call, _ISINSTANCE):
                remaining.append(expr)
                continue

            subject = call.args[0].value
            type_arg = call.args[1].value
            # Reuse the key of a structurally identical subject seen earlier, if any.
            seen = next((key for key in grouped if subject.deep_equals(key)), None)
            if seen is None:
                remaining.append(subject)
                grouped[subject] = [type_arg]
            else:
                grouped[seen].append(type_arg)

        return remaining, grouped
Example #7
0
 def partition_bases(self, original_bases: Sequence[cst.Arg]) -> Tuple[Optional[cst.Arg], List[cst.Arg]]:
     """Split ``original_bases`` into the base matching ``self.qualified_typeddict`` and the rest.

     Returns ``(typeddict_base, other_bases)``:

     * ``typeddict_base`` is a base whose value ``QualifiedNameProvider.has_name``
       matches against ``self.qualified_typeddict``, or ``None`` if none match.
     * ``other_bases`` lists every non-matching base in its original order.
     """
     # NOTE(review): if several bases match, only the last one is returned and
     # the earlier matches appear in neither result — confirm this is intended.
     typeddict_base: Optional[cst.Arg] = None
     other_bases: List[cst.Arg] = []
     for base_class in original_bases:
         if QualifiedNameProvider.has_name(self, base_class.value, self.qualified_typeddict):
             typeddict_base = base_class
         else:
             other_bases.append(base_class)
     return (typeddict_base, other_bases)
Example #8
0
    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> Union[cst.Attribute, cst.Name]:
        """Rewrite a bare Name that refers to ``self.old_name``.

        The node is rewritten when it resolves to ``self.old_name``. It is also
        rewritten when it sits inside an import statement (detected by having
        no QualifiedName metadata) and its generated replacement equals
        ``self.new_name``.

        The result comes from ``self.gen_name_or_attr_node``, using the
        generated replacement, or ``self.new_name`` when that replacement is
        empty. Outside imports, ``original_node`` is added to
        ``self.scheduled_removals``.
        """
        replacement = self.gen_replacement(original_node.value)

        # Import-statement nodes carry no QualifiedName metadata.
        in_import: bool = not self.get_metadata(QualifiedNameProvider, original_node, set())

        matches_old = QualifiedNameProvider.has_name(self, original_node, self.old_name)
        if not (matches_old or (in_import and replacement == self.new_name)):
            return updated_node

        target_name = replacement or self.new_name
        if not in_import:
            self.scheduled_removals.add(original_node)
        return self.gen_name_or_attr_node(target_name)