コード例 #1
0
    def leave_Call(self, original_node: Call,
                   updated_node: Call) -> BaseExpression:
        """
        Remove the `weak` keyword argument from matching calls.

        Only calls matched by one of ``self.disconnect_call_matchers`` are
        rewritten, and only when they pass ``weak=`` as a keyword argument.
        Every other call is delegated to the parent ``leave_Call``.

        :param original_node: The call as it appeared in the source.
        :param updated_node: The call after its children were transformed.
        :return: The call without its ``weak=`` argument, or whatever the
            parent ``leave_Call`` returns when nothing needs changing.
        """
        if self.disconnect_call_matchers and m.matches(
                updated_node, m.OneOf(*self.disconnect_call_matchers)):
            kept_args = []
            found_weak = False
            # Keep all arguments except the one with the keyword `weak`.
            for arg in updated_node.args:
                if m.matches(arg, m.Arg(keyword=m.Name("weak"))):
                    found_weak = True
                else:
                    kept_args.append(arg)
            if found_weak:
                # `weak` may have been the only argument; there is then no
                # remaining argument whose comma needs adjusting.
                if kept_args:
                    # Give the new last argument the comma (or lack of one)
                    # that the original last argument had, so the end of the
                    # call is formatted as before.
                    kept_args[-1] = kept_args[-1].with_changes(
                        comma=updated_node.args[-1].comma)
                return updated_node.with_changes(args=kept_args)
        return super().leave_Call(original_node, updated_node)
コード例 #2
0
ファイル: Transformers.py プロジェクト: 1MrEnot/Obf2
    def new_obf_function_name(self, func: cst.Call):
        """Return *func* with the name of the called object obfuscated.

        For ``obj.method(...)`` calls the receiver ``obj`` is obfuscated
        (kind ``'v'``) when ``self.change_variables`` is set, and the method
        name (kind ``'cf'``) when ``self.change_methods`` is set.  For plain
        ``name(...)`` calls the name is replaced via ``get_new_cst_name``
        when function or class renaming is enabled and ``can_rename`` allows
        it.  Any other kind of callee is left untouched.
        """
        callee = func.func

        if m.matches(callee, m.Attribute()):
            attribute = cst.ensure_type(callee, cst.Attribute)
            # Obfuscate the object the method is looked up on.
            if self.change_variables:
                attribute = attribute.with_changes(
                    value=self.obf_universal(attribute.value, 'v'))
            # Obfuscate the method name itself.
            if self.change_methods:
                attribute = attribute.with_changes(
                    attr=self.obf_universal(attribute.attr, 'cf'))
            callee = attribute
        elif m.matches(callee, m.Name()):
            name_node = cst.ensure_type(callee, cst.Name)
            renaming_enabled = self.change_functions or self.change_classes
            if renaming_enabled and self.can_rename(name_node.value, 'c', 'f'):
                callee = self.get_new_cst_name(name_node.value)
            else:
                callee = name_node

        return func.with_changes(func=callee)
コード例 #3
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Fold an API method call's field arguments into one ``request`` dict.

        Calls of the form ``x.<method>(...)`` where ``<method>`` is a key of
        ``self.METHOD_TO_PARAMS`` are rewritten so that their field
        arguments become ``request={'field': value, ...}``.  Control
        parameters (``self.CTRL_PARAMS``) stay keyword arguments; those passed
        positionally after the fields are converted to keywords.  Calls that
        are not such method calls, or that already pass ``request=``, are
        returned unchanged.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        # Positional args beyond the method's own fields are control
        # parameters passed positionally; turn them into keyword args.
        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
        ctrl_kwargs.extend(
            cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
            for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))

        # Positional args are named by their position in the parameter list.
        # Keyword args already carry their field name and must keep it:
        # callers may pass them in any order, so zipping them against
        # `kword_params` would mislabel (or silently drop) them.
        fields = [(name, arg.value) for name, arg in zip(kword_params, args)]
        fields.extend((arg.keyword.value, arg.value) for arg in kwargs)

        request_arg = cst.Arg(
            value=cst.Dict([
                cst.DictElement(cst.SimpleString("'{}'".format(name)), value)
                for name, value in fields
            ]),
            keyword=cst.Name("request"))

        return updated.with_changes(args=[request_arg] + ctrl_kwargs)
コード例 #4
0
ファイル: Transformers.py プロジェクト: 1MrEnot/Obf2
    def obf_function_args(self, func: cst.Call):
        """Return *func* with its arguments obfuscated.

        When ``self.change_arguments`` or ``self.change_method_arguments`` is
        set, every argument value is passed through ``self.obf_universal``,
        and keyword names are replaced via ``self.get_new_cst_name`` when
        ``self.can_rename_func_param`` allows it for the called function.
        When neither flag is set the call is returned unchanged (instead of
        having all of its arguments removed).
        """
        if not (self.change_arguments or self.change_method_arguments):
            # Nothing to obfuscate: keep the original arguments intact.
            return func

        func_root = func.func
        func_name = ''
        # Bare name of the called function, used to decide whether its
        # parameters may be renamed.
        if m.matches(func_root, m.Name()):
            func_name = cst.ensure_type(func_root, cst.Name).value
        elif m.matches(func_root, m.Attribute()):
            func_name = split_attribute(
                cst.ensure_type(func_root, cst.Attribute))[-1]

        new_args = []
        for arg in func.args:
            # Argument values
            arg = arg.with_changes(value=self.obf_universal(arg.value))
            # Argument (keyword) names
            if arg.keyword is not None and self.can_rename_func_param(
                    arg.keyword.value, func_name):
                arg = arg.with_changes(
                    keyword=self.get_new_cst_name(arg.keyword))
            new_args.append(arg)

        return func.with_changes(args=new_args)
コード例 #5
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rename the called attribute to ``InsertColumn`` on matching calls.

        Calls that do not match ``self.matcher`` are returned untouched.
        """
        if not matchers.matches(updated_node, self.matcher):
            return updated_node

        renamed_func = updated_node.func.with_changes(
            attr=cst.Name(value="InsertColumn"))
        return updated_node.with_changes(func=renamed_func)
コード例 #6
0
ファイル: await_async_call.py プロジェクト: tsx/Fixit
 def _get_async_call_replacement(self,
                                 node: cst.Call) -> Optional[cst.CSTNode]:
     """Return a replacement for the call *node*, or ``None``.

     An attribute callee is first offered to
     ``_get_async_attr_replacement``; if that yields a new callee, the call
     is rebuilt around it.  In every other case the decision is delegated
     to ``_get_awaitable_replacement``.
     """
     if not m.matches(node.func, m.Attribute()):
         return self._get_awaitable_replacement(node)
     new_func = self._get_async_attr_replacement(
         cast(cst.Attribute, node.func))
     if new_func is None:
         return self._get_awaitable_replacement(node)
     return node.with_changes(func=new_func)
コード例 #7
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Append a literal ``0`` argument to calls matching ``self.matcher``.

        Non-matching calls are returned as-is.
        """
        if not matchers.matches(updated_node, self.matcher):
            return updated_node

        zero_arg = cst.Arg(value=cst.Integer(value="0"))
        return updated_node.with_changes(args=[*updated_node.args, zero_arg])
コード例 #8
0
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Replace ``print(...)`` calls with ``pprint(...)``.

     When a replacement is made, ``from pprint import pprint`` is requested
     through ``AddImportsVisitor``; any other call is handed to the parent
     transformer's ``leave_Call``.
     """
     is_print_call = m.matches(updated_node, m.Call(func=m.Name("print")))
     if not is_print_call:
         return super().leave_Call(original_node, updated_node)
     AddImportsVisitor.add_needed_import(self.context, "pprint", "pprint")
     return updated_node.with_changes(func=Name("pprint"))
コード例 #9
0
    def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Point calls matching ``gen_sleep_matcher`` at ``asyncio.sleep``.

        Only applies while ``self.in_coroutine(self.coroutine_stack)`` is
        true; a rewrite also records ``"asyncio"`` in
        ``self.required_imports``.  All other calls are returned unchanged.
        """
        if self.in_coroutine(self.coroutine_stack) and m.matches(
                updated_node, gen_sleep_matcher):
            self.required_imports.add("asyncio")
            asyncio_sleep = cst.Attribute(value=cst.Name("asyncio"),
                                          attr=cst.Name("sleep"))
            return updated_node.with_changes(func=asyncio_sleep)
        return updated_node
コード例 #10
0
 def _gen_builtin_call(self, node: cst.Call) -> cst.Call:
     """Rewrite the first argument of *node* as a generator expression.

     A list-comprehension first argument becomes a generator expression,
     and a parenthesised generator-expression first argument loses its
     redundant parentheses.  Parentheses are only omitted when the call has
     exactly one plain positional argument with no trailing comma.  Python
     rejects a bare generator expression as a keyword value
     (``f(k=x for x in y)``), after ``*``/``**``, or followed by a comma
     (``f(x for x in y,)``), so those cases keep their parentheses.
     """
     if not node.args:
         return node
     first = node.args[0]
     value = first.value
     # Whether the resulting generator expression may appear unparenthesised.
     can_drop_parens = (
         len(node.args) == 1
         and first.keyword is None
         and not first.star
         and not isinstance(first.comma, cst.Comma)
     )
     if isinstance(value, cst.ListComp):
         # Leaving lpar/rpar unset keeps GeneratorExp's default parentheses.
         pars: dict = {"lpar": [], "rpar": []} if can_drop_parens else {}
         arg0 = first.with_changes(
             value=cst.GeneratorExp(elt=value.elt, for_in=value.for_in, **pars)
         )
         return node.with_changes(args=(arg0, *node.args[1:]))
     if isinstance(value, cst.GeneratorExp) and value.lpar and can_drop_parens:
         arg0 = first.with_changes(
             value=cst.GeneratorExp(
                 elt=value.elt, for_in=value.for_in, lpar=[], rpar=[]
             )
         )
         return node.with_changes(args=(arg0, *node.args[1:]))
     return node
コード例 #11
0
ファイル: cst_utc.py プロジェクト: KGerring/metaproj
	def datetime_replace(
			self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Rewrite the call as ``datetime.now(UTC)``.

		``self._update_imports()`` is invoked first.  The original callee and
		all original arguments are discarded.
		"""
		self._update_imports()
		now_func = cst.Attribute(value=cst.Name(value="datetime"), attr=cst.Name("now"))
		utc_args = [cst.Arg(value=cst.Name(value="UTC"))]
		return updated_node.with_changes(func=now_func, args=utc_args)
コード例 #12
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Qualify and rename wx symbols used as callees.

        Two passes are made, and the first matching entry wins:

        1. ``self.matchers_short_map`` entries match calls to symbols that are
           listed in ``self.wx_imports``.  The symbol's import is scheduled
           for removal, ``import wx`` is requested, and the callee becomes
           ``wx.<renamed>`` (or ``wx.<a>.<b>`` when ``renamed`` is a tuple).
        2. ``self.matchers_full_map`` entries match already-qualified calls.
           The callee's attribute is renamed in place (or the callee becomes
           ``wx.<a>.<b>`` when ``renamed`` is a tuple).

        Calls matching neither map are returned as-is.
        """
        def nested_wx_attribute(parts) -> cst.Attribute:
            # Build the callee ``wx.<parts[0]>.<parts[1]>``.
            return cst.Attribute(
                value=cst.Attribute(value=cst.Name(value="wx"),
                                    attr=cst.Name(value=parts[0])),
                attr=cst.Name(value=parts[1]),
            )

        # Calls to symbols imported without the wx prefix.
        for symbol, matcher, renamed in self.matchers_short_map:
            if symbol not in self.wx_imports:
                continue
            if not matchers.matches(updated_node, matcher):
                continue
            RemoveImportsVisitor.remove_unused_import_by_node(
                self.context, original_node)
            AddImportsVisitor.add_needed_import(self.context, "wx")
            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                new_func = cst.Attribute(value=cst.Name(value="wx"),
                                         attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        # Fully-qualified calls such as ``wx.MySymbol(...)``.
        for matcher, renamed in self.matchers_full_map:
            if not matchers.matches(updated_node, matcher):
                continue
            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                new_func = updated_node.func.with_changes(
                    attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        return updated_node
コード例 #13
0
 def _set_call(self, node: cst.Call) -> Union[cst.Call, cst.Set, cst.SetComp]:
     """Rewrite ``set(<list/tuple/comprehension>)`` as a set display.

     ``set([a, b])`` and ``set((a, b))`` become ``{a, b}``.  An empty list
     or tuple argument becomes ``set()``.  A list/set/generator
     comprehension becomes a set comprehension.  Anything else is returned
     unchanged.  That includes an unpacked argument (``set(*[xs])`` means
     ``set(xs)``, not ``{xs}``) and a keyword argument (which ``set``
     rejects at runtime).
     """
     if len(node.args) != 1:
         return node
     arg = node.args[0]
     if arg.star or arg.keyword is not None:
         # Rewriting these would change (or "fix") the call's meaning.
         return node
     value = arg.value
     if isinstance(value, (cst.List, cst.Tuple)):
         if value.elements:
             return cst.Set(elements=value.elements)
         # ``{}`` would be a dict, so an empty input becomes ``set()``.
         return node.with_changes(args=[])
     if isinstance(value, (cst.ListComp, cst.SetComp, cst.GeneratorExp)):
         return cst.SetComp(elt=value.elt, for_in=value.for_in)
     return node
コード例 #14
0
 def leave_Call(self, original_node: cst.Call) -> None:
     """Report ``super(<cls>, <x>)`` calls, suggesting plain ``super()``.

     Only runs while ``self.current_classes`` is non-empty.  The first
     argument must match ``self._build_arg_class_matcher()``, and exactly
     two arguments must be present.
     """
     if not self.current_classes:
         return
     explicit_super = m.Call(
         func=m.Name("super"),
         args=[
             m.Arg(value=self._build_arg_class_matcher()),
             m.Arg(),
         ],
     )
     if m.matches(original_node, explicit_super):
         self.report(original_node, replacement=original_node.with_changes(args=()))
コード例 #15
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rename matching calls to ``AddTool`` and rename their keywords.

        For calls matching ``self.call_matcher`` the called attribute is
        renamed to ``AddTool``.  Then, for each ``(arg_matcher, renamed)``
        pair in ``self.args_matchers_map``, every matching argument gets
        ``renamed`` as its keyword.  Later pairs see the renames made by
        earlier ones.  Non-matching calls are returned unchanged.
        """
        if not matchers.matches(updated_node, self.call_matcher):
            return updated_node

        # Update method's call
        updated_node = updated_node.with_changes(
            func=updated_node.func.with_changes(attr=cst.Name(
                value="AddTool")))

        # Transform keywords on a private working copy and rebuild the node
        # once at the end, so no node ever holds a list that is later mutated.
        new_args = list(updated_node.args)
        for arg_matcher, renamed in self.args_matchers_map.items():
            for i, node_arg in enumerate(new_args):
                if matchers.matches(node_arg, arg_matcher):
                    new_args[i] = node_arg.with_changes(
                        keyword=cst.Name(value=renamed))

        return updated_node.with_changes(args=tuple(new_args))
コード例 #16
0
 def leave_Call(
     self,
     original_node: cst.Call,
     updated_node: cst.Call,
 ) -> cst.Call:
     """Add a trailing comma after the last argument of long-enough calls.

     Calls with fewer than ``self.argument_count`` arguments, calls with no
     arguments at all, and calls whose last argument already ends in a
     comma are returned unchanged.  An existing comma is kept as-is, with
     its surrounding whitespace.
     """
     args = updated_node.args
     if not args or len(args) < self.argument_count:
         return updated_node
     last_arg = args[-1]
     if isinstance(last_arg.comma, cst.Comma):
         # Already has a trailing comma; don't clobber its formatting.
         return updated_node
     return updated_node.with_changes(
         args=(*args[:-1], last_arg.with_changes(comma=cst.Comma())),
     )
コード例 #17
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Migrate deprecated ``AppendItem()`` calls and rename keywords.

        Calls matching ``self.deprecated_call_matcher`` have their attribute
        renamed to ``Append``.  Then, if the (possibly renamed) call matches
        ``self.call_matcher``, each ``(arg_matcher, renamed)`` pair in
        ``self.args_matchers_map`` gives every matching argument the keyword
        ``renamed``.  Later pairs see the renames made by earlier ones.
        """
        # Migrate from the deprecated method AppendItem()
        if matchers.matches(updated_node, self.deprecated_call_matcher):
            updated_node = updated_node.with_changes(
                func=updated_node.func.with_changes(attr=cst.Name(
                    value="Append")))

        # Update keywords
        if not matchers.matches(updated_node, self.call_matcher):
            return updated_node

        # Work on a private copy and rebuild the node once at the end, so no
        # node ever holds a list that is later mutated.
        new_args = list(updated_node.args)
        for arg_matcher, renamed in self.args_matchers_map.items():
            for i, node_arg in enumerate(new_args):
                if matchers.matches(node_arg, arg_matcher):
                    new_args[i] = node_arg.with_changes(
                        keyword=cst.Name(value=renamed))

        return updated_node.with_changes(args=tuple(new_args))
コード例 #18
0
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Append ``obj=obj`` to ``super().has_add_permission(...)`` calls.

     Only applies while ``self.is_visiting_subclass`` is true and the call
     has fewer than two arguments.  Everything else is handed to the parent
     transformer's ``leave_Call``.
     """
     super_has_add_permission = m.Call(func=m.Attribute(
         attr=m.Name("has_add_permission"),
         value=m.Call(func=m.Name("super")),
     ))
     needs_obj = (self.is_visiting_subclass
                  and m.matches(updated_node, super_has_add_permission)
                  and len(updated_node.args) < 2)
     if not needs_obj:
         return super().leave_Call(original_node, updated_node)
     return updated_node.with_changes(
         args=(*updated_node.args, parse_arg("obj=obj")))
コード例 #19
0
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Append ``on_delete=models.CASCADE`` to relation fields lacking it.

     Applies to calls recognised by ``is_one_to_one_field`` or
     ``is_foreign_key`` for which ``has_on_delete`` is false.  In that case
     ``from django.db import models`` is requested as well.  Other calls are
     handed to the parent transformer's ``leave_Call``.
     """
     is_relation = is_one_to_one_field(original_node) or is_foreign_key(original_node)
     if not is_relation or has_on_delete(original_node):
         return super().leave_Call(original_node, updated_node)
     AddImportsVisitor.add_needed_import(
         context=self.context,
         module="django.db",
         obj="models",
     )
     cascade = parse_arg("on_delete=models.CASCADE")
     return updated_node.with_changes(args=(*updated_node.args, cascade))
コード例 #20
0
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.BaseExpression:
        """Rewrite or drop selected keyword arguments of the target call.

        Only the call equal to ``self.call_node`` is touched.  For each
        keyword argument whose name is a key of ``self.keywords_to_change``,
        the mapped value replaces the argument's value.  When the mapped
        value is ``None``, the argument is removed instead.  All other
        arguments are kept in order.
        """
        if original_node == self.call_node:
            kept_args = []
            for arg in updated_node.args:
                keyword = arg.keyword
                if (isinstance(keyword, cst.Name)
                        and keyword.value in self.keywords_to_change):
                    replacement = self.keywords_to_change[keyword.value]
                    # A ``None`` replacement means "drop this argument".
                    if replacement is not None:
                        kept_args.append(arg.with_changes(value=replacement))
                else:
                    kept_args.append(arg)
            return updated_node.with_changes(args=kept_args)

        return updated_node
コード例 #21
0
 def leave_Call(self, original_node: libcst.Call,
                updated_node: libcst.Call) -> libcst.Call:
     """Remove type-checking fields from calls that are Pyre-checked.

     The call is treated as Pyre-checked when it has ``check_types = True``
     (a ``True`` name) and no ``check_types_options`` string mentioning
     mypy.  In that case the ``check_types``, ``check_types_options``,
     ``typing`` and ``typing_options`` keyword arguments are removed.
     Positional arguments and every other keyword argument are kept.
     Otherwise the call is returned unchanged.
     """
     check_types = False
     uses_pyre = True
     updated_fields = []
     # Iterate the *updated* arguments so that changes already made to
     # child nodes during this transform are not thrown away.
     for field in updated_node.args:
         if field.keyword is None:
             # Positional argument: nothing to inspect, always keep it.
             updated_fields.append(field)
             continue
         name = field.keyword.value
         value = field.value
         if name == "check_types":
             if isinstance(value, libcst.Name):
                 check_types = check_types or value.value.lower() == "true"
         elif name == "check_types_options":
             if isinstance(value, libcst.SimpleString):
                 uses_pyre = uses_pyre and "mypy" not in value.value.lower()
         elif name not in ("typing", "typing_options"):
             updated_fields.append(field)
     if check_types and uses_pyre:
         return updated_node.with_changes(args=updated_fields)
     return updated_node
コード例 #22
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``assertTrue``/``assertFalse`` used on ``is (not) None``.

        Matches ``self.assertTrue(...)`` / ``self.assertFalse(...)`` whose only
        argument is ``x is None``, ``x is not None``, or ``not`` applied to
        either.  It reports a replacement with ``self.assertIsNone(x)`` or
        ``self.assertIsNotNone(x)``.  The choice counts the negations present
        (``assertFalse``, a leading ``not``, and ``is not``): an even count
        means ``assertIsNone``.
        """
        match_compare_is_none = m.ComparisonTarget(
            m.SaveMatchedNode(
                m.OneOf(m.Is(), m.IsNot()),
                "comparison_type",
            ),
            comparator=m.Name("None"),
        )
        result = m.extract(
            node,
            m.Call(
                func=m.Attribute(
                    value=m.Name("self"),
                    attr=m.SaveMatchedNode(
                        m.OneOf(m.Name("assertTrue"), m.Name("assertFalse")),
                        "assertion_name",
                    ),
                ),
                args=[
                    m.Arg(
                        m.SaveMatchedNode(
                            m.OneOf(
                                m.Comparison(
                                    comparisons=[match_compare_is_none]),
                                m.UnaryOperation(
                                    operator=m.Not(),
                                    expression=m.Comparison(
                                        comparisons=[match_compare_is_none]),
                                ),
                            ),
                            "argument",
                        ))
                ],
            ),
        )
        if not result:
            return

        def saved(key: str):
            # A saved capture may be a sequence of nodes; use the first one.
            captured = result[key]
            return captured[0] if isinstance(captured, Sequence) else captured

        assertion_name = saved("assertion_name")
        argument = saved("argument")
        comparison_type = saved("comparison_type")

        if m.matches(argument, m.Comparison()):
            comparison = ensure_type(argument, cst.Comparison)
        else:
            comparison = ensure_type(
                ensure_type(argument, cst.UnaryOperation).expression,
                cst.Comparison)

        negations_seen = sum((
            m.matches(assertion_name, m.Name("assertFalse")),
            m.matches(argument, m.UnaryOperation()),
            m.matches(comparison_type, m.IsNot()),
        ))
        new_attr = "assertIsNone" if negations_seen % 2 == 0 else "assertIsNotNone"
        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"),
                               attr=cst.Name(new_attr)),
            args=[cst.Arg(comparison.left)],
        )
        # `with_changes` always builds a new node, so an identity check
        # against `node` would always pass; report unconditionally.
        self.report(node, replacement=new_call)
コード例 #23
0
    def visit_Call(self, node: cst.Call) -> None:
        """Suggest ``assertIn``/``assertNotIn`` for membership assertions.

        Reported replacements:

        * ``self.assertTrue(a in b)``     -> ``self.assertIn(a, b)``
        * ``self.assertTrue(not a in b)`` -> ``self.assertNotIn(a, b)``
        * ``self.assertTrue(a not in b)`` -> ``self.assertNotIn(a, b)``
        * ``self.assertFalse(a in b)``    -> ``self.assertNotIn(a, b)``

        Only calls with exactly one argument holding a single comparison are
        considered.
        """
        # TODO: consider one m.extract instead of several matches.

        def self_assertion(method, argument):
            # Matcher for ``self.<method>(<argument>)`` with one argument.
            return m.Call(
                func=m.Attribute(value=m.Name("self"), attr=m.Name(method)),
                args=[m.Arg(argument)],
            )

        def single_comparison(operator):
            # Matcher for a comparison with exactly one operator.
            return m.Comparison(comparisons=[m.ComparisonTarget(operator)])

        def assertion_replacement(method, comparison):
            # ``self.<method>(left, right)`` built from ``left <op> right``.
            return node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name(method)),
                args=[
                    cst.Arg(comparison.left),
                    cst.Arg(comparison.comparisons[0].comparator),
                ],
            )

        if m.matches(node, self_assertion("assertTrue",
                                          single_comparison(m.In()))):
            # self.assertTrue(a in b) -> self.assertIn(a, b)
            comparison = ensure_type(node.args[0].value, cst.Comparison)
            self.report(node,
                        replacement=assertion_replacement("assertIn",
                                                          comparison))
            return

        negated = None
        if m.matches(
                node,
                self_assertion(
                    "assertTrue",
                    m.UnaryOperation(operator=m.Not(),
                                     expression=single_comparison(m.In())),
                ),
        ):
            # self.assertTrue(not a in b) -> self.assertNotIn(a, b)
            negated = ensure_type(
                ensure_type(node.args[0].value,
                            cst.UnaryOperation).expression,
                cst.Comparison,
            )
        elif (m.matches(node, self_assertion("assertTrue",
                                             single_comparison(m.NotIn())))
              or m.matches(node, self_assertion("assertFalse",
                                                single_comparison(m.In())))):
            # self.assertTrue(a not in b) / self.assertFalse(a in b)
            #     -> self.assertNotIn(a, b)
            negated = ensure_type(node.args[0].value, cst.Comparison)

        if negated is not None:
            self.report(node,
                        replacement=assertion_replacement("assertNotIn",
                                                          negated))
コード例 #24
0
ファイル: modernizer.py プロジェクト: ybastide/cst-test
 def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
     """Turn ``obj.<method>(...)`` into ``<method>(obj)``.

     The attribute name becomes the callee, and the object it was looked up
     on becomes the sole argument.  Any original arguments are discarded.
     """
     method_call = ensure_type(updated_node.func, Attribute)
     return updated_node.with_changes(
         func=method_call.attr,
         args=[Arg(method_call.value)],
     )
コード例 #25
0
ファイル: modernizer.py プロジェクト: ybastide/cst-test
 def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
     """Rename the callee: ``xrange`` becomes ``range``, any other becomes ``input``.

     The callee must be a plain name (checked with ``ensure_type``).
     Presumably the only other name routed here is ``raw_input``; confirm
     against the dispatcher.
     """
     old_name = ensure_type(updated_node.func, Name).value
     if old_name == "xrange":
         new_name = "range"
     else:
         new_name = "input"
     return updated_node.with_changes(func=Name(new_name))
コード例 #26
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``self.assertTrue(x, <literal>)`` and suggest a rewrite.

        The rule treats a literal second argument (the message slot of
        ``assertTrue``) as a misplaced expected value:

        * ``True``  -> ``self.assertTrue(x)``
        * ``None``  -> ``self.assertIsNone(x)``
        * ``False`` -> ``self.assertFalse(x)``
        * any other matched literal -> ``self.assertEqual(x, <literal>)``
        """
        literal_matcher = m.OneOf(
            m.Integer(),
            m.Float(),
            m.Imaginary(),
            m.Tuple(),
            m.List(),
            m.Set(),
            m.Dict(),
            m.Name("None"),
            m.Name("True"),
            m.Name("False"),
        )
        result = m.extract(
            node,
            m.Call(
                func=m.Attribute(value=m.Name("self"),
                                 attr=m.Name("assertTrue")),
                args=[
                    m.DoNotCare(),
                    m.Arg(value=m.SaveMatchedNode(literal_matcher, "second")),
                ],
            ),
        )
        if not result:
            return

        second_arg = result["second"]
        if isinstance(second_arg, Sequence):
            second_arg = second_arg[0]

        method_name = cst.ensure_type(node.func, cst.Attribute).attr
        first_arg_only = [
            node.args[0].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        ]

        def renamed_func(new_name: str):
            # Swap only the method name, keeping ``self.`` and its whitespace.
            return node.func.with_deep_changes(old_node=method_name,
                                               value=new_name)

        if m.matches(second_arg, m.Name("True")):
            new_call = node.with_changes(args=first_arg_only)
        elif m.matches(second_arg, m.Name("None")):
            new_call = node.with_changes(func=renamed_func("assertIsNone"),
                                         args=first_arg_only)
        elif m.matches(second_arg, m.Name("False")):
            new_call = node.with_changes(func=renamed_func("assertFalse"),
                                         args=first_arg_only)
        else:
            new_call = node.with_deep_changes(old_node=method_name,
                                              value="assertEqual")

        self.report(node, replacement=new_call)
コード例 #27
0
 def update_call(self, updated_node: Call) -> BaseExpression:
     """Return *updated_node* retargeted at ``self.new_func``.

     The arguments are replaced by ``self.update_call_args(updated_node)``.
     """
     new_args = self.update_call_args(updated_node)
     return updated_node.with_changes(func=self.new_func, args=new_args)
コード例 #28
0
    def visit_Call(self, node: cst.Call) -> None:
        # `self.assertTrue(x is not None)` -> `self.assertIsNotNone(x)`
        if m.matches(
                node,
                m.Call(
                    func=m.Attribute(value=m.Name("self"),
                                     attr=m.Name("assertTrue")),
                    args=[
                        m.Arg(
                            m.Comparison(comparisons=[
                                m.ComparisonTarget(m.IsNot(),
                                                   comparator=m.Name("None"))
                            ]))
                    ],
                ),
        ):
            new_call = node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name("assertIsNotNone")),
                args=[
                    cst.Arg(
                        ensure_type(node.args[0].value, cst.Comparison).left)
                ],
            )
            self.report(node, replacement=new_call)

        # `self.assertTrue(not x is None)` -> `self.assertIsNotNone(x)`
        elif m.matches(
                node,
                m.Call(
                    func=m.Attribute(value=m.Name("self"),
                                     attr=m.Name("assertTrue")),
                    args=[
                        m.Arg(value=m.UnaryOperation(
                            operator=m.Not(),
                            expression=m.Comparison(comparisons=[
                                m.ComparisonTarget(m.Is(),
                                                   comparator=m.Name("None"))
                            ]),
                        ))
                    ],
                ),
        ):

            new_call = node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name("assertIsNotNone")),
                args=[
                    cst.Arg(
                        ensure_type(
                            ensure_type(node.args[0].value,
                                        cst.UnaryOperation).expression,
                            cst.Comparison,
                        ).left)
                ],
            )
            self.report(node, replacement=new_call)
        # `self.assertFalse(x is None)` -> `self.assertIsNotNone(x)`
        elif m.matches(
                node,
                m.Call(
                    func=m.Attribute(value=m.Name("self"),
                                     attr=m.Name("assertFalse")),
                    args=[
                        m.Arg(
                            m.Comparison(comparisons=[
                                m.ComparisonTarget(m.Is(),
                                                   comparator=m.Name("None"))
                            ]))
                    ],
                ),
        ):
            new_call = node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name("assertIsNotNone")),
                args=[
                    cst.Arg(
                        ensure_type(node.args[0].value, cst.Comparison).left)
                ],
            )
            self.report(node, replacement=new_call)
        # `self.assertTrue(x is None)` -> `self.assertIsNotNone(x))
        elif m.matches(
                node,
                m.Call(
                    func=m.Attribute(value=m.Name("self"),
                                     attr=m.Name("assertTrue")),
                    args=[
                        m.Arg(
                            m.Comparison(comparisons=[
                                m.ComparisonTarget(m.Is(),
                                                   comparator=m.Name("None"))
                            ]))
                    ],
                ),
        ):
            new_call = node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name("assertIsNone")),
                args=[
                    cst.Arg(
                        ensure_type(node.args[0].value, cst.Comparison).left)
                ],
            )
            self.report(node, replacement=new_call)

        # `self.assertFalse(x is not None)` -> `self.assertIsNone(x)`
        elif m.matches(
                node,
                m.Call(
                    func=m.Attribute(value=m.Name("self"),
                                     attr=m.Name("assertFalse")),
                    args=[
                        m.Arg(
                            m.Comparison(comparisons=[
                                m.ComparisonTarget(m.IsNot(),
                                                   comparator=m.Name("None"))
                            ]))
                    ],
                ),
        ):
            new_call = node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name("assertIsNone")),
                args=[
                    cst.Arg(
                        ensure_type(node.args[0].value, cst.Comparison).left)
                ],
            )
            self.report(node, replacement=new_call)
        # `self.assertFalse(not x is None)` -> `self.assertIsNone(x)`
        elif m.matches(
                node,
                m.Call(
                    func=m.Attribute(value=m.Name("self"),
                                     attr=m.Name("assertFalse")),
                    args=[
                        m.Arg(value=m.UnaryOperation(
                            operator=m.Not(),
                            expression=m.Comparison(comparisons=[
                                m.ComparisonTarget(m.Is(),
                                                   comparator=m.Name("None"))
                            ]),
                        ))
                    ],
                ),
        ):

            new_call = node.with_changes(
                func=cst.Attribute(value=cst.Name("self"),
                                   attr=cst.Name("assertIsNone")),
                args=[
                    cst.Arg(
                        ensure_type(
                            ensure_type(node.args[0].value,
                                        cst.UnaryOperation).expression,
                            cst.Comparison,
                        ).left)
                ],
            )
            self.report(node, replacement=new_call)
コード例 #29
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Rewrite a call to a known API method into its new calling form.

        The called attribute name (``original.func.attr.value``) is looked up
        in ``self.METHOD_TO_PARAMS``; calls that do not resolve are returned
        unchanged. Arguments are then classified:

        * positional arguments take the method's parameter names in order;
        * keyword arguments keep their own names;
        * control parameters (``self.CTRL_PARAMS``), whether passed by keyword
          or positionally after the method's own parameters, are always
          emitted as separate ``name=value`` keyword arguments.

        When ``self._use_keywords`` is true, the method parameters become
        ``name=value`` keyword arguments; otherwise they are collected into a
        single ``request={...}`` dict argument.

        Args:
            original: The call node as it appeared in the source.
            updated: The call node after its children were transformed.

        Returns:
            The rewritten call, or ``updated`` unchanged when the method is
            unknown, the call uses ``*``/``**`` unpacking, or it already
            passes a ``request`` keyword.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # `*args` / `**kwargs` unpacking cannot be mapped onto parameter
        # names statically, so leave such calls alone as well.
        if any(a.star for a in updated.args):
            return updated

        def make_kwarg(name, value):
            # Build a `name=value` argument with no spaces around the `=`.
            return cst.Arg(
                value=value,
                keyword=cst.Name(value=name),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace(""),
                    whitespace_after=cst.SimpleWhitespace(""),
                ),
            )

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        args, kwargs = partition(lambda a: a.keyword is None, updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        # Positional args beyond the method's own parameters are control
        # parameters passed positionally, in CTRL_PARAMS order.
        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
        ctrl_pairs = [(a.keyword.value, a.value) for a in ctrl_kwargs]
        ctrl_pairs.extend(
            (ctrl, a.value) for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))
        new_ctrl_kwargs = [make_kwarg(name, value) for name, value in ctrl_pairs]

        # Positional args take the method's parameter names in order. Keyword
        # args already name their parameter, so keep that name rather than
        # re-assigning names by position: with params ("a", "b", "c"), the
        # call `f(x, c=1)` must map to a=x, c=1 -- not a=x, b=1.
        param_pairs = [(name, a.value) for name, a in zip(kword_params, args)]
        param_pairs.extend((a.keyword.value, a.value) for a in kwargs)

        if self._use_keywords:
            new_kwargs = [make_kwarg(name, value) for name, value in param_pairs]
            return updated.with_changes(args=new_kwargs + new_ctrl_kwargs)

        request_arg = make_kwarg(
            "request",
            cst.Dict([
                cst.DictElement(cst.SimpleString('"{}"'.format(name)), value)
                for name, value in param_pairs
            ]),
        )
        # Control parameters are arguments of the method itself, not fields
        # of the request, so they stay outside the request dict (matching
        # the keyword branch above).
        return updated.with_changes(args=[request_arg] + new_ctrl_kwargs)