def test_decorators(self) -> None:
    # A decorator slot may be filled either with a plain expression or with a
    # full ``cst.Decorator`` node; both spellings must render the same code.
    template = "@{decorator}\ndef foo(): pass\n"
    expected = "@bar\ndef foo(): pass\n"
    for replacement in (cst.Name("bar"), cst.Decorator(cst.Name("bar"))):
        statement = parse_template_statement(template, decorator=replacement)
        self.assertEqual(self.code(statement), expected)
def test_parameters(self) -> None:
    # Each case: (template, keyword replacements, expected rendered code).
    cases = (
        # A plain name dropped into the parameter slot.
        (
            "def foo({arg}): pass",
            {"arg": cst.Name("bar")},
            "def foo(bar): pass\n",
        ),
        # A ``cst.Param`` node is special-cased into the parameter slot.
        (
            "def foo({arg}): pass",
            {"arg": cst.Param(cst.Name("bar"))},
            "def foo(bar): pass\n",
        ),
        # A whole ``cst.Parameters`` list is special-cased as well.
        (
            "def foo({args}): pass",
            {"args": cst.Parameters(params=(cst.Param(cst.Name("bar")),))},
            "def foo(bar): pass\n",
        ),
        # Several parameters plus a ``**`` catch-all in one substitution.
        (
            "def foo({args}): pass",
            {
                "args": cst.Parameters(
                    params=(
                        cst.Param(cst.Name("bar")),
                        cst.Param(cst.Name("baz")),
                    ),
                    star_kwarg=cst.Param(cst.Name("rest")),
                )
            },
            "def foo(bar, baz, **rest): pass\n",
        ),
    )
    for template, replacements, expected in cases:
        statement = parse_template_statement(template, **replacements)
        self.assertEqual(self.code(statement), expected)
def test_annotation(self) -> None:
    # The annotation slot accepts either a bare expression or a
    # ``cst.Annotation`` node; the assigned value stays the same literal.
    template = "x: {type} = {val}"
    expected = "x: int = 5\n"
    for annotation in (cst.Name("int"), cst.Annotation(cst.Name("int"))):
        statement = parse_template_statement(
            template,
            type=annotation,
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), expected)
def test_simple_statement(self) -> None:
    # Two expression placeholders inside a single simple statement.
    template = "assert {test}, {msg}\n"
    condition = cst.Name("True")
    message = cst.SimpleString('"Somehow True is no longer True..."')
    expected = 'assert True, "Somehow True is no longer True..."\n'

    rendered = self.code(
        parse_template_statement(template, test=condition, msg=message)
    )
    self.assertEqual(rendered, expected)
def test_assign_target(self) -> None:
    # Targets may be plain expressions (identity wrapper) or explicit
    # ``cst.AssignTarget`` nodes; the rendered chain must be identical.
    template = "{a} = {b} = {val}"
    expected = "first = second = 5\n"
    for wrap in (lambda node: node, cst.AssignTarget):
        statement = parse_template_statement(
            template,
            a=wrap(cst.Name("first")),
            b=wrap(cst.Name("second")),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), expected)
def leave_FunctionDef(
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
    """Replace list/dict literal parameter defaults with ``None``.

    Each positional-or-keyword parameter (``updated_node.params.params``)
    whose default is a ``cst.List`` or ``cst.Dict`` literal gets its default
    rewritten to ``None``. A statement built from ``DEFAULT_INIT_TEMPLATE``
    (filled with the parameter name as ``arg`` and the original literal as
    ``init``) is then inserted into the function body, after any leading
    docstring statements.

    Args:
        original_node: The definition before its children were transformed.
            Unused here; kept for the transformer callback signature.
        updated_node: The definition with child transformations applied.
            All rewrites are based on this node.

    Returns:
        ``updated_node`` itself when no parameter has a list/dict default.
        Otherwise, a copy with rewritten defaults and the initialization
        statements added to its body.
    """
    # NOTE: keyword-only (``kwonly_params``) and positional-only
    # (``posonly_params``) parameters are not inspected here.
    modified_defaults: List[cst.Param] = []
    mutable_args: List[Tuple[cst.Name, Union[cst.List, cst.Dict]]] = []
    for param in updated_node.params.params:
        default = param.default
        # A plain isinstance check is equivalent to matching
        # Param(default=OneOf(List(), Dict())) and also narrows the type for
        # type checkers. Parameters that do not match are always kept.
        if isinstance(default, (cst.List, cst.Dict)):
            mutable_args.append((param.name, default))
            modified_defaults.append(
                param.with_changes(default=cst.Name("None"))
            )
        else:
            modified_defaults.append(param)

    if not mutable_args:
        # Return the *updated* node so transformations already applied to
        # children (e.g. rewritten nested function definitions) are kept.
        return updated_node

    modified_params: cst.Parameters = updated_node.params.with_changes(
        params=modified_defaults
    )

    initializations: List[
        Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]
    ] = [
        # We use generation by template here since construction of the
        # resulting 'if' can be burdensome due to many nested objects
        # involved. Additional line is attached so that we may control
        # exact spacing between generated statements.
        parse_template_statement(
            DEFAULT_INIT_TEMPLATE, config=self.module_config, arg=arg, init=init
        ).with_changes(leading_lines=[EMPTY_LINE])
        for arg, init in mutable_args
    ]

    body = updated_node.body
    if isinstance(body, cst.SimpleStatementSuite):
        # A one-line body (``def f(x=[]): pass``) holds small statements.
        # They have no ``leading_lines`` and cannot sit alongside the
        # generated statements, so expand the suite into an indented block.
        body = cst.IndentedBlock(
            body=[
                cst.SimpleStatementLine(
                    body=body.body,
                    trailing_whitespace=body.trailing_whitespace,
                )
            ]
        )

    # Docstring should always go right after the function definition,
    # so we take special care to insert our initializations after the
    # last docstring found.
    docstrings = list(takewhile(is_docstring, body.body))
    function_code = list(dropwhile(is_docstring, body.body))

    # It is not possible to insert empty line after the statement line,
    # because whitespace is owned by the next statement after it. The body
    # may also consist solely of docstrings, in which case nothing follows
    # the initializations and no separator is needed.
    if function_code:
        first = function_code[0]
        # Collapse any blank lines before the first comment into the single
        # separator, but keep comments that belonged to the statement.
        kept_lines = list(
            dropwhile(lambda line: line.comment is None, first.leading_lines)
        )
        function_code[0] = first.with_changes(
            leading_lines=[EMPTY_LINE, *kept_lines]
        )

    modified_body = (*docstrings, *initializations, *function_code)
    return updated_node.with_changes(
        params=modified_params,
        body=body.with_changes(body=modified_body),
    )