def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """
        Merge stub annotations from the codemod context and apply them to ``tree``.

        The imports already present in ``tree`` are gathered first and handed
        to ``TypeCollector`` so that duplicate imports are not added. If the
        context scratch space holds an entry under
        ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY`` (a ``(stub, overwrite)``
        pair), the stub's function, attribute and class annotations are merged
        into ``self.annotations``. ``tree`` is then run through
        ``AddImportsVisitor`` with this codemod's context, and the result is
        visited by this transformer.
        """
        gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(gatherer)
        known_imports = _get_import_names(gatherer.all_imports)

        scratch_key = ApplyTypeAnnotationsVisitor.CONTEXT_KEY
        registered = self.context.scratch.get(scratch_key)
        if registered:
            stub_module, stub_overwrites = registered
            # An overwrite request from either this instance or the stored
            # context entry is honoured.
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations or stub_overwrites
            )
            collector = TypeCollector(known_imports, self.context)
            stub_module.visit(collector)
            merged = self.annotations
            merged.function_annotations.update(collector.function_annotations)
            merged.attribute_annotations.update(collector.attribute_annotations)
            merged.class_definitions.update(collector.class_definitions)

        importer = AddImportsVisitor(self.context)
        return importer.transform_module(tree).visit(self)
示例#2
0
    def assertCodemod(
        self,
        before: str,
        after: str,
        *args: object,
        context_override: Optional[CodemodContext] = None,
        python_version: Optional[str] = None,
        expected_warnings: Optional[Sequence[str]] = None,
        expected_skip: bool = False,
        **kwargs: object,
    ) -> None:
        """
        Run :attr:`~CodemodTest.TRANSFORM` over ``before`` and check the result.

        The codemod is constructed with a context plus any extra ``args`` /
        ``kwargs``. It is then applied to ``before``, which is normalized with
        :meth:`CodemodTest.make_fixture_data`. The produced code must equal
        ``after``, with both sides normalized the same way.

        Options:

        * ``context_override`` -- the :class:`CodemodContext` to use. A fresh,
          empty context is created when omitted.
        * ``python_version`` -- a parser version string such as `"3.6"`. When
          omitted, the default :class:`PartialParserConfig` is used, which
          targets the interpreter running the tests.
        * ``expected_skip`` -- when ``True``, the codemod must raise
          ``SkipFile``. When ``False``, a ``SkipFile`` propagates.
        * ``expected_warnings`` -- when given, the context's warnings must
          match this sequence, in order. Warnings are not checked otherwise.
        """

        if context_override is None:
            context = CodemodContext()
        else:
            context = context_override
        # pyre-fixme[45]: Cannot instantiate abstract class `Codemod`.
        codemod = self.TRANSFORM(context, *args, **kwargs)

        fixture_source = CodemodTest.make_fixture_data(before)
        if python_version is None:
            parser_config = PartialParserConfig()
        else:
            parser_config = PartialParserConfig(python_version=python_version)
        source_tree = parse_module(fixture_source, config=parser_config)

        try:
            result_tree = codemod.transform_module(source_tree)
        except SkipFile:
            if expected_skip:
                # A skip leaves the module untouched.
                result_tree = source_tree
            else:
                raise
        else:
            if expected_skip:
                # pyre-ignore This mixin needs to be used with a UnitTest subclass.
                self.fail("Expected SkipFile but was not raised")

        expected_code = CodemodTest.make_fixture_data(after)
        actual_code = CodemodTest.make_fixture_data(result_tree.code)
        # pyre-ignore This mixin needs to be used with a UnitTest subclass.
        self.assertEqual(expected_code, actual_code)

        if expected_warnings is None:
            return
        # pyre-ignore This mixin needs to be used with a UnitTest subclass.
        self.assertSequenceEqual(expected_warnings, context.warnings)
    def transform_module_impl(
        self,
        tree: cst.Module,
    ) -> cst.Module:
        """
        Collect type annotations from all stubs and apply them to ``tree``.

        Gather existing imports from ``tree`` so that we don't add duplicate imports.

        If the context scratch space holds an entry under
        ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``, it is unpacked as
        ``(stub, overwrite_existing_annotations, use_future_annotations,
        strict_posargs_matching, strict_annotation_matching)``. The stub's
        annotations are merged into ``self.annotations`` and its flags are
        combined with this instance's flags.

        The ``from __future__ import annotations`` request (when enabled) and
        the ``AddImportsVisitor`` pass always run, whether or not a context
        entry was found.

        Returns the annotated tree, or the original ``tree`` unchanged when no
        annotation was applied, so imports are not modified for a no-op run.
        """
        import_gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(import_gatherer)
        existing_import_names = _get_imported_names(
            import_gatherer.all_imports)

        context_contents = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if context_contents is not None:
            (
                stub,
                overwrite_existing_annotations,
                use_future_annotations,
                strict_posargs_matching,
                strict_annotation_matching,
            ) = context_contents
            # These options are enabled if either this instance or the context
            # entry enables them...
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations
                or overwrite_existing_annotations)
            self.use_future_annotations = (self.use_future_annotations
                                           or use_future_annotations)
            # ...while strict positional-argument matching is kept only if both
            # sides request it.
            self.strict_posargs_matching = (self.strict_posargs_matching
                                            and strict_posargs_matching)
            self.strict_annotation_matching = (self.strict_annotation_matching
                                               or strict_annotation_matching)
            visitor = TypeCollector(existing_import_names, self.context)
            cst.MetadataWrapper(stub).visit(visitor)
            self.annotations.update(visitor.annotations)

        # Done outside the ``if`` so that ``tree_with_imports`` is always bound,
        # even when no stub was registered in the context.
        if self.use_future_annotations:
            AddImportsVisitor.add_needed_import(self.context, "__future__",
                                                "annotations")
        tree_with_imports = AddImportsVisitor(
            self.context).transform_module(tree)

        tree_with_changes = tree_with_imports.visit(self)

        # don't modify the imports if we didn't actually add any type information
        if self.annotation_counts.any_changes_applied():
            return tree_with_changes
        return tree
示例#4
0
    def assertCodemod(
        self,
        before: str,
        after: str,
        *args: object,
        context_override: Optional[CodemodContext] = None,
        python_version: Optional[str] = "3.7",
        expected_warnings: Optional[Sequence[str]] = None,
        expected_skip: bool = False,
        **kwargs: object,
    ) -> None:
        """
        Given a before and after string, and optionally any args/kwargs that
        should be passed to the codemod visitor constructor, validate that
        the codemod executes as expected.

        ``before`` is normalized with ``CodemodTest.make_fixture_data`` and
        parsed for ``python_version``. The default is ``"3.7"``; pass ``None``
        to use the default ``PartialParserConfig()``.

        If ``expected_skip`` is true, the codemod must raise ``SkipFile``.
        Otherwise a ``SkipFile`` is re-raised.

        ``after`` and the produced code are both normalized with
        ``make_fixture_data`` before comparison, so the two sides are treated
        symmetrically.

        When ``expected_warnings`` is given, it must equal the context's
        warnings, in order.
        """

        context = context_override if context_override is not None else CodemodContext(
        )
        transform_instance = self.TRANSFORM(context, *args, **kwargs)
        input_tree = parse_module(
            CodemodTest.make_fixture_data(before),
            config=(PartialParserConfig(python_version=python_version)
                    if python_version is not None else PartialParserConfig()),
        )
        try:
            output_tree = transform_instance.transform_module(input_tree)
        except SkipFile:
            if not expected_skip:
                raise
            output_tree = input_tree
        else:
            if expected_skip:
                # pyre-ignore This mixin needs to be used with a UnitTest subclass.
                self.fail("Expected SkipFile but was not raised")
        # Normalize both sides: the expected text is normalized, so the
        # produced text must be too, or some outputs could never match.
        # pyre-ignore This mixin needs to be used with a UnitTest subclass.
        self.assertEqual(
            CodemodTest.make_fixture_data(after),
            CodemodTest.make_fixture_data(output_tree.code),
        )
        if expected_warnings is not None:
            # pyre-ignore This mixin needs to be used with a UnitTest subclass.
            self.assertSequenceEqual(expected_warnings, context.warnings)