def __extract_names_multi_assign(self, elements):
    """Collect the targets bound by a (possibly nested) tuple assignment.

    Handles plain names (``a, b = 1, 2``), attribute targets such as
    ``self.x, self.y = 1, 2`` and nested tuples (``a, (b, c) = 1, (2, 3)``).

    :param elements: the ``Element`` nodes of the outer target tuple.
    :return: the ``Name`` / ``Attribute`` nodes found; outer elements come
        first, followed by those discovered inside nested tuples.
    """
    # Work on a private copy. Extending ``elements`` in place would mutate
    # the caller's sequence (and fail if it is an immutable tuple).
    pending = list(elements)
    names: List[cst.BaseExpression] = []
    i = 0
    while i < len(pending):
        element = pending[i]
        if match.matches(
                element,
                match.Element(value=match.Name(value=match.DoNotCare()))):
            names.append(element.value)
        elif match.matches(
                element,
                match.Element(value=match.Attribute(attr=match.Name(
                    value=match.DoNotCare())))):
            names.append(element.value)
        elif match.matches(
                element,
                match.Element(value=match.Tuple(
                    elements=match.DoNotCare()))):
            # findall searches the nested tuple recursively and only returns
            # Name/Attribute elements, so the queued items are processed by
            # the first two branches above.
            pending.extend(
                match.findall(
                    element.value,
                    match.Element(value=match.OneOf(
                        match.Attribute(attr=match.Name(
                            value=match.DoNotCare())),
                        match.Name(value=match.DoNotCare())))))
        i += 1
    return names
def process_variable(self, node: Union[cst.BaseExpression, cst.BaseAssignTargetExpression]):
    """Record the variable(s) bound by an assignment target.

    * ``Name``: stored on the innermost class when directly inside a class
      body (no enclosing function), otherwise on ``self.info``.
    * ``Attribute``: ``self.<attr>`` inside a method of a class is stored as
      the class variable ``<attr>``; any other attribute records its root
      name on ``self.info``.
    * ``Tuple``: every element is processed recursively.

    Any other node is ignored.
    """
    if isinstance(node, cst.Tuple):
        for element in node.elements:
            self.process_variable(element.value)
        return

    if isinstance(node, cst.Name):
        in_class_body = self.class_stack and not self.function_stack
        bucket = (self.class_stack[-1].variables
                  if in_class_body else self.info.variables)
        bucket.append(node.value)
        return

    if isinstance(node, cst.Attribute):
        parts = split_attribute(node)
        is_self_member = (parts[0] == 'self' and self.class_stack
                          and self.function_stack and len(parts) > 1)
        if is_self_member:
            self.class_stack[-1].variables.append(parts[1])
        else:
            self.info.variables.append(parts[0])
def _func_name(self, func):
    """Return a short display name for a call's ``func`` expression.

    ``foo`` gives ``"foo"`` and ``obj.method`` gives ``"method"``; any other
    callee expression gives the placeholder ``"func"``.
    """
    if m.matches(func, m.Attribute()):
        return func.attr.value
    if m.matches(func, m.Name()):
        return func.value
    return 'func'
def test_at_least_n_matcher_args_false(self) -> None:
    # Every case below is checked against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )

    # First argument is 1, followed by at least two string arguments:
    # the trailing arguments are integers, so this must not match.
    self.assertFalse(
        matches(
            call,
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.Arg(m.Integer("1")),
                    m.AtLeastN(m.Arg(m.SimpleString()), n=2),
                ),
            ),
        ))

    # First argument is 1, followed by at least three wildcard arguments:
    # only two follow, so this must not match.
    self.assertFalse(
        matches(
            call,
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.AtLeastN(m.Arg(), n=3)),
            ),
        ))

    # First argument is 1, followed by at least two arguments that are the
    # integer 2: only one of them is, so this must not match.
    self.assertFalse(
        matches(
            call,
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.Arg(m.Integer("1")),
                    m.AtLeastN(m.Arg(m.Integer("2")), n=2),
                ),
            ),
        ))
def leave_Comparison(self, original_node: cst.Comparison,
                     updated_node: cst.Comparison) -> cst.BaseExpression:
    """Simplify an explicit comparison against a boolean literal.

    ``x == True`` becomes ``x`` and ``x == False`` becomes ``not x``.

    Only a comparison with a single operator is rewritten. In a chain such
    as ``a < b == False`` the literal is compared with ``b`` only, so
    negating ``a`` or dropping the target would change the meaning (this also
    covers the ``a == False == True`` case).

    The result is built from ``updated_node`` so changes already made to
    child nodes are kept. Parentheses that wrapped the comparison are carried
    over so the result still binds correctly in its context, e.g.
    ``(a + b == True) * 2`` becomes ``(a + b) * 2``.
    """
    if len(updated_node.comparisons) != 1:
        return updated_node

    target = updated_node.comparisons[0]
    left = updated_node.left
    if m.matches(
            target,
            m.ComparisonTarget(comparator=m.Name("False"), operator=m.Equal()),
    ):
        replacement: cst.BaseExpression = cst.UnaryOperation(
            operator=cst.Not(), expression=left)
    elif m.matches(
            target,
            m.ComparisonTarget(comparator=m.Name("True"), operator=m.Equal()),
    ):
        replacement = left
    else:
        return updated_node

    # Keep the comparison's own parentheses unless the result already has some.
    if updated_node.lpar and not replacement.lpar:
        replacement = replacement.with_changes(lpar=updated_node.lpar,
                                               rpar=updated_node.rpar)
    return replacement
def test_lambda_metadata_matcher(self) -> None:
    # Match on the qualified name provider via a lambda predicate.
    source = "from typing import List\n\ndef foo() -> None: pass\n"
    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    functiondef = cst.ensure_type(wrapper.module.body[1], cst.FunctionDef)

    def name_in(candidates):
        # FunctionDef matcher whose name has a qualified name in *candidates*.
        return m.FunctionDef(name=m.MatchMetadataIfTrue(
            meta.QualifiedNameProvider,
            lambda qualnames: any(n.name in candidates for n in qualnames),
        ))

    # The function is named "foo", so a candidate set containing it matches.
    self.assertTrue(
        matches(functiondef,
                name_in({"foo", "bar", "baz"}),
                metadata_resolver=wrapper))
    # Without "foo" in the candidate set there is no match.
    self.assertFalse(
        matches(functiondef,
                name_in({"bar", "baz"}),
                metadata_resolver=wrapper))
def _has_none(node):
    """Return True if ``None`` is *node* itself or appears as an operand,
    at any depth, of a chain of binary operations (e.g. ``int | str | None``).
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if m.matches(current, m.Name("None")):
            return True
        if m.matches(current, m.BinaryOperation()):
            # Push right first so the left operand is examined first.
            pending.append(current.right)
            pending.append(current.left)
    return False
def test_and_matcher_false(self) -> None:
    # A Name cannot be both True and False, so AllOf can never match.
    impossible_name = m.AllOf(m.Name("True"), m.Name("False"))
    self.assertFalse(matches(cst.Name("None"), impossible_name))

    # foo(1, 2, 3) satisfies "three wildcard arguments" but not the
    # reversed "3, 2, 1" sequence, so AllOf over both must fail.
    call = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    any_three = (m.Arg(), m.Arg(), m.Arg())
    reversed_values = (
        m.Arg(m.Integer("3")),
        m.Arg(m.Integer("2")),
        m.Arg(m.Integer("1")),
    )
    self.assertFalse(
        matches(
            call,
            m.Call(func=m.Name("foo"),
                   args=m.AllOf(any_three, reversed_values)),
        ))
def test_at_least_n_matcher_args_true(self) -> None:
    # Every case below matches against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )

    # First argument is the integer 1, followed by at least two wildcard
    # arguments.
    self.assertTrue(
        matches(
            call,
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.AtLeastN(m.Arg(), n=2)),
            ),
        ))

    # First argument is the integer 1, followed by at least two arguments
    # that are integers of any value.
    self.assertTrue(
        matches(
            call,
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")),
                      m.AtLeastN(m.Arg(m.Integer()), n=2)),
            ),
        ))

    # First argument is the integer 1, followed by at least two arguments
    # that are the integer 2 or 3.
    self.assertTrue(
        matches(
            call,
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.Arg(m.Integer("1")),
                    m.AtLeastN(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3"))),
                               n=2),
                ),
            ),
        ))
def new_obf_function_name(self, func: cst.Call):
    """Return *func* with its callee expression obfuscated.

    * ``obj.method(...)``: the object part is obfuscated (kind ``'v'``) when
      ``change_variables`` is set, and the method name (kind ``'cf'``) when
      ``change_methods`` is set.
    * ``name(...)``: the name is replaced via ``get_new_cst_name`` when
      function or class renaming is enabled and
      ``can_rename(name, 'c', 'f')`` allows it.
    * Any other callee is left untouched.
    """
    callee = func.func
    if m.matches(callee, m.Name()):
        callee = cst.ensure_type(callee, cst.Name)
        # NOTE(review): either flag enables renaming for either kind; confirm
        # a class name cannot be renamed when only change_functions is set.
        renaming_enabled = self.change_functions or self.change_classes
        if renaming_enabled and self.can_rename(callee.value, 'c', 'f'):
            callee = self.get_new_cst_name(callee.value)
    elif m.matches(callee, m.Attribute()):
        callee = cst.ensure_type(callee, cst.Attribute)
        # Rename the object the attribute is accessed on.
        if self.change_variables:
            callee = callee.with_changes(
                value=self.obf_universal(callee.value, 'v'))
        # Rename the method itself.
        if self.change_methods:
            callee = callee.with_changes(
                attr=self.obf_universal(callee.attr, 'cf'))
    return func.with_changes(func=callee)
def obf_function_args(self, func: cst.Call):
    """Return *func* with its call arguments obfuscated.

    When argument or method-argument renaming is enabled:

    * every argument value is passed through ``obf_universal``;
    * a keyword name is replaced via ``get_new_cst_name`` when
      ``can_rename_func_param(keyword, callee_name)`` allows it.

    When neither option is enabled, *func* is returned unchanged.
    """
    if not (self.change_arguments or self.change_method_arguments):
        # Nothing to do. Returning early guarantees the existing arguments
        # are never replaced by an empty list.
        return func

    callee = func.func
    func_name = ''
    if m.matches(callee, m.Name()):
        func_name = cst.ensure_type(callee, cst.Name).value
    elif m.matches(callee, m.Attribute()):
        # For ``a.b.c(...)`` the callee name is the last segment.
        func_name = split_attribute(cst.ensure_type(callee, cst.Attribute))[-1]

    new_args = []
    for arg in func.args:
        # Obfuscate the argument value.
        arg = arg.with_changes(value=self.obf_universal(arg.value))
        # Obfuscate the keyword name (the keyword is known to be present here).
        if arg.keyword is not None and self.can_rename_func_param(
                arg.keyword.value, func_name):
            arg = arg.with_changes(keyword=self.get_new_cst_name(arg.keyword))
        new_args.append(arg)
    return func.with_changes(args=new_args)
def leave_AnnAssign(
    self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
) -> Union[cst.BaseSmallStatement, cst.RemovalSentinel]:
    """Replace a type-annotated assignment with a plain assignment.

    ``foo: str = "x"`` becomes ``foo = "x"``. An annotation without a value,
    for any target (``foo: str``, ``self.foo: str``, ``d["k"]: int``),
    becomes an assignment of ``...``. Later traversal therefore always sees
    an ``Assign`` that has a value and won't encounter exceptions.
    """
    # ``cst.Assign`` requires a value. Checking the value itself, rather than
    # only Name/Attribute targets, also covers subscript targets. Those
    # previously produced an invalid ``Assign(value=None)``.
    value = (updated_node.value
             if updated_node.value is not None else cst.Ellipsis())
    # Built from ``updated_node`` so changes already applied to the target
    # and value by child visits are kept.
    return cst.Assign(
        targets=[cst.AssignTarget(target=updated_node.target)],
        value=value,
        semicolon=updated_node.semicolon)
def obf_function_name(self, func: cst.Call, updated_node):
    """Return *updated_node* with the called function or method name obfuscated.

    * ``obj.method(...)``: the method name is replaced when
      ``change_methods`` is set and ``can_rename(method, 'cf')`` allows it.
    * ``name(...)``: the name is replaced when it may be renamed as a
      function (``change_functions`` and kind ``'f'``) or as a class
      (``change_classes`` and kind ``'c'``).

    Otherwise *updated_node* is returned unchanged. The callee is read from
    *func*, not from *updated_node*.
    """
    callee = func.func

    if (m.matches(callee, m.Attribute()) and self.change_methods
            and self.can_rename(callee.attr.value, 'cf')):
        callee = cst.ensure_type(callee, cst.Attribute)
        renamed = callee.with_changes(attr=self.get_new_cst_name(callee.attr))
        return updated_node.with_changes(func=renamed)

    if m.matches(callee, m.Name()):
        as_function = self.change_functions and self.can_rename(callee.value, 'f')
        if as_function or (self.change_classes
                           and self.can_rename(callee.value, 'c')):
            return updated_node.with_changes(
                func=self.get_new_cst_name(callee.value))

    return updated_node
def test_and_operator_matcher_true(self) -> None:
    true_name = cst.Name("True")
    regex_true = m.Name(value=m.MatchRegex(r"True"))
    is_name = m.MatchIfTrue(lambda x: isinstance(x, cst.Name))

    # Match on the True identifier in a roundabout way.
    self.assertTrue(matches(true_name, m.Name() & regex_true))

    # Chain a third matcher to check that the AllOf produced by ``&``
    # itself supports ``__and__``.
    self.assertTrue(
        matches(true_name, m.Name() & regex_true & m.Name("True")))

    # MatchIfTrue must combine with ``&`` from either side.
    self.assertTrue(matches(true_name, is_name & regex_true))
    self.assertTrue(matches(true_name, regex_true & is_name))
def _get_async_expr_replacement(
        self, node: cst.CSTNode) -> Optional[cst.CSTNode]:
    """Return an async-compatible replacement for *node*, or None if none is needed.

    Calls and attribute accesses are delegated to their dedicated helpers.
    ``not <expr>`` and boolean operations recurse into their operand(s). They
    are rebuilt only when at least one operand has a replacement.
    """
    if isinstance(node, cst.Call):
        return self._get_async_call_replacement(node)
    if isinstance(node, cst.Attribute):
        return self._get_async_attr_replacement(node)
    if isinstance(node, cst.UnaryOperation) and isinstance(node.operator, cst.Not):
        inner = self._get_async_expr_replacement(node.expression)
        if inner is None:
            return None
        return node.with_changes(expression=inner)
    if isinstance(node, cst.BooleanOperation):
        new_left = self._get_async_expr_replacement(node.left)
        new_right = self._get_async_expr_replacement(node.right)
        if new_left is None and new_right is None:
            return None
        return node.with_changes(
            left=node.left if new_left is None else new_left,
            right=node.right if new_right is None else new_right,
        )
    return None
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """
    Remove the `weak` argument if present in the call.

    This is only changing calls with keyword arguments. Calls not matched by
    ``self.disconnect_call_matchers``, or without a `weak` argument, are
    passed on to the parent class's ``leave_Call``.
    """
    if self.disconnect_call_matchers and m.matches(
            updated_node, m.OneOf(*self.disconnect_call_matchers)):
        # Keep all arguments except the one with the keyword `weak` (if present).
        kept_args = [
            arg for arg in updated_node.args
            if not m.matches(arg, m.Arg(keyword=m.Name("weak")))
        ]
        if len(kept_args) != len(updated_node.args):
            # A `weak` argument was found, so the call must be rewritten.
            if kept_args:
                # The removed argument may have been the last one. Give the
                # new last argument the original last argument's trailing
                # comma so the end of the line is formatted as before. The
                # guard avoids an IndexError when `weak` was the only argument.
                last_comma = updated_node.args[-1].comma
                kept_args[-1] = kept_args[-1].with_changes(comma=last_comma)
            return updated_node.with_changes(args=kept_args)
    return super().leave_Call(original_node, updated_node)
def test_does_not_match_true(self) -> None:
    def call_foo(value):
        # Build the one-argument call ``foo(<value>)``.
        return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

    # A one-argument call whose argument value isn't None.
    self.assertTrue(
        matches(
            call_foo(libcst.Name("True")),
            m.Call(args=(m.Arg(value=m.DoesNotMatch(m.Name("None"))), )),
        ))
    # DoesNotMatch may also wrap the whole Arg matcher...
    self.assertTrue(
        matches(
            call_foo(libcst.Integer("1")),
            m.Call(args=(m.DoesNotMatch(m.Arg(m.Name("None"))), )),
        ))
    # ...or the whole argument sequence.
    self.assertTrue(
        matches(
            call_foo(libcst.Integer("1")),
            m.Call(args=m.DoesNotMatch((m.Arg(m.Integer("2")), ))),
        ))
    # A call whose argument is neither True nor False.
    self.assertTrue(
        matches(
            call_foo(libcst.Integer("1")),
            m.Call(args=(m.Arg(value=m.DoesNotMatch(
                m.OneOf(m.Name("True"), m.Name("False")))), )),
        ))
    # A Name that doesn't match the regex for True.
    self.assertTrue(
        matches(
            libcst.Name("False"),
            m.Name(value=m.DoesNotMatch(m.MatchRegex(r"True"))),
        ))
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef):
    """Leave a class definition.

    Always pops ``self.class_stack``. When class renaming is enabled, the
    class itself is renamed (if ``can_rename(name, 'c')`` allows it). Every
    base class given as a plain name that may be renamed is replaced with
    its new name.
    """
    self.class_stack.pop()
    if not self.change_classes:
        return updated_node

    if self.can_rename(updated_node.name.value, 'c'):
        updated_node = self.renamed(updated_node)

    renamed_bases = []
    for base in updated_node.bases:
        base_expr = base.value
        if m.matches(base_expr, m.Name()):
            base_name = cst.ensure_type(base_expr, cst.Name).value
            if self.can_rename(base_name, 'c'):
                base = base.with_changes(value=self.get_new_cst_name(base_name))
        # TODO: support imports. Dotted bases (``module.Base``) and any
        # other expressions are currently left unchanged.
        renamed_bases.append(base)
    return updated_node.with_changes(bases=renamed_bases)
def test_and_matcher_true(self) -> None:
    # Match on the True identifier in a roundabout way.
    self.assertTrue(
        matches(
            cst.Name("True"),
            m.AllOf(m.Name(), m.Name(value=m.MatchRegex(r"True"))),
        ))

    # foo(1, 2, 3) satisfies both "three wildcard arguments" and
    # "exactly 1, 2, 3", so AllOf over the two sequences matches.
    call = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    any_three = (m.Arg(), m.Arg(), m.Arg())
    exact_values = (
        m.Arg(m.Integer("1")),
        m.Arg(m.Integer("2")),
        m.Arg(m.Integer("3")),
    )
    self.assertTrue(
        matches(
            call,
            m.Call(func=m.Name("foo"), args=m.AllOf(any_three, exact_values)),
        ))
def test_simple_matcher_false(self) -> None:
    position = meta.SyntacticPositionProvider

    # A Name matcher pinned to a range on line 2 must not match the
    # fixture built from "foo".
    node, wrapper = self._make_fixture("foo")
    wrong_line = m.MatchMetadata(position,
                                 self._make_coderange((2, 0), (2, 3)))
    self.assertFalse(
        matches(node,
                m.Name(value="foo", metadata=wrong_line),
                metadata_resolver=wrapper))

    # A BinaryOperation matcher whose two children are pinned to exact
    # ranges that the operands of "foo + bar" do not occupy.
    node, wrapper = self._make_fixture("foo + bar")
    pinned_left = m.MatchMetadata(position,
                                  self._make_coderange((1, 0), (1, 1)))
    pinned_right = m.MatchMetadata(position,
                                   self._make_coderange((1, 4), (1, 5)))
    self.assertFalse(
        matches(node,
                m.BinaryOperation(left=pinned_left, right=pinned_right),
                metadata_resolver=wrapper))
def test_at_most_n_matcher_args_true(self) -> None:
    # Match a function call to "foo" with at most two arguments, both of which
    # are the integer 1.
    self.assertTrue(
        matches(
            cst.Call(func=cst.Name("foo"), args=(cst.Arg(cst.Integer("1")), )),
            m.Call(func=m.Name("foo"),
                   args=(m.AtMostN(m.Arg(m.Integer("1")), n=2), )),
        ))
    # Match a function call to "foo" with at most two arguments, both of which
    # can be the integer 1 or 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(m.AtMostN(m.Arg(m.OneOf(m.Integer("1"), m.Integer("2"))),
                                n=2), ),
            ),
        ))
    # Match a function call to "foo" with at most two arguments, the first
    # one being the integer 1 and the second one, if included, being the
    # integer 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.ZeroOrOne(m.Arg(m.Integer("2")))),
            ),
        ))
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1 and the second one, if included, being the
    # integer 2. The trailing AtMostN allows up to four further arguments;
    # here it matches none. (This case was previously a verbatim duplicate of
    # the one above and never exercised the six-argument bound.)
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.Arg(m.Integer("1")),
                    m.ZeroOrOne(m.Arg(m.Integer("2"))),
                    m.AtMostN(n=4),
                ),
            ),
        ))
def test_or_operator_matcher_false(self) -> None:
    true_or_false = m.Name("True") | m.Name("False")

    # None is neither True nor False.
    self.assertFalse(matches(cst.Name("None"), true_or_false))

    # Assigning None to a target is not assigning True or False to it.
    assign_none = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("None"))
    self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))
def leave_Attribute(
    self, original: libcst.Attribute, updated: libcst.Attribute
) -> Any:
    """Rewrite ``six.<name>`` to ``io.<name>`` for names listed in ``_IO_OBJECTS``.

    Only the ``value`` part of the attribute is swapped (via
    ``with_changes``). The node's parentheses and the whitespace around the
    dot are therefore preserved instead of being reset to defaults.
    """
    if (m.matches(updated.value, m.Name("six"))
            and m.matches(updated.attr, m.Name())
            and updated.attr.value in _IO_OBJECTS):
        return updated.with_changes(value=libcst.Name("io"))
    return updated
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                              updated_node: cst.SimpleStatementLine):
    """Add a type annotation to a single-statement assignment to a bare name.

    Two shapes are handled, both with exactly one small statement on the line:

    * ``x = value``: if ``__get_var_type_assign_t("x")`` returns a type, the
      statement is rewritten to ``x: <annotation> = value``, reusing the
      original whitespace around ``=``.
    * ``x: T = value`` / ``x: T``: if ``__get_var_type_an_assign("x")``
      returns a type, the existing annotation is replaced.

    The type is passed through ``resolve_type_alias`` and then
    ``__name2annotation``. A rewrite happens only when the latter yields an
    annotation node. Each applied (resolved type, annotation node) pair is
    added to ``self.all_applied_types``.

    NOTE(review): the new ``AnnAssign`` is built from ``original_node``
    children, and the fall-through returns ``original_node``. Both discard
    any changes that child visits made to this statement line; libcst's
    convention is to build from and return ``updated_node``. Confirm this is
    intentional.
    """
    # Case 1: a plain single-target assignment ``name = value``.
    if match.matches(
            original_node,
            match.SimpleStatementLine(body=[
                match.Assign(targets=[
                    match.AssignTarget(target=match.Name(
                        value=match.DoNotCare()))
                ])
            ])):
        t = self.__get_var_type_assign_t(
            original_node.body[0].targets[0].target.value)
        if t is not None:
            t_annot_node_resolved = self.resolve_type_alias(t)
            t_annot_node = self.__name2annotation(t_annot_node_resolved)
            if t_annot_node is not None:
                self.all_applied_types.add(
                    (t_annot_node_resolved, t_annot_node))
                # Re-use the whitespace around ``=`` from the AssignTarget so
                # the rewritten line keeps its original spacing.
                return updated_node.with_changes(body=[
                    cst.AnnAssign(
                        target=original_node.body[0].targets[0].target,
                        value=original_node.body[0].value,
                        annotation=t_annot_node,
                        equal=cst.AssignEqual(
                            whitespace_after=original_node.body[0].targets[0].whitespace_after_equal,
                            whitespace_before=original_node.body[0].targets[0].whitespace_before_equal))
                ])
    # Case 2: an already-annotated assignment to a bare name.
    elif match.matches(
            original_node,
            match.SimpleStatementLine(body=[
                match.AnnAssign(target=match.Name(value=match.DoNotCare()))
            ])):
        t = self.__get_var_type_an_assign(
            original_node.body[0].target.value)
        if t is not None:
            t_annot_node_resolved = self.resolve_type_alias(t)
            t_annot_node = self.__name2annotation(t_annot_node_resolved)
            if t_annot_node is not None:
                self.all_applied_types.add(
                    (t_annot_node_resolved, t_annot_node))
                # Only the annotation changes; the target, value (possibly
                # None for a bare annotation) and ``=`` token are kept.
                return updated_node.with_changes(body=[
                    cst.AnnAssign(target=original_node.body[0].target,
                                  value=original_node.body[0].value,
                                  annotation=t_annot_node,
                                  equal=original_node.body[0].equal)
                ])
    return original_node
def visit_Call(self, node: cst.Call) -> None:
    """Report ``list``/``set``/``dict`` calls that wrap a single generator
    expression or list comprehension, suggesting the equivalent comprehension.

    * ``list(x for x in y)`` becomes ``[x for x in y]``.
    * ``set(x for x in y)`` becomes ``{x for x in y}``.
    * ``dict((k, v) for ...)`` becomes ``{k: v for ...}``. For ``dict`` the
      element must be a tuple or list of exactly two non-starred items;
      other forms are not reported.

    Starred arguments (``list(*(...))``) are not matched because unpacking
    changes the meaning of the call.
    """
    if not m.matches(
            node,
            m.Call(
                func=m.Name("list") | m.Name("set") | m.Name("dict"),
                args=[m.Arg(value=m.GeneratorExp() | m.ListComp(), star="")],
            ),
    ):
        return

    call_name = cst.ensure_type(node.func, cst.Name).value
    arg_value = node.args[0].value
    if m.matches(arg_value, m.GeneratorExp()):
        exp = cst.ensure_type(arg_value, cst.GeneratorExp)
        message_formatter = UNNECESSARY_GENERATOR
    else:
        exp = cst.ensure_type(arg_value, cst.ListComp)
        message_formatter = UNNECESSARY_LIST_COMPREHENSION

    replacement = None
    if call_name == "list":
        replacement = node.deep_replace(
            node, cst.ListComp(elt=exp.elt, for_in=exp.for_in))
    elif call_name == "set":
        replacement = node.deep_replace(
            node, cst.SetComp(elt=exp.elt, for_in=exp.for_in))
    elif call_name == "dict":
        # The matcher fields must be named. Positional arguments would bind to
        # ``elements`` and ``lpar``, which accepts tuples of any length and
        # leads to IndexError / wrong fixes below.
        pair = [m.Element(), m.Element()]
        elt = exp.elt
        if not m.matches(elt, m.Tuple(elements=pair) | m.List(elements=pair)):
            # Unrecognized form
            return
        key = elt.elements[0].value
        value = elt.elements[1].value
        replacement = node.deep_replace(
            node,
            # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
            # param but got `BaseExpression`.
            cst.DictComp(key=key, value=value, for_in=exp.for_in),
        )
    self.report(node,
                message_formatter.format(func=call_name),
                replacement=replacement)
def test_at_least_n_matcher_no_args_false(self) -> None:
    # Every case below is checked against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )

    # At least four arguments: the call only has three.
    self.assertFalse(
        matches(call, m.Call(func=m.Name("foo"), args=(m.AtLeastN(n=4), ))))

    # First argument is 1, followed by at least three more: only two follow.
    self.assertFalse(
        matches(
            call,
            m.Call(func=m.Name("foo"),
                   args=(m.Arg(m.Integer("1")), m.AtLeastN(n=3))),
        ))

    # At least two arguments followed by a final argument equal to 2:
    # the last argument is 3.
    self.assertFalse(
        matches(
            call,
            m.Call(func=m.Name("foo"),
                   args=(m.AtLeastN(n=2), m.Arg(m.Integer("2")))),
        ))
def _split_module(
    self, orig_module: libcst.Module, updated_module: libcst.Module
) -> Tuple[
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
]:
    """Split the updated module body into three consecutive statement lists.

    1. Statements that must stay before any added import: a leading module
       docstring, or a leading ``__strict__ = ...`` flag.
    2. The statements after those, up to and including the last statement
       that contains one of ``self.all_imports``. This list is empty when no
       such statement exists.
    3. Everything after that.

    Positions are computed on ``orig_module`` and applied to
    ``updated_module``. This relies on both having the same number of
    top-level statements.
    """
    statement_before_import_location = 0
    import_add_location = 0

    # never insert an import before initial __strict__ flag
    if m.matches(
            orig_module,
            m.Module(body=[
                m.SimpleStatementLine(body=[
                    m.Assign(targets=[
                        m.AssignTarget(target=m.Name("__strict__"))
                    ])
                ]),
                m.ZeroOrMore(),
            ]),
    ):
        statement_before_import_location = import_add_location = 1

    # This works under the principle that while we might modify node contents,
    # we have yet to modify the number of statements. So we can match on the
    # original tree but break up the statements of the modified tree. If we
    # change this assumption in this visitor, we will have to change this code.
    for i, statement in enumerate(orig_module.body):
        # Only the very first statement can be the module docstring. A bare
        # string statement later in the file must not move the insertion
        # point back to the top.
        if i == 0 and m.matches(
                statement,
                m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])):
            statement_before_import_location = import_add_location = 1
        elif isinstance(statement, libcst.SimpleStatementLine):
            if any(possible_import is last_import
                   for possible_import in statement.body
                   for last_import in self.all_imports):
                import_add_location = i + 1

    return (
        list(updated_module.body[:statement_before_import_location]),
        list(updated_module.body[statement_before_import_location:import_add_location]),
        list(updated_module.body[import_add_location:]),
    )
def visit_Call(self, node: cst.Call) -> Optional[bool]:
    """Abort the codemod on calls matched by ``gen_task_matcher``.

    Per the error message, these are tornado 2.4.1 ``gen.Task`` calls,
    which this codemod does not support. Any other call returns True so its
    children are still visited.
    """
    if not m.matches(node, gen_task_matcher):
        return True
    raise TransformError(
        "gen.Task (https://www.tornadoweb.org/en/branch2.4/gen.html#tornado.gen.Task) from tornado 2.4.1 is unsupported by this codemod. This file has not been modified. Manually update to supported syntax before running again."
    )
def leave_Yield(self, node: cst.Yield,
                updated_node: cst.Yield) -> Union[cst.Await, cst.Yield]:
    """Turn ``yield <expr>`` inside a coroutine into ``await <expr>``.

    * Outside a coroutine, or when the yielded value is not an expression
      (e.g. a bare ``yield``), the node is returned unchanged.
    * A yielded list or list comprehension is converted by
      ``pluck_asyncio_gather_expression_from_yield_list_or_list_comp``, and
      ``"asyncio"`` is added to ``self.required_imports``.
    * Yielding a dict, dict comprehension or ``dict(...)`` call raises
      ``TransformError``.

    The whitespace after ``yield`` and the surrounding parentheses are
    carried over to the ``Await``.
    """
    if (not self.in_coroutine(self.coroutine_stack)
            or not isinstance(updated_node.value, cst.BaseExpression)):
        return updated_node

    yielded = updated_node.value
    dict_like = m.Yield(value=(m.Dict() | m.DictComp()) | m.Call(func=m.Name("dict")))

    if isinstance(yielded, (cst.List, cst.ListComp)):
        self.required_imports.add("asyncio")
        awaited = self.pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
            updated_node)
    elif m.matches(updated_node, dict_like):
        raise TransformError(
            "Yielding a dict of futures (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) added in tornado 3.2 is unsupported by the codemod. This file has not been modified. Manually update to supported syntax before running again."
        )
    else:
        awaited = yielded

    return cst.Await(
        expression=awaited,
        whitespace_after_await=updated_node.whitespace_after_yield,
        lpar=updated_node.lpar,
        rpar=updated_node.rpar,
    )
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Report ``dataclasses.dataclass`` decorators that do not pass ``frozen``.

    The suggested replacement calls the decorator with the original
    arguments plus ``frozen=True``. A decorator that already passes
    ``frozen=`` (with any value) is not reported.
    """
    dataclass_name = QualifiedName(
        name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
    )
    for d in node.decorators:
        decorator = d.decorator
        if not QualifiedNameProvider.has_name(self, decorator, dataclass_name):
            continue

        if isinstance(decorator, cst.Call):
            func, args = decorator.func, decorator.args
        else:
            # decorator is either cst.Name or cst.Attribute
            func, args = decorator, ()

        # pyre-fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]], typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
        if any(m.matches(arg.keyword, m.Name("frozen")) for arg in args):
            continue

        frozen_arg = cst.Arg(
            keyword=cst.Name("frozen"),
            value=cst.Name("True"),
            equal=cst.AssignEqual(
                whitespace_before=SimpleWhitespace(value=""),
                whitespace_after=SimpleWhitespace(value=""),
            ),
        )
        new_decorator = cst.Call(func=func, args=list(args) + [frozen_arg])
        self.report(d, replacement=d.with_changes(decorator=new_decorator))