class TestVisitor(MatcherDecoratableTransformer):
    """Collect the ``value`` of string literals found inside ``A`` and ``foo``.

    ``visit_SimpleString`` only fires while the traversal is inside both a
    class named ``A`` and a function named ``foo`` (one ``call_if_inside``
    guard per enclosing definition).
    """

    def __init__(self) -> None:
        super().__init__()
        # ``node.value`` of every visited SimpleString, in visit order.
        self.visits: List[str] = []

    @call_if_inside(m.ClassDef(m.Name("A")))
    @call_if_inside(m.FunctionDef(m.Name("foo")))
    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        literal_text = node.value
        self.visits.append(literal_text)
Esempio n. 2
0
class DeprecatedModelFieldValidator(m.MatcherDecoratableVisitor):
    """Collect errors for model fields carrying a malformed deprecation comment.

    For every statement line inside a class body that matches
    ``django_model_field_with_comments``, the leading comment is checked
    first and the trailing comment second. The first one containing the
    deprecation marker is then validated. If it does not match the "valid"
    pattern, an ``Error`` is appended to ``self.errors``.
    """

    METADATA_DEPENDENCIES = (PositionProvider, )

    def __init__(
        self,
        model_file_path: str,
        valid_deprecation_comment_pattern: re.Pattern,
        deprecation_comment_marker_pattern: re.Pattern,
    ):
        """
        Args:
            model_file_path: path stored in every recorded ``Error``.
            valid_deprecation_comment_pattern: a marked comment is valid when
                ``search`` finds this pattern anywhere in it.
            deprecation_comment_marker_pattern: a comment is a deprecation
                comment when ``search`` finds this pattern anywhere in it.
        """
        super().__init__()

        self.model_file_path = model_file_path
        self.valid_deprecation_comment_pattern = valid_deprecation_comment_pattern
        self.deprecation_comment_marker_pattern = deprecation_comment_marker_pattern
        self.errors: typing.List[Error] = []

    def is_deprecation_comment(self, comment: str) -> bool:
        """Return True if *comment* contains the deprecation marker."""
        return self.deprecation_comment_marker_pattern.search(
            comment) is not None

    def is_valid_deprecation_comment(self, comment: str) -> bool:
        """Return True if *comment* matches the valid-deprecation pattern."""
        return self.valid_deprecation_comment_pattern.search(
            comment) is not None

    def _find_deprecation_comment(
            self, node: cst.SimpleStatementLine) -> typing.Optional[str]:
        """Return the leading comment if it carries the deprecation marker.

        Otherwise return the trailing comment if it carries the marker, or
        None if neither does.
        """
        leading_comment = get_leading_comment(node)
        trailing_comment = get_trailing_comment(node)
        for comment in (leading_comment, trailing_comment):
            if comment and self.is_deprecation_comment(comment):
                return comment
        return None

    @m.call_if_inside(m.ClassDef())
    def visit_SimpleStatementLine(
            self, node: cst.SimpleStatementLine) -> None:  # noqa: N802 C901
        if not self.matches(node, django_model_field_with_comments):
            return None

        comment = self._find_deprecation_comment(node)
        if comment is None or self.is_valid_deprecation_comment(comment):
            return None

        position = self.get_metadata(PositionProvider, node)
        self.errors.append(
            Error(
                self.model_file_path,
                position.start.line,
                position.start.column,
                get_model_field_name(node),
            ))
        return None

    # The annotation is a string because the class name is not yet bound
    # while its own body executes; a bare name would raise NameError unless
    # the module uses ``from __future__ import annotations``.
    def run_for_module(self,
                       module: cst.Module) -> "DeprecatedModelFieldValidator":
        """Visit *module* (wrapped for metadata access) and return ``self``."""
        cst.MetadataWrapper(module).visit(self)
        return self
Esempio n. 3
0
class FieldValidator(m.MatcherDecoratableVisitor):
    """Record an ``Error`` for every class-body line matching ``field_without_comment``."""

    # Declared for type checkers only. Each instance gets its own list in
    # __init__; a class-level ``= []`` default would be shared by every
    # validator, so errors would leak from one run into the next.
    errors: List[Error]

    METADATA_DEPENDENCIES = (PositionProvider, )

    def __init__(self) -> None:
        super().__init__()
        self.errors = []

    @m.call_if_inside(m.ClassDef())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> None:
        if not self.matches(node, field_without_comment):
            return
        position = self.get_metadata(PositionProvider, node).start
        # NOTE(review): assumes field_without_comment only matches lines whose
        # first statement is an Assign with a Name target; confirm.
        field_name = cast(Assign, node.body[0]).targets[0].target.value
        self.errors.append(Error(position.line, position.column, field_name))
Esempio n. 4
0
 def _has_testnode(node: cst.Module) -> bool:
     """Return True if the module body contains a test function or test class.

     A test function is a ``def`` whose name starts with ``test_``. A test
     class is a ``class`` whose name starts with ``Test``. Only statements
     directly in ``Module.body`` are considered.
     """
     test_function = m.FunctionDef(
         name=m.Name(value=m.MatchIfTrue(lambda name: name.startswith("test_")))
     )
     test_class = m.ClassDef(
         name=m.Name(value=m.MatchIfTrue(lambda name: name.startswith("Test")))
     )
     # Sequence matchers must describe the *whole* sequence; they never match
     # a partial run. So the run of interesting nodes is padded with
     # ZeroOrMore() wildcards on both sides.
     module_with_tests = m.Module(
         body=[
             m.ZeroOrMore(),
             m.AtLeastN(n=1, matcher=test_function | test_class),
             m.ZeroOrMore(),
         ]
     )
     return m.matches(node, module_with_tests)
Esempio n. 5
0
class Modernizer(m.MatcherDecoratableTransformer):
    """Rewrite Python 2 idioms using the ``builtins`` / ``future.utils`` modules.

    The transformer:

    * wraps ``filter``/``map``/``zip``/``range`` calls in ``list(...)`` and
      schedules a ``from builtins import`` for them, unless the name is
      already imported from ``builtins``;
    * wraps ``.keys()``/``.values()``/``.items()`` calls in ``list(...)``;
    * renames ``xrange`` -> ``range`` and ``raw_input`` -> ``input``;
    * rewrites ``d.iterkeys()`` (etc.) into ``iterkeys(d)``;
    * renames ``unicode`` to ``text_type``;
    * adds or extends the ``from builtins import ...`` and
      ``from future.utils import ...`` lines those rewrites need.

    None of the ``list(...)`` wrapping happens when the call is already
    inside ``list``/``set``/``tuple``/``.join(...)``, a comprehension ``for``
    or a ``for`` statement.

    In verbose mode it also prints each parameter and assigned variable
    together with the type found in its ``# type:`` comment.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)
    # FIXME use a stack of e.g. SimpleStatementLine then proper visit_Import/ImportFrom to store the ssl node

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        """
        Args:
            path: file being transformed; used as the prefix of verbose output.
            verbose: print parameters/variables with their comment types.
            ignored: stored as a set in ``self.ignored``.
        """
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.errors = False
        self.stack: List[Tuple[str, ...]] = []
        self.annotations: Dict[
            Tuple[str, ...], Comment  # key: tuple of canonical variable name
        ] = {}
        # For each target module, this state holds:
        #  - names already imported (name -> bound alias),
        #  - names that still need importing,
        #  - the existing import line, if any, that will be extended.
        self.python_future_updated_node: Optional[SimpleStatementLine] = None
        self.python_future_imports: Dict[str, str] = {}
        self.python_future_new_imports: Set[str] = set()
        self.builtins_imports: Dict[str, str] = {}
        self.builtins_new_imports: Set[str] = set()
        self.builtins_updated_node: Optional[SimpleStatementLine] = None
        self.future_utils_imports: Dict[str, str] = {}
        self.future_utils_new_imports: Set[str] = set()
        self.future_utils_updated_node: Optional[SimpleStatementLine] = None
        # self.last_import_node: Optional[CSTNode] = None
        # Last statement line (outside classes/functions/ifs) containing an
        # import; brand-new import lines are inserted after it.
        self.last_import_node_stmt: Optional[CSTNode] = None
        # Module-body index right after the last import line inserted by
        # update_imports, so a second new line lands next to the first one.
        self._next_import_index: Optional[int] = None

    # @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    # @m.visit(m.ImportAlias() | m.ImportStar())
    # def import_python_future_check(self, node: Union[ImportAlias, ImportStar]) -> None:
    #     self.add_import(self.python_future_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("__future__")))
    # def import_python_future_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(m.ImportFrom(module=m.Name("builtins")))
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_builtins_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Remember names already imported via ``from builtins import ...``."""
        self.add_import(self.builtins_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("builtins")))
    # def builtins_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(
        m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    )
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_future_utils_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Remember names already imported via ``from future.utils import ...``."""
        self.add_import(self.future_utils_imports, node)

    # @m.leave(
    #     m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    # )
    # def future_utils_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @staticmethod
    def add_import(
        imports: Dict[str, str], node: Union[ImportAlias, ImportStar]
    ) -> None:
        """Record ``name -> alias`` (or ``"*" -> "*"`` for a star import)."""
        if isinstance(node, ImportAlias):
            imports[node.name.value] = (
                node.asname.name.value if node.asname else node.name.value
            )
        else:
            imports["*"] = "*"

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_Import(self, node: Import) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_ImportFrom(self, node: ImportFrom) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> Optional[bool]:
        """Track the last import statement line seen outside class/def/if."""
        for n in node.body:
            if m.matches(n, m.Import() | m.ImportFrom()):
                self.last_import_node_stmt = node
        return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def leave_SimpleStatementLine(
        self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
    ) -> Union[BaseStatement, RemovalSentinel]:
        """Remember the (updated) existing import lines that may be extended."""
        for n in updated_node.body:
            if m.matches(n, m.ImportFrom(module=m.Name("__future__"))):
                self.python_future_updated_node = updated_node
            elif m.matches(n, m.ImportFrom(module=m.Name("builtins"))):
                self.builtins_updated_node = updated_node
            elif m.matches(
                n,
                m.ImportFrom(
                    module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
                ),
            ):
                self.future_utils_updated_node = updated_node
        return updated_node

    # @m.visit(
    #     m.AllOf(
    #         m.SimpleStatementLine(),
    #         m.MatchIfTrue(
    #             lambda node: any(m.matches(c, m.Assign()) for c in node.children)
    #         ),
    #         m.MatchIfTrue(
    #             lambda node: "# type:" in node.trailing_whitespace.comment.value
    #         ),
    #     )
    # )
    # def visit_assign(self, node: SimpleStatementSuite) -> None:
    #     return None

    def visit_Param(self, node: Param) -> Optional[bool]:
        """In verbose mode, print the parameter with its ``# type:`` comment type."""

        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.ptype: Optional[str] = None

            def visit_TrailingWhitespace_comment(
                self, node: "TrailingWhitespace"
            ) -> None:
                if node.comment and "type:" in node.comment.value:
                    mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                    self.ptype = mo.group(1) if mo else None
                return None

        v = Visitor()
        node.visit(v)
        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: parameter {node.name.value}: {v.ptype or 'unknown type'}"
            )
        return None

    @m.visit(m.SimpleStatementLine())
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        """In verbose mode, print each name bound by an assignment on this line.

        Each name is printed with the type taken from the line's
        ``# type:`` comment, or ``unknown type`` if there is none.
        """
        assign = None
        for c in node.children:
            if m.matches(c, m.Assign()):
                assign = ensure_type(c, Assign)
        if not assign:
            return

        # An assignment's type comment is the statement line's trailing
        # comment, so read it directly. A bare matcher object must not be
        # used as an `if` condition: it is always truthy.
        vtype: Optional[str] = None
        comment = node.trailing_whitespace.comment
        if comment and "type:" in comment.value:
            mo = re.match(r"#\s*type:\s*(\S*)", comment.value)
            vtype = mo.group(1) if mo else None

        class NameVisitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.names: List[str] = []

            def visit_Name(self, node: Name) -> Optional[bool]:
                self.names.append(node.value)
                return None

        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            for target in assign.targets:
                v = NameVisitor()
                target.visit(v)
                for name in v.names:
                    print(
                        f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                    )

    def visit_FunctionDef_body(self, node: FunctionDef) -> None:
        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()

            def visit_EmptyLine_comment(self, node: "EmptyLine") -> None:
                # FIXME too many matches on test_param_02
                if not node.comment:
                    return
                # TODO: use comment.value
                return None

        v = Visitor()
        node.visit(v)
        return None

    map_matcher = m.Call(
        func=m.Name("filter") | m.Name("map") | m.Name("zip") | m.Name("range")
    )

    @m.visit(map_matcher)
    def visit_map(self, node: Call) -> None:
        func_name = ensure_type(node.func, Name).value
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(map_matcher)
    def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
        # TODO test with CompFor etc.
        # TODO improve join test
        func_name = ensure_type(updated_node.func, Name).value
        if func_name not in self.builtins_imports:
            updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.visit(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def visit_xrange(self, node: Call) -> None:
        orig_func_name = ensure_type(node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.leave(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
        orig_func_name = ensure_type(updated_node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        return updated_node.with_changes(func=Name(func_name))

    iter_matcher = m.Call(
        func=m.Attribute(
            attr=m.Name("iterkeys") | m.Name("itervalues") | m.Name("iteritems")
        )
    )

    @m.visit(iter_matcher)
    def visit_iter(self, node: Call) -> None:
        func_name = ensure_type(node.func, Attribute).attr.value
        if func_name not in self.future_utils_imports:
            self.future_utils_new_imports.add(func_name)

    @m.leave(iter_matcher)
    def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Turn ``d.iterX()`` into ``iterX(d)``."""
        attribute = ensure_type(updated_node.func, Attribute)
        func_name = attribute.attr
        dict_name = attribute.value
        return updated_node.with_changes(func=func_name, args=[Arg(dict_name)])

    not_iter_matcher = m.Call(
        func=m.Attribute(attr=m.Name("keys") | m.Name("values") | m.Name("items"))
    )

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(not_iter_matcher)
    def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.call_if_not_inside(m.Import() | m.ImportFrom())
    @m.leave(m.Name(value="unicode"))
    def fix_unicode(self, original_node: Name, updated_node: Name) -> BaseExpression:
        value = "text_type"
        if value not in self.future_utils_imports:
            self.future_utils_new_imports.add(value)
        return updated_node.with_changes(value=value)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        """Add the ``builtins`` and ``future.utils`` imports collected during traversal."""
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "builtins",
            self.builtins_updated_node,
            self.builtins_imports,
            self.builtins_new_imports,
            True,
        )
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "future.utils",
            self.future_utils_updated_node,
            self.future_utils_imports,
            self.future_utils_new_imports,
            False,
        )
        return updated_node

    def update_imports(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        updated_import_node: Optional[SimpleStatementLine],
        current_imports: Dict[str, str],
        new_imports: Set[str],
        noqa: bool,
    ) -> Module:
        """Make *new_imports* available from *import_name* in *updated_module*.

        If there is no existing ``from <import_name> import`` line, a new one
        is inserted after the last import. Otherwise the existing line is
        replaced by one listing both old and new names. An existing star
        import already covers everything, so nothing is changed in that case.
        *noqa* appends ``  # noqa`` to the generated line.
        """
        if not new_imports:
            return updated_module
        noqa_comment = "  # noqa" if noqa else ""
        if not updated_import_node:
            return self._insert_import_line(
                original_module, updated_module, import_name, new_imports, noqa_comment
            )
        if "*" in current_imports:
            return updated_module
        current_imports_set = {
            k if k == v else f"{k} as {v}" for k, v in current_imports.items()
        }
        stmt = parse_statement(
            f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}",
            config=updated_module.config_for_parsing,
        )
        # Keep comments/blank lines that sat above the import being replaced.
        stmt = stmt.with_changes(leading_lines=updated_import_node.leading_lines)
        return updated_module.deep_replace(updated_import_node, stmt)

    def _insert_import_line(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        new_imports: Set[str],
        noqa_comment: str,
    ) -> Module:
        """Insert a brand-new ``from <import_name> import ...`` line."""
        if self._next_import_index is not None:
            # An earlier call already inserted a line; put this one right after it.
            index = self._next_import_index
            blank_lines = ""
        elif self.last_import_node_stmt is not None:
            index = self._index_after_last_import(original_module)
            blank_lines = ""
        else:
            # No imports at all: go first and separate from the code below.
            index = 0
            blank_lines = "\n\n"
        stmt = parse_module(
            f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
            config=updated_module.config_for_parsing,
        )
        body = list(updated_module.body)
        # Next insertion goes right after this statement; the trailing blank
        # lines (parsed as the module footer) stay below both.
        self._next_import_index = index + len(stmt.body)
        return updated_module.with_changes(
            body=body[:index] + stmt.children + body[index:]
        )

    def _index_after_last_import(self, original_module: Module) -> int:
        """Return the top-level body index just after the last tracked import.

        The tracked line may be nested in a top-level compound statement
        (e.g. ``try: import x``). In that case the index after that enclosing
        statement is returned. Falls back to the end of the body.

        The index is valid for the updated module too. At this point the
        transform has not changed the number of top-level statements.
        """
        target = self.last_import_node_stmt
        for index, statement in enumerate(original_module.body):
            if statement is target or any(
                found is target
                for found in m.findall(statement, m.SimpleStatementLine())
            ):
                return index + 1
        return len(original_module.body)
Esempio n. 6
0
class ConvertTypeComments(VisitorBasedCodemodCommand):
    DESCRIPTION = """
    Codemod that converts type comments into Python 3.6+ style
    annotations.

    Notes:
    - This transform requires using the `ast` module, which is not compatible
      with multiprocessing. So you should run using a recent version of python,
      and set `--jobs=1` if using `python -m libcst.tool codemod ...` from the
      commandline.
    - This transform requires capabilities from `ast` that are not available
      prior to Python 3.9, so libcst must run on Python 3.9+. The code you are
      transforming can by Python 3.6+, this limitation applies only to libcst
      itself.

    We can handle type comments in the following statement types:
    - Assign
      - This is converted into a single AnnAssign when possible
      - In more complicated cases it will produce multiple AnnAssign
        nodes with no value (i.e. "type declaration" statements)
        followed by an Assign
    - For and With
      - We prepend both of these with type declaration statements.
    - FunctionDef
      - We apply all the types we can find. If we find several:
        - We prefer any existing annotations to type comments
        - For parameters, we prefer inline type comments to
          function-level type comments if we find both.

    We always apply the type comments as quote_annotations annotations, unless
    we know that it refers to a builtin. We do not guarantee that
    the resulting string annotations would parse, but they should
    never cause failures at module import time.

    We attempt to:
    - Always strip type comments for statements where we successfully
      applied types.
    - Never strip type comments for statements where we failed to
      apply types.

    There are many edge case possible where the arity of a type
    hint (which is either a tuple or a func_type) might not match
    the code. In these cases we generally give up:
    - For Assign, For, and With, we require that every target of
      bindings (e.g. a tuple of names being bound) must have exactly
      the same arity as the comment.
      - So, for example, we would skip an assignment statement such as
        ``x = y, z = 1, 2  # type: int, int`` because the arity
        of ``x`` does not match the arity of the hint.
    - For FunctionDef, we do *not* check arity of inline parameter
      type comments but we do skip the transform if the arity of
      the function does not match the function-level comment.
    """

    # Finding the location of a type comment in a FunctionDef is difficult.
    #
    # As a result, if, when visiting a FunctionDef header, we are able to
    # successfully extract type information, then we aggressively strip type
    # comments until we reach the first statement in the body.
    #
    # Once we get there we have to stop, so that we don't unintentionally remove
    # unprocessed type comments.
    #
    # This state handles tracking everything we need for this.
    function_type_info_stack: List[FunctionTypeInfo]
    function_body_stack: List[cst.BaseSuite]
    aggressively_strip_type_comments: bool

    @staticmethod
    def add_args(arg_parser: argparse.ArgumentParser) -> None:
        """Register this codemod's command-line options on *arg_parser*.

        Adds the boolean ``--no-quote-annotations`` flag.
        """
        help_text = (
            "Add unquoted annotations. This leads to prettier code "
            "but possibly more errors if type comments are invalid."
        )
        arg_parser.add_argument(
            "--no-quote-annotations",
            action="store_true",
            help=help_text,
        )

    def __init__(
        self,
        context: CodemodContext,
        no_quote_annotations: bool = False,
    ) -> None:
        """Set up the codemod and its traversal state.

        Args:
            context: codemod context passed on to the base class.
            no_quote_annotations: when True, emit annotations unquoted.

        Raises:
            NotImplementedError: if libcst is running on Python older than 3.9.
        """
        # ``ast.unparse`` only exists from 3.9 (and ``type_comments`` from 3.8).
        # Fail up front instead of midway through a transform. Supporting
        # older interpreters via typed_ast might be possible but is not a
        # priority.
        if sys.version_info[:2] < (3, 9):
            raise NotImplementedError(
                "You are trying to run ConvertTypeComments, but libcst "
                "needs to be running with Python 3.9+ in order to "
                "do this. Try using Python 3.9+ to run your codemod. "
                "Note that the target code can be using Python 3.6+, "
                "it is only libcst that needs a new Python version."
            )
        super().__init__(context)
        # Overall behaviour flag.
        self.quote_annotations: bool = not no_quote_annotations
        # Traversal state for FunctionDef handling.
        self.function_type_info_stack = []
        self.function_body_stack = []
        self.aggressively_strip_type_comments = False

    def _strip_TrailingWhitespace(
        self,
        node: cst.TrailingWhitespace,
    ) -> cst.TrailingWhitespace:
        """Return *node* with its comment and the whitespace before it removed."""
        # The whitespace only separated the code from the comment, so it goes
        # away together with the comment.
        empty_whitespace = cst.SimpleWhitespace("")
        return node.with_changes(comment=None, whitespace=empty_whitespace)

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> Union[cst.SimpleStatementLine, cst.FlattenSentinel]:
        """
        Convert a SimpleStatementLine whose last small statement is an Assign
        with a type comment into one that uses PEP 526 annotations.

        Depending on what ``convert_Assign`` returns, the line is:
        - returned unchanged (last statement is not an Assign, there is no
          type comment, or the annotation could not be applied);
        - rewritten in place with the AnnAssign and its comment stripped;
        - split into several lines (via FlattenSentinel) when more than one
          type declaration has to be emitted.

        Raises:
            RuntimeError: if ``convert_Assign`` returns an unexpected value.
        """
        # determine whether to apply an annotation
        # Only the last small statement on the line is considered.
        assign = updated_node.body[-1]
        if not isinstance(assign, cst.Assign):  # only Assign matters
            return updated_node
        # The type comment is looked up on the original (pre-transform) node.
        annotation = _annotation_for_statement(original_node)
        if annotation is None:
            return updated_node
        # At this point have a single-line Assign with a type comment.
        # Convert it to an AnnAssign and strip the comment.
        converted = convert_Assign(
            node=assign,
            annotation=annotation,
            quote_annotations=self.quote_annotations,
        )
        if isinstance(converted, _FailedToApplyAnnotation):
            # We were unable to consume the type comment, so return the
            # original code unchanged.
            # TODO: allow stripping the invalid type comments via a flag
            return updated_node
        elif isinstance(converted, cst.AnnAssign):
            # We were able to convert the Assign into an AnnAssign, so
            # we can update the node.
            return updated_node.with_changes(
                body=[*updated_node.body[:-1], converted],
                trailing_whitespace=self._strip_TrailingWhitespace(
                    updated_node.trailing_whitespace,
                ),
            )
        elif isinstance(converted, list):
            # We need to inject two or more type declarations.
            #
            # In this case, we need to split across multiple lines, and
            # this also means we'll spread any multi-statement lines out
            # (multi-statement lines are PEP 8 violating anyway).
            #
            # We still preserve leading lines from before our transform.
            # Statements that shared the line lose their semicolons, since
            # each ends up on its own line below.
            new_statements = [
                *(
                    statement.with_changes(
                        semicolon=cst.MaybeSentinel.DEFAULT,
                    )
                    for statement in updated_node.body[:-1]
                ),
                *converted,
            ]
            if len(new_statements) < 2:
                raise RuntimeError("Unreachable code.")
            # The first statement reuses the original line, keeping its leading
            # lines; every following statement gets a fresh line.
            return cst.FlattenSentinel(
                [
                    updated_node.with_changes(
                        body=[new_statements[0]],
                        trailing_whitespace=self._strip_TrailingWhitespace(
                            updated_node.trailing_whitespace,
                        ),
                    ),
                    *(
                        cst.SimpleStatementLine(body=[statement])
                        for statement in new_statements[1:]
                    ),
                ]
            )
        else:
            raise RuntimeError(f"Unhandled value {converted}")

    def leave_For(
        self,
        original_node: cst.For,
        updated_node: cst.For,
    ) -> Union[cst.For, cst.FlattenSentinel]:
        """
        Convert a For with a type hint on the bound variable(s) to
        use type declarations placed before the loop.

        The loop is returned unchanged when its body is not an indented
        block, when it has no type comment, or when the target and the hint
        disagree on arity.
        """
        loop_body = updated_node.body
        # A header type comment can only exist when the body is an indented
        # block. The header is also what gets stripped below.
        if not isinstance(loop_body, cst.IndentedBlock):
            return updated_node
        annotation = _annotation_for_statement(original_node)
        if annotation is None:
            return updated_node
        try:
            declarations = AnnotationSpreader.type_declaration_statements(
                bindings=AnnotationSpreader.unpack_target(updated_node.target),
                annotations=AnnotationSpreader.unpack_annotation(annotation),
                leading_lines=updated_node.leading_lines,
                quote_annotations=self.quote_annotations,
            )
        except _ArityError:
            return updated_node
        # The loop's leading lines were handed to the declarations, so the
        # loop itself no longer carries them.
        stripped_header = self._strip_TrailingWhitespace(loop_body.header)
        rewritten_loop = updated_node.with_changes(
            body=loop_body.with_changes(header=stripped_header),
            leading_lines=[],
        )
        return cst.FlattenSentinel([*declarations, rewritten_loop])

    def leave_With(
        self,
        original_node: cst.With,
        updated_node: cst.With,
    ) -> Union[cst.With, cst.FlattenSentinel]:
        """
        Convert a With with a type hint on the bound variable(s) to
        use type declarations placed before the statement.

        The statement is returned unchanged when its body is not an indented
        block, when it has no type comment, when it does not bind exactly
        one ``as`` target, or when the target and the hint disagree on arity.
        """
        with_body = updated_node.body
        # A header type comment can only exist when the body is an indented
        # block. The header is also what gets stripped below.
        if not isinstance(with_body, cst.IndentedBlock):
            return updated_node
        annotation = _annotation_for_statement(original_node)
        if annotation is None:
            return updated_node
        # PEP 484 leaves the meaning of a type comment on several `with`
        # bindings unspecified, so only a single binding is handled.
        bound_targets = [
            item.asname.name for item in updated_node.items if item.asname is not None
        ]
        if len(bound_targets) != 1:
            return updated_node
        (only_target,) = bound_targets
        try:
            declarations = AnnotationSpreader.type_declaration_statements(
                bindings=AnnotationSpreader.unpack_target(only_target),
                annotations=AnnotationSpreader.unpack_annotation(annotation),
                leading_lines=updated_node.leading_lines,
                quote_annotations=self.quote_annotations,
            )
        except _ArityError:
            return updated_node
        # The statement's leading lines were handed to the declarations.
        stripped_header = self._strip_TrailingWhitespace(with_body.header)
        rewritten_with = updated_node.with_changes(
            body=with_body.with_changes(header=stripped_header),
            leading_lines=[],
        )
        return cst.FlattenSentinel([*declarations, rewritten_with])

    # Handle function definitions -------------------------

    # **Implementation Notes**
    #
    # It is much harder to predict where exactly type comments will live
    # in function definitions than in Assign / For / With.
    #
    # As a result, we use two different patterns:
    # (A) we aggressively strip out type comments from whitespace between the
    #     start of a function definition and the start of the body, whenever we were
    #     able to extract type information. This is done via mutable state and the
    #     usual visitor pattern.
    # (B) we also manually reach down to the first statement inside of the
    #     function body and aggressively strip type comments from leading
    #     whitespaces
    #
    # PEP 484 underspecifies how to apply type comments to (non-static)
    # methods - it would be possible to provide a type for `self`, or to omit
    # it. So we accept either approach when interpreting type comments on
    # non-static methods: the first argument can have a type provided or not.

    def _visit_FunctionDef(
        self,
        node: cst.FunctionDef,
        is_method: bool,
    ) -> None:
        """
        Record the state needed while a function definition is being visited.

        - Parse the type comments attached to ``node`` into a
          ``FunctionTypeInfo``.
        - Push that info onto ``function_type_info_stack`` and the original
          body node onto ``function_body_stack``; both are popped again in
          ``leave_FunctionDef``.
        - Enable aggressive type-comment stripping when any type information
          was found; ``visit_FunctionDef_body`` disables it again once the
          traversal reaches the body.
        """
        info = FunctionTypeInfo.from_cst(node, is_method=is_method)
        self.function_type_info_stack.append(info)
        # The *original* body is stored so leave_IndentedBlock can recognise
        # it by identity.
        self.function_body_stack.append(node.body)
        self.aggressively_strip_type_comments = not info.is_empty()

    @m.call_if_not_inside(m.ClassDef())
    @m.visit(m.FunctionDef())
    def visit_method(
        self,
        node: cst.FunctionDef,
    ) -> None:
        """
        Enter a FunctionDef that is not nested inside any class.

        Despite its name, this handles the free-function case: the
        ``call_if_not_inside(ClassDef)`` guard means the first parameter is
        never treated as ``self``/``cls`` when reading the type comment.
        """
        self._visit_FunctionDef(node, is_method=False)

    @m.call_if_inside(m.ClassDef())
    @m.visit(m.FunctionDef())
    def visit_function(
        self,
        node: cst.FunctionDef,
    ) -> None:
        """
        Enter a FunctionDef that appears somewhere inside a class.

        Despite its name, this handles the method case. The definition is
        treated as a method unless one of its decorators is the bare name
        ``staticmethod``.
        """
        # NOTE(review): `call_if_inside` matches any ClassDef ancestor, so a
        # plain function nested inside a method also lands here and is
        # treated as a method — confirm this is intended.
        is_static = any(
            m.matches(decorator.decorator, m.Name("staticmethod"))
            for decorator in node.decorators
        )
        self._visit_FunctionDef(node, is_method=not is_static)

    def leave_TrailingWhitespace(
        self,
        original_node: cst.TrailingWhitespace,
        updated_node: cst.TrailingWhitespace,
    ) -> cst.TrailingWhitespace:
        """
        Aggressively remove type comments when in header if we extracted types.

        Only the comment and the whitespace preceding it are dropped. The
        node's ``newline`` is kept, so a line ending that differs from the
        module default is not silently rewritten.
        """
        if not (
            self.aggressively_strip_type_comments
            and _is_type_comment(updated_node.comment)
        ):
            return updated_node
        return updated_node.with_changes(
            whitespace=cst.SimpleWhitespace(""),
            comment=None,
        )

    def leave_EmptyLine(
        self,
        original_node: cst.EmptyLine,
        updated_node: cst.EmptyLine,
    ) -> Union[cst.EmptyLine, cst.RemovalSentinel]:
        """
        Delete a whole comment line holding a type comment while aggressive
        stripping is on (i.e. in a function header whose type comments
        yielded type information); keep every other line.
        """
        should_remove = self.aggressively_strip_type_comments and _is_type_comment(
            updated_node.comment
        )
        return cst.RemovalSentinel.REMOVE if should_remove else updated_node

    def visit_FunctionDef_body(
        self,
        node: cst.FunctionDef,
    ) -> None:
        """
        Turn off aggressive type-comment removal once we have left the header.

        Called as traversal enters a FunctionDef's ``body``. For indented
        bodies, a type comment on the body header or the first statement's
        leading lines is removed separately by ``leave_IndentedBlock``.
        """
        self.aggressively_strip_type_comments = False

    def leave_IndentedBlock(
        self,
        original_node: cst.IndentedBlock,
        updated_node: cst.IndentedBlock,
    ) -> cst.IndentedBlock:
        """
        Strip the function type comment from the body of the function being
        transformed, when type information was extracted for it.

        The comment can be in one of two places:
        - the block header, if it was on the same line as the colon;
        - a leading line of the first body statement, if it was on the line
          after the colon.

        Blocks that are not the body of the innermost function on the stack,
        or whose function yielded no type information, are returned unchanged.
        """
        # Abort unless this is the body of a function we are transforming.
        # The stack holds the *original* body node, hence the identity check.
        if not self.function_body_stack:
            return updated_node
        if original_node is not self.function_body_stack[-1]:
            return updated_node
        if self.function_type_info_stack[-1].is_empty():
            return updated_node
        if _is_type_comment(updated_node.header.comment):
            # Drop the comment and the whitespace before it, but keep the
            # header's newline token rather than resetting it to the default.
            updated_node = updated_node.with_changes(
                header=updated_node.header.with_changes(
                    whitespace=cst.SimpleWhitespace(""),
                    comment=None,
                ),
            )
        # An empty body has no first statement to clean up; indexing it
        # would raise IndexError.
        if not updated_node.body:
            return updated_node
        first_statement = updated_node.body[0]
        if not hasattr(first_statement, "leading_lines"):
            return updated_node
        return updated_node.with_changes(
            body=[
                first_statement.with_changes(
                    leading_lines=[
                        line
                        # pyre-ignore[16]: we refined via `hasattr`
                        for line in first_statement.leading_lines
                        if not _is_type_comment(line.comment)
                    ]
                ),
                *updated_node.body[1:],
            ]
        )

    # Methods for adding type annotations ----
    #
    # By the time we get here, all type comments should already be stripped.

    def leave_Param(
        self,
        original_node: cst.Param,
        updated_node: cst.Param,
    ) -> cst.Param:
        """
        Annotate a parameter using the enclosing function's type comment.

        A parameter that already has an annotation is left untouched, as is
        one for which the type comment supplied no type.
        """
        if updated_node.annotation is not None:
            return updated_node
        arguments = self.function_type_info_stack[-1].arguments
        raw_annotation = arguments.get(updated_node.name.value)
        if raw_annotation is None:
            return updated_node
        annotation = _convert_annotation(
            raw=raw_annotation,
            quote_annotations=self.quote_annotations,
        )
        return updated_node.with_changes(annotation=annotation)

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef:
        """
        Pop this function's bookkeeping and, when its type comment supplied a
        return type and the definition has no return annotation yet, add one.
        """
        self.function_body_stack.pop()
        type_info = self.function_type_info_stack.pop()
        raw_returns = type_info.returns
        if updated_node.returns is not None or raw_returns is None:
            return updated_node
        return updated_node.with_changes(
            returns=_convert_annotation(
                raw=raw_returns,
                quote_annotations=self.quote_annotations,
            )
        )

    def visit_Lambda(
        self,
        node: cst.Lambda,
    ) -> bool:
        """
        Prune traversal at lambdas.

        A lambda contains no statements, so there is nothing beneath it to
        convert, and its Params must not reach ``leave_Param``, which would
        otherwise annotate them from the enclosing function's type info.
        """
        # Returning False tells the visitor not to descend into children.
        return False