Example #1
0
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                                  updated_node: cst.SimpleStatementLine):
        """Attach an inferred type annotation to a single-variable assignment line.

        Two line shapes are handled, both matched against ``original_node``:

        * ``name = value`` becomes ``name: T = value`` when
          ``__get_var_type_assign_t(name)`` returns a type;
        * ``name: U = value`` (or ``name: U``) gets its annotation replaced when
          ``__get_var_type_an_assign(name)`` returns a type.

        The looked-up type goes through ``resolve_type_alias`` and
        ``__name2annotation``.  Only when the latter yields a node is the line
        rewritten and the ``(resolved_type, annotation_node)`` pair recorded in
        ``self.all_applied_types``.  Otherwise ``updated_node`` is returned as is.
        """
        if match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.Assign(targets=[
                        match.AssignTarget(target=match.Name(
                            value=match.DoNotCare()))
                    ])
                ])):
            t = self.__get_var_type_assign_t(
                original_node.body[0].targets[0].target.value)

            if t is not None:
                t_annot_node_resolved = self.resolve_type_alias(t)
                t_annot_node = self.__name2annotation(t_annot_node_resolved)
                if t_annot_node is not None:
                    self.all_applied_types.add(
                        (t_annot_node_resolved, t_annot_node))
                    # NOTE(review): target/value are taken from original_node,
                    # so changes to nested nodes of this one statement are not
                    # carried over; use updated_node.body[0] if that matters.
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(
                            target=original_node.body[0].targets[0].target,
                            value=original_node.body[0].value,
                            annotation=t_annot_node,
                            # Keep the spacing that surrounded the original "=".
                            equal=cst.AssignEqual(
                                whitespace_after=original_node.body[0].
                                targets[0].whitespace_after_equal,
                                whitespace_before=original_node.body[0].
                                targets[0].whitespace_before_equal))
                    ])
        elif match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.AnnAssign(target=match.Name(value=match.DoNotCare()))
                ])):
            t = self.__get_var_type_an_assign(
                original_node.body[0].target.value)
            if t is not None:
                t_annot_node_resolved = self.resolve_type_alias(t)
                t_annot_node = self.__name2annotation(t_annot_node_resolved)
                if t_annot_node is not None:
                    self.all_applied_types.add(
                        (t_annot_node_resolved, t_annot_node))
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(target=original_node.body[0].target,
                                      value=original_node.body[0].value,
                                      annotation=t_annot_node,
                                      equal=original_node.body[0].equal)
                    ])

        # Return the *updated* node: returning original_node here would throw
        # away every change this transformer made to nodes nested in the line.
        return updated_node
Example #2
0
    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[List[Union[libcst.SimpleStatementLine,
                          libcst.BaseCompoundStatement]],
               List[Union[libcst.SimpleStatementLine,
                          libcst.BaseCompoundStatement]], List[Union[
                              libcst.SimpleStatementLine,
                              libcst.BaseCompoundStatement]], ]:
        """Split ``updated_module.body`` into three consecutive slices.

        Returns ``(before, imports, after)``:

        * ``before`` holds a leading module docstring or an initial
          ``__strict__ = ...`` flag, if present;
        * ``imports`` runs from there up to and including the last top-level
          line that contains one of ``self.all_imports``;
        * ``after`` is everything else.

        Positions are computed on ``orig_module`` and applied to
        ``updated_module``, which relies on both having the same statement count.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # never insert an import before initial __strict__ flag
        if m.matches(
                orig_module,
                m.Module(body=[
                    m.SimpleStatementLine(body=[
                        m.Assign(targets=[
                            m.AssignTarget(target=m.Name("__strict__"))
                        ])
                    ]),
                    m.ZeroOrMore(),
                ]),
        ):
            statement_before_import_location = import_add_location = 1

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.
        for i, statement in enumerate(orig_module.body):
            # Only the very first statement can be the module docstring.  A bare
            # string expression further down must not pull the insertion point
            # back to 1 (ahead of the imports already found).
            if i == 0 and m.matches(
                    statement,
                    m.SimpleStatementLine(
                        body=[m.Expr(value=m.SimpleString())])):
                statement_before_import_location = import_add_location = 1
            elif isinstance(statement, libcst.SimpleStatementLine):
                # Identity comparison: all_imports holds the original nodes.
                if any(possible_import is last_import
                       for possible_import in statement.body
                       for last_import in self.all_imports):
                    import_add_location = i + 1

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(updated_module.
                 body[statement_before_import_location:import_add_location]),
            list(updated_module.body[import_add_location:]),
        )
        class RemoveBarTransformer(VisitorBasedCodemodCommand):
            """Delete expression lines that only call the imported ``foo.bar``.

            A line qualifies when its single statement is a call expression whose
            ``QualifiedNameProvider`` metadata compares equal to
            ``{QualifiedName(source=IMPORT, name="foo.bar")}``.
            """

            METADATA_DEPENDENCIES = (QualifiedNameProvider, ScopeProvider)

            @m.leave(
                m.SimpleStatementLine(
                    body=[
                        m.Expr(
                            m.Call(
                                metadata=m.MatchMetadata(
                                    QualifiedNameProvider,
                                    {
                                        QualifiedName(
                                            source=QualifiedNameSource.IMPORT,
                                            name="foo.bar",
                                        )
                                    },
                                )
                            )
                        )
                    ]
                )
            )
            def _leave_foo_bar(
                self,
                original_node: cst.SimpleStatementLine,
                updated_node: cst.SimpleStatementLine,
            ) -> cst.RemovalSentinel:
                # Hand the doomed statement to RemoveImportsVisitor, presumably so
                # imports referenced only by this line can be cleaned up as well.
                RemoveImportsVisitor.remove_unused_import_by_node(
                    self.context, original_node)
                # Equivalent to cst.RemoveFromParent().
                return cst.RemovalSentinel.REMOVE
Example #4
0
    def on_leave(self, original_node, updated_node):
        """Remove statements whose recorded execution count is zero.

        The parent's ``on_leave`` runs first.  Non-statement results, and lines
        consisting of a single bare string expression (e.g. docstrings), are
        always kept.  Other statements are removed when
        ``self.exec_counts[original_node]`` equals 0.
        """
        result = super().on_leave(original_node, updated_node)

        if not isinstance(result, cst.BaseStatement):
            return result

        is_string_only_line = m.matches(
            result, m.SimpleStatementLine(body=[m.Expr(m.SimpleString())]))
        if is_string_only_line:
            return result

        # exec_counts is only consulted for statements that passed both checks.
        if self.exec_counts[original_node] == 0:
            return cst.RemoveFromParent()
        return result
Example #5
0
 def _is_awaitable_callable(annotation: str) -> bool:
     if not (annotation.startswith("typing.Callable")
             or annotation.startswith("typing.ClassMethod")
             or annotation.startswith("StaticMethod")):
         # Exit early if this is not even a `typing.Callable` annotation.
         return False
     try:
         # Wrap this in a try-except since the type annotation may not be parse-able as a module.
         # If it is not parse-able, we know it's not what we are looking for anyway, so return `False`.
         parsed_ann = cst.parse_module(annotation)
     except Exception:
         return False
     # If passed annotation does not match the expected annotation structure for a `typing.Callable` with
     # typing.Coroutine as the return type, matched_callable_ann will simply be `None`.
     # The expected structure of an awaitable callable annotation from Pyre is: typing.Callable()[[...], typing.Coroutine[...]]
     matched_callable_ann: Optional[Dict[str, Union[
         Sequence[cst.CSTNode], cst.CSTNode]]] = m.extract(
             parsed_ann,
             m.Module(body=[
                 m.SimpleStatementLine(body=[
                     m.Expr(value=m.Subscript(slice=[
                         m.SubscriptElement(),
                         m.SubscriptElement(slice=m.Index(value=m.Subscript(
                             value=m.SaveMatchedNode(
                                 m.Attribute(),
                                 "base_return_type",
                             )))),
                     ], ))
                 ]),
             ]),
         )
     if (matched_callable_ann is not None
             and "base_return_type" in matched_callable_ann):
         base_return_type = get_full_name_for_node(
             cst.ensure_type(matched_callable_ann["base_return_type"],
                             cst.CSTNode))
         return (base_return_type is not None
                 and base_return_type == "typing.Coroutine")
     return False
Example #6
0
    def on_leave(self, old_node, new_node):
        """Restructure blocks of ``self.block_types`` around ``return`` points.

        After the parent's ``on_leave``, the block's statement list is scanned
        from the end.  At the last bare ``return`` line, or the last compound
        statement whose body is recorded in ``self.return_blocks``, the list is
        cut:
          * statements before it are kept;
          * a bare ``return`` line itself is not re-added, a compound statement is;
          * statements after it (if any) are passed to ``self._build_if`` and the
            result is appended.
        The scan repeats until a pass finds nothing to cut.  A block rewritten
        this way is added to ``self.return_blocks``, so the compound statement
        that contains it is treated the same way when its enclosing block is left.
        """
        new_node = super().on_leave(old_node, new_node)

        if isinstance(new_node, self.block_types):
            # NOTE(review): cur_stmts.append() below needs slices of this to be
            # lists; if new_node.body is a tuple the append would fail - confirm.
            cur_stmts = new_node.body

            # True once any cut was made in this block.
            any_change = False
            while True:
                change = False
                N = len(cur_stmts)
                # Scan backwards so the last return point is handled first.
                for i in reversed(range(N)):
                    stmt = cur_stmts[i]
                    # A line that is exactly one `return` statement.
                    is_return = m.matches(stmt,
                                          m.SimpleStatementLine([m.Return()]))
                    # A compound statement whose body this method rewrote earlier.
                    is_return_block = isinstance(stmt, cst.BaseCompoundStatement) and \
                        stmt.body in self.return_blocks

                    if is_return or is_return_block:
                        change = True
                        any_change = True
                        # Keep what precedes stmt; set aside what follows it.
                        [cur_stmts, block] = [cur_stmts[:i], cur_stmts[i + 1:]]

                        if is_return_block:
                            # Consume the marker and keep the compound statement.
                            self.return_blocks.remove(stmt.body)
                            cur_stmts.append(stmt)

                        # Re-attach the trailing statements via _build_if.
                        if i < N - 1:
                            cur_stmts.append(self._build_if(block))

                        # Restart the scan on the rebuilt list.
                        break

                if not change:
                    break

            new_node = new_node.with_changes(body=cur_stmts)
            if any_change:
                # Mark this block so its enclosing compound statement is cut too.
                self.return_blocks.add(new_node)
        return new_node
    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report ``@sorted-attributes`` classes whose attribute assignments are unsorted.

        Only classes whose docstring contains ``@sorted-attributes`` are checked.
        Lines of the form ``name = value`` (a single plain-name target) are
        collected and sorted by name.  The suggested replacement keeps the lines
        before the first such assignment in place, followed by the sorted
        assignments, followed by every other line that came after the first
        assignment.
        """
        doc_string = node.get_docstring()
        if not doc_string or "@sorted-attributes" not in doc_string:
            return

        pre_assign_lines: List[LineType] = []
        assign_lines: List[LineType] = []
        post_assign_lines: List[LineType] = []

        # Only plain names can be sorted by ``target.value``.  Tuple targets have
        # no ``.value`` and attribute/subscript targets yield nodes that cannot
        # be ordered, so such lines are treated like any non-assignment line.
        single_name_assign = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])])

        for line in node.body.body:
            if m.matches(line, single_name_assign):
                assign_lines.append(line)
            elif assign_lines:
                post_assign_lines.append(line)
            else:
                pre_assign_lines.append(line)

        sorted_assign_lines = sorted(
            assign_lines,
            key=lambda line: line.body[0].targets[0].target.value)
        if sorted_assign_lines == assign_lines:
            return
        self.report(
            node,
            replacement=node.with_changes(body=node.body.with_changes(
                body=pre_assign_lines + sorted_assign_lines +
                post_assign_lines)),
        )
def _is_import_line(
        line: Union[cst.SimpleStatementLine,
                    cst.BaseCompoundStatement]) -> bool:
    """Return True if *line* is a simple line holding exactly one import.

    Both ``import x`` and ``from x import y`` qualify.  Compound statements and
    lines with several ``;``-separated statements do not.
    """
    if not isinstance(line, cst.SimpleStatementLine):
        return False
    statements = line.body
    return len(statements) == 1 and isinstance(
        statements[0], (cst.Import, cst.ImportFrom))
Example #9
0
class Modernizer(m.MatcherDecoratableTransformer):
    """Rewrite some Python 2 idioms using ``builtins`` / ``future.utils``.

    * ``filter``/``map``/``zip``/``range`` calls queue a ``from builtins import``
      of that name (unless already imported).  Outside ``list``/``set``/
      ``tuple``/``.join`` calls, comprehension ``for`` clauses and ``for`` loops,
      they are also wrapped in ``list(...)`` unless already imported.
    * ``xrange``/``raw_input`` are renamed to ``range``/``input``.
    * ``d.iterkeys()``/``itervalues()``/``iteritems()`` become ``iterkeys(d)``
      etc., imported from ``future.utils``.
    * ``d.keys()``/``values()``/``items()`` are wrapped in ``list(...)``.
    * The name ``unicode`` becomes ``text_type`` from ``future.utils``.

    Missing imports are added in ``leave_Module``.  With ``verbose`` set,
    ``# type:`` comments on parameters and assignments are printed.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)
    # FIXME use a stack of e.g. SimpleStatementLine then proper visit_Import/ImportFrom to store the ssl node

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.errors = False
        self.stack: List[Tuple[str, ...]] = []
        self.annotations: Dict[
            Tuple[str, ...], Comment  # key: tuple of canonical variable name
        ] = {}
        self.python_future_updated_node: Optional[SimpleStatementLine] = None
        self.python_future_imports: Dict[str, str] = {}
        self.python_future_new_imports: Set[str] = set()
        self.builtins_imports: Dict[str, str] = {}
        self.builtins_new_imports: Set[str] = set()
        self.builtins_updated_node: Optional[SimpleStatementLine] = None
        self.future_utils_imports: Dict[str, str] = {}
        self.future_utils_new_imports: Set[str] = set()
        self.future_utils_updated_node: Optional[SimpleStatementLine] = None
        # self.last_import_node: Optional[CSTNode] = None
        # Last top-level (not in class/def/if) line containing an import.
        self.last_import_node_stmt: Optional[CSTNode] = None
        # Body index just past the last import line inserted by update_imports;
        # None until one has been inserted for the current module.
        self.import_insert_index: Optional[int] = None

    # @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    # @m.visit(m.ImportAlias() | m.ImportStar())
    # def import_python_future_check(self, node: Union[ImportAlias, ImportStar]) -> None:
    #     self.add_import(self.python_future_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("__future__")))
    # def import_python_future_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(m.ImportFrom(module=m.Name("builtins")))
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_builtins_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Record a name imported by ``from builtins import ...``."""
        self.add_import(self.builtins_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("builtins")))
    # def builtins_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(
        m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    )
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_future_utils_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Record a name imported by ``from future.utils import ...``."""
        self.add_import(self.future_utils_imports, node)

    # @m.leave(
    #     m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    # )
    # def future_utils_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @staticmethod
    def add_import(
        imports: Dict[str, str], node: Union[ImportAlias, ImportStar]
    ) -> None:
        """Store ``name -> local alias`` for *node*, or ``"*" -> "*"`` for a star import."""
        if isinstance(node, ImportAlias):
            imports[node.name.value] = (
                node.asname.name.value if node.asname else node.name.value
            )
        else:
            imports["*"] = "*"

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_Import(self, node: Import) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_ImportFrom(self, node: ImportFrom) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> Optional[bool]:
        """Remember the most recent import line outside classes, functions and ifs."""
        for n in node.body:
            if m.matches(n, m.Import() | m.ImportFrom()):
                self.last_import_node_stmt = node
        return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def leave_SimpleStatementLine(
        self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
    ) -> Union[BaseStatement, RemovalSentinel]:
        """Remember the (updated) ``__future__``/``builtins``/``future.utils`` import lines."""
        for n in updated_node.body:
            if m.matches(n, m.ImportFrom(module=m.Name("__future__"))):
                self.python_future_updated_node = updated_node
            elif m.matches(n, m.ImportFrom(module=m.Name("builtins"))):
                self.builtins_updated_node = updated_node
            elif m.matches(
                n,
                m.ImportFrom(
                    module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
                ),
            ):
                self.future_utils_updated_node = updated_node
        return updated_node

    # @m.visit(
    #     m.AllOf(
    #         m.SimpleStatementLine(),
    #         m.MatchIfTrue(
    #             lambda node: any(m.matches(c, m.Assign()) for c in node.children)
    #         ),
    #         m.MatchIfTrue(
    #             lambda node: "# type:" in node.trailing_whitespace.comment.value
    #         ),
    #     )
    # )
    # def visit_assign(self, node: SimpleStatementSuite) -> None:
    #     return None

    def visit_Param(self, node: Param) -> Optional[bool]:
        """In verbose mode, print the ``# type:`` comment type found in a parameter."""

        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.ptype: Optional[str] = None

            def visit_TrailingWhitespace_comment(
                self, node: "TrailingWhitespace"
            ) -> None:
                if node.comment and "type:" in node.comment.value:
                    mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                    self.ptype = mo.group(1) if mo else None
                return None

        v = Visitor()
        node.visit(v)
        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: parameter {node.name.value}: {v.ptype or 'unknown type'}"
            )
        return None

    @m.visit(m.SimpleStatementLine())
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        """In verbose mode, print each assigned name with its ``# type:`` comment type."""
        assign = None
        for c in node.children:
            if m.matches(c, m.Assign()):
                assign = ensure_type(c, Assign)
        if assign is not None:
            vtype = None
            # Only look for a type comment when the line itself ends in one.
            comment = node.trailing_whitespace.comment
            if comment is not None and "type:" in comment.value:

                class TypingVisitor(m.MatcherDecoratableVisitor):
                    def __init__(self):
                        super().__init__()
                        self.vtype: Optional[str] = None

                    def visit_TrailingWhitespace_comment(
                        self, node: "TrailingWhitespace"
                    ) -> None:
                        if node.comment:
                            mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                            if mo:
                                # Store on the visitor so the caller can read it.
                                self.vtype = mo.group(1)
                        return None

                tv = TypingVisitor()
                node.visit(tv)
                vtype = tv.vtype

            class NameVisitor(m.MatcherDecoratableVisitor):
                def __init__(self):
                    super().__init__()
                    self.names: List[str] = []

                def visit_Name(self, node: Name) -> Optional[bool]:
                    self.names.append(node.value)
                    return None

            if self.verbose:
                pos = self.get_metadata(PositionProvider, node).start
                for target in assign.targets:
                    v = NameVisitor()
                    target.visit(v)
                    for name in v.names:
                        print(
                            f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                        )

    def visit_FunctionDef_body(self, node: FunctionDef) -> None:
        """Placeholder: walks the function's empty-line comments without using them yet."""

        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()

            def visit_EmptyLine_comment(self, node: "EmptyLine") -> None:
                # FIXME too many matches on test_param_02
                if not node.comment:
                    return
                # TODO: use comment.value
                return None

        v = Visitor()
        node.visit(v)
        return None

    map_matcher = m.Call(
        func=m.Name("filter") | m.Name("map") | m.Name("zip") | m.Name("range")
    )

    @m.visit(map_matcher)
    def visit_map(self, node: Call) -> None:
        """Queue a ``builtins`` import for filter/map/zip/range if not imported yet."""
        func_name = ensure_type(node.func, Name).value
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(map_matcher)
    def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Wrap filter/map/zip/range in ``list(...)`` unless imported from builtins."""
        # TODO test with CompFor etc.
        # TODO improve join test
        func_name = ensure_type(updated_node.func, Name).value
        if func_name not in self.builtins_imports:
            updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.visit(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def visit_xrange(self, node: Call) -> None:
        """Queue a ``builtins`` import of ``range``/``input`` for xrange/raw_input calls."""
        orig_func_name = ensure_type(node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.leave(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Rename ``xrange`` to ``range`` and ``raw_input`` to ``input``."""
        orig_func_name = ensure_type(updated_node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        return updated_node.with_changes(func=Name(func_name))

    iter_matcher = m.Call(
        func=m.Attribute(
            attr=m.Name("iterkeys") | m.Name("itervalues") | m.Name("iteritems")
        )
    )

    @m.visit(iter_matcher)
    def visit_iter(self, node: Call) -> None:
        """Queue a ``future.utils`` import for iterkeys/itervalues/iteritems."""
        func_name = ensure_type(node.func, Attribute).attr.value
        if func_name not in self.future_utils_imports:
            self.future_utils_new_imports.add(func_name)

    @m.leave(iter_matcher)
    def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Rewrite ``d.iteritems()`` (etc.) as ``iteritems(d)``."""
        attribute = ensure_type(updated_node.func, Attribute)
        func_name = attribute.attr
        dict_name = attribute.value
        return updated_node.with_changes(func=func_name, args=[Arg(dict_name)])

    not_iter_matcher = m.Call(
        func=m.Attribute(attr=m.Name("keys") | m.Name("values") | m.Name("items"))
    )

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(not_iter_matcher)
    def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Wrap ``.keys()``/``.values()``/``.items()`` calls in ``list(...)``."""
        updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.call_if_not_inside(m.Import() | m.ImportFrom())
    @m.leave(m.Name(value="unicode"))
    def fix_unicode(self, original_node: Name, updated_node: Name) -> BaseExpression:
        """Rename ``unicode`` to ``text_type`` and queue its ``future.utils`` import."""
        value = "text_type"
        if value not in self.future_utils_imports:
            self.future_utils_new_imports.add(value)
        return updated_node.with_changes(value=value)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        """Add the ``builtins`` and ``future.utils`` imports queued during the walk."""
        # No import line has been synthesised for this module yet.
        self.import_insert_index = None
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "builtins",
            self.builtins_updated_node,
            self.builtins_imports,
            self.builtins_new_imports,
            True,
        )
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "future.utils",
            self.future_utils_updated_node,
            self.future_utils_imports,
            self.future_utils_new_imports,
            False,
        )
        return updated_node

    def update_imports(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        updated_import_node: Optional[SimpleStatementLine],
        current_imports: Dict[str, str],
        new_imports: Set[str],
        noqa: bool,
    ) -> Module:
        """Ensure ``from <import_name> import ...`` covers ``new_imports``.

        :param original_module: the module as parsed; used to locate
            ``last_import_node_stmt`` among its top-level statements.
        :param updated_module: the module being returned from ``leave_Module``.
        :param import_name: module to import from (e.g. ``"builtins"``).
        :param updated_import_node: the existing import line for
            ``import_name`` in ``updated_module``, or ``None``.
        :param current_imports: names already imported (name -> alias, or
            ``"*"`` for a star import).
        :param new_imports: names that still need importing.
        :param noqa: append ``  # noqa`` to the generated line.
        :returns: the module with the import line added or extended.
        """
        if not new_imports:
            return updated_module
        noqa_comment = "  # noqa" if noqa else ""
        if not updated_import_node:
            if self.import_insert_index is None:
                # First synthesised import: right after the last top-level
                # import line of the original module, or at the very top
                # (followed by two blank lines) when there is none.  Indices of
                # original_module.body are still valid for updated_module here
                # because no statements have been added or removed yet.
                insert_at = 0
                blank_lines = "\n\n"
                if self.last_import_node_stmt:
                    blank_lines = ""
                    # Fallback when the recorded import is not a top-level
                    # statement (e.g. nested in try/with): end of the module.
                    insert_at = len(original_module.body)
                    for index, original in enumerate(original_module.body):
                        if original is self.last_import_node_stmt:
                            insert_at = index + 1
                            break
            else:
                # An earlier call already inserted a line: go right after it.
                insert_at = self.import_insert_index
                blank_lines = ""
            new_module = parse_module(
                f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
                config=updated_module.config_for_parsing,
            )
            # The generated source starts with the import itself, so its
            # statements come first among the children; trailing blank lines
            # (if any) follow them.
            self.import_insert_index = insert_at + len(new_module.body)
            body = list(updated_module.body)
            return updated_module.with_changes(
                body=body[:insert_at] + list(new_module.children) + body[insert_at:]
            )
        else:
            if "*" not in current_imports:
                current_imports_set = {
                    f"{k}" if k == v else f"{k} as {v}"
                    for k, v in current_imports.items()
                }
                stmt = parse_statement(
                    f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}",
                    config=updated_module.config_for_parsing,
                )
                return updated_module.deep_replace(updated_import_node, stmt)
                # for i, (original, updated) in enumerate(
                #     zip(original_module.body, updated_module.body)
                # ):
                #     if original is original_import_node:
                #         body = list(updated_module.body)
                #         return updated_module.with_changes(
                #             body=body[:i] + [stmt] + body[i + 1 :]
                #         )
        # A star import already provides every name.
        return updated_module
Example #10
0
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        """Turn ``raise NotImplemented`` into ``raise NotImplementedError``.

        The decorator limits this hook to names inside such a ``raise``.  That
        also includes names in a ``from <cause>`` clause (``raise NotImplemented
        from err``), so only the ``NotImplemented`` name itself is renamed.
        """
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        """Simplify asserts whose condition is a constant literal.

        ``assert <truthy literal>`` is removed.  ``assert <falsy literal>[, msg]``
        becomes ``raise AssertionError`` or ``raise AssertionError(msg)``.  A
        condition that ``literal_eval`` cannot evaluate is left untouched.
        """
        condition_src = cst.Module("").code_for_node(updated_node.test)
        try:
            condition = literal_eval(condition_src)
        except Exception:
            return updated_node

        if condition:
            return cst.RemovalSentinel.REMOVE

        message = updated_node.msg
        if message is not None:
            return cst.Raise(
                cst.Call(cst.Name("AssertionError"), args=[cst.Arg(message)]))
        return cst.Raise(cst.Name("AssertionError"))

    @m.leave(
        m.ComparisonTarget(
            operator=m.Equal(),
            comparator=oneof_names("None", "False", "True"),
        ))
    def convert_none_cmp(self, _, updated_node):
        """Rewrite ``== None``/``== False``/``== True`` as ``is`` (inspired by Pybetter)."""
        identity_operator = cst.Is()
        return updated_node.with_changes(operator=identity_operator)

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Rewrite ``not x in y`` as ``x not in y`` (also inspired by Pybetter).

        Parentheses around the original ``not`` expression are carried over to
        the new comparison.
        """
        comparison = cst.ensure_type(updated_node.expression, cst.Comparison)
        membership_test = comparison.comparisons[0]
        negated_test = membership_test.with_changes(operator=cst.NotIn())
        return cst.Comparison(
            left=comparison.left,
            comparisons=[negated_test],
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        """Drop parentheses wrapped around a call expression when they are redundant.

        The call is rendered without its parentheses and compiled in ``eval``
        mode.  If that fails, e.g. for a multi-line parenthesised "fluent
        interface" chain that needs the parens for line continuation, the
        original node is kept.
        """
        without_parens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(without_parens), "<string>",
                    "eval")
        except SyntaxError:
            return updated_node
        return without_parens

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp(), star="")]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400 for lists.

        Unnecessary generator: ``list(x for x in y)`` becomes ``[x for x in y]``.
        The set and dict variants (C401/C402) are left to pyupgrade.
        ``list(*gen)`` is deliberately not matched: it unpacks the generator into
        positional arguments, so the rewrite would change its meaning.
        """
        generator = updated_node.args[0].value
        return cst.ListComp(elt=generator.elt, for_in=generator.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.  The outer call is dropped and its argument returned.
        """
        inner = updated_node.args[0].value
        # The decorator matched the *original* node.  updated_node may already
        # reflect other hooks: C416 turns an inner comprehension into
        # list(...)/set(...), and, when hooks are chained on the same node,
        # C414 may turn list(list(x)) into list(x).  Unwrap only when the
        # argument still makes the outer call redundant, so list(x) is never
        # reduced to a bare x and list(set(x)) never to set(x).
        if m.matches(updated_node.func, m.Name("set")):
            still_redundant = m.SetComp() | m.Call(func=m.Name("set"))
        else:
            still_redundant = m.ListComp() | m.Call(
                func=oneof_names("sorted", "list"))
        if not m.matches(inner, still_redundant):
            return updated_node
        return inner

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted(): ``reversed(sorted(x, ...))``
        becomes ``sorted(x, ...)``.  The ``reverse=`` flag is inverted, or
        ``reverse=True`` is appended when no such keyword is given.
        """
        call = updated_node.args[0].value
        if not m.matches(call, m.Call(func=m.Name("sorted"))):
            # Another hook has already rewritten the inner call.
            return updated_node
        args = list(call.args)
        if any(arg.star == "**" for arg in args):
            # ``sorted(x, **kw)`` may already pass ``reverse``: appending another
            # would raise TypeError at runtime, and it cannot be inverted.
            return updated_node
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Non-literal flag: negate it at runtime.  Low-precedence
                    # operands need parentheses so ``not`` covers the whole
                    # expression: ``not (a or b)`` rather than ``(not a) or b``.
                    operand = arg.value
                    if not operand.lpar and isinstance(
                            operand,
                            (cst.BooleanOperation, cst.IfExp, cst.Lambda)):
                        operand = operand.with_changes(
                            lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), operand))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        # reverse=True cancels the outer reversed(): drop it.
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    # Callee-name matchers (built with oneof_names) used by the C414 hook below:
    # set-like constructors, and builtins that produce a sequence/iterable.
    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    # Shapes handled by replace_unnecessary_nested_calls (flake8-comprehensions
    # C414).  Kept as a class attribute so the hook can re-check the node it is
    # actually given.
    _NESTED_CALL_MATCHER = (
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""),
                  m.ZeroOrMore()],
        ))

    @m.leave(_NESTED_CALL_MATCHER)
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>().
        The inner call is dropped and its first argument passed to the outer
        call directly.  Further outer arguments (e.g. ``key=`` of ``sorted``)
        are kept.
        """
        # The decorator matched the original node.  updated_node may have been
        # rewritten already (e.g. the inner call turned into a comprehension, or
        # list(list(set(x))) into list(set(x))), so re-check the shape.
        if not m.matches(updated_node, self._NESTED_CALL_MATCHER):
            return updated_node
        inner = updated_node.args[0].value
        # ``list(tuple())`` has nothing to unwrap, and a starred or keyword
        # first argument cannot simply be moved to the outer call.
        if not m.matches(
                inner,
                m.Call(args=[m.Arg(star="", keyword=None), m.ZeroOrMore()])):
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(inner.args[0].value)] +
            list(updated_node.args[1:]), )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]))
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().

        ``f(xs[<whole-sequence slice>])`` becomes ``f(xs)``.  Two cases are left
        alone because dropping the slice would change the result:

        * a starred argument (``f(*xs[::-1])``), which unpacks the sliced
          sequence into separate arguments;
        * ``reversed(xs[::-1])`` (negative step), which iterates ``xs`` in
          forward order, whereas ``reversed(xs)`` would iterate it backwards.
        """
        arg = updated_node.args[0]
        if arg.star:
            return updated_node
        step = getattr(arg.value.slice[0].slice, "step", None)
        if (m.matches(updated_node.func, m.Name("reversed"))
                and isinstance(step, cst.UnaryOperation)
                and isinstance(step.operator, cst.Minus)):
            return updated_node
        return updated_node.with_changes(args=[cst.Arg(arg.value.value)])

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(target=m.Name(),
                             ifs=[],
                             inner_for_in=None,
                             asynchronous=None),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().

        ``[x for x in it]`` -> ``list(it)`` and ``{x for x in it}`` -> ``set(it)``.
        The rewrite applies only when the element is exactly the loop variable;
        otherwise the comprehension is returned unchanged.
        """
        loop = updated_node.for_in
        if updated_node.elt.value != loop.target.value:
            return updated_node
        constructor = "list" if isinstance(updated_node, cst.ListComp) else "set"
        return cst.Call(func=cst.Name(constructor), args=[cst.Arg(loop.iter)])

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        """Move a bare ``None`` member to the end of a Union/Literal subscript.

        The sort is stable, so every other member keeps its relative order.  The
        new last element is passed through ``remove_trailing_comma``.  If
        anything raises, e.g. an element without ``.slice.value.value``, the
        subscript is returned unchanged.
        """
        def _is_none(element):
            return element.slice.value.value == "None"

        try:
            members = sorted(updated_node.slice, key=_is_none)
            members[-1] = remove_trailing_comma(members[-1])
            return updated_node.with_changes(slice=members)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        """In an ``X | Y`` annotation, push the operand holding ``None`` rightwards.

        Only runs inside an annotation that is itself a binary operation.  If
        the left operand is ``None``, or a nested binary operation containing
        ``None`` anywhere, the two operands are swapped.
        """
        def _contains_none(expr):
            if m.matches(expr, m.Name("None")):
                return True
            if not m.matches(expr, m.BinaryOperation()):
                return False
            return _contains_none(expr.left) or _contains_none(expr.right)

        left = updated_node.left
        if not _contains_none(left):
            return updated_node
        return updated_node.with_changes(left=updated_node.right, right=left)

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        """Flatten nested literals: ``Literal[Literal[1, 2], 3]`` -> ``Literal[1, 2, 3]``.

        Only index elements whose value is itself ``Literal[...]`` are spliced
        in.  Every other element is kept as-is, including slice elements such as
        ``Literal[a:b]``, which have no ``.value`` and previously raised
        AttributeError.
        """
        nested_literal = m.SubscriptElement(
            slice=m.Index(value=m.Subscript(m.Name("Literal"))))
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item, nested_literal):
                new_slice.extend(item.slice.value.slice)
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        """Flatten a ``Union[...]`` subscript and move ``None`` to the end.

        Applies the following rewrites:

        * ``Optional[X]`` members are replaced by ``X`` and a ``None`` is added;
        * nested ``Union[...]`` members are spliced in;
        * bare ``None`` members are removed.

        A single ``None`` is then appended if any was seen.  Slice elements
        (``Union[a:b]``) have no ``.value``; they are kept as-is instead of
        raising AttributeError.
        """
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if not isinstance(item.slice, cst.Index):
                new_slice.append(item)
                continue
            value = item.slice.value
            if m.matches(value, m.Subscript(m.Name("Optional"))):
                new_slice += value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(value, m.Subscript(m.Name("Union"))):
                new_slice += value.slice  # peel off the nested "Union"
            elif m.matches(value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        """Drop an ``else:`` clause whose whole body is a single ``pass``.

        libcst only attaches an Else node to an If, While, For or Try node, and
        in each of those a missing else clause is represented as ``None``.
        Returning RemoveFromParent therefore deletes the clause cleanly.  If the
        clause contains any comment, it is kept so the comment is not lost.
        """
        keep_for_comments = bool(m.findall(updated_node, m.Comment()))
        return updated_node if keep_for_comments else cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None and not node.kwonly_params and not node.
            posonly_params and isinstance(node.star_arg, cst.MaybeSentinel) and
            all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        """Replace ``lambda a, b: f(a, b)`` with plain ``f``.

        Applies only to lambdas with plain positional parameters: no defaults,
        ``*args``, ``**kwargs``, keyword-only or positional-only parameters.
        The body must call something with exactly those parameters, in order,
        as plain positional arguments.

        The rewrite is skipped when the callee itself mentions one of the
        parameters (``lambda f: f(f)``, ``lambda x: x.method(x)``).  Without the
        lambda those names would be unbound or would refer to an outer variable.

        After the rewrite, the callee expression is evaluated once where the
        lambda was, instead of on every call.
        """
        param_names = [param.name.value for param in updated_node.params.params]
        same_args = [
            m.Arg(m.Name(name), star="", keyword=None) for name in param_names
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        # The `.attr` part of `obj.attr` is not a variable reference, so a
        # parameter that merely shares its name does not block the rewrite.
        attr_name_ids = {id(attr.attr) for attr in m.findall(func, m.Attribute())}
        bound = set(param_names)
        if any(name.value in bound and id(name) not in attr_name_ids
               for name in m.findall(func, m.Name())):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        """Merge ``isinstance(x, A) or isinstance(x, B)`` into ``isinstance(x, (A, B))``.

        The targets are compared as expressions, so formatting differences in
        the argument separators don't prevent the merge.  Type arguments that
        are already tuples are flattened, so chained ``or``s produce
        ``(A, B, C)`` rather than ``((A, B), C)``.  Starred type arguments are
        left alone.

        The merged call evaluates the target once instead of twice, so this
        assumes evaluating it has no side effects.
        """
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_type.star or right_type.star:
            return updated_node
        if (left_target.star != right_target.star
                or not left_target.value.deep_equals(right_target.value)):
            return updated_node

        def _type_elements(arg):
            # Reset commas so the merged tuple gets uniform ", " separators.
            if isinstance(arg.value, cst.Tuple):
                return [
                    elem.with_changes(comma=cst.MaybeSentinel.DEFAULT)
                    for elem in arg.value.elements
                ]
            return [cst.Element(arg.value)]

        merged_type = cst.Arg(
            cst.Tuple(_type_elements(left_type) + _type_elements(right_type)))
        return updated_node.left.with_changes(args=[left_target, merged_type])
Example #11
0
def is_docstring(node):
    """Return True if ``node`` is a statement line holding only a plain string literal.

    The line must contain exactly one expression statement whose value is a
    ``SimpleString``.  Concatenated or formatted strings do not match.
    """
    lone_string_statement = m.SimpleStatementLine(
        body=[m.Expr(value=m.SimpleString())])
    return m.matches(node, lone_string_statement)
Example #12
0
                                    newline=m.Newline())

# A call whose callee passes `is_model_field_type` (defined elsewhere), either
# as an attribute name (presumably e.g. `models.CharField(...)`) or as a bare
# name (`CharField(...)`).  The predicate receives the name string.
_django_model_field_name_value = m.Call(func=m.Attribute(
    attr=m.Name(m.MatchIfTrue(is_model_field_type)))) | m.Call(
        func=m.Name(m.MatchIfTrue(is_model_field_type)))

# Same calls, but the first line inside the opening parenthesis must end in a
# comment matching `_any_comment` (a trailing-whitespace matcher defined
# elsewhere), i.e. a comment placed right after `(`.
_django_model_field_name_with_leading_comment_value = m.Call(
    func=m.Attribute(attr=m.Name(m.MatchIfTrue(is_model_field_type))),
    whitespace_before_args=m.ParenthesizedWhitespace(_any_comment),
) | m.Call(
    func=m.Name(m.MatchIfTrue(is_model_field_type)),
    whitespace_before_args=m.ParenthesizedWhitespace(_any_comment),
)

# A statement line consisting of exactly one assignment (plain or annotated)
# whose value is a field call with a comment right after its `(`.
_django_model_field_with_leading_comment = m.SimpleStatementLine(body=[
    m.Assign(value=_django_model_field_name_with_leading_comment_value)
    | m.AnnAssign(value=_django_model_field_name_with_leading_comment_value)
])

# A statement line consisting of exactly one assignment (plain or annotated)
# of a field call, where the line itself ends in a comment matching
# `_any_comment`.
_django_model_field_with_trailing_comment = m.SimpleStatementLine(
    body=[
        m.Assign(value=_django_model_field_name_value)
        | m.AnnAssign(value=_django_model_field_name_value)
    ],
    trailing_whitespace=_any_comment,
)

# Public matcher: a model-field assignment carrying either kind of comment.
django_model_field_with_comments = (_django_model_field_with_leading_comment |
                                    _django_model_field_with_trailing_comment)


def get_leading_comment(node: cst.SimpleStatementLine) -> typing.Optional[str]:
    comment=m.Comment(m.MatchIfTrue(is_valid_comment)),
    newline=m.Newline(),
)

# Matches a statement line made of exactly one assignment whose value is a call
# that is "nullable" and carries no `null_comment` (defined elsewhere;
# presumably a comment justifying the null).  A call counts as nullable when it
# has a `null=True` keyword anywhere among its arguments, or when it is
# `NullBooleanField(...)` / `<something>.NullBooleanField(...)`.  The comment is
# looked for both right after the call's opening parenthesis and at the end of
# the line.
field_without_comment = m.SimpleStatementLine(
    body=[
        m.Assign(
            value=(
                m.Call(
                    args=[
                        m.ZeroOrMore(),
                        m.Arg(keyword=m.Name('null'), value=m.Name('True')),
                        m.ZeroOrMore(),
                    ],
                    whitespace_before_args=m.DoesNotMatch(
                        m.ParenthesizedWhitespace(null_comment)
                    ),
                )
                | m.Call(
                    func=m.Attribute(attr=m.Name('NullBooleanField')),
                    whitespace_before_args=m.DoesNotMatch(
                        m.ParenthesizedWhitespace(null_comment)
                    ),
                )
                | m.Call(
                    func=m.Name('NullBooleanField'),
                    whitespace_before_args=m.DoesNotMatch(
                        m.ParenthesizedWhitespace(null_comment)
                    ),
                )
            )
        )
    ],
    trailing_whitespace=m.DoesNotMatch(null_comment),
)


class FieldValidator(m.MatcherDecoratableVisitor):
    # NOTE(review): only the metadata declaration of this visitor is visible
    # here.  Declaring PositionProvider makes libcst compute source positions
    # for the visited nodes, retrievable via
    # self.get_metadata(PositionProvider, node); presumably used to report
    # where offending fields are -- confirm against the rest of the class.
    METADATA_DEPENDENCIES = (PositionProvider, )
Example #14
0
class LoggerTransformer(cst.CSTTransformer):
    """Interactively rewrite ``print(...)`` statements as ``logging.<level>(...)`` calls.

    Targets statement lines whose first statement is a call to the bare name
    ``print``.  For each one:

    * the positional arguments are merged into one (f-)string message;
    * all comments on the line are joined into a single trailing comment;
    * the proposed line is printed next to its surrounding source.

    The user then accepts it (default or another level), rejects it, accepts
    all remaining changes, or quits.  The metadata providers listed in
    METADATA_DEPENDENCIES must be resolved for the visited tree, because
    positions and parents are read through ``self.get_metadata``.
    """
    METADATA_DEPENDENCIES = (
            WhitespaceInclusivePositionProvider, PositionProvider, ParentNodeProvider
    )

    # Characters that cannot appear inside an f-string replacement field when
    # the f-string is delimited by double quotes (Python < 3.12): the closing
    # quote, backslashes, comment markers, braces, format-spec colons and line
    # breaks.
    _UNSAFE_IN_FSTRING_FIELD = frozenset('"\\#{}:\n\r')

    def __init__(
            self,
            fpath,
            lines,
            default_level='info',
            accept_all=False,
            comment_sep=' / ',
            context_lines=13,
    ):
        """
        Args:
            fpath: file path, shown in the header printed before each prompt.
            lines: presumably the file's source lines; passed to
                ``print_context2`` to display context.
            default_level: name of the ``logging`` function used when a change
                is accepted with ``y``/Enter or via ``accept_all``.
            accept_all: if true, apply every change without prompting.
            comment_sep: separator used to join a line's comments into one.
            context_lines: context size passed to ``print_context2``.
        """
        super().__init__()
        self.fpath = fpath
        self.lines = lines
        self.default_level: str = default_level
        self.accept_all: bool = accept_all
        self.comment_sep: str = comment_sep
        self.context_lines: int = context_lines

    def get_parent(self, node) -> CSTNodeT:
        """Return the parent of ``node`` from ParentNodeProvider metadata."""
        return self.get_metadata(cst.metadata.ParentNodeProvider, node)

    # NOTE(review): `on_leave` is overridden below, which bypasses
    # CSTTransformer's `leave_<NodeType>` dispatch.  Matcher decorators also
    # only take effect on matcher-decoratable transformers.  So neither
    # `leave_line` nor `leave_call` is invoked during a visit; both are kept as
    # no-ops for compatibility.
    @m.call_if_not_inside(m.SimpleStatementLine())
    def leave_line(self, original_node, updated_node):
        return updated_node

    @m.call_if_inside(m.Call())
    def leave_call(self, original_node, updated_node):
        return updated_node

    def on_leave(self,
                 original_node: CSTNodeT,
                 updated_node: CSTNodeT) -> Union[CSTNodeT, RemovalSentinel]:
        """Offer to replace a ``print(...)`` statement line with a logging call.

        Called for every node.  The node is returned unchanged unless it is a
        statement line whose first statement is a ``print(...)`` call with
        arguments that can be merged safely.  In that case the proposal is
        printed with context and, unless ``accept_all`` is set, the user is
        prompted.  Answering ``q`` calls ``sys.exit(0)``.
        """
        # Visit line nodes with print function calls
        if not isinstance(updated_node, SimpleStatementLine):
            return updated_node

        original_line_node = original_node
        line_node = updated_node

        if not isinstance(line_node.body[0], Expr):
            return updated_node

        node = line_node.body[0].value
        # Match only `print(...)`: other callees (nested calls, lambdas, ...)
        # may not even have a `.value` attribute.
        if not m.matches(node, m.Call(func=m.Name('print'))):
            return line_node

        args = self._build_log_args(node)
        if args is None:
            # Converting would change what gets printed or produce invalid
            # code, so leave this line alone.
            return line_node

        # Gather up comments
        cg = GatherCommentsVisitor()
        original_line_node.visit(cg)
        comment = cg.get_joined_comment(self.comment_sep)

        # Remove all comments in order to put them all at the end
        stripped_line_node = line_node.visit(RemoveCommentsTransformer())

        def get_line_node(level):
            """Build the replacement statement line using ``logging.<level>``."""
            func = Attribute(value=Name('logging'), attr=Name(level))
            node_ = node.with_changes(func=func, args=args)

            line_node_ = stripped_line_node.deep_replace(
                    stripped_line_node.body[0].value, node_)
            line_node_ = line_node_.with_deep_changes(
                    line_node_.trailing_whitespace, comment=comment
            )
            return line_node_

        line_node = get_line_node(self.default_level)

        pos = self.get_metadata(PositionProvider, original_line_node)
        lineix = pos.start.line - 1  # 0-indexed; metadata lines are 1-indexed
        end_lineix, end_lineno = pos.end.line - 1, pos.end.line

        # Predict the source code for the newly changed line node
        line = cst.parse_module("").code_for_node(line_node)

        print(
                Bcolor.HEADER,
                f"{self.fpath}{self._source_context(original_line_node)}:"
                f"{lineix + 1}-{end_lineix + 1}",
                Bcolor.ENDC
        )
        print()
        print_context2(self.lines, lineix, end_lineno, line, self.context_lines)
        print()

        # Query the user to decide whether to accept, modify, or reject changes
        if self.accept_all:
            return line_node
        inp = self._prompt_choice()
        if inp == 'q':
            sys.exit(0)
        elif inp == 'n':
            return original_line_node
        elif inp in ['i', 'w', 'e', 'c', 'x']:
            line_node = get_line_node(levels[inp])
        elif inp == 'A':
            self.accept_all = True

        return line_node

    def _build_log_args(self, node):
        """Return the argument list for the logging call, or None if unsafe.

        Positional ``print`` arguments are joined with print's ``sep``, which is
        a single space unless a literal ``sep=`` string is given.

        * String literals are inlined into the message.
        * Any other expression becomes an f-string replacement field.
        * A lone non-string argument is passed through unchanged.

        Returns None in three cases: an argument is unpacked
        (``*args``/``**kwargs``); an expression can't be embedded in a
        double-quoted f-string; or the assembled message doesn't parse.
        """
        if any(arg.star for arg in node.args):
            return None

        render = cst.parse_module("").code_for_node
        has_vars = False
        terms = []
        exprs = []  # (node, source) of each interpolated expression, in order
        simple_ixs = []  # indexes into `terms` of plain (non-f) strings
        for arg in (a.value for a in node.args if not a.keyword):
            if isinstance(arg, FormattedString):
                if any(isinstance(part, FormattedStringExpression)
                       for part in arg.parts):
                    has_vars = True
                terms.append(make_str(arg))
            elif isinstance(arg, SimpleString):
                simple_ixs.append(len(terms))
                terms.append(extract_string(arg.value))
            elif isinstance(arg, ConcatenatedString):
                visitor = GatherStringVisitor()
                arg.visit(visitor)
                simple_ixs.append(len(terms))
                terms.append(''.join([extract_string(s) for s in visitor.strings]))
            else:
                # Names, attributes, calls, ... are interpolated rather than
                # dropped from the message.
                code = render(arg)
                has_vars = True
                exprs.append((arg, code))
                terms.append('{' + code + '}')

        if len(terms) == 1 and len(exprs) == 1:
            # A single expression is logged directly (not wrapped in an
            # f-string, and not as the set literal `{expr}`).
            return [Arg(value=exprs[0][0])]

        if any(self._UNSAFE_IN_FSTRING_FIELD.intersection(code)
               for _, code in exprs):
            return None

        sep = ' '
        sep_ = get_keyword(node, 'sep')
        try:
            # fails if sep is a variable
            sep = extract_string(sep_)
        except TypeError:
            pass

        if has_vars:
            # Literal braces must be doubled once the message is an f-string.
            for ix in simple_ixs:
                terms[ix] = terms[ix].replace('{', '{{').replace('}', '}}')
            sep = sep.replace('{', '{{').replace('}', '}}')

        arg_line = '"' + sep.join(terms) + '"'
        if has_vars:
            arg_line = 'f' + arg_line
        try:
            message = cst.parse_expression(arg_line)
        except cst.ParserSyntaxError:
            # e.g. a string term containing an unescaped double quote.
            return None
        return [Arg(value=message)]

    def _source_context(self, node):
        """Return ``'/<name>'`` of the innermost enclosing def/class, or ``''``."""
        while not isinstance(node, (FunctionDef, ClassDef, Module)):
            node = self.get_parent(node)
        return '' if isinstance(node, Module) else '/' + node.name.value

    def _prompt_choice(self):
        """Prompt until the user enters a recognised answer, and return it."""
        inp = None
        while inp not in ['', 'y', 'n', 'A', 'i', 'w', 'e', 'c', 'x', 'q']:
            inp = input(
                    Bcolor.OKCYAN + "Accept change? ("
                                    f"y = yes ({self.default_level}) [default], "
                                    "n = no, "
                                    "A = yes to all, "
                                    "i = yes (info), "
                                    "w = yes (warning), "
                                    "e = yes (error), "
                                    "c = yes (critical), "
                                    "x = yes (exception), "
                                    "q = quit): " + Bcolor.ENDC
            )
        return inp
Example #15
0
def inline_function(func_obj,
                    call,
                    ret_var,
                    cls=None,
                    f_ast=None,
                    is_toplevel=False):
    """Inline a call to ``func_obj`` and return the statements replacing it.

    Args:
        func_obj: the called function.  Its source is fetched with
            ``inspect.getsource`` unless ``f_ast`` is given.
        call: the call expression being inlined.  Used for argument binding,
            log messages and the optional header comment.
        ret_var: target for the inlined body's return values (passed to
            ``ReplaceReturn``).
        cls: class owning the method, passed to ``replace_super``.  Per the
            comment below, a non-None value marks an ``__init__`` call; in that
            case no implicit ``return None`` is appended.
        f_ast: an already-parsed ``FunctionDef`` to inline instead of parsing
            the source of ``func_obj``.
        is_toplevel: when true, no implicit ``return None`` is appended (the
            call's value is not being assigned).

    Returns:
        A list of statements.  If the function has a decorator other than
        property/classmethod/staticmethod/``*.setter``, returns whatever
        ``inline_decorators`` returns instead.

    Raises:
        TypeError, OSError: re-raised from ``inspect.getsource`` when the
            function's source cannot be retrieved.
    """
    log.debug('Inlining {}'.format(a2s(call)))

    inliner = ctx_inliner.get()
    pass_ = ctx_pass.get()

    if f_ast is None:
        # Get the source code for the function
        try:
            f_source = inspect.getsource(func_obj)
        except (TypeError, OSError):
            print('Failed to get source of {}'.format(a2s(call)))
            raise

        # Record statistics about length of inlined source.  splitlines() does
        # not count the empty string after the final newline.
        inliner.length_inlined += len(f_source.splitlines())

        # Then parse the function into an AST
        f_ast = parse_statement(f_source)

    # A one-line definition (`def f(x): return x`) has a SimpleStatementSuite
    # body, whose `.body` holds small statements (Return, Expr, ...) rather
    # than statement lines.  The code below (return-None padding, splicing the
    # body into the caller) needs an IndentedBlock, so normalize it.
    if isinstance(f_ast.body, cst.SimpleStatementSuite):
        suite = f_ast.body
        f_ast = f_ast.with_changes(body=cst.IndentedBlock(body=[
            cst.SimpleStatementLine(
                body=list(suite.body),
                trailing_whitespace=suite.trailing_whitespace)
        ]))

    # Give the function a fresh name so it won't conflict with other calls to
    # the same function
    f_ast = f_ast.with_changes(
        name=cst.Name(pass_.fresh_var(f_ast.name.value)))

    # TODO
    # If function has decorators, deal with those first. Just inline decorator call
    # and stop there.
    decorators = f_ast.decorators
    assert len(decorators) <= 1  # TODO: deal with multiple decorators
    if len(decorators) == 1:
        d = decorators[0].decorator
        builtin_decorator = (isinstance(d, cst.Name) and
                             (d.value
                              in ['property', 'classmethod', 'staticmethod']))
        derived_decorator = (isinstance(d, cst.Attribute)
                             and (d.attr.value in ['setter']))
        if not (builtin_decorator or derived_decorator):
            return inline_decorators(f_ast, call, func_obj, ret_var)

    # # If we're inlining a decorator, we need to remove @functools.wraps calls
    # # to avoid messing up inspect.getsource
    f_ast = f_ast.with_changes(body=f_ast.body.visit(RemoveFunctoolsWraps()))

    new_stmts = []

    # If the function is a method (which we proxy by first arg being named "self"),
    # then we need to replace uses of special "super" keywords.
    args_def = f_ast.params
    if len(args_def.params) > 0:
        first_arg_is_self = m.matches(args_def.params[0],
                                      m.Param(m.Name('self')))
        if first_arg_is_self:
            f_ast = replace_super(f_ast, cls, call, func_obj, new_stmts)

    # Add bindings from arguments in the call expression to arguments in function def
    f_ast = bind_arguments(f_ast, call, new_stmts)

    scopes = cst.MetadataWrapper(
        f_ast, unsafe_skip_copy=True).resolve(ScopeProviderFunction)
    func_scope = scopes[f_ast.body]

    # Rename each locally-assigned variable once, even if it is assigned
    # several times; after the first rename the old name no longer occurs.
    renamed = set()
    for assgn in func_scope.assignments:
        if m.matches(assgn.node, m.Name()):
            var = assgn.node.value
            if var in renamed:
                continue
            renamed.add(var)
            f_ast = unique_and_rename(f_ast, var)

    # Add an explicit return None at the end to reify implicit return
    f_body = f_ast.body
    last_stmt_is_return = m.matches(f_body.body[-1],
                                    m.SimpleStatementLine([m.Return()]))
    if (not is_toplevel and  # If function return is being assigned
            cls is None and  # And not an __init__ fn
            not last_stmt_is_return):
        f_ast = f_ast.with_deep_changes(f_body,
                                        body=list(f_body.body) +
                                        [parse_statement("return None")])

    # Replace returns with if statements
    f_ast = f_ast.with_changes(body=f_ast.body.visit(ReplaceReturn(ret_var)))

    # Inline function body
    new_stmts.extend(f_ast.body.body)

    # Create imports for non-local variables
    imports = generate_imports_for_nonlocals(f_ast, func_obj, call)
    new_stmts = imports + new_stmts

    if inliner.add_comments:
        # Add header comment to first statement
        call_str = a2s(call)
        header_comment = [
            cst.EmptyLine(comment=cst.Comment(f'# {line}'))
            for line in call_str.splitlines()
        ]
        first_stmt = new_stmts[0]
        new_stmts[0] = first_stmt.with_changes(
            leading_lines=[cst.EmptyLine(indent=False)] + header_comment +
            list(first_stmt.leading_lines))

    return new_stmts