Esempio n. 1
0
    def test_replace_sequence_extract(self) -> None:
        """Sequences captured with SaveMatchedNode are usable from replace().

        The replacement callback reverses the captured parameter list, so the
        rewritten signature must list the parameters in the opposite order.
        """

        def _reverse_params(
            node: cst.CSTNode,
            extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
        ) -> cst.CSTNode:
            func_def = cst.ensure_type(node, cst.FunctionDef)
            # pyre-ignore We know "params" is a Sequence[Parameters] but asserting that
            # to pyre is difficult.
            flipped = list(reversed(extraction["params"]))
            return func_def.with_changes(params=cst.Parameters(params=flipped))

        source = "def bar(baz: int, foo: int, ) -> int:\n    return baz + foo\n"
        expected = "def bar(foo: int, baz: int, ) -> int:\n    return baz + foo\n"

        module = cst.parse_module(source)
        # Capture every Param of the signature under the "params" key so the
        # callback receives them as one sequence.
        pattern = m.FunctionDef(
            params=m.Parameters(
                params=m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params")
            )
        )
        rewritten = m.replace(module, pattern, _reverse_params)
        self.assertEqual(cst.ensure_type(rewritten, cst.Module).code, expected)
Esempio n. 2
0
 def test_replace_simple_sentinel(self) -> None:
     """replace() with a RemoveFromParent sentinel deletes every matched node."""
     source = "def bar(x: int, y: int) -> bool:\n    return False\n"
     module = cst.parse_module(source)
     # Every Param matches, so both parameters should disappear.
     rewritten = m.replace(module, m.Param(), cst.RemoveFromParent())
     code = cst.ensure_type(rewritten, cst.Module).code
     self.assertEqual(code, "def bar() -> bool:\n    return False\n")
Esempio n. 3
0
class MakeModalCommand(VisitorBasedCodemodCommand):
    """Codemod that keeps ``self.MakeModal(...)`` calls working.

    For every class whose body contains a ``self.MakeModal(...)`` call but
    which does not define a ``MakeModal(self, ...)`` method, a helper
    implementation based on ``wx.WindowDisabler`` is appended to the class.
    """

    DESCRIPTION: str = "Replace built-in method MakeModal with helper"

    # Matches an existing ``def MakeModal(self, ...)`` method definition.
    method_matcher = matchers.FunctionDef(
        name=matchers.Name(value="MakeModal"),
        params=matchers.Parameters(params=[
            matchers.Param(name=matchers.Name(value="self")),
            matchers.ZeroOrMore()
        ]),
    )
    # Matches a ``self.MakeModal(...)`` call expression.
    call_matcher = matchers.Call(
        func=matchers.Attribute(value=matchers.Name(value="self"),
                                attr=matchers.Name(value="MakeModal")))

    # Helper method injected into classes that call MakeModal without
    # defining it.
    method_cst = cst.parse_statement(
        textwrap.dedent("""
            def MakeModal(self, modal=True):
                if modal and not hasattr(self, '_disabler'):
                    self._disabler = wx.WindowDisabler(self)
                if not modal and hasattr(self, '_disabler'):
                    del self._disabler
            """))

    def __init__(self, context: CodemodContext):
        super().__init__(context)

        # One flag per currently open class, innermost last. A flag becomes
        # True once a ``self.MakeModal(...)`` call is seen inside that class.
        self.stack: List[bool] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.stack.append(False)

    def leave_ClassDef(self, original_node: cst.ClassDef,
                       updated_node: cst.ClassDef) -> cst.ClassDef:
        """Append the helper method if the class calls but lacks MakeModal."""
        needs_helper = self.stack.pop()
        if not needs_helper:
            return updated_node

        # Build from updated_node rather than a snapshot taken on entry, so
        # changes already made to children (e.g. helpers added to nested
        # classes) are preserved.
        body = updated_node.body
        if any(matchers.matches(stmt, self.method_matcher)
               for stmt in body.body):
            return updated_node

        return updated_node.with_changes(
            body=body.with_changes(body=[*body.body, self.method_cst]))

    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Record that the innermost enclosing class calls self.MakeModal()."""
        # Calls outside any class (empty stack) are ignored. The class
        # itself is rewritten in leave_ClassDef.
        if self.stack and matchers.matches(updated_node, self.call_matcher):
            self.stack[-1] = True
        return updated_node
Esempio n. 4
0
 def __get_fn_params(self, fn_params: cst.Parameters):
     """Return the ``self.nlp_p``-processed names of ``fn_params``'s parameters.

     Names are produced in this order:
       1. regular params,
       2. keyword-only params,
       3. positional-only params,
       4. the star parameter, only when it is a named ``Param``,
       5. the double-star parameter, if present.

     NOTE(review): this is not source order (positional-only parameters come
     first in a signature). The order is kept because callers may rely on it.
     """
     ordered = [
         *fn_params.params,
         *fn_params.kwonly_params,
         *fn_params.posonly_params,
     ]
     star = fn_params.star_arg
     # star_arg is not always a named Param, so only append it when the
     # matcher confirms a Param with a Name.
     if match.matches(star,
                      match.Param(name=match.Name(value=match.DoNotCare()))):
         ordered.append(star)
     if fn_params.star_kwarg is not None:
         ordered.append(fn_params.star_kwarg)
     return [self.nlp_p(param.name.value) for param in ordered]
Esempio n. 5
0
    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        """Replace list/dict literal parameter defaults with ``None``.

        Every parameter in ``updated_node.params.params`` whose default is a
        ``[...]`` or ``{...}`` literal gets ``None`` as its default. For each
        one, a statement rendered from ``DEFAULT_INIT_TEMPLATE`` (given the
        parameter name and the original literal) is inserted at the top of
        the body, after any leading docstring statements.

        Only ``params.params`` is inspected. Positional-only and keyword-only
        parameters are not rewritten.

        Returns:
            ``updated_node`` unchanged when no parameter needs rewriting,
            otherwise a copy with rewritten parameters and body.
        """
        modified_defaults: List[cst.Param] = []
        mutable_args: List[Tuple[cst.Name, Union[cst.List, cst.Dict]]] = []

        for param in updated_node.params.params:
            if not m.matches(param,
                             m.Param(default=m.OneOf(m.List(), m.Dict()))):
                modified_defaults.append(param)
                continue

            # This check is only for type checkers, which cannot narrow
            # types from the matcher result. Keep the param if it ever fails.
            if not isinstance(param.default, (cst.List, cst.Dict)):
                modified_defaults.append(param)
                continue

            mutable_args.append((param.name, param.default))
            modified_defaults.append(
                param.with_changes(default=cst.Name("None")))

        if not mutable_args:
            # Must return updated_node. Returning original_node would discard
            # rewrites already applied to nested definitions in this body.
            return updated_node

        modified_params: cst.Parameters = updated_node.params.with_changes(
            params=modified_defaults)

        initializations: List[Union[
            cst.SimpleStatementLine, cst.BaseCompoundStatement]] = [
                # We use generation by template here since construction of the
                # resulting 'if' can be burdensome due to many nested objects
                # involved. Additional line is attached so that we may control
                # exact spacing between generated statements.
                parse_template_statement(
                    DEFAULT_INIT_TEMPLATE,
                    config=self.module_config,
                    arg=arg,
                    init=init).with_changes(leading_lines=[EMPTY_LINE])
                for arg, init in mutable_args
            ]

        # Docstrings must stay right after the function definition, so the
        # initializations go after the last leading docstring.
        docstrings = takewhile(is_docstring, updated_node.body.body)
        function_code = dropwhile(is_docstring, updated_node.body.body)

        # A blank line cannot be attached after a statement, because trailing
        # whitespace is owned by the next statement. So the first remaining
        # statement gets the leading empty line instead.
        first_stmt = next(function_code, None)
        if first_stmt is None:
            # Docstring-only body: nothing follows the initializations.
            # A plain next() would raise StopIteration here.
            rest: Tuple[cst.BaseStatement, ...] = ()
        else:
            rest = (
                first_stmt.with_changes(leading_lines=[EMPTY_LINE]),
                *function_code,
            )

        modified_body = (
            *docstrings,
            *initializations,
            *rest,
        )

        return updated_node.with_changes(
            params=modified_params,
            body=updated_node.body.with_changes(body=modified_body),
        )
Esempio n. 6
0
def inline_function(func_obj,
                    call,
                    ret_var,
                    cls=None,
                    f_ast=None,
                    is_toplevel=False):
    """Produce the statements that replace ``call`` by the body of ``func_obj``.

    The function definition is processed in this order:
      1. Parse it from ``inspect.getsource(func_obj)``, unless ``f_ast`` is
         given.
      2. Give it a fresh name.
      3. Strip ``functools.wraps`` uses via ``RemoveFunctoolsWraps``.
      4. If the first parameter is ``self``, pass it through
         ``replace_super``.
      5. Bind the call's arguments with ``bind_arguments``.
      6. Pass each simple name assigned in the function scope to
         ``unique_and_rename``.
      7. Rewrite ``return`` statements with ``ReplaceReturn(ret_var)``.

    The returned list is: imports from ``generate_imports_for_nonlocals``,
    then statements collected by ``replace_super``/``bind_arguments``, then
    the function body.

    Args:
        func_obj: The function object being inlined. Its source is read with
            ``inspect.getsource`` when ``f_ast`` is None.
        call: The call expression being inlined. Used for logging, argument
            binding, ``super`` replacement and import generation.
        ret_var: Passed to ``ReplaceReturn``. Presumably the variable that
            receives the call's result; confirm in ``ReplaceReturn``.
        cls: Passed to ``replace_super``. When not None, no implicit
            ``return None`` is appended (the comment below calls this the
            ``__init__`` case).
        f_ast: Optional pre-parsed definition (presumably a
            ``cst.FunctionDef``). When given, source lookup and the
            ``length_inlined`` statistic are skipped.
        is_toplevel: When True, no implicit ``return None`` is appended.

    Returns:
        A list of statements to splice in place of the call. If the
        definition has a decorator other than property / classmethod /
        staticmethod / ``.setter``, returns the result of
        ``inline_decorators`` instead.

    Raises:
        TypeError: Re-raised when ``inspect.getsource`` fails.
        AssertionError: If the definition has more than one decorator.
    """
    log.debug('Inlining {}'.format(a2s(call)))

    inliner = ctx_inliner.get()
    pass_ = ctx_pass.get()

    if f_ast is None:
        # Get the source code for the function
        try:
            f_source = inspect.getsource(func_obj)
        except TypeError:
            # inspect.getsource raises TypeError for objects without Python
            # source (e.g. builtins). Report the offending call, then re-raise.
            print('Failed to get source of {}'.format(a2s(call)))
            raise

        # Record statistics about length of inlined source
        inliner.length_inlined += len(f_source.split('\n'))

        # Then parse the function into an AST
        f_ast = parse_statement(f_source)

    # Give the function a fresh name so it won't conflict with other calls to
    # the same function
    f_ast = f_ast.with_changes(
        name=cst.Name(pass_.fresh_var(f_ast.name.value)))

    # TODO: only a single decorator is supported.
    # Builtin decorators (property / classmethod / staticmethod) and
    # ``.setter`` are ignored and the body is inlined as usual. Any other
    # decorator is handed to inline_decorators, and inlining stops there.
    decorators = f_ast.decorators
    assert len(decorators) <= 1  # TODO: deal with multiple decorators
    if len(decorators) == 1:
        d = decorators[0].decorator
        builtin_decorator = (isinstance(d, cst.Name) and
                             (d.value
                              in ['property', 'classmethod', 'staticmethod']))
        derived_decorator = (isinstance(d, cst.Attribute)
                             and (d.attr.value in ['setter']))
        if not (builtin_decorator or derived_decorator):
            return inline_decorators(f_ast, call, func_obj, ret_var)

    # If we're inlining a decorator, we need to remove @functools.wraps calls
    # to avoid messing up inspect.getsource
    f_ast = f_ast.with_changes(body=f_ast.body.visit(RemoveFunctoolsWraps()))

    # Statements emitted before the inlined body (filled by replace_super and
    # bind_arguments below).
    new_stmts = []

    # If the function is a method (which we proxy by first arg being named "self"),
    # then we need to replace uses of special "super" keywords.
    args_def = f_ast.params
    if len(args_def.params) > 0:
        first_arg_is_self = m.matches(args_def.params[0],
                                      m.Param(m.Name('self')))
        if first_arg_is_self:
            f_ast = replace_super(f_ast, cls, call, func_obj, new_stmts)

    # Add bindings from arguments in the call expression to arguments in function def
    f_ast = bind_arguments(f_ast, call, new_stmts)

    # Resolve scope metadata to enumerate names assigned in the function
    # body. unsafe_skip_copy keeps node identity (MetadataWrapper deep-copies
    # by default), so f_ast.body can be used as a lookup key.
    scopes = cst.MetadataWrapper(
        f_ast, unsafe_skip_copy=True).resolve(ScopeProviderFunction)
    func_scope = scopes[f_ast.body]

    # Rename every simple-name local so it cannot clash with names at the
    # call site.
    for assgn in func_scope.assignments:
        if m.matches(assgn.node, m.Name()):
            var = assgn.node.value
            f_ast = unique_and_rename(f_ast, var)

    # Add an explicit return None at the end to reify implicit return
    f_body = f_ast.body
    last_stmt_is_return = m.matches(f_body.body[-1],
                                    m.SimpleStatementLine([m.Return()]))
    if (not is_toplevel and  # If function return is being assigned
            cls is None and  # And not an __init__ fn
            not last_stmt_is_return):
        f_ast = f_ast.with_deep_changes(f_body,
                                        body=list(f_body.body) +
                                        [parse_statement("return None")])

    # Replace returns with if statements
    f_ast = f_ast.with_changes(body=f_ast.body.visit(ReplaceReturn(ret_var)))

    # Inline function body
    new_stmts.extend(f_ast.body.body)

    # Create imports for non-local variables
    imports = generate_imports_for_nonlocals(f_ast, func_obj, call)
    new_stmts = imports + new_stmts

    if inliner.add_comments:
        # Prefix the first statement with a blank line plus the original call
        # source, one "# ..." comment line per source line.
        call_str = a2s(call)
        header_comment = [
            cst.EmptyLine(comment=cst.Comment(f'# {line}'))
            for line in call_str.splitlines()
        ]
        first_stmt = new_stmts[0]
        new_stmts[0] = first_stmt.with_changes(
            leading_lines=[cst.EmptyLine(indent=False)] + header_comment +
            list(first_stmt.leading_lines))

    return new_stmts