Exemplo n.º 1
0
    def __extract_names_multi_assign(self, elements):
        """Collect the assignment targets found in (possibly nested) tuple elements.

        Handles targets such as ``self.x, self.y = 1, 2`` and
        ``a, (b, c) = 1, (2, 3)``. Every element whose value is a ``Name`` or
        an ``Attribute`` is collected. An element whose value is a ``Tuple`` is
        expanded with ``match.findall``, so the ``Name``/``Attribute``
        elements nested inside it are collected as well.

        :param elements: sequence of libcst elements (e.g. ``Tuple.elements``);
            it is not modified.
        :return: the matched ``Name``/``Attribute`` nodes. Top-level targets
            come first, in order, followed by the nested ones.
        """
        names: List[cst.BaseExpression] = []
        # Work on a private copy, because the work list grows as nested tuples
        # are expanded. Extending the caller's sequence in place would mutate
        # it, and would fail outright when it is a tuple.
        pending = list(elements)
        i = 0
        while i < len(pending):
            element = pending[i]
            if match.matches(
                    element,
                    match.Element(value=match.Name(value=match.DoNotCare()))):
                names.append(element.value)
            elif match.matches(
                    element,
                    match.Element(value=match.Attribute(attr=match.Name(
                        value=match.DoNotCare())))):
                names.append(element.value)
            elif match.matches(
                    element,
                    match.Element(value=match.Tuple(
                        elements=match.DoNotCare()))):
                # findall searches the whole subtree, so names in deeper
                # nested tuples are found here too. Elements whose value is a
                # Tuple do not match the pattern and are not queued again.
                pending.extend(
                    match.findall(
                        element.value,
                        match.Element(value=match.OneOf(
                            match.Attribute(attr=match.Name(
                                value=match.DoNotCare())),
                            match.Name(value=match.DoNotCare())))))
            i += 1

        return names
Exemplo n.º 2
0
    def process_variable(self, node: Union[cst.BaseExpression,
                                           cst.BaseAssignTargetExpression]):
        """Record the variable name(s) bound by the assignment target ``node``.

        - ``Name``: appended to the innermost class's ``variables`` when
          directly inside a class body (a class on ``class_stack`` and no
          function on ``function_stack``); otherwise to
          ``self.info.variables``.
        - ``Attribute``: split into components with ``split_attribute``.
          ``self.<attr>`` inside a function within a class records ``<attr>``
          on the innermost class. Any other attribute records its first
          component on ``self.info``.
        - ``Tuple`` / ``List`` destructuring (``a, b = ...`` or
          ``[a, b] = ...``): each element's ``value`` is processed
          recursively.

        Any other target (e.g. a subscript) is ignored.
        """
        if m.matches(node, m.Name()):
            node = cst.ensure_type(node, cst.Name)

            if self.class_stack and not self.function_stack:
                self.class_stack[-1].variables.append(node.value)
            else:
                self.info.variables.append(node.value)

        elif m.matches(node, m.Attribute()):
            node = cst.ensure_type(node, cst.Attribute)

            splitted_attributes = split_attribute(node)
            is_self_attribute = (splitted_attributes[0] == 'self'
                                 and len(splitted_attributes) > 1)

            if is_self_attribute and self.class_stack and self.function_stack:
                self.class_stack[-1].variables.append(splitted_attributes[1])
            else:
                self.info.variables.append(splitted_attributes[0])

        elif m.matches(node, m.Tuple()):
            node = cst.ensure_type(node, cst.Tuple)
            for el in node.elements:
                self.process_variable(el.value)

        elif m.matches(node, m.List()):
            # ``[a, b] = ...`` destructures exactly like ``a, b = ...``.
            node = cst.ensure_type(node, cst.List)
            for el in node.elements:
                self.process_variable(el.value)
Exemplo n.º 3
0
 def _func_name(self, func):
     """Return a short display name for a call's ``func`` expression.

     ``foo(...)`` yields ``"foo"`` and ``obj.method(...)`` yields ``"method"``.
     Any other callee expression falls back to ``"func"``.
     """
     if m.matches(func, m.Attribute()):
         # For dotted access only the final attribute name is used.
         return func.attr.value
     if m.matches(func, m.Name()):
         return func.value
     return 'func'
Exemplo n.º 4
0
 def test_at_least_n_matcher_args_false(self) -> None:
     # Every case uses the same call, foo(1, 2, 3), with a matcher that must
     # reject it.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     rejecting_matchers = (
         # First argument is the integer 1, followed by at least two string
         # arguments; the remaining arguments are integers.
         m.Call(
             func=m.Name("foo"),
             args=(
                 m.Arg(m.Integer("1")),
                 m.AtLeastN(m.Arg(m.SimpleString()), n=2),
             ),
         ),
         # First argument is the integer 1, followed by at least three
         # arguments of any kind; only two follow.
         m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), m.AtLeastN(m.Arg(), n=3)),
         ),
         # First argument is the integer 1, followed by at least two
         # arguments equal to the integer 2; only one is.
         m.Call(
             func=m.Name("foo"),
             args=(
                 m.Arg(m.Integer("1")),
                 m.AtLeastN(m.Arg(m.Integer("2")), n=2),
             ),
         ),
     )
     for matcher in rejecting_matchers:
         self.assertFalse(matches(call, matcher))
Exemplo n.º 5
0
    def leave_Comparison(self, original_node: cst.Comparison,
                         updated_node: cst.Comparison) -> cst.BaseExpression:
        """Simplify an explicit comparison against a boolean literal.

        ``x == True`` becomes ``x`` and ``x == False`` becomes ``not x``.

        Only a single comparison is rewritten. In a chain such as
        ``a < b == True`` or ``a == False == True``, the literal is compared
        with the neighbouring operand rather than with ``a``. Dropping the
        target or negating ``a`` would therefore change the meaning, so chains
        are returned unchanged.

        The result is built from ``updated_node`` so that rewrites already
        applied to child nodes are kept. Parentheses around the comparison
        are carried over, so ``(x + y == True) * 2`` does not become
        ``x + y * 2``.
        """
        if len(updated_node.comparisons) != 1:
            return updated_node

        target = updated_node.comparisons[0]
        left = updated_node.left

        if m.matches(
                target,
                m.ComparisonTarget(comparator=m.Name("False"),
                                   operator=m.Equal()),
        ):
            return cst.UnaryOperation(operator=cst.Not(),
                                      expression=left,
                                      lpar=updated_node.lpar,
                                      rpar=updated_node.rpar)

        if m.matches(
                target,
                m.ComparisonTarget(comparator=m.Name("True"),
                                   operator=m.Equal()),
        ):
            # Wrap the operand in the comparison's own parentheses. Outer
            # parentheses come first in ``lpar`` and last in ``rpar``.
            return left.with_changes(
                lpar=[*updated_node.lpar, *left.lpar],
                rpar=[*left.rpar, *updated_node.rpar])

        return updated_node
Exemplo n.º 6
0
    def test_lambda_metadata_matcher(self) -> None:
        # Match a FunctionDef through a predicate over the qualified names
        # that QualifiedNameProvider reports for its name.
        module = cst.parse_module(
            "from typing import List\n\ndef foo() -> None: pass\n")
        wrapper = cst.MetadataWrapper(module)
        # body[0] is the import statement; body[1] is ``def foo``.
        functiondef = cst.ensure_type(wrapper.module.body[1], cst.FunctionDef)

        def name_is_one_of(candidates):
            # FunctionDef matcher accepting any qualified name in candidates.
            return m.FunctionDef(name=m.MatchMetadataIfTrue(
                meta.QualifiedNameProvider,
                lambda qualnames: any(q.name in candidates
                                      for q in qualnames),
            ))

        # "foo" is among the accepted names...
        self.assertTrue(
            matches(functiondef,
                    name_is_one_of({"foo", "bar", "baz"}),
                    metadata_resolver=wrapper))
        # ...but matching fails when only "bar" and "baz" are accepted.
        self.assertFalse(
            matches(functiondef,
                    name_is_one_of({"bar", "baz"}),
                    metadata_resolver=wrapper))
Exemplo n.º 7
0
 def _has_none(node):
     """Return True if ``node`` is the name ``None`` or contains it as an operand.

     Binary operations are searched recursively on both sides, so e.g.
     ``int | None`` reports True. Any other node reports False.
     """
     if m.matches(node, m.BinaryOperation()):
         return _has_none(node.left) or _has_none(node.right)
     return m.matches(node, m.Name("None"))
Exemplo n.º 8
0
 def test_and_matcher_false(self) -> None:
     # AllOf requires every sub-matcher to match. Name("True") and
     # Name("False") can never both match, so the identifier None is rejected.
     self.assertFalse(
         matches(cst.Name("None"), m.AllOf(m.Name("True"), m.Name("False")))
     )
     # foo(1, 2, 3) satisfies the "any three arguments" sequence but not the
     # "3, 2, 1" sequence, so the AllOf of both rejects it.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     three_args = (m.Arg(), m.Arg(), m.Arg())
     descending_args = (
         m.Arg(m.Integer("3")),
         m.Arg(m.Integer("2")),
         m.Arg(m.Integer("1")),
     )
     self.assertFalse(
         matches(
             call,
             m.Call(func=m.Name("foo"), args=m.AllOf(three_args, descending_args)),
         )
     )
Exemplo n.º 9
0
 def test_at_least_n_matcher_args_true(self) -> None:
     # Every case uses the same call, foo(1, 2, 3), with a matcher that must
     # accept it. The first argument is the integer 1, followed by at least
     # two further arguments constrained as described per case.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     accepting_matchers = (
         # ...of any kind.
         m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), m.AtLeastN(m.Arg(), n=2)),
         ),
         # ...that are integers of any value.
         m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), m.AtLeastN(m.Arg(m.Integer()), n=2)),
         ),
         # ...that are each the integer 2 or 3.
         m.Call(
             func=m.Name("foo"),
             args=(
                 m.Arg(m.Integer("1")),
                 m.AtLeastN(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3"))), n=2),
             ),
         ),
     )
     for matcher in accepting_matchers:
         self.assertTrue(matches(call, matcher))
Exemplo n.º 10
0
    def new_obf_function_name(self, func: cst.Call):
        """Return ``func`` with its callee expression obfuscated.

        - ``obj.method(...)``: the object expression goes through
          ``obf_universal(..., 'v')`` when ``change_variables`` is set. The
          method name goes through ``obf_universal(..., 'cf')`` when
          ``change_methods`` is set.
        - ``name(...)``: replaced with ``get_new_cst_name(name)`` when
          functions or classes are being renamed and
          ``can_rename(name, 'c', 'f')`` allows it.
        - Any other callee is kept as is.

        The call is always rebuilt with ``with_changes(func=...)``.
        """
        callee = func.func

        if m.matches(callee, m.Name()):
            callee = cst.ensure_type(callee, cst.Name)
            renaming_enabled = self.change_functions or self.change_classes
            if renaming_enabled and self.can_rename(callee.value, 'c', 'f'):
                callee = self.get_new_cst_name(callee.value)
        elif m.matches(callee, m.Attribute()):
            callee = cst.ensure_type(callee, cst.Attribute)
            # Rename the object the method is looked up on.
            if self.change_variables:
                callee = callee.with_changes(
                    value=self.obf_universal(callee.value, 'v'))
            # Rename the method itself.
            if self.change_methods:
                callee = callee.with_changes(
                    attr=self.obf_universal(callee.attr, 'cf'))

        return func.with_changes(func=callee)
Exemplo n.º 11
0
    def obf_function_args(self, func: cst.Call):
        """Obfuscate the argument values and keyword names of a call.

        When neither ``change_arguments`` nor ``change_method_arguments`` is
        set, ``func`` is returned unchanged. Otherwise:

        - every argument value goes through ``obf_universal``;
        - a keyword argument's name is replaced with ``get_new_cst_name``
          when ``can_rename_func_param(keyword, func_name)`` allows it.

        ``func_name`` is the called name for ``name(...)``. For
        ``obj.method(...)`` it is the last component returned by
        ``split_attribute``. For any other callee it is ``''``.
        """
        if not (self.change_arguments or self.change_method_arguments):
            # Rebuilding the call here would replace its arguments with an
            # empty list, so return it untouched.
            return func

        func_root = func.func
        func_name = ''
        if m.matches(func_root, m.Name()):
            func_name = cst.ensure_type(func_root, cst.Name).value
        elif m.matches(func_root, m.Attribute()):
            func_name = split_attribute(
                cst.ensure_type(func_root, cst.Attribute))[-1]

        new_args = []
        for arg in func.args:
            # Argument values.
            arg = arg.with_changes(value=self.obf_universal(arg.value))
            # Argument (keyword) names.
            if arg.keyword is not None and self.can_rename_func_param(
                    arg.keyword.value, func_name):
                arg = arg.with_changes(
                    keyword=self.get_new_cst_name(arg.keyword))
            new_args.append(arg)

        return func.with_changes(args=new_args)
Exemplo n.º 12
0
 def leave_AnnAssign(
     self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
 ) -> Union[cst.BaseSmallStatement, cst.RemovalSentinel]:
     """Rewrite an annotated assignment as a plain assignment.

     ``foo: str = "x"`` becomes ``foo = "x"``. A declaration without a value,
     whatever its target (``foo: str``, ``self.foo: str``, ``foo[0]: str``),
     becomes ``<target> = ...``. This ensures that later node traversal
     always sees an assignment with a value and does not run into
     exceptions.
     """
     # Build from updated_node so rewrites already applied to the target or
     # value are not discarded.
     value = updated_node.value
     if value is None:
         # Handling only Name/Attribute targets here let any other target
         # produce ``cst.Assign(value=None)``, which is not a valid node.
         value = cst.Ellipsis()
     return cst.Assign(
         targets=[cst.AssignTarget(target=updated_node.target)],
         value=value)
Exemplo n.º 13
0
    def obf_function_name(self, func: cst.Call, updated_node):
        """Return ``updated_node`` with the called function/method name obfuscated.

        The callee is read from ``func.func``:

        - ``obj.method(...)``: when ``change_methods`` is set and
          ``can_rename(method, 'cf')`` allows it, the method name node is
          replaced with ``get_new_cst_name``.
        - ``name(...)``: renamed when functions are being renamed and
          ``can_rename(name, 'f')`` allows it, or when classes are being
          renamed and ``can_rename(name, 'c')`` allows it.

        In every other case ``updated_node`` is returned unchanged.
        """
        callee = func.func

        if m.matches(callee, m.Attribute()):
            if not (self.change_methods
                    and self.can_rename(callee.attr.value, 'cf')):
                return updated_node
            callee = cst.ensure_type(callee, cst.Attribute)
            renamed = callee.with_changes(
                attr=self.get_new_cst_name(callee.attr))
            return updated_node.with_changes(func=renamed)

        if m.matches(callee, m.Name()):
            name = callee.value
            # Kept as one short-circuiting expression so can_rename(..., 'c')
            # only runs when the function check does not already succeed.
            wants_rename = (
                (self.change_functions and self.can_rename(name, 'f'))
                or (self.change_classes and self.can_rename(name, 'c')))
            if wants_rename:
                callee = cst.ensure_type(callee, cst.Name)
                return updated_node.with_changes(
                    func=self.get_new_cst_name(callee.value))

        return updated_node
Exemplo n.º 14
0
 def test_and_operator_matcher_true(self) -> None:
     # Match the True identifier in a roundabout way: a type-only matcher
     # combined via ``&`` with a regex on the value.
     self.assertTrue(
         matches(cst.Name("True"), m.Name() & m.Name(value=m.MatchRegex(r"True")))
     )
     # Match in a really roundabout way that verifies the __and__ behavior on
     # AllOf itself: the first ``&`` yields an AllOf, which a second ``&``
     # then combines with a third matcher.
     self.assertTrue(
         matches(
             cst.Name("True"),
             m.Name() & m.Name(value=m.MatchRegex(r"True")) & m.Name("True"),
         )
     )
     # Verify that MatchIfTrue works with __and__ when it is the left
     # operand...
     self.assertTrue(
         matches(
             cst.Name("True"),
             m.MatchIfTrue(lambda x: isinstance(x, cst.Name))
             & m.Name(value=m.MatchRegex(r"True")),
         )
     )
     # ...and when it is the right operand.
     self.assertTrue(
         matches(
             cst.Name("True"),
             m.Name(value=m.MatchRegex(r"True"))
             & m.MatchIfTrue(lambda x: isinstance(x, cst.Name)),
         )
     )
Exemplo n.º 15
0
 def _get_async_expr_replacement(
         self, node: cst.CSTNode) -> Optional[cst.CSTNode]:
     """Return an async-aware replacement for ``node``, or None if nothing changes.

     Calls and attribute accesses are delegated to
     ``_get_async_call_replacement`` / ``_get_async_attr_replacement``.
     ``not <expr>`` and boolean operations are handled recursively. They are
     rebuilt only when at least one operand has a replacement; any operand
     without one is kept as is.
     """
     if m.matches(node, m.Call()):
         return self._get_async_call_replacement(cast(cst.Call, node))
     if m.matches(node, m.Attribute()):
         return self._get_async_attr_replacement(cast(cst.Attribute, node))
     if m.matches(node, m.UnaryOperation(operator=m.Not())):
         negation = cast(cst.UnaryOperation, node)
         inner = self._get_async_expr_replacement(negation.expression)
         if inner is None:
             return None
         return negation.with_changes(expression=inner)
     if m.matches(node, m.BooleanOperation()):
         boolean = cast(cst.BooleanOperation, node)
         # Left is evaluated before right, as both may call helpers.
         new_left = self._get_async_expr_replacement(boolean.left)
         new_right = self._get_async_expr_replacement(boolean.right)
         if new_left is None and new_right is None:
             return None
         return boolean.with_changes(
             left=boolean.left if new_left is None else new_left,
             right=boolean.right if new_right is None else new_right)
     return None
Exemplo n.º 16
0
    def leave_Call(self, original_node: Call,
                   updated_node: Call) -> BaseExpression:
        """
        Remove the `weak` keyword argument from matching disconnect calls.

        Only calls matching one of ``self.disconnect_call_matchers`` are
        considered, and only an argument passed as ``weak=...`` is removed.
        The comma of the original last argument is moved onto the new last
        argument, so the end of the argument list keeps its formatting. All
        other calls are delegated to the parent class.
        """
        if self.disconnect_call_matchers and m.matches(
                updated_node, m.OneOf(*self.disconnect_call_matchers)):
            updated_args = []
            should_change = False
            last_comma = MaybeSentinel.DEFAULT
            # Keep all arguments except the one with the keyword `weak`.
            for arg in updated_node.args:
                if m.matches(arg, m.Arg(keyword=m.Name("weak"))):
                    # A `weak` argument was found, so the call is rewritten.
                    should_change = True
                else:
                    updated_args.append(arg)
                last_comma = arg.comma
            if should_change:
                # If `weak` was the only argument, nothing is left to carry
                # the trailing comma; indexing [-1] would raise IndexError.
                if updated_args:
                    updated_args[-1] = updated_args[-1].with_changes(
                        comma=last_comma)
                return updated_node.with_changes(args=updated_args)
        return super().leave_Call(original_node, updated_node)
Exemplo n.º 17
0
 def test_does_not_match_true(self) -> None:
     def call_foo(value):
         # foo(<value>) with a single positional argument.
         return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

     cases = (
         # A call whose single argument's value is not None.
         (
             call_foo(libcst.Name("True")),
             m.Call(args=(m.Arg(value=m.DoesNotMatch(m.Name("None"))), )),
         ),
         # A call whose single argument, as a whole, is not Arg(None).
         (
             call_foo(libcst.Integer("1")),
             m.Call(args=(m.DoesNotMatch(m.Arg(m.Name("None"))), )),
         ),
         # A call whose whole argument sequence is not (2,).
         (
             call_foo(libcst.Integer("1")),
             m.Call(args=m.DoesNotMatch((m.Arg(m.Integer("2")), ))),
         ),
         # A call whose single argument is neither True nor False.
         (
             call_foo(libcst.Integer("1")),
             m.Call(args=(m.Arg(value=m.DoesNotMatch(
                 m.OneOf(m.Name("True"), m.Name("False")))), )),
         ),
         # A name whose value does not match the regex for True.
         (
             libcst.Name("False"),
             m.Name(value=m.DoesNotMatch(m.MatchRegex(r"True"))),
         ),
     )
     for node, matcher in cases:
         self.assertTrue(matches(node, matcher))
Exemplo n.º 18
0
    def leave_ClassDef(self, original_node: cst.ClassDef,
                       updated_node: cst.ClassDef):
        """Pop the class off ``class_stack`` and, if enabled, obfuscate its names.

        When ``change_classes`` is set:

        - the class is passed through ``renamed`` if
          ``can_rename(name, 'c')`` allows it;
        - every base class written as a plain name is replaced with
          ``get_new_cst_name(name)`` if ``can_rename(name, 'c')`` allows it.

        Dotted bases such as ``module.Base`` are left untouched.
        """
        self.class_stack.pop()

        if not self.change_classes:
            return updated_node

        if self.can_rename(updated_node.name.value, 'c'):
            updated_node = self.renamed(updated_node)

        def rename_base(base):
            # Only plain-name bases are renamed.
            # TODO: support imports (dotted bases such as ``module.Base``).
            if m.matches(base.value, m.Name()):
                base_name = cst.ensure_type(base.value, cst.Name).value
                if self.can_rename(base_name, 'c'):
                    return base.with_changes(
                        value=self.get_new_cst_name(base_name))
            return base

        return updated_node.with_changes(
            bases=[rename_base(base) for base in updated_node.bases])
Exemplo n.º 19
0
 def test_and_matcher_true(self) -> None:
     # AllOf of a type-only matcher and a regex on the value; both accept
     # the identifier True.
     self.assertTrue(
         matches(
             cst.Name("True"), m.AllOf(m.Name(), m.Name(value=m.MatchRegex(r"True")))
         )
     )
     # AllOf of two argument-sequence matchers: foo(1, 2, 3) has exactly
     # three arguments, and they are 1, 2, 3 in that order.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     three_args = (m.Arg(), m.Arg(), m.Arg())
     ascending_args = (
         m.Arg(m.Integer("1")),
         m.Arg(m.Integer("2")),
         m.Arg(m.Integer("3")),
     )
     self.assertTrue(
         matches(
             call,
             m.Call(func=m.Name("foo"), args=m.AllOf(three_args, ascending_args)),
         )
     )
Exemplo n.º 20
0
 def test_simple_matcher_false(self) -> None:
     def at_position(start, end):
         # Matcher requiring the syntactic position to span start..end.
         return m.MatchMetadata(
             meta.SyntacticPositionProvider,
             self._make_coderange(start, end),
         )

     # Fail to match "foo" by type and value when the expected position is
     # line 2, columns 0-3.
     node, wrapper = self._make_fixture("foo")
     name_matcher = m.Name(value="foo", metadata=at_position((2, 0), (2, 3)))
     self.assertFalse(matches(node, name_matcher, metadata_resolver=wrapper))

     # Fail to match "foo + bar" when its operands are required to sit at
     # exactly (1, 0)-(1, 1) and (1, 4)-(1, 5).
     node, wrapper = self._make_fixture("foo + bar")
     binop_matcher = m.BinaryOperation(
         left=at_position((1, 0), (1, 1)),
         right=at_position((1, 4), (1, 5)),
     )
     self.assertFalse(matches(node, binop_matcher, metadata_resolver=wrapper))
Exemplo n.º 21
0
 def test_at_most_n_matcher_args_true(self) -> None:
     # Match a call to "foo" with at most two arguments, each of which is the
     # integer 1.
     self.assertTrue(
         matches(
             cst.Call(func=cst.Name("foo"),
                      args=(cst.Arg(cst.Integer("1")), )),
             m.Call(func=m.Name("foo"),
                    args=(m.AtMostN(m.Arg(m.Integer("1")), n=2), )),
         ))
     # Match a call to "foo" with at most two arguments, each of which can be
     # the integer 1 or 2.
     self.assertTrue(
         matches(
             cst.Call(
                 func=cst.Name("foo"),
                 args=(cst.Arg(cst.Integer("1")),
                       cst.Arg(cst.Integer("2"))),
             ),
             m.Call(
                 func=m.Name("foo"),
                 args=(m.AtMostN(m.Arg(
                     m.OneOf(m.Integer("1"), m.Integer("2"))),
                                 n=2), ),
             ),
         ))
     # Match a call to "foo" with at most two arguments: the first is the
     # integer 1, and the second, if present, is the integer 2.
     self.assertTrue(
         matches(
             cst.Call(
                 func=cst.Name("foo"),
                 args=(cst.Arg(cst.Integer("1")),
                       cst.Arg(cst.Integer("2"))),
             ),
             m.Call(
                 func=m.Name("foo"),
                 args=(m.Arg(m.Integer("1")),
                       m.ZeroOrOne(m.Arg(m.Integer("2")))),
             ),
         ))
     # Match a call to "foo" with at most six arguments: the first is the
     # integer 1, the second, if present, is the integer 2, and up to four
     # more arguments of any kind may follow.
     self.assertTrue(
         matches(
             cst.Call(
                 func=cst.Name("foo"),
                 args=(cst.Arg(cst.Integer("1")),
                       cst.Arg(cst.Integer("2"))),
             ),
             m.Call(
                 func=m.Name("foo"),
                 args=(m.Arg(m.Integer("1")),
                       m.ZeroOrOne(m.Arg(m.Integer("2"))),
                       m.AtMostN(n=4)),
             ),
         ))
Exemplo n.º 22
0
 def test_or_operator_matcher_false(self) -> None:
     # ``|`` accepts a node matching either side; None matches neither.
     true_or_false = m.Name("True") | m.Name("False")
     # The identifier None is neither True nor False.
     self.assertFalse(matches(cst.Name("None"), true_or_false))
     # Assigning None to a target is not assigning True or False.
     assign_none = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("None"))
     self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))
Exemplo n.º 23
0
 def leave_Attribute(
     self, original: libcst.Attribute, updated: libcst.Attribute
 ) -> Any:
     """Rewrite ``six.<name>`` to ``io.<name>`` for names listed in ``_IO_OBJECTS``.

     Only the object expression is swapped (via ``with_changes``), so the
     attribute's parentheses and the whitespace around the dot are kept.
     Any other attribute is returned unchanged.
     """
     # ``Attribute.attr`` is always a Name node, so its value can be read
     # directly.
     if (
         m.matches(updated.value, m.Name("six"))
         and updated.attr.value in _IO_OBJECTS
     ):
         return updated.with_changes(value=libcst.Name("io"))
     return updated
Exemplo n.º 24
0
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                                  updated_node: cst.SimpleStatementLine):
        """Annotate a single-name assignment line with its looked-up type.

        Two line shapes are handled, both matched on ``original_node``:

        - ``x = value`` (one target, a plain name): the type comes from
          ``__get_var_type_assign_t``. The line is rewritten as
          ``x: T = value``, reusing the original spacing around ``=``.
        - ``x: old = value`` (annotated plain name): the type comes from
          ``__get_var_type_an_assign`` and replaces the annotation.

        The type goes through ``resolve_type_alias`` and then
        ``__name2annotation``. If either the type or the resulting
        annotation is None, nothing is rewritten. Each applied
        (resolved type, annotation) pair is added to ``all_applied_types``.

        NOTE(review): the fall-through returns ``original_node``, which
        discards any child changes made by other ``leave_*`` hooks.
        ``updated_node`` is the usual libcst choice; confirm this is intended.
        """
        def annotation_for(var_type):
            # Resolve the alias, build the annotation node, and record the
            # pair when an annotation could be produced.
            resolved = self.resolve_type_alias(var_type)
            annotation = self.__name2annotation(resolved)
            if annotation is not None:
                self.all_applied_types.add((resolved, annotation))
            return annotation

        single_name_assign = match.SimpleStatementLine(body=[
            match.Assign(targets=[
                match.AssignTarget(target=match.Name(
                    value=match.DoNotCare()))
            ])
        ])
        single_name_ann_assign = match.SimpleStatementLine(body=[
            match.AnnAssign(target=match.Name(value=match.DoNotCare()))
        ])

        if match.matches(original_node, single_name_assign):
            assign = original_node.body[0]
            assign_target = assign.targets[0]
            var_type = self.__get_var_type_assign_t(assign_target.target.value)
            if var_type is not None:
                annotation = annotation_for(var_type)
                if annotation is not None:
                    equal = cst.AssignEqual(
                        whitespace_after=assign_target.whitespace_after_equal,
                        whitespace_before=assign_target.whitespace_before_equal)
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(target=assign_target.target,
                                      value=assign.value,
                                      annotation=annotation,
                                      equal=equal)
                    ])
        elif match.matches(original_node, single_name_ann_assign):
            ann_assign = original_node.body[0]
            var_type = self.__get_var_type_an_assign(ann_assign.target.value)
            if var_type is not None:
                annotation = annotation_for(var_type)
                if annotation is not None:
                    return updated_node.with_changes(body=[
                        cst.AnnAssign(target=ann_assign.target,
                                      value=ann_assign.value,
                                      annotation=annotation,
                                      equal=ann_assign.equal)
                    ])

        return original_node
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``list``/``set``/``dict`` calls that wrap a comprehension.

        A single generator-expression or list-comprehension argument is
        reported with a replacement comprehension:

        - ``list(x for x in y)`` becomes ``[x for x in y]``;
        - ``set(...)`` becomes a set comprehension;
        - ``dict(...)`` becomes a dict comprehension, but only when the
          element is a two-item tuple or list (``(k, v)`` / ``[k, v]``).
          Any other element shape is not reported.
        """
        if not m.matches(
                node,
                m.Call(
                    func=m.Name("list") | m.Name("set") | m.Name("dict"),
                    args=[m.Arg(value=m.GeneratorExp() | m.ListComp())],
                ),
        ):
            return

        call_name = cst.ensure_type(node.func, cst.Name).value

        if m.matches(node.args[0].value, m.GeneratorExp()):
            exp = cst.ensure_type(node.args[0].value, cst.GeneratorExp)
            message_formatter = UNNECESSARY_GENERATOR
        else:
            exp = cst.ensure_type(node.args[0].value, cst.ListComp)
            message_formatter = UNNECESSARY_LIST_COMPREHENSION

        replacement = None
        if call_name == "list":
            replacement = cst.ListComp(elt=exp.elt, for_in=exp.for_in)
        elif call_name == "set":
            replacement = cst.SetComp(elt=exp.elt, for_in=exp.for_in)
        elif call_name == "dict":
            # Require exactly two non-starred elements. Passed positionally,
            # ``m.Tuple(m.DoNotCare(), m.DoNotCare())`` binds to ``elements``
            # and ``lpar`` and matches tuples of any length. A 1-tuple then
            # crashed on ``elements[1]``, and a 3-tuple was silently
            # truncated to a key/value pair.
            pair = [m.Element(), m.Element()]
            elt = exp.elt
            if m.matches(elt, m.Tuple(elements=pair)):
                elements = cst.ensure_type(elt, cst.Tuple).elements
            elif m.matches(elt, m.List(elements=pair)):
                elements = cst.ensure_type(elt, cst.List).elements
            else:
                # Unrecognized form.
                return

            # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
            #  param but got `BaseExpression`.
            replacement = cst.DictComp(key=elements[0].value,
                                       value=elements[1].value,
                                       for_in=exp.for_in)

        self.report(node,
                    message_formatter.format(func=call_name),
                    replacement=replacement)
Exemplo n.º 26
0
 def test_at_least_n_matcher_no_args_false(self) -> None:
     # Every case uses the same three-argument call, foo(1, 2, 3), with a
     # wildcard-count matcher that must reject it.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     rejecting_matchers = (
         # At least four arguments of any kind.
         m.Call(func=m.Name("foo"), args=(m.AtLeastN(n=4),)),
         # The integer 1 followed by at least three more arguments, i.e. at
         # least four arguments in total.
         m.Call(
             func=m.Name("foo"), args=(m.Arg(m.Integer("1")), m.AtLeastN(n=3))
         ),
         # At least three arguments, the last one being the integer 2; the
         # actual last argument is 3.
         m.Call(
             func=m.Name("foo"), args=(m.AtLeastN(n=2), m.Arg(m.Integer("2")))
         ),
     )
     for matcher in rejecting_matchers:
         self.assertFalse(matches(call, matcher))
Exemplo n.º 27
0
    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    ]:
        """Split ``updated_module.body`` into three consecutive statement lists.

        Returns ``(before, imports, after)``:

        * ``before`` holds the leading statement that no new import may precede.
          This is either a module docstring or an initial ``__strict__ = ...``
          assignment. It is empty when neither is present.
        * ``imports`` runs from there up to and including the last top-level
          statement that contains one of the nodes in ``self.all_imports``.
        * ``after`` is everything that follows.

        Positions are computed on ``orig_module`` but applied to
        ``updated_module``. This is only valid while both modules have the same
        number of top-level statements.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # Never insert an import before an initial __strict__ flag.
        if m.matches(
            orig_module,
            m.Module(
                body=[
                    m.SimpleStatementLine(
                        body=[
                            m.Assign(
                                targets=[m.AssignTarget(target=m.Name("__strict__"))]
                            )
                        ]
                    ),
                    m.ZeroOrMore(),
                ]
            ),
        ):
            statement_before_import_location = import_add_location = 1

        # Tracked imports are compared by identity. Collecting their ids once
        # keeps the scan below linear in the number of statements instead of
        # re-walking self.all_imports for every small statement.
        tracked_import_ids = {id(node) for node in self.all_imports}

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.
        for i, statement in enumerate(orig_module.body):
            if i == 0 and m.matches(
                statement,
                m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
            ):
                # Only a *leading* string statement is the module docstring.
                # Without the i == 0 guard, any later bare string statement
                # would reset both locations to 1 and discard the position of
                # the imports found so far.
                statement_before_import_location = import_add_location = 1
            elif isinstance(statement, libcst.SimpleStatementLine):
                if any(id(small) in tracked_import_ids for small in statement.body):
                    # Later matches overwrite earlier ones, so this ends up just
                    # past the last statement holding a tracked import.
                    import_add_location = i + 1

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(
                updated_module.body[
                    statement_before_import_location:import_add_location
                ]
            ),
            list(updated_module.body[import_add_location:]),
        )
Exemplo n.º 28
0
    def visit_Call(self, node: cst.Call) -> Optional[bool]:
        """Abort the codemod when ``node`` matches ``gen_task_matcher``.

        Returns:
            True for every other call, so traversal continues into its children.

        Raises:
            TransformError: if ``node`` matches ``gen_task_matcher``. Per the
                message, that is a tornado ``gen.Task`` call, which this codemod
                cannot rewrite.
        """
        if not m.matches(node, gen_task_matcher):
            return True

        raise TransformError(
            "gen.Task (https://www.tornadoweb.org/en/branch2.4/gen.html#tornado.gen.Task) from tornado 2.4.1 is unsupported by this codemod. This file has not been modified. Manually update to supported syntax before running again."
        )
Exemplo n.º 29
0
    def leave_Yield(self, node: cst.Yield,
                    updated_node: cst.Yield) -> Union[cst.Await, cst.Yield]:
        """Rewrite ``yield <expr>`` into ``await <expr>`` inside coroutines.

        The node is returned unchanged in two cases. The first is when
        ``self.in_coroutine(self.coroutine_stack)`` is false. The second is when
        the yielded value is not a ``cst.BaseExpression``, for example a bare
        ``yield`` or a ``yield from``.

        For a yielded list or list comprehension, "asyncio" is recorded in
        ``self.required_imports``. The awaited expression then comes from
        ``pluck_asyncio_gather_expression_from_yield_list_or_list_comp``, which,
        per its name, builds an ``asyncio.gather`` call.

        The yield's trailing whitespace and parentheses carry over to the new
        ``cst.Await`` node.

        Raises:
            TransformError: when a dict, dict comprehension, or ``dict(...)``
                call is yielded.
        """
        yielded = updated_node.value
        if not self.in_coroutine(self.coroutine_stack) or not isinstance(
                yielded, cst.BaseExpression):
            return updated_node

        yields_dict_of_futures = m.matches(
            updated_node,
            m.Yield(value=m.Dict() | m.DictComp() | m.Call(func=m.Name("dict"))),
        )
        if yields_dict_of_futures:
            raise TransformError(
                "Yielding a dict of futures (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) added in tornado 3.2 is unsupported by the codemod. This file has not been modified. Manually update to supported syntax before running again."
            )

        if isinstance(yielded, (cst.List, cst.ListComp)):
            self.required_imports.add("asyncio")
            awaited = self.pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
                updated_node)
        else:
            awaited = yielded

        return cst.Await(
            expression=awaited,
            whitespace_after_await=updated_node.whitespace_after_yield,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
        )
    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        for d in node.decorators:
            decorator = d.decorator
            if QualifiedNameProvider.has_name(
                self,
                decorator,
                QualifiedName(
                    name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
                ),
            ):
                if isinstance(decorator, cst.Call):
                    func = decorator.func
                    args = decorator.args
                else:  # decorator is either cst.Name or cst.Attribute
                    args = ()
                    func = decorator

                # pyre-fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]], typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
                if not any(m.matches(arg.keyword, m.Name("frozen")) for arg in args):
                    new_decorator = cst.Call(
                        func=func,
                        args=list(args)
                        + [
                            cst.Arg(
                                keyword=cst.Name("frozen"),
                                value=cst.Name("True"),
                                equal=cst.AssignEqual(
                                    whitespace_before=SimpleWhitespace(value=""),
                                    whitespace_after=SimpleWhitespace(value=""),
                                ),
                            )
                        ],
                    )
                    self.report(d, replacement=d.with_changes(decorator=new_decorator))