예제 #1
0
 def test_ignored_lines(self, *, source: str, ignored_code: str,
                        ignored_lines: Container[int]) -> None:
     """Check which lines of ``source`` suppress reports for ``ignored_code``.

     A synthetic ``CstLintRuleReport`` is built for every line from 1 up to
     the end line of the last token, and each one is passed to
     ``IgnoreInfo.should_ignore_report``. The lines for which that call is
     truthy must equal ``ignored_lines``.

     Args:
         source: Python source text to tokenize and parse.
         ignored_code: Rule code placed on every synthetic report.
         ignored_lines: Line numbers expected to be ignored.
     """
     source_bytes = source.encode("utf-8")
     tokens = tuple(tokenize.tokenize(BytesIO(source_bytes).readline))
     ignore_info = IgnoreInfo.compute(
         comment_info=CommentInfo.compute(tokens=tokens),
         line_mapping_info=LineMappingInfo.compute(tokens=tokens),
     )
     # The parsed module and its bytes are the same for every probed line,
     # so build them once instead of re-parsing inside the loop.
     module = cst.MetadataWrapper(cst.parse_module(source))
     last_line = tokens[-1].end[0]
     actual_ignored_lines = [
         line for line in range(1, last_line + 1)
         if ignore_info.should_ignore_report(
             CstLintRuleReport(
                 file_path=Path("fake/path.py"),
                 node=cst.EmptyLine(),
                 code=ignored_code,
                 message="message",
                 line=line,
                 column=0,
                 module=module,
                 module_bytes=source_bytes,
             ))
     ]
     # pyre-fixme[6]: Expected `Iterable[Variable[_T]]` for 1st param but got
     #  `Container[int]`.
     self.assertEqual(actual_ignored_lines, list(ignored_lines))
예제 #2
0
 def target_suggestions(self):
     """Trace ``self.module`` and return the collector's ``suggestions``.

     Everything runs inside ``ctx_inliner.set(self)``. The globals produced
     by the tracer are handed to ``CollectTargetSuggestions``, which then
     visits the module through a ``MetadataWrapper``.
     """
     with ctx_inliner.set(self):
         tracer = Tracer(self.module, globls=self.base_globls)
         traced_globals = tracer.trace().globls
         suggestion_collector = CollectTargetSuggestions(self, traced_globals)
         wrapper = cst.MetadataWrapper(self.module)
         wrapper.visit(suggestion_collector)
         return suggestion_collector.suggestions
예제 #3
0
    def test_lambda_metadata_matcher(self) -> None:
        """``MatchMetadataIfTrue`` applies a predicate to qualified names."""
        source = "from typing import List\n\ndef foo() -> None: pass\n"
        wrapper = cst.MetadataWrapper(cst.parse_module(source))
        functiondef = cst.ensure_type(wrapper.module.body[1], cst.FunctionDef)

        def named_one_of(candidates):
            # FunctionDef matcher whose name must carry a qualified name
            # contained in ``candidates``.
            return m.FunctionDef(name=m.MatchMetadataIfTrue(
                meta.QualifiedNameProvider,
                lambda qualnames: any(n.name in candidates
                                      for n in qualnames),
            ))

        # "foo" is among the accepted names, so the definition matches.
        accepted = matches(
            functiondef,
            named_one_of({"foo", "bar", "baz"}),
            metadata_resolver=wrapper,
        )
        self.assertTrue(accepted)

        # Without "foo" in the set the same definition must not match.
        rejected = matches(
            functiondef,
            named_one_of({"bar", "baz"}),
            metadata_resolver=wrapper,
        )
        self.assertFalse(rejected)
예제 #4
0
def main():
    """Lint the bundled sample file with pylint and print the fix as a diff.

    Pylint warnings are parsed with ``Delinter.parse_linter_warnings`` for
    the message id given on the command line. Each target file is then
    rewritten by the transformer registered in ``SUPPORTED_LINTER_MAP`` and
    a unified diff of original versus fixed source is printed.
    """
    options = get_arg_parser().parse_args()

    root_file_path = 'delinter/test/input/test_unused_imports.py'
    msg_template = r'{path}:{line}:[{msg_id}({symbol}),{obj}]{msg}'
    pylint_command = f"{root_file_path} --enable=W --disable=C,R,E,F --msg-template={msg_template} --score=n"

    out, _ = lint.py_run(pylint_command, return_std=True)
    stripped_lines = (
        raw.strip() for raw in "".join(out.readlines()).split('\n'))
    # Drop blank lines and pylint's per-module banner lines.
    result = [
        text for text in stripped_lines
        if text and not text.startswith('************* Module ')
    ]
    parsed_warnings = Delinter.parse_linter_warnings(result, options.msg_id)

    if os.path.isdir(root_file_path):
        from pathlib import Path
        files = Path(root_file_path).glob('**/*.py')
    else:
        files = [root_file_path]

    for file_path in files:
        path_text = str(file_path)
        with open(file_path) as f:
            source_code = f.read()
        wrapper = cst.MetadataWrapper(cst.parse_module(source_code))
        local_warnings = [
            warning for warning in parsed_warnings
            if warning.file_path == path_text
        ]
        transformer = SUPPORTED_LINTER_MAP[options.msg_id][1](local_warnings)
        fixed_module = wrapper.visit(transformer)
        diff_lines = difflib.unified_diff(
            source_code.splitlines(True),
            fixed_module.code.splitlines(True),
            fromfile='a/' + path_text,
            tofile='b/' + path_text,
        )
        print("".join(diff_lines))
예제 #5
0
 def get_refs(mod):
     """Return one record per ``(name, span)`` pair found by ``AnalyzeCalls``.

     Each record is a dict with ``start`` (the span start), ``end`` (start
     plus span length) and ``name``.
     """
     call_analyzer = AnalyzeCalls()
     cst.MetadataWrapper(mod).visit(call_analyzer)

     refs = []
     for name, span in call_analyzer.calls:
         refs.append({
             'start': span.start,
             'end': span.start + span.length,
             'name': name,
         })
     return refs
예제 #6
0
 def code_folding(self):
     """Return the sorted ``unexecuted`` entries found in ``self.module``.

     The module is traced with ``trace_lines=True`` and the finished tracer
     is given to ``FindUnexecutedBlocks``, which visits the module in place
     (``unsafe_skip_copy=True``).
     """
     line_tracer = Tracer(
         self.module,
         globls=self.base_globls,
         args=TracerArgs(trace_lines=True),
     )
     block_finder = FindUnexecutedBlocks(line_tracer.trace())
     wrapper = cst.MetadataWrapper(self.module, unsafe_skip_copy=True)
     wrapper.visit(block_finder)
     return sorted(block_finder.unexecuted)
예제 #7
0
class LintRuleReportTest(UnitTest):
    """Pickling checks for the AST and CST lint rule report types."""

    @data_provider(
        dict(
            AstLintRuleReport=[
                AstLintRuleReport(
                    file_path=Path("fake/path.py"),
                    node=ast.parse(""),
                    code="SomeFakeRule",
                    message="some message",
                    line=1,
                    column=1,
                ),
            ],
            CstLintRuleReport=[
                CstLintRuleReport(
                    file_path=Path("fake/path.py"),
                    node=cst.parse_statement("pass\n"),
                    code="SomeFakeRule",
                    message="some message",
                    line=1,
                    column=1,
                    module=cst.MetadataWrapper(cst.parse_module(b"pass\n")),
                    module_bytes=b"pass\n",
                ),
            ],
        )
    )
    def test_is_not_pickleable(self, report: BaseLintRuleReport) -> None:
        """``pickle.dumps`` on a report must raise ``pickle.PicklingError``."""
        with pytest.raises(pickle.PicklingError):
            pickle.dumps(report)
예제 #8
0
    def test(
        self,
        *,
        source: bytes,
        rules_in_lint_run: Collection[Type[CstLintRule]],
        rules_without_report: Collection[Type[CstLintRule]],
        suppressed_line: int,
        expected_unused_suppressions_report_messages: Collection[str],
        expected_replacements: Optional[List[str]] = None,
    ) -> None:
        """Run ``RemoveUnusedSuppressionsRule`` and check messages and patches.

        A report is fabricated at ``suppressed_line`` for every rule in
        ``rules_in_lint_run`` that is not in ``rules_without_report``. Each
        report is passed through ``ignore_info.should_ignore_report`` before
        the rule visits the module. The messages it reports must equal the
        expected messages; its patches, applied to the decoded source, must
        produce ``expected_replacements`` (or there must be no patches when
        that argument is ``None``).
        """
        reports = []
        for rule in rules_in_lint_run:
            if rule in rules_without_report:
                continue
            reports.append(
                CstLintRuleReport(
                    file_path=FILE_PATH,
                    node=cst.EmptyLine(),
                    code=rule.__name__,
                    message="message",
                    line=suppressed_line,
                    column=0,
                    module=cst.MetadataWrapper(cst.parse_module(source)),
                    module_bytes=source,
                ))

        tokens = _get_tokens(source)
        comment_info = CommentInfo.compute(tokens=tokens)
        line_mapping_info = LineMappingInfo.compute(tokens=tokens)
        ignore_info = IgnoreInfo.compute(
            comment_info=comment_info,
            line_mapping_info=line_mapping_info,
        )
        cst_wrapper = MetadataWrapper(cst.parse_module(source),
                                      unsafe_skip_copy=True)
        rule_settings = {
            "ignore_info": ignore_info,
            "rules": rules_in_lint_run,
        }
        config = LintConfig(
            rule_config={RemoveUnusedSuppressionsRule.__name__: rule_settings})
        unused_suppressions_context = CstContext(cst_wrapper, source,
                                                 FILE_PATH, config)

        # Feed every fabricated report to the ignore info before the rule
        # under test visits the module.
        for report in reports:
            ignore_info.should_ignore_report(report)
        _visit_cst_rules_with_context(cst_wrapper,
                                      [RemoveUnusedSuppressionsRule],
                                      unused_suppressions_context)

        # Single pass over the context's reports, as in a plain loop.
        collected = [(report.message, report.patch)
                     for report in unused_suppressions_context.reports]
        messages = [message for message, _ in collected]
        patches = [patch for _, patch in collected]

        self.assertEqual(messages,
                         expected_unused_suppressions_report_messages)
        if expected_replacements is None:
            self.assertEqual(len(patches), 0)
            return

        self.assertEqual(len(patches), len(expected_replacements))
        for patch, expected in zip(patches, expected_replacements):
            self.assertEqual(patch.apply(source.decode()), expected)
예제 #9
0
    def test_replace_metadata(self) -> None:
        """``m.replace`` with a metadata matcher renames only matching names.

        Only the module-level ``foo`` is expected to become ``replaced``;
        the parameter and the return expression inside ``bar`` stay as-is.
        """
        def _rename_foo(
            node: cst.CSTNode,
            extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
        ) -> cst.CSTNode:
            name_node = cst.ensure_type(node, cst.Name)
            return name_node.with_changes(value="replaced")

        original = cst.parse_module(
            "foo: int = 37\ndef bar(foo: int) -> int:\n    return foo\n\nbiz: int = bar(42)\n"
        )
        wrapper = cst.MetadataWrapper(original)
        # Match Name nodes having a qualified name equal to "foo".
        foo_matcher = m.Name(metadata=m.MatchMetadataIfTrue(
            meta.QualifiedNameProvider,
            lambda qualnames: any(n.name == "foo" for n in qualnames),
        ))
        result = m.replace(wrapper, foo_matcher, _rename_foo)
        replaced = cst.ensure_type(result, cst.Module).code

        expected = "replaced: int = 37\ndef bar(foo: int) -> int:\n    return foo\n\nbiz: int = bar(42)\n"
        self.assertEqual(replaced, expected)
예제 #10
0
def collect_names(tree, ctx=()):
    """
    Run a ``NameCollector`` built from ``ctx`` over ``tree`` and return its
    ``names``. The tree is visited without copying (``unsafe_skip_copy``).
    """
    name_collector = NameCollector(ctx)
    cst.MetadataWrapper(tree, unsafe_skip_copy=True).visit(name_collector)
    return name_collector.names
예제 #11
0
    def _parse(self, raw):
        """Parse ``raw`` with libcst and wrap the module for metadata access.

        Args:
            raw: Source accepted by ``libcst.parse_module``.

        Returns:
            A ``libcst.MetadataWrapper`` around the parsed module.

        Raises:
            libcst.ParserSyntaxError: when ``raw`` cannot be parsed; the
                error is propagated unchanged.
        """
        try:
            module = libcst.parse_module(raw)
        except libcst.ParserSyntaxError:
            # TODO: log something? What's the wanted behavior here?
            # Bare ``raise`` re-raises the active exception as-is.
            raise

        return libcst.MetadataWrapper(module)
예제 #12
0
def refactor_string(
    source: str,
    unused_imports: List[Union[Import, ImportFrom]],
) -> str:
    """Return ``source`` with the given unused imports removed.

    When ``unused_imports`` is empty the source is returned untouched and
    is never parsed. Otherwise the parsed module is visited by
    ``RemoveUnusedImportTransformer`` and the resulting code is returned.
    """
    if not unused_imports:
        return source

    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    remover = RemoveUnusedImportTransformer(unused_imports)
    return wrapper.visit(remover).code
예제 #13
0
def test_assertEquals_02():
    """A bare ``assertEquals(...)`` call must leave ``checker.errors`` empty."""
    source = (
        "\n"
        "def test_1():\n"
        "    assertEquals(True, False)\n"
    )
    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    checker = Checker(Path("(test)"))
    wrapper.visit(checker)
    assert not checker.errors
예제 #14
0
def ambiguous_tuple_equality_example():
    """Print each sample expression and run a ``Visitor`` over its parse.

    The three samples differ only in where parentheses are placed around
    the tuple and the ``==`` comparison.
    """
    codes = (
        "True, True, True == (True, True, True)",
        "True, True, (True == (True, True, True))",
        "(True, True, True) == (True, True, True)",
    )

    for code in codes:
        print(code)
        wrapped_module = cst.MetadataWrapper(cst.parse_module(code))
        wrapped_module.visit(Visitor())
예제 #15
0
def test_assertEquals_01():
    """``self.assertEquals(...)`` in a TestCase method must produce errors."""
    source = (
        "\n"
        "class MyTestCase(unittest.TestCase):\n"
        "    def test_1():\n"
        "        self.assertEquals(True, False)\n"
    )
    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    checker = Checker(Path("(test)"))
    wrapper.visit(checker)
    assert checker.errors
예제 #16
0
 def _make_fixture(
         self,
         code: str) -> Tuple[cst.BaseExpression, meta.MetadataWrapper]:
     """Parse dedented ``code`` and return its first expression and wrapper.

     The first module statement must be a ``SimpleStatementLine`` whose
     first element is an ``Expr``; that expression's ``value`` is returned
     together with the ``MetadataWrapper`` it belongs to.
     """
     wrapper = cst.MetadataWrapper(cst.parse_module(dedent(code)))
     first_line = cst.ensure_type(wrapper.module.body[0],
                                  cst.SimpleStatementLine)
     expression_statement = cst.ensure_type(first_line.body[0], cst.Expr)
     return expression_statement.value, wrapper
예제 #17
0
def check_result(source, expected):
    """Assert that running ``Modernizer`` over ``source`` yields ``expected``.

    On mismatch the assertion message is a unified diff from the expected
    text to the actual output.
    """
    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    modernizer = Modernizer(Path("(test)"))
    actual = wrapper.visit(modernizer).code

    diff_lines = difflib.unified_diff(
        expected.splitlines(True),
        actual.splitlines(True),
        fromfile="expected",
        tofile="actual",
    )
    diff = "".join(diff_lines)
    assert actual == expected, diff
예제 #18
0
    def transform_module_impl(
        self,
        tree: cst.Module,
    ) -> cst.Module:
        """
        Collect type annotations from all stubs and apply them to ``tree``.

        Gather existing imports from ``tree`` so that we don't add duplicate imports.

        If the codemod context holds an entry under ``CONTEXT_KEY``, its
        options are merged into this visitor's flags and the annotations
        collected from its stub are added to ``self.annotations``. The
        import-adding step and the visit run whether or not such an entry
        exists, so ``tree_with_imports`` is always bound.

        Returns the transformed tree when any annotation change was applied,
        otherwise the original ``tree`` untouched.
        """
        import_gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(import_gatherer)
        existing_import_names = _get_imported_names(
            import_gatherer.all_imports)

        context_contents = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if context_contents is not None:
            (
                stub,
                overwrite_existing_annotations,
                use_future_annotations,
                strict_posargs_matching,
                strict_annotation_matching,
            ) = context_contents
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations
                or overwrite_existing_annotations)
            self.use_future_annotations = (self.use_future_annotations
                                           or use_future_annotations)
            self.strict_posargs_matching = (self.strict_posargs_matching
                                            and strict_posargs_matching)
            self.strict_annotation_matching = (self.strict_annotation_matching
                                               or strict_annotation_matching)
            visitor = TypeCollector(existing_import_names, self.context)
            cst.MetadataWrapper(stub).visit(visitor)
            self.annotations.update(visitor.annotations)

        # Kept outside the ``if`` above: binding ``tree_with_imports`` only
        # when a stub was stored made the visit below raise
        # UnboundLocalError whenever the context had no entry.
        if self.use_future_annotations:
            AddImportsVisitor.add_needed_import(self.context, "__future__",
                                                "annotations")
        tree_with_imports = AddImportsVisitor(
            self.context).transform_module(tree)

        tree_with_changes = tree_with_imports.visit(self)

        # don't modify the imports if we didn't actually add any type information
        if self.annotation_counts.any_changes_applied():
            return tree_with_changes
        else:
            return tree
예제 #19
0
def get_fully_qualified_names(file_path: str, module_str: str) -> Set[QualifiedName]:
    """Return every fully qualified name resolved for ``module_str``.

    The dedented source is parsed and wrapped with a
    ``FullyQualifiedNameProvider`` cache entry generated for ``file_path``
    (falling back to ``""`` when the generated cache has no such key). The
    per-node name sets resolved by the provider are flattened into one set.
    """
    module = cst.parse_module(dedent(module_str))
    provider_cache = FullyQualifiedNameProvider.gen_cache(
        Path(""), [file_path], None
    ).get(file_path, "")
    wrapper = cst.MetadataWrapper(
        module,
        cache={FullyQualifiedNameProvider: provider_cache},
    )

    names = set()
    for qnames in wrapper.resolve(FullyQualifiedNameProvider).values():
        names.update(qnames)
    return names
    def generate_trial(self, program, cond):
        """Build a trial record for ``program`` under condition ``cond``.

        Metadata is read as JSON from the first header comment (minus its
        leading character), and the function name from the first body
        statement. The program is rewritten by ``Preprocessor`` and, for the
        ``Random`` condition, by ``RandomRenamer`` as well. The resulting
        code is executed and ``mystery(<input>)`` is evaluated to obtain the
        answer.

        Returns:
            A dict with keys ``program``, ``call``, ``function``, ``cond``,
            ``answer`` and ``schema``.
        """
        header_comment = program.header[0].comment.value
        trial_metadata = json.loads(header_comment[1:])
        function_name = program.body[0].name.value

        program = cst.MetadataWrapper(program).visit(Preprocessor())
        if cond == self.Condition.Random:
            program = cst.MetadataWrapper(program).visit(RandomRenamer())

        call = f'mystery({trial_metadata["input"]})'

        # NOTE(review): exec/eval run the program text and the metadata
        # "input" verbatim. This presumably relies on programs coming from a
        # trusted corpus. Confirm before feeding it external data.
        namespace = {}
        exec(program.code, namespace, namespace)
        answer = eval(call, namespace, namespace)

        trial = {
            'program': program.code,
            'call': call,
            'function': function_name,
            'cond': str(cond),
            'answer': str(answer),
            'schema': trial_metadata['schema'],
        }
        return trial
예제 #21
0
    def test_batchable_provider(self) -> None:
        """``ByteSpanPositionProvider`` is usable from a batched visitor.

        For the module ``pass``, the ``Pass`` node must report a span
        starting at byte 0 with length 4.
        """
        testcase = self
        expected_span = CodeSpan(start=0, length=4)

        class PassSpanVisitor(cst.BatchableCSTVisitor):
            METADATA_DEPENDENCIES = (ByteSpanPositionProvider,)

            def visit_Pass(self, node: cst.Pass) -> None:
                actual_span = self.get_metadata(ByteSpanPositionProvider, node)
                testcase.assertEqual(actual_span, expected_span)

        module = cst.parse_module("pass")
        cst.MetadataWrapper(module).visit_batched([PassSpanVisitor()])
예제 #22
0
    def unused_vars(self):
        """Return the visitor's unused-variable entries keyed by mapped nodes.

        ``UnusedVarsVisitor`` visits ``self.transformed_module``. Only
        entries whose key is present in ``self.node_map`` are kept, and
        each kept key is replaced by its ``node_map`` value.
        """
        assert self.trace_reads, "Tracer was not executed with trace_reads=True"

        unused_visitor = UnusedVarsVisitor(self)
        # unsafe_skip_copy to ensure that nodes in map are pointer-equivalent
        # to input mod
        cst.MetadataWrapper(self.transformed_module,
                            unsafe_skip_copy=True).visit(unused_visitor)

        remapped = {}
        for node, value in unused_visitor.unused_vars.items():
            if node in self.node_map:
                remapped[self.node_map[node]] = value
        return remapped
예제 #23
0
 def setUp(self) -> None:
     """Create a fake file path and a ``CstLintRuleReport`` fixture.

     The report carries a two-paragraph message and points at line 1,
     column 1 of a module consisting of a single ``pass`` statement.
     """
     module_bytes = b"pass\n"
     long_message = (
         "Some long message that should span multiple lines.\n"
         "\n"
         "Another paragraph with more information about the lint rule."
     )
     self.fake_filepath = Path("fake/path.py")
     self.report = CstLintRuleReport(
         file_path=self.fake_filepath,
         node=cst.parse_statement("pass\n"),
         code="SomeFakeRule",
         message=long_message,
         line=1,
         column=1,
         module=cst.MetadataWrapper(cst.parse_module(module_bytes)),
         module_bytes=module_bytes,
     )
예제 #24
0
def coverage_collector_for_module(relative_path: str, module: cst.Module,
                                  strict_default: bool) -> CoverageCollector:
    """Visit ``module`` with a strictness counter, then a coverage collector.

    The ``CoverageCollector`` is built from the strictness verdict of the
    first pass. A ``RecursionError`` raised by either pass is logged as a
    warning naming ``relative_path`` and otherwise ignored, so the returned
    collector may hold partial results.
    """
    wrapped_module = cst.MetadataWrapper(module)

    strictness_counter = StrictCountCollector(strict_default)
    try:
        wrapped_module.visit(strictness_counter)
    except RecursionError:
        LOG.warning(f"LibCST encountered recursion error in `{relative_path}`")

    is_strict = strictness_counter.is_strict_module()
    collector = CoverageCollector(is_strict)
    try:
        wrapped_module.visit(collector)
    except RecursionError:
        LOG.warning(f"LibCST encountered recursion error in `{relative_path}`")

    return collector
예제 #25
0
def main() -> Optional[int]:
    """Check and modernize the given Python files, printing a diff for each.

    Directories and paths without a ``.py`` suffix are skipped. Each
    remaining file is visited by a ``Checker`` and then a ``Modernizer``.
    Unless ``--quiet`` is given, a non-empty unified diff between the
    original and modernized source is printed.

    Returns:
        ``1`` if any checker or modernizer recorded errors (immediately on
        the first such file when ``--exitfirst`` is set), otherwise ``None``.
    """
    parser = argparse.ArgumentParser(description="Test things.")
    parser.add_argument("file", nargs="+")
    simple_flags = (
        ("-v", "--verbose", "verbose output."),
        ("-q", "--quiet", "no output."),
        ("-x", "--exitfirst",
         "exit instantly on first error or failed test."),
    )
    for short_name, long_name, help_text in simple_flags:
        parser.add_argument(short_name, long_name, action="store_true",
                            help=help_text)
    parser.add_argument("--ignore", nargs="*", help="errors to ignore.")
    args = parser.parse_args()

    user_paths = [Path(name).expanduser() for name in args.file]
    had_errors = False
    for path in expand_paths(user_paths):
        if path.is_dir() or path.suffix != ".py":
            continue
        if args.verbose:
            print(f"Checking {path}")

        py_source = path.read_text()
        wrapper = cst.MetadataWrapper(cst.parse_module(py_source))

        checker = Checker(path, args.verbose, args.ignore)
        wrapper.visit(checker)
        if checker.errors:
            if args.exitfirst:
                return 1
            had_errors = True

        modernizer = Modernizer(path, args.verbose, args.ignore)
        modified_tree = wrapper.visit(modernizer)
        if modernizer.errors:
            if args.exitfirst:
                return 1
            had_errors = True

        if args.quiet:
            continue
        diff_lines = difflib.unified_diff(
            py_source.splitlines(True),
            modified_tree.code.splitlines(True),
            fromfile=f"a{path}",
            tofile=f"b{path}",
        )
        diff = "".join(diff_lines)
        if diff:
            print(diff)

    return 1 if had_errors else None
예제 #26
0
    def test_extract_metadata(self) -> None:
        """``m.extract`` honours ``MatchMetadata`` inside ``SaveMatchedNode``.

        The left operand of the first tuple element is captured only when
        the expected code range equals the one passed to the matcher.
        """
        wrapper = cst.MetadataWrapper(
            cst.parse_module("a + b[c], d(e, f * g)"))
        statement = cst.ensure_type(wrapper.module.body[0],
                                    cst.SimpleStatementLine)
        expression = cst.ensure_type(statement.body[0], cst.Expr).value

        def tuple_with_left_ending_at(end):
            # Tuple matcher that saves the binary operation's left Name as
            # "left" when its position is the range (1, 0)..``end``.
            position = m.MatchMetadata(
                meta.PositionProvider,
                self._make_coderange((1, 0), end),
            )
            left_name = m.Name(metadata=m.SaveMatchedNode(position, "left"))
            return m.Tuple(elements=[
                m.Element(m.BinaryOperation(left=left_name)),
                m.Element(m.Call()),
            ])

        # Verify true behavior
        nodes = m.extract(
            expression,
            tuple_with_left_ending_at((1, 1)),
            metadata_resolver=wrapper,
        )
        first_element = cst.ensure_type(expression, cst.Tuple).elements[0]
        extracted_node = cst.ensure_type(first_element.value,
                                         cst.BinaryOperation).left
        self.assertEqual(nodes, {"left": extracted_node})

        # Verify false behavior
        nodes = m.extract(
            expression,
            tuple_with_left_ending_at((1, 2)),
            metadata_resolver=wrapper,
        )
        self.assertIsNone(nodes)
예제 #27
0
def _run_delinter(options):
    '''
    Lint ``options.file_path_or_folder`` with pylint, apply the transformer
    registered for ``options.msg_id`` to every target file and print each
    non-empty unified diff. Empty files are skipped.
    '''
    root_file_path = options.file_path_or_folder
    # TODO: Handle Windows paths
    if Path(root_file_path).is_absolute():
        msg_template, sep = r'{abspath}:{line}:[{msg_id}({symbol}),{obj}]{msg}', ''
    else:
        msg_template, sep = r'{path}:{line}:[{msg_id}({symbol}),{obj}]{msg}', '/'

    pylint_command = f"{root_file_path} --enable=W --disable=C,R,E,F --msg-template={msg_template} --score=n"

    out, _ = lint.py_run(pylint_command, return_std=True)
    stripped_lines = (
        raw.strip() for raw in "".join(out.readlines()).split('\n'))
    # Drop blank lines and pylint's per-module banner lines.
    warning_lines = [
        text for text in stripped_lines
        if text and not text.startswith('************* Module ')
    ]
    parsed_warnings = Delinter.parse_linter_warnings(warning_lines,
                                                     options.msg_id)

    if os.path.isdir(root_file_path):
        files = list(Path(root_file_path).glob('**/*.py'))
    else:
        files = [root_file_path]

    for file_path in files:
        with open(file_path) as f:
            source_code = f.read()
        if not source_code:
            continue

        path_text = str(file_path)
        wrapper = cst.MetadataWrapper(cst.parse_module(source_code))
        local_warnings = [
            warning for warning in parsed_warnings
            if warning.file_path == path_text
        ]
        transformer = SUPPORTED_LINTER_MAP[options.msg_id][1](local_warnings)
        fixed_module = wrapper.visit(transformer)
        diff_text = "".join(
            difflib.unified_diff(source_code.splitlines(True),
                                 fixed_module.code.splitlines(True),
                                 fromfile=f'a{sep}{file_path}',
                                 tofile=f'b{sep}{file_path}'))
        if diff_text:
            print(diff_text)
예제 #28
0
    def exec_counts(self) -> ExecCounts:
        """Return the visitor's execution counts keyed by mapped nodes.

        ``ExecCountsVisitor`` visits ``self.transformed_module``. Only
        entries whose key is present in ``self.node_map`` are kept, and
        each kept key is replaced by its ``node_map`` value.
        """
        assert self.trace_lines, "Tracer was not executed with trace_lines=True"

        counts_visitor = ExecCountsVisitor(self)
        # unsafe_skip_copy to ensure that nodes in map are pointer-equivalent
        # to input mod
        cst.MetadataWrapper(self.transformed_module,
                            unsafe_skip_copy=True).visit(counts_visitor)

        remapped = {}
        for node, count in counts_visitor.exec_counts.items():
            if node in self.node_map:
                remapped[self.node_map[node]] = count
        return remapped
예제 #29
0
    def test_pylint_warning(self):
        """The unused-import transformer must produce the expected diff.

        Before comparison, the space after the ``+++`` / ``---`` markers and
        the leading space on context lines are removed from the produced
        diff; the expected diff gets only the context-line normalisation.
        """
        warnings = [w for w in unused_import_warnings.split('\n') if w]
        parsed_warnings = Delinter.parse_linter_warnings(warnings)

        wrapper = cst.MetadataWrapper(cst.parse_module(source_code))
        transformer = unused_imports.RemoveUnusedImportTransformer(
            parsed_warnings)
        fixed_module = wrapper.visit(transformer)

        diff = "".join(
            difflib.unified_diff(source_code.splitlines(True),
                                 fixed_module.code.splitlines(True)))
        for old, new in (('+++ ', '+++'), ('--- ', '---'), ('\n ', '\n')):
            diff = diff.replace(old, new)

        self.assertEqual(diff, expected_diff.replace('\n ', '\n'))
예제 #30
0
def refactor_string(
    source: str,
    unused_imports: List[Union[Import, ImportFrom]],
    show_error: bool = False,
) -> str:
    """Return ``source`` with the given unused imports removed.

    The source is always parsed first. If parsing raises
    ``cst.ParserSyntaxError`` the original source is returned, and the
    error is printed in red when ``show_error`` is true. The source is also
    returned unchanged when ``unused_imports`` is empty.
    """
    try:
        wrapper = cst.MetadataWrapper(cst.parse_module(source))
    except cst.ParserSyntaxError as err:
        if show_error:
            print(Color(str(err)).red)
        return source

    if not unused_imports:
        return source

    remover = RemoveUnusedImportTransformer(unused_imports)
    return wrapper.visit(remover).code