示例#1
0
 def visit_Assign(self, node: Assign) -> Optional[bool]:
     """Record the variable name(s) the `Library()` call is assigned to.

     When visiting a statement such as ``register = template.Library()``,
     build one ``@<name>.assignment_tag`` decorator matcher per bare-name
     target on the left-hand side. Store their union in
     ``self.context.scratch[self.ctx_key_decorator_matcher]``.

     Targets that are not a plain ``Name`` (attributes, subscripts, tuples)
     cannot appear as ``<name>`` in such a decorator, so they are skipped.
     Reading ``.value`` from them either raised (tuples) or produced a matcher
     wrapping a CST node instead of a string (attributes).
     """
     if self.library_call_matcher and m.matches(
         node,
         m.Assign(value=self.library_call_matcher),
     ):
         # Visiting a `register = template.Library()` statement.
         # Collect the identifiers of all bare-name targets, e.g. `register`.
         target_names = [
             assign_target.target.value
             for assign_target in node.targets
             if m.matches(assign_target.target, m.Name())
         ]
         # Build the `@<name>.assignment_tag` decorator matchers to look out for.
         decorator_matchers = [
             m.Decorator(
                 decorator=m.Attribute(
                     value=m.Name(name),
                     attr=m.Name("assignment_tag"),
                 )
             )
             for name in target_names
         ]
         # The final matcher matches if any of the decorator matchers matches.
         self.context.scratch[self.ctx_key_decorator_matcher] = m.OneOf(
             *decorator_matchers
         )
     return super().visit_Assign(node)
示例#2
0
 def visit_Assign(self, node) -> None:
     """Register a top-level single-name assignment as an importable name.

     If ``node`` has exactly one target, that target is a bare ``Name``, and
     ``self.toplevel == 0``, store a ``from <self.mod> import <name>`` node in
     ``self.imprts``, keyed by the assigned identifier.
     """
     single_name_assign = m.Assign(targets=[m.AssignTarget(m.Name())])
     if not m.matches(node, single_name_assign):
         return
     if self.toplevel != 0:
         return
     assigned = node.targets[0].target
     source_module = parse_expr(self.mod)
     self.imprts[assigned.value] = cst.ImportFrom(
         module=source_module,
         names=[cst.ImportAlias(name=assigned, asname=None)],
     )
示例#3
0
 def test_or_operator_matcher_false(self) -> None:
     """`|` between matchers must reject nodes matching neither side."""
     true_or_false = m.Name("True") | m.Name("False")
     # `None` is neither `True` nor `False`, so the bare name must not match.
     self.assertFalse(matches(cst.Name("None"), true_or_false))
     # `x = None` assigns neither `True` nor `False`, so a matcher that only
     # constrains the assigned value must not match either.
     assign_none = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("None"))
     self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))
示例#4
0
 def visit_Assign(self, node: cst.Assign) -> None:
     """Collect string entries from a ``CARBON_EXTS = [...]`` assignment.

     Only an assignment whose single target is the name ``CARBON_EXTS`` and
     whose value is a list literal is considered. The evaluated value of every
     element that is a simple string literal is appended to
     ``self.extension_names``. All other elements are ignored.
     """
     carbon_exts_matcher = m.Assign(
         targets=(m.AssignTarget(target=m.Name("CARBON_EXTS")), ),
         value=m.SaveMatchedNode(m.List(), "list"),
     )
     captured = m.extract(node, carbon_exts_matcher)
     if not captured:
         return
     ext_list = captured["list"]
     assert isinstance(ext_list, cst.List)
     for element in ext_list.elements:
         if not isinstance(element.value, cst.SimpleString):
             continue
         self.extension_names.append(element.value.evaluated_value)
示例#5
0
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                                  updated_node: cst.SimpleStatementLine):
        """Apply a known type annotation to a single-variable assignment line.

        Two line shapes are handled, both matched against ``original_node``:

        * ``x = value``: exactly one statement, one bare-name target. If
          ``__get_var_type_assign_t`` yields a type for ``x``, the line becomes
          ``x: T = value``, reusing the spacing around ``=``.
        * ``x: Old = value``: exactly one annotated assignment with a bare-name
          target. If ``__get_var_type_an_assign`` yields a type for ``x``, the
          annotation is replaced.

        Every applied ``(resolved_type, annotation_node)`` pair is added to
        ``self.all_applied_types``.

        Returns the rewritten line. When no annotation applies, returns
        ``updated_node`` unchanged. Returning ``original_node`` there, as
        before, discarded any changes other ``leave_*`` hooks had already made
        to this line's children.
        """
        if match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.Assign(targets=[
                        match.AssignTarget(target=match.Name(
                            value=match.DoNotCare()))
                    ])
                ])):
            assign = original_node.body[0]
            assign_target = assign.targets[0]
            t = self.__get_var_type_assign_t(assign_target.target.value)

            if t is not None:
                t_annot_node_resolved = self.resolve_type_alias(t)
                t_annot_node = self.__name2annotation(t_annot_node_resolved)
                if t_annot_node is not None:
                    self.all_applied_types.add(
                        (t_annot_node_resolved, t_annot_node))
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(
                            target=assign_target.target,
                            value=assign.value,
                            annotation=t_annot_node,
                            # Preserve the original whitespace around `=`.
                            equal=cst.AssignEqual(
                                whitespace_after=assign_target.
                                whitespace_after_equal,
                                whitespace_before=assign_target.
                                whitespace_before_equal))
                    ])
        elif match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.AnnAssign(target=match.Name(value=match.DoNotCare()))
                ])):
            ann_assign = original_node.body[0]
            t = self.__get_var_type_an_assign(ann_assign.target.value)
            if t is not None:
                t_annot_node_resolved = self.resolve_type_alias(t)
                t_annot_node = self.__name2annotation(t_annot_node_resolved)
                if t_annot_node is not None:
                    self.all_applied_types.add(
                        (t_annot_node_resolved, t_annot_node))
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(target=ann_assign.target,
                                      value=ann_assign.value,
                                      annotation=t_annot_node,
                                      equal=ann_assign.equal)
                    ])

        # Nothing to annotate: keep whatever the children's leave_* hooks produced.
        return updated_node
示例#6
0
 def visit_Assign(self, node: Assign) -> Optional[bool]:
     """Record variable name the `Library()` call is assigned to.

     If ``self.library_call_matcher`` is set and matches the assigned value
     (e.g. ``register = template.Library()``), the decorator matchers produced
     by ``self._gen_decorator_matchers`` for the assignment targets are
     combined with ``m.OneOf``. The result is stored in
     ``self.context.scratch[self.ctx_key_decorator_matcher]``. The base-class
     ``visit_Assign`` result is always returned.
     """
     library_matcher = self.library_call_matcher
     if not library_matcher or not m.matches(
         node, m.Assign(value=library_matcher)
     ):
         return super().visit_Assign(node)
     # Any one of the per-target decorator matchers is enough to match.
     self.context.scratch[self.ctx_key_decorator_matcher] = m.OneOf(
         *self._gen_decorator_matchers(node.targets)
     )
     return super().visit_Assign(node)
示例#7
0
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        """Report the variables assigned on a simple statement line.

        Uses the last ``Assign`` statement on the line. If the line has a
        trailing comment containing ``type:`` that matches
        ``#\\s*type:\\s*(\\S*)``, the captured token is used as the variable
        type. When ``self.verbose`` is set, one
        ``<path>:<line>:<col>: variable <name>: <type>`` line is printed for
        every ``Name`` node found inside each assignment target.

        Fixes two defects in the previous version:

        * The trailing-comment check tested the truthiness of a ``MatchIfTrue``
          matcher object, which is always true, instead of matching it against
          ``node``.
        * The captured type was bound to a local variable, so every variable
          was reported as ``unknown type``.
        """
        # Only statements in the body can be assignments; the last one wins.
        assign = None
        for stmt in node.body:
            if m.matches(stmt, m.Assign()):
                assign = ensure_type(stmt, Assign)
        if assign is None:
            return

        vtype = None
        comment = node.trailing_whitespace.comment
        if comment and "type:" in comment.value:
            # NOTE: `(\S*)` stops at the first whitespace, so
            # `# type: Dict[str, int]` yields `Dict[str,`.
            mo = re.match(r"#\s*type:\s*(\S*)", comment.value)
            if mo:
                vtype = mo.group(1)

        if not self.verbose:
            return

        class NameVisitor(m.MatcherDecoratableVisitor):
            """Collect the identifier of every ``Name`` node visited."""

            def __init__(self):
                super().__init__()
                self.names: List[str] = []

            def visit_Name(self, node: Name) -> Optional[bool]:
                self.names.append(node.value)
                return None

        pos = self.get_metadata(PositionProvider, node).start
        for target in assign.targets:
            v = NameVisitor()
            target.visit(v)
            for name in v.names:
                print(
                    f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                )
示例#8
0
    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[List[Union[libcst.SimpleStatementLine,
                          libcst.BaseCompoundStatement]],
               List[Union[libcst.SimpleStatementLine,
                          libcst.BaseCompoundStatement]], List[Union[
                              libcst.SimpleStatementLine,
                              libcst.BaseCompoundStatement]], ]:
        """Split ``updated_module.body`` into three lists around the import area.

        Returns ``(before, middle, after)``:

        * ``before`` holds statements that must stay ahead of any new import:
          a leading ``__strict__ = ...`` flag, or a module docstring that is
          the first statement.
        * ``middle`` runs from there up to and including the last top-level
          simple statement containing one of ``self.all_imports`` (compared by
          identity).
        * ``after`` holds the remaining statements.

        A string-expression statement counts as the docstring only at index
        0. Previously any bare string statement later in the module reset both
        split points to 1, discarding an import location already found.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # never insert an import before initial __strict__ flag
        if m.matches(
                orig_module,
                m.Module(body=[
                    m.SimpleStatementLine(body=[
                        m.Assign(targets=[
                            m.AssignTarget(target=m.Name("__strict__"))
                        ])
                    ]),
                    m.ZeroOrMore(),
                ]),
        ):
            statement_before_import_location = import_add_location = 1

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.
        #
        # Imports are found by object identity, so hash their ids once instead
        # of scanning self.all_imports for every small statement.
        known_import_ids = {id(imp) for imp in self.all_imports}
        for i, statement in enumerate(orig_module.body):
            if i == 0 and m.matches(
                    statement,
                    m.SimpleStatementLine(
                        body=[m.Expr(value=m.SimpleString())])):
                # Module docstring: new imports must go after it.
                statement_before_import_location = import_add_location = 1
            elif isinstance(statement, libcst.SimpleStatementLine):
                if any(id(small_stmt) in known_import_ids
                       for small_stmt in statement.body):
                    import_add_location = i + 1

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(updated_module.
                 body[statement_before_import_location:import_add_location]),
            list(updated_module.body[import_add_location:]),
        )
示例#9
0
 def test_or_operator_matcher_true(self) -> None:
     """`|` between matchers accepts a node matching any alternative."""
     true_or_false = m.Name("True") | m.Name("False")
     # Each boolean identifier satisfies the two-way alternative.
     for literal in ("True", "False"):
         self.assertTrue(matches(cst.Name(literal), true_or_false))
     # Chaining another `|` adds `None` as a third alternative.
     self.assertTrue(matches(cst.Name("None"), true_or_false | m.Name("None")))
     # Constraining only the value lets `x = True` match whatever the target.
     assign_true = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("True"))
     self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))
示例#10
0
 def test_or_matcher_false(self) -> None:
     """`OneOf` must reject nodes that satisfy none of its alternatives."""
     true_or_false = m.OneOf(m.Name("True"), m.Name("False"))
     # `None` is neither boolean identifier.
     self.assertFalse(matches(libcst.Name("None"), true_or_false))
     # Assigning `None` is not assigning `True` or `False`.
     assign_none = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
     )
     self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))
     # `foo(1, 2, 3)` matches neither the reversed list nor `4, 5, 6`.
     call = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(str(k))) for k in (1, 2, 3)),
     )
     arg_lists = m.OneOf(
         tuple(m.Arg(m.Integer(str(k))) for k in (3, 2, 1)),
         tuple(m.Arg(m.Integer(str(k))) for k in (4, 5, 6)),
     )
     self.assertFalse(matches(call, m.Call(m.Name("foo"), arg_lists)))
示例#11
0
 def test_or_matcher_true(self) -> None:
     """`OneOf` accepts a node that satisfies at least one alternative."""
     true_or_false = m.OneOf(m.Name("True"), m.Name("False"))
     # The `True` identifier satisfies the first alternative.
     self.assertTrue(matches(libcst.Name("True"), true_or_false))
     # Constraining only the value lets `x = True` match whatever the target.
     assign_true = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
     )
     self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))
     # `foo(1, 2, 3)` matches the second (in-order) argument-list alternative.
     call = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(str(k))) for k in (1, 2, 3)),
     )
     arg_lists = m.OneOf(
         tuple(m.Arg(m.Integer(str(k))) for k in (3, 2, 1)),
         tuple(m.Arg(m.Integer(str(k))) for k in (1, 2, 3)),
     )
     self.assertTrue(matches(call, m.Call(m.Name("foo"), arg_lists)))
示例#12
0
    def leave_Assign(self, original_node, updated_node):
        """Inline a single-name assignment whose value matches ``self.rhs_patterns``.

        Applies when the assignment has one bare-name target, its value matches
        at least one pattern, and the name is assigned exactly once in its
        scope. In that case the value is handed to ``self.propagate`` for that
        scope plus every child scope (from ``self._scope_children``) that does
        not assign the name itself, and the assignment is removed. Otherwise
        ``updated_node`` is returned unchanged.
        """
        pattern_hits = [
            m.matches(
                updated_node,
                m.Assign(targets=[m.AssignTarget(m.Name())], value=pattern))
            for pattern in self.rhs_patterns
        ]
        if not any(pattern_hits):
            return updated_node

        var = original_node.targets[0].target
        scope = self.get_metadata(ScopeProvider, var)
        children = self._scope_children[scope]
        if len(scope.assignments[var]) != 1:
            return updated_node

        # Child scopes that rebind `var` keep their own value.
        valid_scopes = [scope]
        valid_scopes.extend(
            child for child in children if len(child.assignments[var]) == 0)
        self.propagate(valid_scopes, var, updated_node.value)
        return cst.RemoveFromParent()
示例#13
0
class VersionTransformer(m.MatcherDecoratableTransformer):
    """Rewrite the string literal assigned to ``__version__``.

    The current value is parsed with ``Version``, passed through
    ``version_mod``, and written back in place. The resulting version is
    exposed as ``new_version``.
    """

    new_version: Union[None, Version] = None

    def __init__(self, version_mod: Callable[[Version], Version]):
        """
        Args:
            version_mod: Maps the version found in the source to the new one.
        """
        super().__init__()
        self.version_mod = version_mod

    @m.call_if_inside(
        m.Assign(
            targets=[m.AssignTarget(target=m.Name("__version__"))],
            value=m.SimpleString(),
        ))
    @m.leave(m.SimpleString())
    def update_version(self, original_node: cst.SimpleString,
                       updated_node: cst.SimpleString) -> cst.SimpleString:
        """Replace the ``__version__`` string with the modified version.

        The literal's prefix and quote style are preserved. Previously the
        result was always re-emitted with double quotes.

        Raises:
            Exception: if a second ``__version__`` string is encountered.
        """
        if self.new_version is not None:
            raise Exception("Multiple versions found.")

        old_version = Version(updated_node.evaluated_value)
        self.new_version = self.version_mod(old_version)
        quote = updated_node.quote
        return updated_node.with_changes(
            value=f"{updated_node.prefix}{quote}{self.new_version}{quote}")
示例#14
0
    def __extract_assign_newtype(self, node: cst.Assign):
        """
        Record the name declared by a ``X = NewType("X", ...)`` assignment.

        The assignment must have exactly one bare-name target. Its value must
        call the plain name ``NewType`` with a simple string literal as the
        first argument; any further arguments are accepted. On a match, the
        target's identifier is appended to ``self.class_defs``.
        """
        # Exactly one `Name` target; its identifier is captured as "type".
        # The regex accepts any non-empty identifier.
        single_name_target = match.AssignTarget(
            target=match.Name(
                value=match.SaveMatchedNode(match.MatchRegex(r'(.)+'), "type")))
        # `NewType(<string literal>, ...)`.
        newtype_call = match.Call(
            func=match.Name(value="NewType"),
            args=[
                match.Arg(value=match.SimpleString()),
                match.ZeroOrMore(),
            ])
        captured = match.extract(
            node, match.Assign(targets=[single_name_target], value=newtype_call))
        if captured is None:
            return
        # TODO: Either rename class defs, or create new list for additional types
        self.class_defs.append(captured["type"].strip("\'"))

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report ``@sorted-attributes`` classes whose assignments are out of order.

        Only classes whose docstring contains ``@sorted-attributes`` are
        checked. A body line that is a single assignment to a single bare name
        counts as an "assign line". Every other line is kept in a "before"
        bucket until the first assign line is seen, and in an "after" bucket
        from then on. If the assign lines are not already ordered by target
        name, a report is emitted whose replacement body is
        ``before + sorted assign lines + after``.

        Assignments with non-name targets (``a.b = ...``, ``a, b = ...``) are
        treated as ordinary lines. The sort key needs a plain identifier, and
        such targets previously could crash the sort.
        """
        doc_string = node.get_docstring()
        if not doc_string or "@sorted-attributes" not in doc_string:
            return

        pre_assign_lines: List[LineType] = []
        assign_lines: List[LineType] = []
        post_assign_lines: List[LineType] = []

        for line in node.body.body:
            if m.matches(
                    line,
                    m.SimpleStatementLine(body=[
                        m.Assign(targets=[m.AssignTarget(target=m.Name())])
                    ])):
                assign_lines.append(line)
            elif assign_lines:
                # At least one assignment already seen.
                post_assign_lines.append(line)
            else:
                pre_assign_lines.append(line)

        sorted_assign_lines = sorted(
            assign_lines,
            key=lambda line: line.body[0].targets[0].target.value)
        if sorted_assign_lines == assign_lines:
            return
        self.report(
            node,
            replacement=node.with_changes(body=node.body.with_changes(
                body=pre_assign_lines + sorted_assign_lines +
                post_assign_lines)),
        )
示例#16
0
                                    newline=m.Newline())

# A call whose callee, either a bare name `Foo(...)` or the attribute part of
# `mod.Foo(...)`, satisfies `is_model_field_type`.
_django_model_field_name_value = (
    m.Call(func=m.Attribute(attr=m.Name(m.MatchIfTrue(is_model_field_type))))
    | m.Call(func=m.Name(m.MatchIfTrue(is_model_field_type)))
)

# Same call shapes, but the whitespace right after the opening parenthesis is a
# `ParenthesizedWhitespace` whose first line matches `_any_comment`
# (presumably a comment directly after `(` — confirm against `_any_comment`).
_django_model_field_name_with_leading_comment_value = (
    m.Call(
        func=m.Attribute(attr=m.Name(m.MatchIfTrue(is_model_field_type))),
        whitespace_before_args=m.ParenthesizedWhitespace(_any_comment),
    )
    | m.Call(
        func=m.Name(m.MatchIfTrue(is_model_field_type)),
        whitespace_before_args=m.ParenthesizedWhitespace(_any_comment),
    )
)

# A one-statement line `x = Field(...)` or `x: T = Field(...)` whose call
# arguments begin with a comment.
_django_model_field_with_leading_comment = m.SimpleStatementLine(
    body=[
        m.Assign(value=_django_model_field_name_with_leading_comment_value)
        | m.AnnAssign(value=_django_model_field_name_with_leading_comment_value)
    ]
)

# A one-statement field line whose trailing whitespace matches `_any_comment`.
_django_model_field_with_trailing_comment = m.SimpleStatementLine(
    body=[
        m.Assign(value=_django_model_field_name_value)
        | m.AnnAssign(value=_django_model_field_name_value)
    ],
    trailing_whitespace=_any_comment,
)

# Field line carrying either a leading-argument comment or a trailing comment.
django_model_field_with_comments = (
    _django_model_field_with_leading_comment
    | _django_model_field_with_trailing_comment
)

示例#17
0
 def _is_property_fset(self, assgn):
     """Return whether ``assgn`` is an assignment to exactly one attribute.

     True for statements shaped like ``obj.attr = value``, i.e. a single
     target that is an ``Attribute``; False otherwise.
     """
     attribute_target = m.AssignTarget(m.Attribute())
     return m.matches(assgn, m.Assign(targets=[attribute_target]))

class TypeCollector(m.MatcherDecoratableVisitor):
    """
    Collect type annotations from a stub module.

    Function signatures, annotated attributes, class definitions (with their
    bases rewritten) and ``TypeVar`` assignments are stored in
    ``self.annotations``. While annotations are processed, each referenced
    name is checked against ``existing_imports``. If an import is missing it
    is requested via ``AddImportsVisitor.add_needed_import``. A name is kept
    fully qualified only when its module, but not the name itself, is already
    imported; otherwise it is dequalified.
    """

    METADATA_DEPENDENCIES = (
        PositionProvider,
        QualifiedNameProvider,
    )

    annotations: Annotations

    def __init__(
        self,
        existing_imports: Set[str],
        context: CodemodContext,
    ) -> None:
        """
        Args:
            existing_imports: Fully-qualified names of types and module names
                already in scope in the target module.
            context: Codemod context used when requesting new imports.
        """
        super().__init__()
        self.context = context
        # Existing imports, determined by looking at the target module.
        # Used to help us determine when a type in a stub will require new imports.
        #
        # The contents of this are fully-qualified names of types in scope
        # as well as module names, although downstream we effectively ignore
        # the module names as of the current implementation.
        self.existing_imports: Set[str] = existing_imports
        # Stack of enclosing class/function/attribute names; joined with "."
        # to build the keys under which annotations are stored.
        self.qualifier: List[str] = []
        # The Assign statement currently being visited; used to collect typevars.
        self.current_assign: Optional[cst.Assign] = None
        # Store the annotations.
        self.annotations = Annotations.empty()

    def visit_ClassDef(
        self,
        node: cst.ClassDef,
    ) -> None:
        """Record the class with its bases dequalified.

        Raises:
            ValueError: if a base is not a name, attribute or subscript.
        """
        self.qualifier.append(node.name.value)
        new_bases = []
        for base in node.bases:
            value = base.value
            if isinstance(value, NAME_OR_ATTRIBUTE):
                new_value = self._handle_NameOrAttribute(value)
            elif isinstance(value, cst.Subscript):
                new_value = self._handle_Subscript(value)
            else:
                start = self.get_metadata(PositionProvider, node).start
                raise ValueError(
                    "Invalid type used as base class in stub file at " +
                    f"{start.line}:{start.column}. Only subscripts, names, and "
                    + "attributes are valid base classes for static typing.")
            new_bases.append(base.with_changes(value=new_value))

        self.annotations.class_definitions[
            node.name.value] = node.with_changes(bases=new_bases)

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
    ) -> None:
        self.qualifier.pop()

    def visit_FunctionDef(
        self,
        node: cst.FunctionDef,
    ) -> bool:
        """Record the function's parameter and return annotations."""
        self.qualifier.append(node.name.value)
        returns = node.returns
        return_annotation = (self._handle_Annotation(
            annotation=returns) if returns is not None else None)
        parameter_annotations = self._handle_Parameters(node.params)
        name = ".".join(self.qualifier)
        key = FunctionKey.make(name, node.params)
        self.annotations.functions[key] = FunctionAnnotation(
            parameters=parameter_annotations, returns=return_annotation)

        # pyi files don't support inner functions, return False to stop the traversal.
        return False

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
    ) -> None:
        self.qualifier.pop()

    def visit_AnnAssign(
        self,
        node: cst.AnnAssign,
    ) -> bool:
        """Record an attribute annotation under its fully qualified name."""
        name = get_full_name_for_node(node.target)
        if name is None:
            # Nothing to key the annotation on. leave_AnnAssign repeats this
            # check so the qualifier stack stays balanced; previously it popped
            # unconditionally and removed the enclosing class/function name.
            return True
        self.qualifier.append(name)
        annotation_value = self._handle_Annotation(annotation=node.annotation)
        self.annotations.attributes[".".join(
            self.qualifier)] = annotation_value
        return True

    def leave_AnnAssign(
        self,
        original_node: cst.AnnAssign,
    ) -> None:
        # Pop only what visit_AnnAssign pushed.
        if get_full_name_for_node(original_node.target) is not None:
            self.qualifier.pop()

    def visit_Assign(
        self,
        node: cst.Assign,
    ) -> None:
        self.current_assign = node

    def leave_Assign(
        self,
        original_node: cst.Assign,
    ) -> None:
        self.current_assign = None

    @m.call_if_inside(m.Assign())
    @m.visit(m.Call(func=m.Name("TypeVar")))
    def record_typevar(
        self,
        node: cst.Call,
    ) -> None:
        """Record ``X = TypeVar(...)`` under the assignment target's name."""
        assign = self.current_assign
        # current_assign is cleared after the first TypeVar of an assignment is
        # recorded, so further TypeVar calls in the same statement are ignored.
        if assign is None:
            return
        name = get_full_name_for_node(assign.targets[0].target)
        if name is not None:
            self.annotations.typevars[name] = assign
            self._handle_qualification_and_should_qualify("typing.TypeVar")
            self.current_assign = None

    def leave_Module(
        self,
        original_node: cst.Module,
    ) -> None:
        self.annotations.finish()

    def _get_unique_qualified_name(
        self,
        node: cst.CSTNode,
    ) -> str:
        """Return the single qualified name for ``node``.

        Raises:
            ValueError: if zero (and no full name) or several candidates exist.
        """
        name = None
        names = [
            q.name for q in self.get_metadata(QualifiedNameProvider, node)
        ]
        if len(names) == 0:
            # we hit this branch if the stub is directly using a fully
            # qualified name, which is not technically valid python but is
            # convenient to allow.
            name = get_full_name_for_node(node)
        elif len(names) == 1 and isinstance(names[0], str):
            name = names[0]
        if name is None:
            start = self.get_metadata(PositionProvider, node).start
            raise ValueError(
                "Could not resolve a unique qualified name for type " +
                f"{get_full_name_for_node(node)} at {start.line}:{start.column}. "
                + f"Candidate names were: {names!r}")
        return name

    def _get_qualified_name_and_dequalified_node(
        self,
        node: Union[cst.Name, cst.Attribute],
    ) -> Tuple[str, Union[cst.Name, cst.Attribute]]:
        """Return ``(qualified name, node without its module prefix)``."""
        qualified_name = self._get_unique_qualified_name(node)
        dequalified_node = node.attr if isinstance(node,
                                                   cst.Attribute) else node
        return qualified_name, dequalified_node

    def _module_and_target(
        self,
        qualified_name: str,
    ) -> Tuple[str, str]:
        """Split ``a.b.C`` into ``("a.b", "C")``, keeping leading relative dots."""
        relative_prefix = ""
        while qualified_name.startswith("."):
            relative_prefix += "."
            qualified_name = qualified_name[1:]
        split = qualified_name.rsplit(".", 1)
        if len(split) == 1:
            qualifier, target = "", split[0]
        else:
            qualifier, target = split
        return (relative_prefix + qualifier, target)

    def _handle_qualification_and_should_qualify(
        self,
        qualified_name: str,
    ) -> bool:
        """
        Based on a qualified name and the existing module imports, record that
        we need to add an import if necessary and return whether or not we
        should use the qualified name due to a preexisting import.
        """
        module, target = self._module_and_target(qualified_name)
        if module in ("", "builtins"):
            return False
        elif qualified_name not in self.existing_imports:
            if module in self.existing_imports:
                return True
            else:
                AddImportsVisitor.add_needed_import(
                    self.context,
                    module,
                    target,
                )
                return False
        return False

    # Handler functions.
    #
    # Each of these does one of two things, possibly recursively, over some
    # valid CST node for a static type:
    #  - process the qualified name and ensure we will add necessary imports
    #  - dequalify the node

    def _handle_NameOrAttribute(
        self,
        node: NameOrAttribute,
    ) -> Union[cst.Name, cst.Attribute]:
        (
            qualified_name,
            dequalified_node,
        ) = self._get_qualified_name_and_dequalified_node(node)
        should_qualify = self._handle_qualification_and_should_qualify(
            qualified_name)
        self.annotations.names.add(qualified_name)
        if should_qualify:
            return node
        else:
            return dequalified_node

    def _handle_Index(
        self,
        slice: cst.Index,
    ) -> cst.Index:
        """Return ``slice`` with its subscript/attribute value processed.

        String values are recorded in ``self.annotations.names``; any other
        value (including a bare ``Name``) is returned unchanged.
        """
        value = slice.value
        if isinstance(value, cst.Subscript):
            return slice.with_changes(value=self._handle_Subscript(value))
        elif isinstance(value, cst.Attribute):
            return slice.with_changes(
                value=self._handle_NameOrAttribute(value))
        else:
            if isinstance(value, cst.SimpleString):
                self.annotations.names.add(_get_string_value(value))
            return slice

    def _handle_Subscript(
        self,
        node: cst.Subscript,
    ) -> cst.Subscript:
        """Process the subscripted type and, recursively, its index elements.

        Raises:
            ValueError: if the subscripted value is not a name or attribute.
        """
        value = node.value
        if isinstance(value, NAME_OR_ATTRIBUTE):
            new_node = node.with_changes(
                value=self._handle_NameOrAttribute(value))
        else:
            raise ValueError(
                "Expected any indexed type to have a name or attribute as its "
                f"value, got {type(value).__name__}")
        if self._get_unique_qualified_name(node) in ("Type", "typing.Type"):
            # Note: we are intentionally not handling qualification of
            # anything inside `Type` because it's common to have nested
            # classes, which we cannot currently distinguish from classes
            # coming from other modules, appear here.
            return new_node
        subscript_slice = node.slice
        if isinstance(subscript_slice, (tuple, list)):
            new_slice = []
            for item in subscript_slice:
                index = item.slice
                # Only `Index` elements hold a single type expression; real
                # slices (`a:b`) have no `.value` and are passed through.
                if not isinstance(index, cst.Index):
                    new_slice.append(item)
                    continue
                if isinstance(index.value, NAME_OR_ATTRIBUTE):
                    name = self._handle_NameOrAttribute(index.value)
                    new_slice.append(
                        item.with_changes(slice=index.with_changes(value=name)))
                else:
                    # `_handle_Index` already returns the rewritten `Index`;
                    # it must not be wrapped in another `Index`.
                    new_slice.append(
                        item.with_changes(slice=self._handle_Index(index)))
            return new_node.with_changes(slice=tuple(new_slice))
        elif isinstance(subscript_slice, cst.Index):
            new_slice = self._handle_Index(subscript_slice)
            return new_node.with_changes(slice=new_slice)
        else:
            return new_node

    def _handle_Annotation(
        self,
        annotation: cst.Annotation,
    ) -> cst.Annotation:
        """Return ``annotation`` with its type expression processed.

        Raises:
            ValueError: if the annotation is not a string, subscript, name or
                attribute.
        """
        node = annotation.annotation
        if isinstance(node, cst.SimpleString):
            self.annotations.names.add(_get_string_value(node))
            return annotation
        elif isinstance(node, cst.Subscript):
            return cst.Annotation(annotation=self._handle_Subscript(node))
        elif isinstance(node, NAME_OR_ATTRIBUTE):
            return cst.Annotation(
                annotation=self._handle_NameOrAttribute(node))
        else:
            raise ValueError(f"Unexpected annotation node: {node}")

    def _handle_Parameters(
        self,
        parameters: cst.Parameters,
    ) -> cst.Parameters:
        """Return ``parameters`` with every annotated ``params`` entry processed.

        NOTE(review): only ``parameters.params`` is processed; keyword-only,
        positional-only and star parameters keep their annotations as-is —
        confirm that is intended.
        """
        def update_annotations(
            parameters: Sequence[cst.Param], ) -> List[cst.Param]:
            updated_parameters = []
            for parameter in list(parameters):
                annotation = parameter.annotation
                if annotation is not None:
                    parameter = parameter.with_changes(
                        annotation=self._handle_Annotation(
                            annotation=annotation))
                updated_parameters.append(parameter)
            return updated_parameters

        return parameters.with_changes(
            params=update_annotations(parameters.params))


class ApplyTypeAnnotationsVisitor(ContextAwareTransformer):
    """
    Apply type annotations to a source module using the given stub modules.
    You can also pass in explicit annotations for functions and attributes and
    pass in new class definitions that need to be added to the source module.

    This is one of the transforms that is available automatically to you when
    running a codemod. To use it in this manner, import
    :class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call
    the static
    :meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.store_stub_in_context`
    method, giving it the current context (found as ``self.context`` for all
    subclasses of :class:`~libcst.codemod.Codemod`) and the stub module from
    which you wish to add annotations.

    For example, you can store the type annotation ``int`` for ``x`` using::

        stub_module = parse_module("x: int = ...")

        ApplyTypeAnnotationsVisitor.store_stub_in_context(self.context, stub_module)

    You can apply the type annotation using::

        source_module = parse_module("x = 1")
        ApplyTypeAnnotationsVisitor(self.context).transform_module(source_module)

    This will produce the following code::

        x: int = 1

    If the function or attribute already has a type annotation, it will not be
    overwritten.

    To overwrite existing annotations when applying annotations from a stub,
    use the keyword argument ``overwrite_existing_annotations=True`` when
    constructing the codemod or when calling ``store_stub_in_context``.
    """

    CONTEXT_KEY = "ApplyTypeAnnotationsVisitor"

    def __init__(
        self,
        context: CodemodContext,
        annotations: Optional[Annotations] = None,
        overwrite_existing_annotations: bool = False,
        use_future_annotations: bool = False,
        strict_posargs_matching: bool = True,
        strict_annotation_matching: bool = False,
    ) -> None:
        super().__init__(context)
        # Stack of enclosing class/function (and, transiently, attribute)
        # names; joined with "." it forms the qualified name used as a key
        # into self.annotations.
        self.qualifier: List[str] = []
        # NOTE: a caller-supplied Annotations object is updated in place by
        # transform_module_impl when a stub is stored in the context.
        self.annotations: Annotations = (Annotations.empty() if
                                         annotations is None else annotations)
        self.toplevel_annotations: Dict[str, cst.Annotation] = {}
        self.visited_classes: Set[str] = set()
        self.overwrite_existing_annotations = overwrite_existing_annotations
        self.use_future_annotations = use_future_annotations
        self.strict_posargs_matching = strict_posargs_matching
        self.strict_annotation_matching = strict_annotation_matching

        # We use this to determine the end of the import block so that we can
        # insert top-level annotations.
        self.import_statements: List[cst.ImportFrom] = []

        # We use this to report annotations added, as well as to determine
        # whether to abandon the codemod in edge cases where we may have
        # only made changes to the imports.
        self.annotation_counts: AnnotationCounts = AnnotationCounts()

        # We use this to collect typevars, to avoid importing existing ones
        # from the pyi file.
        self.current_assign: Optional[cst.Assign] = None
        self.typevars: Dict[str, cst.Assign] = {}

    @staticmethod
    def store_stub_in_context(
        context: CodemodContext,
        stub: cst.Module,
        overwrite_existing_annotations: bool = False,
        use_future_annotations: bool = False,
        strict_posargs_matching: bool = True,
        strict_annotation_matching: bool = False,
    ) -> None:
        """
        Store a stub module in the :class:`~libcst.codemod.CodemodContext` so
        that type annotations from the stub can be applied in a later
        invocation of this class.

        If the ``overwrite_existing_annotations`` flag is ``True``, the
        codemod will overwrite any existing annotations.

        If you call this function multiple times, only the values from the
        last call (``stub`` and every flag) will take effect.
        """
        context.scratch[ApplyTypeAnnotationsVisitor.CONTEXT_KEY] = (
            stub,
            overwrite_existing_annotations,
            use_future_annotations,
            strict_posargs_matching,
            strict_annotation_matching,
        )

    def transform_module_impl(
        self,
        tree: cst.Module,
    ) -> cst.Module:
        """
        Collect type annotations from all stubs and apply them to ``tree``.

        Gather existing imports from ``tree`` so that we don't add duplicate
        imports.

        Flags stored via ``store_stub_in_context`` are merged with the ones
        given to the constructor. ``overwrite_existing_annotations``,
        ``use_future_annotations`` and ``strict_annotation_matching`` are on
        if either source turns them on. ``strict_posargs_matching`` stays on
        only if both sources keep it on.

        When no stub is stored in the context, only the annotations passed to
        the constructor are applied.

        Returns the original ``tree`` unchanged when no annotation was
        applied, so that import-only changes are discarded.
        """
        context_contents = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if context_contents is not None:
            (
                stub,
                overwrite_existing_annotations,
                use_future_annotations,
                strict_posargs_matching,
                strict_annotation_matching,
            ) = context_contents
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations
                or overwrite_existing_annotations)
            self.use_future_annotations = (self.use_future_annotations
                                           or use_future_annotations)
            self.strict_posargs_matching = (self.strict_posargs_matching
                                            and strict_posargs_matching)
            self.strict_annotation_matching = (self.strict_annotation_matching
                                               or strict_annotation_matching)

            # Existing import names are only needed by the stub collector.
            import_gatherer = GatherImportsVisitor(CodemodContext())
            tree.visit(import_gatherer)
            existing_import_names = _get_imported_names(
                import_gatherer.all_imports)

            visitor = TypeCollector(existing_import_names, self.context)
            cst.MetadataWrapper(stub).visit(visitor)
            self.annotations.update(visitor.annotations)

        # Run these even without a stored stub. Annotations (and
        # use_future_annotations) may come solely from the constructor, and
        # `tree_with_imports` must always be bound.
        if self.use_future_annotations:
            AddImportsVisitor.add_needed_import(self.context, "__future__",
                                                "annotations")
        tree_with_imports = AddImportsVisitor(
            self.context).transform_module(tree)

        tree_with_changes = tree_with_imports.visit(self)

        # don't modify the imports if we didn't actually add any type information
        if self.annotation_counts.any_changes_applied():
            return tree_with_changes
        else:
            return tree

    # smart constructors: all applied annotations happen via one of these

    def _apply_annotation_to_attribute_or_global(
        self,
        name: str,
        annotation: cst.Annotation,
        value: Optional[cst.BaseExpression],
    ) -> cst.AnnAssign:
        # An empty qualifier means we are at module level (a global);
        # otherwise we are inside a class or function (an attribute).
        if len(self.qualifier) == 0:
            self.annotation_counts.global_annotations += 1
        else:
            self.annotation_counts.attribute_annotations += 1
        return cst.AnnAssign(
            cst.Name(name),
            annotation,
            value,
        )

    def _apply_annotation_to_parameter(
        self,
        parameter: cst.Param,
        annotation: cst.Annotation,
    ) -> cst.Param:
        self.annotation_counts.parameter_annotations += 1
        return parameter.with_changes(annotation=annotation)

    def _apply_annotation_to_return(
        self,
        function_def: cst.FunctionDef,
        annotation: cst.Annotation,
    ) -> cst.FunctionDef:
        self.annotation_counts.return_annotations += 1
        return function_def.with_changes(returns=annotation)

    # private methods used in the visit and leave methods

    def _qualifier_name(self) -> str:
        return ".".join(self.qualifier)

    def _annotate_single_target(
        self,
        node: cst.Assign,
        updated_node: cst.Assign,
    ) -> Union[cst.Assign, cst.AnnAssign]:
        """Annotate an assignment that has exactly one target.

        - Tuple/list unpacking cannot carry inline annotations. Each named
          element (except ``_``) is recorded for a separate top-level
          declaration instead.
        - Subscript targets (``d[k] = v``) are never annotated.
        - Any other target whose qualified name has a stub annotation is
          turned into an ``AnnAssign`` that keeps the original value.
        """
        only_target = node.targets[0].target
        if isinstance(only_target, (cst.Tuple, cst.List)):
            for element in only_target.elements:
                value = element.value
                name = get_full_name_for_node(value)
                if name is not None and name != "_":
                    self._add_to_toplevel_annotations(name)
        elif isinstance(only_target, cst.Subscript):
            pass
        else:
            name = get_full_name_for_node(only_target)
            if name is not None:
                self.qualifier.append(name)
                qualified_name = self._qualifier_name()
                # Pop before building the AnnAssign so that the global vs.
                # attribute count reflects the enclosing scope only.
                self.qualifier.pop()
                if qualified_name in self.annotations.attributes:
                    return self._apply_annotation_to_attribute_or_global(
                        name=name,
                        annotation=self.annotations.attributes[qualified_name],
                        value=node.value,
                    )
        return updated_node

    def _split_module(
        self,
        module: cst.Module,
        updated_module: cst.Module,
    ) -> Tuple[List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
               List[Union[cst.SimpleStatementLine,
                          cst.BaseCompoundStatement]], ]:
        """Split ``updated_module.body`` just after the last recorded import.

        This works under the principle that while we might modify node
        contents, we have yet to modify the number of statements. So we can
        match on the original tree but break up the statements of the
        modified tree. If we change this assumption in this visitor, we will
        have to change this code.
        """
        # Imports are matched by identity. The nodes stay alive in
        # self.import_statements, so their ids are stable.
        import_ids = {id(statement) for statement in self.import_statements}
        import_add_location = 0
        for i, statement in enumerate(module.body):
            if isinstance(statement, cst.SimpleStatementLine) and any(
                    id(small_statement) in import_ids
                    for small_statement in statement.body):
                import_add_location = i + 1

        return (
            list(updated_module.body[:import_add_location]),
            list(updated_module.body[import_add_location:]),
        )

    def _add_to_toplevel_annotations(
        self,
        name: str,
    ) -> None:
        # Record a bare `name: T` declaration if the stub annotates `name`
        # in the current scope.
        self.qualifier.append(name)
        if self._qualifier_name() in self.annotations.attributes:
            annotation = self.annotations.attributes[self._qualifier_name()]
            self.toplevel_annotations[name] = annotation
        self.qualifier.pop()

    def _update_parameters(
        self,
        annotations: FunctionAnnotation,
        updated_node: cst.FunctionDef,
    ) -> cst.Parameters:
        # Update params and default params with annotations
        # Don't override existing annotations or default values unless asked
        # to overwrite existing annotations.
        def update_annotation(
            parameters: Sequence[cst.Param],
            annotations: Sequence[cst.Param],
            positional: bool,
        ) -> List[cst.Param]:
            parameter_annotations = {}
            annotated_parameters = []
            # Positional parameters are matched by index only when strict
            # name matching is off; otherwise everything matches by name.
            positional = positional and not self.strict_posargs_matching
            for i, parameter in enumerate(annotations):
                key = i if positional else parameter.name.value
                if parameter.annotation:
                    parameter_annotations[
                        key] = parameter.annotation.with_changes(
                            whitespace_before_indicator=cst.SimpleWhitespace(
                                value=""))
            for i, parameter in enumerate(parameters):
                key = i if positional else parameter.name.value
                if key in parameter_annotations and (
                        self.overwrite_existing_annotations
                        or not parameter.annotation):
                    parameter = self._apply_annotation_to_parameter(
                        parameter=parameter,
                        annotation=parameter_annotations[key],
                    )
                annotated_parameters.append(parameter)
            return annotated_parameters

        return updated_node.params.with_changes(
            params=update_annotation(
                updated_node.params.params,
                annotations.parameters.params,
                positional=True,
            ),
            kwonly_params=update_annotation(
                updated_node.params.kwonly_params,
                annotations.parameters.kwonly_params,
                positional=False,
            ),
            posonly_params=update_annotation(
                updated_node.params.posonly_params,
                annotations.parameters.posonly_params,
                positional=True,
            ),
        )

    def _insert_empty_line(
        self,
        statements: List[Union[cst.SimpleStatementLine,
                               cst.BaseCompoundStatement]],
    ) -> List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]]:
        if len(statements) < 1:
            # No statements, nothing to add to
            return statements
        if len(statements[0].leading_lines) == 0:
            # Statement has no leading lines, add one!
            return [
                statements[0].with_changes(leading_lines=(cst.EmptyLine(), )),
                *statements[1:],
            ]
        if statements[0].leading_lines[0].comment is None:
            # First line is empty, so its safe to leave as-is
            return statements
        # Statement has a comment first line, so lets add one more empty line
        return [
            statements[0].with_changes(
                leading_lines=(cst.EmptyLine(), *statements[0].leading_lines)),
            *statements[1:],
        ]

    def _match_signatures(  # noqa: C901: Too complex
        self,
        function: cst.FunctionDef,
        annotations: FunctionAnnotation,
    ) -> bool:
        """Check that function annotations on both signatures are compatible."""
        def compatible(
            p: Optional[cst.Annotation],
            q: Optional[cst.Annotation],
        ) -> bool:
            if (self.overwrite_existing_annotations or not _is_non_sentinel(p)
                    or not _is_non_sentinel(q)):
                return True
            if not self.strict_annotation_matching:
                # We will not overwrite clashing annotations, but the signature as a
                # whole will be marked compatible so that holes can be filled in.
                return True
            return p.annotation.deep_equals(q.annotation)  # pyre-ignore[16]

        def match_posargs(
            ps: Sequence[cst.Param],
            qs: Sequence[cst.Param],
        ) -> bool:
            if len(ps) != len(qs):
                return False
            for p, q in zip(ps, qs):
                if self.strict_posargs_matching and not p.name.value == q.name.value:
                    return False
                if not compatible(p.annotation, q.annotation):
                    return False
            return True

        def match_kwargs(
            ps: Sequence[cst.Param],
            qs: Sequence[cst.Param],
        ) -> bool:
            ps_dict = {x.name.value: x for x in ps}
            qs_dict = {x.name.value: x for x in qs}
            if set(ps_dict.keys()) != set(qs_dict.keys()):
                return False
            for k in ps_dict.keys():
                if not compatible(ps_dict[k].annotation,
                                  qs_dict[k].annotation):
                    return False
            return True

        def match_star(
            p: StarParamType,
            q: StarParamType,
        ) -> bool:
            # Only presence/absence of *args / **kwargs must agree.
            return _is_non_sentinel(p) == _is_non_sentinel(q)

        def match_params(
            f: cst.FunctionDef,
            g: FunctionAnnotation,
        ) -> bool:
            p, q = f.params, g.parameters
            return (match_posargs(p.params, q.params)
                    and match_posargs(p.posonly_params, q.posonly_params)
                    and match_kwargs(p.kwonly_params, q.kwonly_params)
                    and match_star(p.star_arg, q.star_arg)
                    and match_star(p.star_kwarg, q.star_kwarg))

        def match_return(
            f: cst.FunctionDef,
            g: FunctionAnnotation,
        ) -> bool:
            return compatible(f.returns, g.returns)

        return match_params(function, annotations) and match_return(
            function, annotations)

    # transform API methods

    def visit_ClassDef(
        self,
        node: cst.ClassDef,
    ) -> None:
        self.qualifier.append(node.name.value)
        self.visited_classes.add(node.name.value)

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef:
        cls_name = ".".join(self.qualifier)
        self.qualifier.pop()
        definition = self.annotations.class_definitions.get(cls_name)
        if definition:
            # Copy the stub's generic base (as found by _find_generic_base)
            # when the source class lacks one.
            b1 = _find_generic_base(definition)
            b2 = _find_generic_base(updated_node)
            if b1 and not b2:
                new_bases = list(updated_node.bases) + [b1]
                self.annotation_counts.typevars_and_generics_added += 1
                return updated_node.with_changes(bases=new_bases)
        return updated_node

    def visit_FunctionDef(
        self,
        node: cst.FunctionDef,
    ) -> bool:
        self.qualifier.append(node.name.value)
        # pyi files don't support inner functions, return False to stop the traversal.
        return False

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef:
        key = FunctionKey.make(self._qualifier_name(), updated_node.params)
        self.qualifier.pop()
        if key in self.annotations.functions:
            function_annotation = self.annotations.functions[key]
            # Only add new annotation if:
            # * we have matching function signatures and
            # * we are explicitly told to overwrite existing annotations or
            # * there is no existing annotation
            if not self._match_signatures(updated_node, function_annotation):
                return updated_node
            set_return_annotation = (self.overwrite_existing_annotations
                                     or updated_node.returns is None)
            if set_return_annotation and function_annotation.returns is not None:
                updated_node = self._apply_annotation_to_return(
                    function_def=updated_node,
                    annotation=function_annotation.returns,
                )
            # Don't override default values when annotating functions
            new_parameters = self._update_parameters(function_annotation,
                                                     updated_node)
            return updated_node.with_changes(params=new_parameters)
        return updated_node

    def visit_Assign(
        self,
        node: cst.Assign,
    ) -> None:
        # Remember the enclosing assignment so record_typevar can find it.
        self.current_assign = node

    @m.call_if_inside(m.Assign())
    @m.visit(m.Call(func=m.Name("TypeVar")))
    def record_typevar(
        self,
        node: cst.Call,
    ) -> None:
        """Record ``Name = TypeVar(...)`` definitions already in the source."""
        current_assign = self.current_assign
        if current_assign is None:
            # This assignment was already recorded, e.g. a second TypeVar(...)
            # call nested in the first one's arguments. Nothing more to do.
            return
        name = get_full_name_for_node(current_assign.targets[0].target)
        if name is not None:
            # Preserve the whole node, even though we currently just use the
            # name, so that we can match bounds and variance at some point and
            # determine if two typevars with the same name are indeed the same.
            self.typevars[name] = current_assign
            self.current_assign = None

    def leave_Assign(
        self,
        original_node: cst.Assign,
        updated_node: cst.Assign,
    ) -> Union[cst.Assign, cst.AnnAssign]:

        self.current_assign = None

        if len(original_node.targets) > 1:
            for assign in original_node.targets:
                target = assign.target
                if isinstance(target, (cst.Name, cst.Attribute)):
                    name = get_full_name_for_node(target)
                    if name is not None and name != "_":
                        # Add separate top-level annotations for `a = b = 1`
                        # as `a: int` and `b: int`.
                        self._add_to_toplevel_annotations(name)
            return updated_node
        else:
            return self._annotate_single_target(original_node, updated_node)

    def leave_ImportFrom(
        self,
        original_node: cst.ImportFrom,
        updated_node: cst.ImportFrom,
    ) -> cst.ImportFrom:
        self.import_statements.append(original_node)
        return updated_node

    def leave_Module(
        self,
        original_node: cst.Module,
        updated_node: cst.Module,
    ) -> cst.Module:
        fresh_class_definitions = [
            definition
            for name, definition in self.annotations.class_definitions.items()
            if name not in self.visited_classes
        ]

        # Stub TypeVars that the source module does not already define.
        typevars = {
            k: v
            for k, v in self.annotations.typevars.items()
            if k not in self.typevars
        }

        # NOTE: The entire change will also be abandoned if
        # self.annotation_counts is all 0s, so if adding any new category make
        # sure to record it there.
        # Check the *filtered* typevars: if every stub TypeVar already exists
        # there is nothing to insert, and the module must not be reshaped.
        if not (self.toplevel_annotations or fresh_class_definitions
                or typevars):
            return updated_node

        toplevel_statements = []
        # First, find the insertion point for imports
        statements_before_imports, statements_after_imports = self._split_module(
            original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(
            statements_after_imports)

        for name, annotation in self.toplevel_annotations.items():
            annotated_assign = self._apply_annotation_to_attribute_or_global(
                name=name,
                annotation=annotation,
                value=None,
            )
            toplevel_statements.append(
                cst.SimpleStatementLine([annotated_assign]))

        # TypeVar definitions could be scattered through the file, so do not
        # attempt to put new ones with existing ones, just add them at the top.
        # NOTE(review): the Newline and Assign nodes are placed directly in the
        # module body rather than wrapped in SimpleStatementLine. This relies
        # on codegen emitting each element as-is; confirm before changing.
        if typevars:
            for stmt in typevars.values():
                toplevel_statements.append(cst.Newline())
                toplevel_statements.append(stmt)
                self.annotation_counts.typevars_and_generics_added += 1
            toplevel_statements.append(cst.Newline())

        self.annotation_counts.classes_added = len(fresh_class_definitions)
        toplevel_statements.extend(fresh_class_definitions)

        return updated_node.with_changes(body=[
            *statements_before_imports,
            *toplevel_statements,
            *statements_after_imports,
        ])
    comment=m.Comment(m.MatchIfTrue(is_valid_comment)),
    newline=m.Newline(),
)

# Matches a statement line of the form `name = <call>(...)` that declares a
# nullable field and is not accompanied by a `null_comment`. A nullable field
# is either a call with a `null=True` keyword argument or a `NullBooleanField`
# call, bare or attribute-qualified. "Not accompanied" means neither the
# line's trailing whitespace nor the parenthesized whitespace right after the
# call's opening bracket matches `null_comment`.
# NOTE(review): `null_comment` is defined earlier in this file (presumably a
# trailing-whitespace matcher requiring a valid comment) — confirm.
field_without_comment = m.SimpleStatementLine(
    body=[
        m.Assign(
            value=m.OneOf(
                m.Call(
                    args=[
                        m.ZeroOrMore(),
                        m.Arg(keyword=m.Name('null'), value=m.Name('True')),
                        m.ZeroOrMore(),
                    ],
                    whitespace_before_args=m.DoesNotMatch(
                        m.ParenthesizedWhitespace(null_comment)),
                ),
                m.Call(
                    func=m.Attribute(attr=m.Name('NullBooleanField')),
                    whitespace_before_args=m.DoesNotMatch(
                        m.ParenthesizedWhitespace(null_comment)),
                ),
                m.Call(
                    func=m.Name('NullBooleanField'),
                    whitespace_before_args=m.DoesNotMatch(
                        m.ParenthesizedWhitespace(null_comment)),
                ),
            ),
        ),
    ],
    trailing_whitespace=m.DoesNotMatch(null_comment),
)


class FieldValidator(m.MatcherDecoratableVisitor):
示例#21
0
import libcst as cst
import libcst.matchers as m
from libcst.metadata import ScopeProvider, ParentNodeProvider
import inspect

from ..tracer import TracerArgs
from ..common import SEP, parse_expr, a2s, EvalException
from .base_pass import BasePass

# Matches `name = Something.__new__(...)`: exactly one bare-name target,
# assigned the result of calling the `__new__` attribute of a plain name.
obj_new_pattern = m.Assign(
    targets=[m.AssignTarget(target=m.Name())],
    value=m.Call(
        func=m.Attribute(value=m.Name(), attr=m.Name("__new__")),
    ),
)


class FindSafeObjsToConvert(cst.CSTVisitor):
    """Visitor that inspects objects created via ``X.__new__(...)``.

    For each assignment matching ``obj_new_pattern`` whose target name is a
    key of ``self.pass_.globls``, it looks up every access to that name in the
    assignment's scope and fetches the parent node of each access.

    NOTE(review): the loop body in ``visit_Assign`` appears truncated in this
    file. Nothing visible here writes ``whitelist``/``blacklist`` or uses
    ``obj``/``parent``.
    """
    # Scope metadata supplies the accesses of a name; parent metadata gives
    # the node enclosing each access.
    METADATA_DEPENDENCIES = (ScopeProvider, ParentNodeProvider)

    def __init__(self, pass_):
        # `pass_` is presumably the owning pass; only its `globls` mapping is
        # read below.
        # NOTE(review): super().__init__() is not called here — confirm the
        # base visitor needs no initialization.
        self.pass_ = pass_
        # Presumably names judged safe / unsafe to convert — TODO confirm;
        # the visible code never adds to either set.
        self.whitelist = set()
        self.blacklist = set()

    def visit_Assign(self, node):
        # Only `name = X.__new__(...)` assignments are of interest.
        if m.matches(node, obj_new_pattern):
            # The pattern guarantees a single plain-Name target.
            name = node.targets[0].target.value
            if name in self.pass_.globls:
                obj = self.pass_.globls[name]

                scope = self.get_metadata(ScopeProvider, node)
                # Visit every recorded access to `name` in this scope.
                for access in scope.accesses[name]:
                    parent = self.get_metadata(ParentNodeProvider, access.node)
示例#22
0
 def leave_Assign(self, original_node, updated_node):
     """Drop dead single-name assignments.

     An assignment `name = value` is removed when `self.unused_vars` flags the
     original node and `is_pure(value)` holds. Any other assignment, including
     multi-target or non-Name targets, is returned unchanged.
     """
     single_name_assign = m.Assign(targets=[m.AssignTarget(target=m.Name())])
     if not m.matches(original_node, single_name_assign):
         return updated_node
     # Short-circuits like the original: is_pure only runs for unused vars.
     removable = self.unused_vars[original_node] and is_pure(updated_node.value)
     return cst.RemoveFromParent() if removable else updated_node