def test_mutiple_returns(self):
    # A callee that returns from both arms of an if/else.  The expected
    # output below has every ``return <expr>`` replaced by an assignment
    # to one shared result variable followed by ``break``, and every
    # argument/local renamed to a mangled ``__ast_pe_var_N`` name.
    original_src = '''
    def f(x, y, z='foo'):
        if x:
            b = y + list(x)
            return b
        else:
            return z
    '''
    tree = ast.parse(shift_source(original_src))

    expected_src = '''
    def f(__ast_pe_var_4, __ast_pe_var_5, __ast_pe_var_6='foo'):
        if __ast_pe_var_4:
            __ast_pe_var_7 = __ast_pe_var_5 + list(__ast_pe_var_4)
            __ast_pe_var_8 = __ast_pe_var_7
            break
        else:
            __ast_pe_var_8 = __ast_pe_var_6
            break
    '''
    # Original name -> mangled name, as the inliner is expected to report.
    expected_bindings = {
        'x': '__ast_pe_var_4',
        'y': '__ast_pe_var_5',
        'z': '__ast_pe_var_6',
        'b': '__ast_pe_var_7',
    }

    # The inliner is seeded with a variable count of 3; the assertions
    # below expect the mangled names to run from 4 up to 8.
    local_names = get_locals(tree)
    inliner = Inliner(3, local_names)
    inlined_tree = inliner.visit(tree)
    expected_tree = ast.parse(shift_source(expected_src))

    self.assertASTEqual(inlined_tree, expected_tree)
    self.assertEqual(inliner.get_var_count(), 8)
    self.assertEqual(inliner.get_return_var(), '__ast_pe_var_8')
    self.assertEqual(inliner.get_bindings(), expected_bindings)
def _inlined_fn(self, node):
    '''
    Inline the function call ``node`` into a sequence of statements.

    ``node.func`` must resolve to a known value through
    ``self._get_node_value_if_known``; calls that use ``*args`` or
    ``**kwargs`` are not supported yet.

    Return a pair ``(statements, result)``: ``statements`` is the list of
    nodes that replace the call (argument assignments followed by the
    inlined body), and ``result`` is a node representing the variable
    that stores the result of the call.

    Side effects: advances ``self._var_count`` and records statically
    known argument values in ``self._constants``.
    '''
    def has_loop_level_break(stmts):
        # True if ``stmts`` contains a ``break`` that is not enclosed by
        # a loop located inside ``stmts`` itself, i.e. a ``break`` that
        # needs an outer loop to bind to.
        for stmt in stmts:
            if isinstance(stmt, ast.Break):
                return True
            if isinstance(stmt, (ast.FunctionDef, ast.ClassDef)):
                # nested definitions start a new scope for ``break``
                continue
            if isinstance(stmt, (ast.For, ast.While)):
                # a ``break`` in the loop body binds to that loop; only
                # the ``else`` clause belongs to the enclosing one
                nested = [stmt.orelse]
            else:
                nested = [
                    getattr(stmt, field, None)
                    for field in ('body', 'orelse', 'handlers', 'finalbody')]
            for child_stmts in nested:
                if (isinstance(child_stmts, list) and
                        has_loop_level_break(child_stmts)):
                    return True
        return False

    is_known, fn = self._get_node_value_if_known(node.func)
    assert is_known, 'only calls to statically known functions can be inlined'
    fn_ast = fn_to_ast(fn).body[0]
    # Mangle the callee, continuing the numbering from our own counter,
    # and take the advanced counter back afterwards.
    inliner = Inliner(self._var_count, get_locals(fn_ast))
    fn_ast = inliner.visit(fn_ast)
    self._var_count = inliner.get_var_count()

    inlined_body = []
    assert not node.kwargs and not node.starargs, \
        'inlining calls with *args or **kwargs is not supported'  # TODO
    # NOTE(review): keyword arguments (node.keywords) and parameters that
    # rely on their default value are not bound here - zip() stops at the
    # shorter of the two sequences. Confirm that callers never pass such
    # calls.
    for callee_arg, fn_arg in zip(node.args, fn_ast.args.args):
        # setup mangled values before call
        # TODO - if callee_arg is "simple" - literal or name,
        # and is never assigned in inlined_body
        # then do not make an assignment, just use it in inlined_body
        inlined_body.append(ast.Assign(
            targets=[ast.Name(id=fn_arg.id, ctx=ast.Store())],
            value=callee_arg))
        is_known, value = self._get_node_value_if_known(callee_arg)
        if is_known:
            # TODO - check that mutations are detected
            self._constants[fn_arg.id] = value

    inlined_code = self._visit(fn_ast.body)  # optimize inlined code

    # A trailing ``break`` can simply be dropped, but only when it is the
    # sole loop-level ``break``: a callee with an early return followed by
    # a final return also ends in ``break`` and still needs the loop
    # below for the earlier one to bind to.
    single_return = (
        bool(inlined_code) and
        isinstance(inlined_code[-1], ast.Break) and
        not has_loop_level_break(inlined_code[:-1]))

    if single_return:
        inlined_body.extend(inlined_code[:-1])
    else:
        # multiple returns - wrap in "while" that runs at most once:
        # the flag is cleared on entry and each ``break`` leaves the loop
        while_var = new_var_name(self)
        inlined_body.extend([
            ast.Assign(
                targets=[ast.Name(id=while_var, ctx=ast.Store())],
                value=self._get_literal_node(True)),
            ast.While(
                test=ast.Name(id=while_var, ctx=ast.Load()),
                body=[
                    ast.Assign(
                        targets=[ast.Name(id=while_var, ctx=ast.Store())],
                        value=self._get_literal_node(False))] +
                inlined_code,
                orelse=[])
            ])

    # The result variable is appended so that remove_assignments() sees
    # it together with the statements, then split off again for the
    # caller.
    all_nodes = inlined_body + \
        [ast.Name(id=inliner.get_return_var(), ctx=ast.Load())]
    remove_assignments(all_nodes)
    return all_nodes[:-1], all_nodes[-1]