def visit_For(self, node):
    """Virtualizes the loop target of a `for` statement.

    The loop is rewritten to iterate over a fresh symbol `n_target`. At the top
    of each iteration the original target name(s) are assigned from it through
    `overload.assign`: the whole value for a plain name target, element by
    element (`n_target[i]`) for a tuple/list target.

    Args:
      node: gast.For node.

    Returns:
      The result of `templates.replace` for the rewritten loop.

    Raises:
      ValueError: if the loop target is not a gast.Tuple, gast.List or
        gast.Name, or if a tuple/list target contains anything but names.
    """
    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    is_sequence_target = isinstance(node.target, (gast.Tuple, gast.List))
    if is_sequence_target:
        targets = list(node.target.elts)
        for target in targets:
            # Nested or starred targets have no `.id`; fail with a clear message
            # instead of an AttributeError further down.
            if not isinstance(target, gast.Name):
                raise ValueError(
                    'For target elements must be gast.Name, got {}.'.format(
                        type(target)))
    elif isinstance(node.target, gast.Name):
        targets = [node.target]
    else:
        raise ValueError(
            'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
            .format(type(node.target)))

    n_target = self.ctx.namer.new_symbol(
        'n_target', set(target.id for target in targets))

    target_assigns = []
    if is_sequence_target:
        # Decide by the target's syntax, not by how many elements it has:
        # `for (a,) in xs` must bind `a = n_target[0]`, not the whole tuple, and
        # `for () in xs` has nothing to assign.
        for i, target in enumerate(targets):
            target_assigns.extend(
                self._make_target_assign(target, n_target, i, self.overload))
    else:
        target_assigns.extend(
            templates.replace(
                'overload.assign(target, n_target)',
                overload=self.overload.symbol_name,
                target=targets[0],
                n_target=n_target))

    template = """
      for n_target in iter:
        target_assigns
        body
      else:
        orelse
    """
    node = templates.replace(
        template,
        n_target=n_target,
        iter=node.iter,
        target_assigns=target_assigns,
        body=node.body,
        orelse=node.orelse)
    return node
def test_replace_attribute(self):
    """An attribute name in a template can be replaced by a string."""
    # `imp` is deprecated (removed in Python 3.12); types.ModuleType is the
    # supported way to create an empty module object.
    import types

    template = """
      def test_fn(a):
        return a.foo
    """
    node = templates.replace(template, foo='b')[0]
    result, _ = parsing.ast_to_object(node)
    mod = types.ModuleType('test')
    mod.b = 3
    self.assertEqual(3, result.test_fn(mod))

    # A non-string replacement for an attribute name is rejected.
    with self.assertRaises(ValueError):
        templates.replace(template, foo=1)
def visit_While(self, node):
    """Lowers a `while` loop to a call to `overload.while_stmt`.

    The condition, body and else-clause are wrapped in generated zero-argument
    functions and passed to `overload.while_stmt`, together with a tuple of the
    variables modified in the loop body or else-clause. If the overload module
    has no `while_stmt`, the visited loop is returned unchanged.

    Args:
      node: gast.While node annotated with BODY_SCOPE and ORELSE_SCOPE.

    Returns:
      The visited node, or the result of `templates.replace`.
    """
    body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
    orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
    # Variables written anywhere in the loop. Sorted so the generated code is
    # the same on every run: iteration order of a set of names depends on
    # string hash randomization.
    loop_writes = tuple(
        sorted(body_scope.modified | orelse_scope.modified, key=str))
    node = self.generic_visit(node)
    if not hasattr(self.overload.module, 'while_stmt'):
        return node

    template = """
      def test_name():
        return test
      def body_name():
        body
      def orelse_name():
        orelse
      overload.while_stmt(test_name, body_name, orelse_name, (local_writes,))
    """
    node = templates.replace(
        template,
        overload=self.overload.symbol_name,
        test_name=self.ctx.namer.new_symbol('while_test', set()),
        test=node.test,
        body_name=self.ctx.namer.new_symbol('while_body', set()),
        body=node.body,
        orelse_name=self.ctx.namer.new_symbol('while_orelse', set()),
        # An empty else-clause still needs a statement inside orelse_name().
        orelse=node.orelse if node.orelse else gast.Pass(),
        local_writes=loop_writes)
    return node
def create_assignment(self, target, expression):
    """Builds `target = expression` from a template.

    Args:
      target: replacement for the assignment target, in any form accepted by
        `templates.replace`.
      expression: replacement for the assigned value.

    Returns:
      Whatever `templates.replace` produces for the assignment.
    """
    assign_template = """
      target = expression
    """
    return templates.replace(
        assign_template, expression=expression, target=target)
def test_replace_call_keyword(self):
    """A keyword argument can be replaced by a list of keywords."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """
    source = parsing.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = parsing.ast_to_object(node)
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own assertRaises block. Inside a single
    # block, the first raising call would exit it and the second would never
    # run.
    with self.assertRaises(ValueError):
        templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
        templates.replace(template, kws=1)
def test_replace_expression_context(self):
    """Names inside a replaced expression statement get Load contexts."""
    template = """
      def test_fn():
        foo
    """
    replacement = parsing.parse_expression('a + 2 * b / -c')
    fn_node = templates.replace(template, foo=replacement)[0]
    expr = fn_node.body[0]
    # expr is `a + ((2 * b) / (-c))`; check the names `a` and `b`.
    self.assertIsInstance(expr.left.ctx, gast.Load)
    self.assertIsInstance(expr.right.left.right.ctx, gast.Load)
def _wrap_in_generator(func, source, namer, overload):
    """Compiles generated code inside a factory function and returns its result.

    The factory takes the overload module as a parameter named after
    `overload.symbol_name`. It declares a dummy local for each free variable of
    `func`, runs `source`, and returns the object named `func.__name__`.

    Args:
      func: the original function.
      source: the generated source code (AST statements).
      namer: naming.Namer, used to pick the factory's name.
      overload: config.VirtualizationConfig.

    Returns:
      The generated function, with the overload module bound in its closure.
    """
    # One `name = None` per free variable of `func`, so the generated function
    # closes over the same names as the original did.
    nonlocals = [
        stmt
        for freevar in six.get_function_code(func).co_freevars
        for stmt in templates.replace('var = None', var=freevar)
    ]

    factory_name = namer.new_symbol('gen_fun', set())
    template = """
      def gen_fun(overload):
        nonlocals
        program
        return f_name
    """
    factory_ast = templates.replace(
        template,
        gen_fun=factory_name,
        nonlocals=nonlocals,
        overload=overload.symbol_name,
        program=source,
        f_name=func.__name__)
    module, _ = parsing.ast_to_object(factory_ast)
    factory = getattr(module, factory_name)
    return factory(overload.module)
def test_replace_tuple(self):
    """A name can be replaced by a tuple of names."""
    template = """
      def test_fn(a, c):
        return b,
    """
    fn_node = templates.replace(template, b=('a', 'c'))[0]
    compiled, _ = parsing.ast_to_object(fn_node)
    self.assertEqual((2, 3), compiled.test_fn(2, 3))
def test_replace_name_with_dict(self):
    """A name can be replaced by a dict literal expression."""
    template = """
      def test_fn():
        return foo['bar']
    """
    dict_expr = parsing.parse_expression("{'bar': 3}")
    fn_node = templates.replace(template, foo=dict_expr)[0]
    compiled, _ = parsing.ast_to_object(fn_node)
    self.assertEqual(3, compiled.test_fn())
def test_replace_function_name(self):
    """The name of a defined function can be replaced by a string."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """
    fn_def = templates.replace(template, fname='test_fn')[0]
    loaded, _ = parsing.ast_to_object(fn_def)
    # (2 + 1) * 2 + 1 == 7
    self.assertEqual(7, loaded.test_fn(2))
def test_replace_variable(self):
    """Every occurrence of a variable, including the parameter, is renamed."""
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """
    fn_def = templates.replace(template, a='b')[0]
    loaded, _ = parsing.ast_to_object(fn_def)
    # After renaming a -> b, test_fn computes (2 + 1) * 2 + 1 == 7.
    self.assertEqual(7, loaded.test_fn(2))
def test_replace_tuple_context(self):
    """A tuple placed in assignment-target position gets Store contexts."""
    template = """
      def test_fn(foo):
        foo = 0
    """
    tuple_expr = parsing.parse_expression('(a, b)')
    fn_node = templates.replace(template, foo=tuple_expr)[0]
    assigned = fn_node.body[0].targets[0]
    self.assertIsInstance(assigned.ctx, gast.Store)
    for i in (0, 1):
        self.assertIsInstance(assigned.elts[i].ctx, gast.Store)
def test_replace_index(self):
    """Names used inside call arguments and subscripts keep Load contexts."""
    template = """
      def test_fn():
        foo = 0
    """
    target_expr = parsing.parse_expression('foo(a[b]).bar')
    fn_node = templates.replace(template, foo=target_expr)[0]
    # The assignment target is `foo(a[b]).bar`; inspect the `a[b]` argument.
    subscript = fn_node.body[0].targets[0].value.args[0]
    self.assertIsInstance(subscript.ctx, gast.Load)
    self.assertIsInstance(subscript.slice.value.ctx, gast.Load)
def test_replace_attribute_context(self):
    """Only the outermost attribute of a replaced target is stored to."""
    template = """
      def test_fn(foo):
        foo = 0
    """
    attr_expr = parsing.parse_expression('a.b.c')
    fn_node = templates.replace(template, foo=attr_expr)[0]
    assigned = fn_node.body[0].targets[0]
    self.assertIsInstance(assigned.ctx, gast.Store)  # a.b.c
    self.assertIsInstance(assigned.value.ctx, gast.Load)  # a.b
    self.assertIsInstance(assigned.value.value.ctx, gast.Load)  # a
def test_replace_complex_context(self):
    """Collections nested in call arguments of a target keep Load contexts."""
    template = """
      def test_fn():
        foo = 0
    """
    target_expr = parsing.parse_expression('bar(([a, b],)).baz')
    fn_node = templates.replace(template, foo=target_expr)[0]
    assigned = fn_node.body[0].targets[0]
    self.assertIsInstance(assigned.ctx, gast.Store)
    # The call argument is the tuple `([a, b],)`; its first element is the list.
    inner_list = assigned.value.args[0].elts[0]
    self.assertIsInstance(inner_list.ctx, gast.Load)
    for i in (0, 1):
        self.assertIsInstance(inner_list.elts[i].ctx, gast.Load)
def visit_For(self, node):
    """Lowers a `for` loop to a call to `overload.for_stmt`.

    Each loop target name is first initialized with `overload.init`. The body
    and else-clause are then wrapped in generated zero-argument functions and
    passed to `overload.for_stmt` with the target, the iterable, and a tuple of
    the variables modified in the loop. If the overload module has no
    `for_stmt`, the visited loop is returned unchanged.

    Args:
      node: gast.For node annotated with BODY_SCOPE and ORELSE_SCOPE.

    Returns:
      The visited node, or the result of `templates.replace`.

    Raises:
      ValueError: if the loop target is not a gast.Tuple, gast.List or
        gast.Name.
    """
    body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
    orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
    # Variables written anywhere in the loop. Sorted so the generated code is
    # the same on every run: iteration order of a set of names depends on
    # string hash randomization.
    loop_writes = tuple(
        sorted(body_scope.modified | orelse_scope.modified, key=str))
    node = self.generic_visit(node)
    if not hasattr(self.overload.module, 'for_stmt'):
        return node

    # TODO(jmd1011): Handle extra_test
    if isinstance(node.target, (gast.Tuple, gast.List)):
        targets = list(node.target.elts)
    elif isinstance(node.target, gast.Name):
        targets = [node.target]
    else:
        raise ValueError(
            'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
            .format(type(node.target)))

    target_inits = [
        self._make_target_init(target, self.overload) for target in targets
    ]

    template = """
      target_inits
      def body_name():
        body
      def orelse_name():
        orelse
      overload.for_stmt(target, iter_, body_name, orelse_name, (local_writes,))
    """
    node = templates.replace(
        template,
        target_inits=target_inits,
        target=node.target,
        body_name=self.ctx.namer.new_symbol('for_body', set()),
        body=node.body,
        orelse_name=self.ctx.namer.new_symbol('for_orelse', set()),
        # An empty else-clause still needs a statement inside orelse_name().
        orelse=node.orelse if node.orelse else gast.Pass(),
        overload=self.overload.symbol_name,
        iter_=node.iter,
        local_writes=loop_writes)
    return node
def test_replace_name_with_call(self):
    """A name can be replaced by a chained call expression."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """
    call_expr = parsing.parse_expression('f()(b)')
    fn_node = templates.replace(template, foo=call_expr)[0]
    compiled, _ = parsing.ast_to_object(fn_node)
    # f()(b) == g(5) == 15
    self.assertEqual(15, compiled.test_fn())
def test_replace_code_block(self):
    """A statement placeholder can be replaced by a list of statements."""
    template = """
      def test_fn(a):
        block
        return a
    """

    def make_increment():
        # Builds a fresh `a = a + 1` each call. `[stmt] * 2` would put the
        # same node object into the tree twice, so any in-place rewrite of one
        # statement would silently affect the other.
        return gast.Assign(
            [gast.Name('a', None, None)],
            gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1)))

    node = templates.replace(
        template, block=[make_increment() for _ in range(2)])[0]
    result, _ = parsing.ast_to_object(node)
    self.assertEqual(3, result.test_fn(1))
def visit_Assign(self, node):
    """Rewrites `name = value` into `overload.assign(name, value)`.

    Only assignments to a single virtualized name are rewritten. Assignments
    to non-virtualized names are returned unchanged.

    Args:
      node: gast.Assign node.

    Returns:
      The (value-visited) node, or the result of `templates.replace`.

    Raises:
      NotImplementedError: for a chained assignment (`a = b = v`) that involves
        a virtualized name. Rewriting only the first target would silently
        drop the others.
      ValueError: if a single assignment target is not a gast.Name.
    """
    node.value = self.visit(node.value)

    if len(node.targets) > 1:
        # TODO(b/123943188): Handle multiple assignment
        virtualized = [
            target.id for target in node.targets
            if isinstance(target, gast.Name) and
            self.scope.should_virtualize(target.id)
        ]
        if virtualized:
            raise NotImplementedError(
                'Multiple assignment to virtualized variables is not supported '
                'yet, got {}.'.format(virtualized))
        return node

    target = node.targets[0]
    if not isinstance(target, gast.Name):
        raise ValueError(
            'Assign target must be gast.Name, got {}.'.format(type(target)))
    lhs = target.id
    rhs = node.value
    if not self.scope.should_virtualize(lhs):
        return node

    node = templates.replace(
        'overload.assign(lhs, rhs)',
        lhs=lhs,
        rhs=rhs,
        overload=self.overload.symbol_name)
    return node
def _make_target_init(self, target, overload):
    """Builds `target = overload.init("<target name>")`.

    Args:
      target: gast.Name of a loop target.
      overload: the virtualization config whose `symbol_name` is used in the
        generated call. Previously this argument was ignored in favour of
        `self.overload`; the callers in this file pass `self.overload`, so
        their output is unchanged.

    Returns:
      The result of `templates.replace` for the initialization.
    """
    return templates.replace(
        'target = overload.init(target_name)',
        target=target,
        # Quoted so the name appears in the generated code as a string.
        target_name='"{}"'.format(target.id),
        overload=overload.symbol_name)
def visit_FunctionDef(self, node):
    """Rebuilds a function definition with virtualized locals and arguments.

    The rebuilt function takes freshly named parameters. Its body is, in
    order:
      * `lhs = overload.init("<name>")` for every local of the scope;
      * `overload.assign(<arg>, <fresh param>)` for every positional argument;
      * the transformed original body.
    If the function's name is a local of its enclosing scope, the function is
    defined under a fresh name and then bound to the original name with
    `overload.assign`. Otherwise it keeps its name.

    Args:
      node: gast.FunctionDef node; must have an entry in `self.scopes`.

    Returns:
      The result of `templates.replace` for the rebuilt definition.
    """
    assert id(node) in self.scopes
    # Enter this function's scope; it is reset to the parent before returning.
    self.scope = self.scopes[id(node)]

    arg_names = [arg.id for arg in node.args.args]
    # One fresh parameter name per original argument; the original names are
    # assigned from these below (arg_nodes).
    # NOTE(review): only `node.args.args` is read and only these names are
    # passed to the template. Defaults, *args/**kwargs, keyword-only args and
    # decorators do not appear to be carried over; confirm this is intended.
    n_arg_names = [
        self.ctx.namer.new_symbol(arg, set(arg_names)) for arg in arg_names
    ]

    # `var = overload.init("var")` for each local of the function's scope.
    init_nodes = []
    for var in self.scope.locals:
        init_template = 'lhs = overload.init(lhs_name)'
        init_node = templates.replace(
            init_template,
            lhs=var,
            # Quoted so the name presumably appears as a string in the output.
            lhs_name='"{}"'.format(var),
            overload=self.overload.symbol_name)
        init_nodes.extend(init_node)

    # `overload.assign(arg, n_arg)` copies each fresh parameter into the
    # original argument name.
    arg_nodes = []
    for (arg, n_arg) in zip(arg_names, n_arg_names):
        arg_template = 'overload.assign(lhs, rhs)'
        arg_node = templates.replace(
            arg_template,
            lhs=arg,
            overload=self.overload.symbol_name,
            rhs=n_arg)
        arg_nodes.extend(arg_node)

    node.body = self.visit_block(node.body)

    # `self.scope.parent` is the enclosing scope here. A function that is a
    # local there is itself virtualized: define it under a new name, then
    # assign it to the original name through the overload.
    if self.scope.parent and self.scope.parent.is_local(node.name):
        template = """
          def new_fun_name(args):
            inits
            arg_nodes
            body
          overload.assign(fun_name, new_fun_name)
        """
        node = templates.replace(
            template,
            new_fun_name=self.ctx.namer.new_symbol(node.name,
                                                   set([node.name])),
            args=n_arg_names,
            arg_nodes=arg_nodes,
            inits=init_nodes,
            body=node.body,
            overload=self.overload.symbol_name,
            fun_name=node.name,
        )
    else:
        template = """
          def fun_name(args):
            inits
            arg_nodes
            body
        """
        node = templates.replace(
            template,
            fun_name=node.name,
            args=n_arg_names,
            arg_nodes=arg_nodes,
            inits=init_nodes,
            body=node.body,
        )

    # Leave the function's scope.
    self.scope = self.scope.parent
    return node
def _make_target_assign(self, target, n_target, i, overload):
    """Builds `overload.assign(target, n_target[i])`.

    Used to unpack the i-th element of a tuple/list loop target.

    Args:
      target: the i-th element of the loop target.
      n_target: name of the symbol holding the current loop value.
      i: index of `target` within the loop target.
      overload: the virtualization config whose `symbol_name` is used in the
        generated call. Previously this argument was ignored in favour of
        `self.overload`; the caller in this file passes `self.overload`, so its
        output is unchanged.

    Returns:
      The result of `templates.replace` for the assignment.
    """
    return templates.replace(
        'overload.assign(target, n_target[{}])'.format(i),
        target=target,
        n_target=n_target,
        overload=overload.symbol_name)