def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
    """Rewrite a flattened API call into the ``request={...}`` calling style.

    Calls whose method name (``original.func.attr.value``) is a key of
    ``self.METHOD_TO_PARAMS`` have their API parameters bundled into a single
    ``request=`` dict argument. Control parameters (``self.CTRL_PARAMS``) are
    kept as keyword arguments after it.

    Args:
        original: The call node as it appeared before child transforms.
        updated: The call node after child transforms.

    Returns:
        The rewritten call. Returns ``updated`` unchanged when the call is not
        a recognised API method, or when it already passes ``request=``.
    """
    try:
        key = original.func.attr.value
        kword_params = self.METHOD_TO_PARAMS[key]
    except (AttributeError, KeyError):
        # Either not a method from the API or too convoluted to be sure.
        return updated

    # If the existing code is valid, keyword args come after positional args.
    # Therefore, all positional args must map to the first parameters.
    args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
    if any(k.keyword.value == "request" for k in kwargs):
        # We've already fixed this file, don't fix it again.
        return updated

    kwargs, ctrl_kwargs = partition(
        lambda a: a.keyword.value not in self.CTRL_PARAMS,
        kwargs,
    )

    # Positional args beyond the API parameters are control parameters that
    # were passed positionally; give them their control-parameter keyword.
    args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
    ctrl_kwargs.extend(
        cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
        for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)
    )

    # Positional args are named by their position in ``kword_params``.
    # Keyword args already carry their own name and must keep it: zipping them
    # against ``kword_params`` would mislabel them whenever they are passed
    # out of declaration order or when some parameters are skipped.
    named_values = [(name, arg.value) for name, arg in zip(kword_params, args)]
    named_values.extend((arg.keyword.value, arg.value) for arg in kwargs)

    request_arg = cst.Arg(
        value=cst.Dict([
            # DictElement expects a plain expression as its value.
            cst.DictElement(cst.SimpleString("'{}'".format(name)), value)
            for name, value in named_values
        ]),
        keyword=cst.Name("request"),
    )

    return updated.with_changes(args=[request_arg] + ctrl_kwargs)
class ExpressionTest(UnitTest):
    """Tests for full-name resolution and ``evaluated_value`` of literals."""

    @data_provider(
        (
            ("a string", "a string"),
            (cst.Name("a_name"), "a_name"),
            (cst.parse_expression("a.b.c"), "a.b.c"),
            (cst.parse_expression("a.b()"), "a.b"),
            (cst.parse_expression("a.b.c[i]"), "a.b.c"),
            (cst.parse_statement("def fun(): pass"), "fun"),
            (cst.parse_statement("class cls: pass"), "cls"),
            (
                cst.Decorator(
                    ensure_type(cst.parse_expression("a.b.c.d"), cst.Attribute)
                ),
                "a.b.c.d",
            ),
            (cst.parse_statement("(a.b()).c()"), None),  # not a supported Node type
        )
    )
    def test_get_full_name_for_expression(
        self,
        input: Union[str, cst.CSTNode],
        output: Optional[str],
    ) -> None:
        # The lenient variant must agree with the expected name (or None).
        self.assertEqual(get_full_name_for_node(input), output)
        if output is not None:
            # The strict variant returns the same name when one exists ...
            self.assertEqual(get_full_name_for_node_or_raise(input), output)
            return
        # ... and raises when the node has no resolvable name.
        with self.assertRaises(Exception):
            get_full_name_for_node_or_raise(input)

    def _assert_evaluated_value(self, raw: str, node_type: type) -> None:
        """Parse ``raw`` as ``node_type``; check ``value`` and ``evaluated_value``.

        ``value`` must be the raw source text. ``evaluated_value`` must equal
        what ``ast.literal_eval`` produces for the same text.
        """
        parsed = ensure_type(cst.parse_expression(raw), node_type)
        self.assertEqual(parsed.value, raw)
        self.assertEqual(parsed.evaluated_value, literal_eval(raw))

    def test_simplestring_evaluated_value(self) -> None:
        self._assert_evaluated_value('"a string."', cst.SimpleString)

    def test_integer_evaluated_value(self) -> None:
        self._assert_evaluated_value("5", cst.Integer)

    def test_float_evaluated_value(self) -> None:
        self._assert_evaluated_value("5.5", cst.Float)

    def test_complex_evaluated_value(self) -> None:
        self._assert_evaluated_value("5j", cst.Imaginary)
def _get_wrapped_union_type(
    node: cst.BaseExpression,
    addition: cst.SubscriptElement,
    *additions: cst.SubscriptElement,
) -> cst.Subscript:
    """
    Build ``Union[node, addition, *additions]``.

    ``node`` is a bare expression and is wrapped in its own subscript element.
    ``addition`` and every entry of ``additions`` are already subscript
    elements and are appended unchanged. The signature requires at least one
    addition (for type safety), so the result always has two or more members.
    """
    members = [cst.SubscriptElement(cst.Index(node)), addition]
    members.extend(additions)
    return cst.Subscript(cst.Name("Union"), members)
def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
    """Insert one annotated declaration per stored top-level annotation.

    For each ``name -> annotation`` entry of ``self.toplevel_annotations`` a
    ``name: <annotation>`` statement (no value) is created. The batch is
    spliced into the module body at the position returned by
    ``self._get_toplevel_index``, in the dict's iteration order.
    """
    body = list(updated_node.body)
    index = self._get_toplevel_index(body)
    new_statements = [
        cst.SimpleStatementLine(
            [
                cst.AnnAssign(
                    cst.Name(name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(annotation.annotation),
                    None,
                )
            ]
        )
        for name, annotation in self.toplevel_annotations.items()
    ]
    # Slice assignment inserts the whole batch in order. The previous
    # per-entry ``body.insert(index, ...)`` at a fixed index emitted the
    # declarations in reverse order.
    body[index:index] = new_statements
    return updated_node.with_changes(body=tuple(body))
def leave_StarImport(
    original_node: cst.ImportFrom,
    updated_node: cst.ImportFrom,
    imp: ImportFrom,
) -> Union[cst.ImportFrom, cst.RemovalSentinel]:
    """Replace or drop a ``from x import *`` statement.

    - ``imp.modules`` non-empty: the star is replaced by one explicit import
      name per entry.
    - otherwise, ``imp.module`` truthy: the whole statement is removed.
    - otherwise: ``original_node`` is returned unchanged.
    """
    if not imp.modules:
        if imp.module:
            return cst.RemoveFromParent()
        return original_node

    explicit_names = []
    for module in imp.modules:
        explicit_names.append(cst.ImportAlias(cst.Name(module)))
    return updated_node.with_changes(names=explicit_names)
def test_annotation(self) -> None:
    """An annotation placeholder accepts an expression or an Annotation node."""
    template = "x: {type} = {val}"
    # First a plain annotation expression, then the special-cased
    # Annotation node; both must render identically.
    for annotation in (cst.Name("int"), cst.Annotation(cst.Name("int"))):
        statement = parse_template_statement(
            template,
            type=annotation,
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), "x: int = 5\n")
def test_args(self) -> None:
    """A call-argument placeholder accepts an expression or an Arg node."""
    template = "foo({arg1}, {arg2})"
    argument_pairs = (
        # Plain expressions are inserted as arguments normally.
        (cst.Name("bar"), cst.Name("baz")),
        # Arg nodes are accepted as a special case.
        (cst.Arg(cst.Name("bar")), cst.Arg(cst.Name("baz"))),
    )
    for first, second in argument_pairs:
        expression = parse_template_expression(template, arg1=first, arg2=second)
        self.assertEqual(self.code(expression), "foo(bar, baz)")
def test_at_most_n_matcher_no_args_true(self) -> None:
    """AtMostN (and ZeroOrOne) wildcards match short enough argument lists."""

    def foo_call(*digits: str) -> libcst.Call:
        # Build ``foo(<d0>, <d1>, ...)`` from the given integer literals.
        return libcst.Call(
            libcst.Name("foo"),
            tuple(libcst.Arg(libcst.Integer(d)) for d in digits),
        )

    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(foo_call("1"), m.Call(m.Name("foo"), (m.AtMostN(n=2),)))
    )
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(foo_call("1", "2"), m.Call(m.Name("foo"), (m.AtMostN(n=2),)))
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 1.
    self.assertTrue(
        matches(
            foo_call("1"),
            m.Call(m.Name("foo"), [m.AtMostN(n=5), m.Arg(m.Integer("1"))]),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 2.
    self.assertTrue(
        matches(
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.AtMostN(n=5), m.Arg(m.Integer("2")))),
        )
    )
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1.
    self.assertTrue(
        matches(
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.AtMostN(n=5))),
        )
    )
    # Match a function call to "foo" with at most two arguments (ZeroOrOne
    # allows a single extra argument), the first one being the integer 1.
    self.assertTrue(
        matches(
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.ZeroOrOne())),
        )
    )
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Migrate ``AppendItem()`` to ``Append()`` and rename matched keywords.

    Calls matching ``self.deprecated_call_matcher`` have their attribute
    renamed to ``Append``. If the (possibly renamed) call then matches
    ``self.call_matcher``, each argument that matches a key of
    ``self.args_matchers_map`` gets its keyword replaced by the mapped name.
    """
    # Migrate from the deprecated AppendItem() method.
    if matchers.matches(updated_node, self.deprecated_call_matcher):
        append_func = updated_node.func.with_changes(attr=cst.Name(value="Append"))
        updated_node = updated_node.with_changes(func=append_func)

    if not matchers.matches(updated_node, self.call_matcher):
        return updated_node

    # Update keywords. Matching is done against the arguments as they were
    # before this loop, so a later matcher hitting the same argument
    # replaces (rather than stacks on) an earlier rename.
    current_args = updated_node.args
    rewritten_args = list(current_args)
    for arg_matcher, renamed in self.args_matchers_map.items():
        for position, arg in enumerate(current_args):
            if matchers.matches(arg, arg_matcher):
                rewritten_args[position] = arg.with_changes(
                    keyword=cst.Name(value=renamed))
    return updated_node.with_changes(args=rewritten_args)
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Report ``dataclasses.dataclass`` decorators without a ``frozen`` keyword.

    For each such decorator, a replacement is suggested that calls the
    decorator with its existing arguments plus ``frozen=True``. A bare
    ``@dataclass`` becomes ``@dataclass(frozen=True)``. Decorators that
    already pass ``frozen=`` (with any value) are left alone.
    """
    dataclass_name = QualifiedName(
        name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
    )
    for d in node.decorators:
        decorator = d.decorator
        if not QualifiedNameProvider.has_name(self, decorator, dataclass_name):
            continue

        if isinstance(decorator, cst.Call):
            func, args = decorator.func, decorator.args
        else:
            # decorator is either cst.Name or cst.Attribute
            func, args = decorator, ()

        # pyre-\fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]],
        # typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
        if any(m.matches(arg.keyword, m.Name("frozen")) for arg in args):
            continue

        frozen_arg = cst.Arg(
            keyword=cst.Name("frozen"),
            value=cst.Name("True"),
            # Render as ``frozen=True`` with no spaces around ``=``.
            equal=cst.AssignEqual(
                whitespace_before=SimpleWhitespace(value=""),
                whitespace_after=SimpleWhitespace(value=""),
            ),
        )
        new_decorator = cst.Call(func=func, args=[*args, frozen_arg])
        self.report(d, replacement=d.with_changes(decorator=new_decorator))
def test_does_not_match_operator_false(self) -> None:
    """Inverted (``~``) matchers reject the nodes they negate."""

    def foo_call(value: cst.BaseExpression) -> cst.Call:
        # ``foo(<value>)`` with a single positional argument.
        return cst.Call(func=cst.Name("foo"), args=(cst.Arg(value),))

    rejected_cases = (
        # Match on any call that takes one argument that isn't the value None.
        (
            foo_call(cst.Name("None")),
            m.Call(args=(m.Arg(value=~m.Name("None")),)),
        ),
        (
            foo_call(cst.Integer("1")),
            m.Call(args=((~m.Arg(m.Integer("1"))),)),
        ),
        # Match any call that takes an argument which isn't True or False.
        (
            foo_call(cst.Name("False")),
            m.Call(args=(m.Arg(value=~(m.Name("True") | m.Name("False"))),)),
        ),
        # Roundabout way of verifying ~(x&y) behavior.
        (
            foo_call(cst.Name("False")),
            m.Call(args=(m.Arg(value=~(m.Name() & m.Name("False"))),)),
        ),
        # Roundabout way of verifying (~x)|(~y) behavior
        (
            foo_call(cst.Name("True")),
            m.Call(args=(m.Arg(value=(~m.Name("True")) | (~m.Name("True"))),)),
        ),
        # Match any name node that doesn't match the regex for True
        (cst.Name("True"), m.Name(value=~m.MatchRegex(r"True"))),
    )
    for node, matcher in rejected_cases:
        self.assertFalse(matches(node, matcher))
def _get_clean_type_from_subscript(
    aliases: List[Alias], typecst: cst.Subscript
) -> cst.BaseExpression:
    """Return the cleaned/aliased form of a subscripted type.

    A ``Union[...]`` is extended in place with the match-metadata and
    MatchIfTrue members, then handed to ``_wrap_clean_type`` under the name
    from ``_get_alias_name``. Any other subscript (e.g. ``Literal``) goes
    through ``_get_clean_type_from_expression`` instead.
    """
    if typecst.value.deep_equals(cst.Name("Union")):
        # We can modify this as-is to add our extra values.
        alias_name = _get_alias_name(typecst)
        extended_union = typecst.with_changes(
            slice=[
                *typecst.slice,
                _get_match_metadata(),
                _get_match_if_true(typecst),
            ]
        )
        return _wrap_clean_type(aliases, alias_name, extended_union)

    # Don't handle other types like "Literal", just widen them.
    return _get_clean_type_from_expression(aliases, typecst)
def visit_Subscript(self, node: cst.Subscript) -> None:
    """Record subscripts that later passes need to know about.

    ``MatchIfTrue[...]`` nodes are added to ``self.in_match_if_true``.
    For ``Sequence[Union[...]]`` seen while ``self.in_match_if_true`` is
    empty, the inner ``Union`` subscript is added to ``self.fixup_nodes``.

    Raises:
        Exception: if a ``Sequence`` subscript does not have exactly one
            element.
    """
    # If the current node is a MatchIfTrue, we don't want to modify it.
    if node.value.deep_equals(cst.Name("MatchIfTrue")):
        self.in_match_if_true.add(node)
        return
    if not node.value.deep_equals(cst.Name("Sequence")):
        return
    if self.in_match_if_true:
        # We don't want to add AtLeastN/AtMostN inside MatchIfTrue
        # type blocks, even for sequence types.
        return

    slc = node.slice
    # TODO: We can remove the instance check after ExtSlice is deprecated.
    if not isinstance(slc, Sequence) or len(slc) != 1:
        raise Exception(
            "Unexpected number of sequence elements inside Sequence type "
            + "annotation!")

    index = slc[0].slice
    if not isinstance(index, cst.Index):
        return
    element_type = index.value
    # If the direct descendant is a union, add it to be fixed up.
    if isinstance(element_type, cst.Subscript) and element_type.value.deep_equals(
            cst.Name("Union")):
        self.fixup_nodes.add(element_type)
def test_or_matcher_true(self) -> None:
    """OneOf matches when any one of its alternatives matches."""
    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # Match on either True or False identifier.
    self.assertTrue(matches(libcst.Name("True"), true_or_false))

    # Match any assignment that assigns a value of True or False to an
    # unspecified target.
    assign_true = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
    )
    self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))

    # OneOf over whole argument sequences: foo(1, 2, 3) matches because the
    # second alternative lists 1, 2, 3 in order.
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(d)) for d in ("1", "2", "3")),
    )
    descending = tuple(m.Arg(m.Integer(d)) for d in ("3", "2", "1"))
    ascending = tuple(m.Arg(m.Integer(d)) for d in ("1", "2", "3"))
    self.assertTrue(
        matches(call, m.Call(m.Name("foo"), m.OneOf(descending, ascending)))
    )
def body(
    self,
) -> typing.Iterable[typing.Union[cst.BaseCompoundStatement, cst.SimpleStatementLine]]:
    """Yield this module's top-level statements in emission order.

    The order is:
    1. ``from typing import *``
    2. the statements from ``assign_properties(self.properties)``
    3. the statements from ``function_defs`` for the module functions
    4. one class definition per entry of ``sort_items(self.classes)``
    """
    typing_star_import = cst.ImportFrom(cst.Name("typing"), names=cst.ImportStar())
    yield cst.SimpleStatementLine([typing_star_import])
    yield from assign_properties(self.properties)
    yield from function_defs(self.function_overloads, self.functions, "function")
    yield from (cls.class_def(cls_name) for cls_name, cls in sort_items(self.classes))
def test_or_matcher_false(self) -> None:
    """OneOf fails when none of its alternatives matches."""
    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # Fail to match since None is not True or False.
    self.assertFalse(matches(libcst.Name("None"), true_or_false))

    # Fail to match since assigning None to a target is not the same as
    # assigning True or False to a target.
    assign_none = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
    )
    self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))

    # foo(1, 2, 3) matches neither (3, 2, 1) nor (4, 5, 6).
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(d)) for d in ("1", "2", "3")),
    )
    descending = tuple(m.Arg(m.Integer(d)) for d in ("3", "2", "1"))
    unrelated = tuple(m.Arg(m.Integer(d)) for d in ("4", "5", "6"))
    self.assertFalse(
        matches(call, m.Call(m.Name("foo"), m.OneOf(descending, unrelated)))
    )
def test_assign_target(self) -> None:
    """Assignment-target placeholders accept expressions or AssignTarget nodes."""
    template = "{a} = {b} = {val}"
    target_pairs = (
        # Plain expressions are inserted as targets normally.
        (cst.Name("first"), cst.Name("second")),
        # AssignTarget nodes are accepted as a special case.
        (cst.AssignTarget(cst.Name("first")), cst.AssignTarget(cst.Name("second"))),
    )
    for first, second in target_pairs:
        statement = parse_template_statement(
            template, a=first, b=second, val=cst.Integer("5")
        )
        self.assertEqual(self.code(statement), "first = second = 5\n")
def test_findall_with_metadata_wrapper(self) -> None:
    """findall takes metadata from a wrapper, implicitly or via a resolver."""
    code = """
        a = 1
        b = True

        def foo(bar: int) -> bool:
            return False
    """
    wrapper = meta.MetadataWrapper(cst.parse_module(dedent(code)))
    store_matcher = m.MatchMetadata(
        meta.ExpressionContextProvider, meta.ExpressionContext.STORE
    )
    expected_stores = [cst.Name("a"), cst.Name("b")]

    # Searching the wrapper itself uses it for both traversal and metadata.
    stores = findall(wrapper, store_matcher)
    self.assertNodeSequenceEqual(stores, expected_stores)

    # An explicit resolver works when searching the bare module.
    stores = findall(wrapper.module, store_matcher, metadata_resolver=wrapper)
    self.assertNodeSequenceEqual(stores, expected_stores)

    # Without any metadata source, metadata matchers find nothing.
    stores = findall(wrapper.module, store_matcher)
    self.assertNodeSequenceEqual(stores, [])
def test_findall_with_transformers(self) -> None:
    """A transformer's own findall uses that transformer's metadata."""

    class StoreCollector(m.MatcherDecoratableTransformer):
        # Expression-context metadata is needed by MatchMetadata below.
        METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = (
            meta.ExpressionContextProvider,
        )

        def __init__(self) -> None:
            super().__init__()
            self.results: Sequence[cst.CSTNode] = ()

        def visit_Module(self, node: cst.Module) -> None:
            # No explicit resolver: findall resolves through this transformer.
            self.results = self.findall(
                node,
                m.MatchMetadata(
                    meta.ExpressionContextProvider, meta.ExpressionContext.STORE
                ),
            )

    code = """
        a = 1
        b = True

        def foo(bar: int) -> bool:
            return False
    """
    collector = StoreCollector()
    meta.MetadataWrapper(cst.parse_module(dedent(code))).visit(collector)
    self.assertNodeSequenceEqual(
        collector.results,
        [cst.Name(ident) for ident in ("a", "b", "foo", "bar")],
    )
def _get_clean_type_from_subscript(
    aliases: List[Alias], typecst: cst.Subscript
) -> cst.BaseExpression:
    """Return the cleaned/aliased form of a subscripted type.

    - ``Sequence[X]``: ``X`` is cleaned recursively, then the sequence is
      wrapped in a union with the do-not-care and MatchIfTrue members.
    - ``Union[...]``: delegated to ``_get_clean_type_from_union``.
    - anything else (e.g. ``Literal``): widened through
      ``_get_clean_type_from_expression``.

    Raises:
        Exception: on a ``Sequence`` subscript whose shape is not a single
            ``Index`` wrapping a Subscript, Name, or SimpleString.
    """
    if not typecst.value.deep_equals(cst.Name("Sequence")):
        if typecst.value.deep_equals(cst.Name("Union")):
            # We can modify this as-is to add our extra values.
            return _get_clean_type_from_union(aliases, typecst)
        # Don't handle other types like "Literal", just widen them.
        return _get_clean_type_from_expression(aliases, typecst)

    # Lets attempt to widen the sequence type and alias it.
    slc = typecst.slice
    # TODO: This instance check can go away once we deprecate ExtSlice
    if not isinstance(slc, Sequence):
        raise Exception("Logic error, expected Sequence to have children!")
    if len(slc) != 1:
        raise Exception(
            "Logic error, Sequence shouldn't have more than one param!")
    index = slc[0].slice
    if not isinstance(index, cst.Index):
        raise Exception(
            "Logic error, expecting Index for only Sequence element!")

    element_type = index.value
    if isinstance(element_type, cst.Subscript):
        cleaned = _get_clean_type_from_subscript(aliases, element_type)
    elif isinstance(element_type, (cst.Name, cst.SimpleString)):
        cleaned = _get_clean_type_from_expression(aliases, element_type)
    else:
        raise Exception("Logic error, unexpected type in Sequence!")

    return _get_wrapped_union_type(
        typecst.deep_replace(element_type, cleaned),
        _get_do_not_care(),
        _get_match_if_true(typecst),
    )
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.ExtSlice:
    """
    Construct a MatchIfTrue type node appropriate for going into a Union.

    The result is the slice element ``MatchIfTrue[Callable[[<T>], bool]]``,
    where ``<T>`` is ``oldtype`` after ``_convert_match_nodes_to_cst_nodes``.
    """
    # MatchIfTrue takes in the original node type, and returns a boolean.
    # So, convert our quoted classes (forward refs to other matchers) back
    # to the CSTNode they refer to. This works because there's always a 1:1
    # name mapping.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)

    # Callable[[<node_type>], bool]
    param_list = cst.ExtSlice(cst.Index(cst.List([cst.Element(node_type)])))
    return_type = cst.ExtSlice(cst.Index(cst.Name("bool")))
    callable_type = cst.Subscript(
        cst.Name("Callable"), slice=[param_list, return_type]
    )

    # MatchIfTrue[Callable[...]], wrapped as a union member.
    match_if_true = cst.Subscript(cst.Name("MatchIfTrue"), cst.Index(callable_type))
    return cst.ExtSlice(cst.Index(match_if_true))
def test_zero_or_more_matcher_args_false(self) -> None:
    """ZeroOrMore fails when the trailing arguments don't fit its matcher."""
    # foo(1, 2, 3)
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(d)) for d in ("1", "2", "3")),
    )
    trailing_matchers = (
        # ...and the rest of the arguments are strings.
        m.Arg(m.SimpleString()),
        # ...and the rest of the arguments are integers with the value 2.
        m.Arg(m.Integer("2")),
    )
    for trailing in trailing_matchers:
        # Fail to match a call to "foo" whose first argument is the integer
        # value 1 and whose remaining arguments all match ``trailing``.
        self.assertFalse(
            matches(
                call,
                m.Call(
                    func=m.Name("foo"),
                    args=(m.Arg(m.Integer("1")), m.ZeroOrMore(trailing)),
                ),
            )
        )
def leave_StarImport(original_node: cst.ImportFrom, updated_node: cst.ImportFrom,
                     **kwargs) -> Union[cst.ImportFrom, RemovalSentinel]:
    """Replace or drop a ``from x import *`` statement.

    Expects ``kwargs["imp"]`` to be a mapping with ``"modules"`` and
    ``"module"`` entries.

    - ``imp["modules"]`` non-empty: the star is replaced by one explicit
      import name per entry.
    - otherwise, ``imp["module"]`` truthy: the statement is removed.
    - otherwise: ``original_node`` is returned unchanged.
    """
    imp = kwargs["imp"]
    if imp["modules"]:
        # Iterate the names directly. The previous ",".join(...) followed by
        # .split(",") round-trip produced the same list, less clearly.
        names_to_suggestion = [
            cst.ImportAlias(cst.Name(module)) for module in imp["modules"]
        ]
        return updated_node.with_changes(names=names_to_suggestion)
    if imp["module"]:
        return cst.RemoveFromParent()
    return original_node
def assign_properties(
    p: typing.Dict[str, typing.Tuple[Metadata, Type]], is_classvar: bool = False
) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield one ``name: <annotation>`` declaration per property.

    Properties are visited in ``sort_items(p)`` order, and names rejected by
    ``bad_name`` are skipped. When ``is_classvar`` is true, the annotation is
    wrapped as ``ClassVar[...]``. Each statement is preceded by a blank line,
    then one ``# ...`` comment line per entry of ``metadata_lines(metadata)``.
    """
    for prop_name, entry in sort_items(p):
        if bad_name(prop_name):
            continue
        metadata, tp = entry

        annotation_expr = tp.annotation
        if is_classvar:
            annotation_expr = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation_expr))],
            )

        declaration = cst.AnnAssign(cst.Name(prop_name), cst.Annotation(annotation_expr))
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
def leave_Attribute(self, orig: cst.Attribute,
                    updated: cst.Attribute) -> cst.BaseExpression:
    """Rewrite ISS-style attribute accesses into their documentation form.

    - ``self.X``            -> ``X``
    - ``state.dmem``        -> ``DMEM``
    - ``state.csrs.flags``  -> ``FLAGs``

    Any other attribute is returned unchanged.
    """
    base = updated.value
    if isinstance(base, cst.Name):
        # Strip out "self." references. In the ISS code, a field in the
        # instruction appears as self.field_name. In the documentation, we
        # can treat all the instruction fields as being in scope.
        if base.value == 'self':
            return updated.attr
        # Replace state.dmem with DMEM. This is an object in the ISS code,
        # so you see things like state.dmem.load_u32(...). We keep the
        # "object-orientated" style (so DMEM.load_u32(...)) because we need
        # to distinguish between 32-bit and 256-bit loads.
        if base.value == 'state' and updated.attr.value == 'dmem':
            return cst.Name(value='DMEM')
    elif isinstance(base, cst.Attribute):
        # This attribute looks like A.B.C where B, C are names and A may be
        # a further attribute or a name.
        head = base.value
        # Replace state.csrs.flags with FLAGs: the flag groups are stored
        # in the CSRs in the ISS and the implementation, but logically
        # exist somewhat separately, so we want named reads/writes from
        # them to look different.
        if isinstance(head, cst.Name) and (
                (head.value, base.attr.value, updated.attr.value)
                == ('state', 'csrs', 'flags')):
            return cst.Name(value='FLAGs')
    return updated
def leave_IfExp(
    self,
    original_node: cst.IfExp,
    updated_node: cst.IfExp,
) -> cst.Call:
    """Replace ``body if test else orelse`` with ``<phi>(test, body, orelse)``.

    The callee name is taken from ``self.phi_name``. The three operands are
    passed positionally in the order test, body, orelse.
    """
    operands = (updated_node.test, updated_node.body, updated_node.orelse)
    phi_args = [cst.Arg(value=operand) for operand in operands]
    return cst.Call(func=cst.Name(value=self.phi_name), args=phi_args)
def leave_Subscript(self, original_node: cst.Subscript,
                    updated_node: cst.Subscript) -> cst.Subscript:
    """Widen ``Sequence[...]`` annotations so they also accept do-not-care.

    - ``Sequence[Union[A, B]]``: the inner union is extended in place with
      the do-not-care and match-metadata members.
    - ``Sequence[X]`` (any other ``X``): ``X`` is wrapped as
      ``Union[X, <do-not-care>, <match-metadata>]``.
    - Non-Sequence subscripts are returned unchanged.

    Raises:
        Exception: if a ``Sequence`` subscript does not hold exactly one
            ``Index`` element.
    """
    if not updated_node.value.deep_equals(cst.Name("Sequence")):
        return updated_node

    slc = updated_node.slice
    # TODO: We can remove the instance check after ExtSlice is deprecated.
    if not isinstance(slc, Sequence) or len(slc) != 1:
        raise Exception(
            "Unexpected number of sequence elements inside Sequence type "
            + "annotation!")
    nodeslice = slc[0].slice
    if not isinstance(nodeslice, cst.Index):
        raise Exception("Unexpected slice type for Sequence!")

    element_type = nodeslice.value
    if isinstance(element_type, cst.Subscript) and element_type.value.deep_equals(
            cst.Name("Union")):
        # Special case for Sequence[Union] so that we make more collapsed
        # types.
        return updated_node.with_deep_changes(
            element_type,
            slice=[
                *element_type.slice,
                _get_do_not_care(),
                _get_match_metadata(),
            ],
        )

    # This is a sequence of some node: add DoNotCareSentinel so that a
    # person can put a do-not-care into a sequence that otherwise has valid
    # matcher nodes.
    widened = _get_wrapped_union_type(
        element_type,
        _get_do_not_care(),
        _get_match_metadata(),
    )
    return updated_node.with_changes(
        slice=(cst.SubscriptElement(cst.Index(widened)),))
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Rename calls to migrated wx symbols.

    Two passes are tried in order; the first match wins:

    1. ``self.matchers_short_map`` — ``(symbol, matcher, renamed)`` entries
       for calls that use a bare symbol imported from wx (no ``wx.`` prefix).
       The symbol import is scheduled for removal and ``import wx`` is
       scheduled for addition.
    2. ``self.matchers_full_map`` — ``(matcher, renamed)`` entries for fully
       qualified ``wx.Symbol(...)`` calls.

    A tuple ``renamed`` produces ``wx.<renamed[0]>.<renamed[1]>``.
    """

    def nested_wx_func(path: tuple) -> cst.Attribute:
        # ``wx.<path[0]>.<path[1]>``
        return cst.Attribute(
            value=cst.Attribute(value=cst.Name(value="wx"),
                                attr=cst.Name(value=path[0])),
            attr=cst.Name(value=path[1]),
        )

    # Calls that use a symbol without the wx prefix.
    for symbol, matcher, renamed in self.matchers_short_map:
        if symbol not in self.wx_imports or not matchers.matches(
                updated_node, matcher):
            continue
        # Remove the symbol's import and add an import of the top-level wx
        # package instead.
        RemoveImportsVisitor.remove_unused_import_by_node(
            self.context, original_node)
        AddImportsVisitor.add_needed_import(self.context, "wx")
        if isinstance(renamed, tuple):
            new_func = nested_wx_func(renamed)
        else:
            new_func = cst.Attribute(value=cst.Name(value="wx"),
                                     attr=cst.Name(value=renamed))
        return updated_node.with_changes(func=new_func)

    # Full calls like wx.MySymbol(...).
    for matcher, renamed in self.matchers_full_map:
        if not matchers.matches(updated_node, matcher):
            continue
        if isinstance(renamed, tuple):
            new_func = nested_wx_func(renamed)
        else:
            new_func = updated_node.func.with_changes(
                attr=cst.Name(value=renamed))
        return updated_node.with_changes(func=new_func)

    return updated_node
def test_simple_module(self) -> None:
    """Module templates substitute placeholders in imports, annotations and calls."""
    template = self.dedent(
        """
        from {module} import {obj}

        def foo() -> {obj}:
            return {obj}()
        """
    )
    rendered = parse_template_module(
        template,
        module=cst.Name("foo"),
        obj=cst.Name("Bar"),
    )
    expected = self.dedent(
        """
        from foo import Bar

        def foo() -> Bar:
            return Bar()
        """
    )
    self.assertEqual(rendered.code, expected)
def _get_alias_name(node: cst.CSTNode) -> Optional[str]:
    """Return the ``...MatchType`` alias name for a type node, or None.

    - ``Name``/``SimpleString``: ``<raw name>MatchType``.
    - ``Union[A, B, ...]``: ``AOrB...MatchType``, provided every member has
      a raw name.
    - Anything else, or a union with an unnamed member: None.
    """
    if isinstance(node, (cst.Name, cst.SimpleString)):
        return f"{_get_raw_name(node)}MatchType"
    if not isinstance(node, cst.Subscript):
        return None
    if not node.value.deep_equals(cst.Name("Union")):
        return None

    slc = node.slice
    # TODO: This instance check can go away once we deprecate ExtSlice
    if not isinstance(slc, Sequence):
        return None

    raw_names = [_get_raw_name(element) for element in slc]
    named = [raw for raw in raw_names if raw is not None]
    if len(named) != len(raw_names):
        # At least one union member has no usable name.
        return None
    return "Or".join(named) + "MatchType"