def visit_Assign(self, node):
    """Rewrite ``lhs = rhs`` into a mask-aware conditional assignment.

    The right-hand side of the assignment is replaced by the expression::

        lhs._update(rhs, matchbox.EXECUTION_MASK)
        if (('lhs' in vars() or 'lhs' in globals())
            and isinstance(lhs, (matchbox.MaskedBatch, matchbox.TENSOR_TYPE)))
        else rhs

    Args:
        node: the ``gast.Assign`` node to rewrite (modified in place).

    Returns:
        The same node, with ``node.value`` replaced by the conditional.

    Raises:
        NotImplementedError: for multiple targets or a non-``Name`` target.
    """
    if len(node.targets) > 1:
        raise NotImplementedError("cannot process multiple assignment")
    target = node.targets[0]
    if not isinstance(target, gast.Name):
        raise NotImplementedError("cannot process indexed assignment")

    lhs_name = target.id
    rhs = node.value

    def load(identifier):
        # Every read needs its own Name node with a Load *instance* as ctx.
        return gast.Name(identifier, gast.Load(), None)

    def matchbox_attr(attr):
        # Attribute names are plain identifier strings, not Name nodes.
        return gast.Attribute(load('matchbox'), attr, gast.Load())

    def name_in(namespace_fn):
        # Builds: 'lhs' in <namespace_fn>()
        return gast.Compare(
            gast.Str(lhs_name), [gast.In()],
            [gast.Call(load(namespace_fn), [], [])])

    already_bound = gast.BoolOp(
        gast.Or(), [name_in('vars'), name_in('globals')])
    is_batch_like = gast.Call(
        load('isinstance'),
        [
            # Read the variable through a Load-context node rather than
            # reusing the Store-context assignment target.
            load(lhs_name),
            gast.Tuple(
                [matchbox_attr('MaskedBatch'), matchbox_attr('TENSOR_TYPE')],
                gast.Load()),
        ],
        [])
    masked_update = gast.Call(
        gast.Attribute(load(lhs_name), '_update', gast.Load()),
        [rhs, matchbox_attr('EXECUTION_MASK')],
        [])

    node.value = gast.IfExp(
        gast.BoolOp(gast.And(), [already_bound, is_batch_like]),
        masked_update,
        rhs)
    return node
def visit_Attribute(self, node):
    """Replace accesses to known attributes by ``__builtin__.getattr`` calls.

    Method names and attributes not listed in ``attributes`` are returned
    untouched. Attributes of imported modules are validated against
    ``MODULES`` and have their module alias patched. A store to a known
    attribute is wrapped in a full slice so it stays a valid assignment
    target.
    """
    node = self.generic_visit(node)
    attr = node.attr

    # method name -> not a getattr
    if attr in methods:
        return node

    # imported module -> not a getattr
    owner = node.value
    if isinstance(owner, ast.Name) and owner.id in self.imports:
        module_id = self.imports[owner.id]
        if attr not in MODULES[self.renamer(module_id, MODULES)[1]]:
            msg = "`{}' is not a member of {} or Pythran does not support it"
            raise PythranSyntaxError(msg.format(attr, module_id), node)
        owner.id = module_id  # patch module aliasing
        self.update = True
        return node

    # not listed as attributed -> not a getattr
    if attr not in attributes:
        return node

    # A getattr !
    self.update = True
    getattr_fn = ast.Attribute(
        ast.Name('__builtin__', ast.Load(), None), 'getattr', ast.Load())
    getattr_call = ast.Call(getattr_fn, [owner, ast.Str(attr)], [])
    if not isinstance(node.ctx, ast.Store):
        return getattr_call

    # The only situation where a store arises is for real/imag of an
    # ndarray. A call is not a valid store target, so add a slice to end
    # up with a valid lhs.
    assert attr in ('real', 'imag'), "only store to imag/real"
    return ast.Subscript(getattr_call, ast.Slice(None, None, None), node.ctx)
def visit_Lambda(self, node):
    """Wrap the body of a lambda in ``ag__.with_function_scope``.

    Only lambdas whose function nesting level is at most 2 are wrapped;
    deeper ones are returned after the generic visit, unchanged.
    """
    self.state[_Function].enter()
    node = self.generic_visit(node)

    # Only wrap the top-level function. Theoretically, we can and should wrap
    # everything, but that can lead to excessive boilerplate when lambdas are
    # nested.
    # TODO(mdan): Looks more closely for use cases that actually require this.
    if self.state[_Function].level <= 2:
        scope = anno.getanno(node, anno.Static.SCOPE)
        context_name = self.ctx.namer.new_symbol('lscope', scope.referenced)
        self.state[_Function].context_name = context_name
        anno.setanno(node, 'function_context_name', context_name)

        template = """
          ag__.with_function_scope(
              lambda function_context: body, function_context_name, options)
        """
        node.body = templates.replace_as_expression(
            template,
            options=self.ctx.program.options.to_ast(),
            function_context=context_name,
            function_context_name=gast.Str(context_name),
            body=node.body)

    self.state[_Function].exit()
    return node
def to_ast(value):
    """
    Turn a value into ast expression.

    ``None``, booleans, the ``bool``/``int``/``float`` types and objects
    whose ``__module__`` is ``"__builtin__"`` go through ``builtin_folding``.
    Numpy scalars are unwrapped to Python scalars first. Containers and
    arrays go through ``size_container_folding``. On Python 3, ``filter``,
    ``map`` and ``zip`` objects are materialised as lists.

    Raises ``ToNotEval`` for any value that matches none of the above.

    >>> a = 1
    >>> print(ast.dump(to_ast(a)))
    Num(n=1)
    >>> a = [1, 2, 3]
    >>> print(ast.dump(to_ast(a)))
    List(elts=[Num(n=1), Num(n=2), Num(n=3)], ctx=Load())
    """
    if isinstance(value, (type(None), bool)):
        return builtin_folding(value)
    if any(value is t for t in (bool, int, float)):
        return builtin_folding(value)
    elif isinstance(value, numpy.generic):
        # ``numpy.asscalar`` has been removed from NumPy; ``item()`` is the
        # equivalent that works on every version.
        return to_ast(value.item())
    elif isinstance(value, numbers.Number):
        return ast.Num(value)
    elif isinstance(value, str):
        return ast.Str(value)
    elif isinstance(value, (list, tuple, set, dict, numpy.ndarray)):
        return size_container_folding(value)
    elif hasattr(value, "__module__") and value.__module__ == "__builtin__":
        # TODO Can be done the same way for others modules
        return builtin_folding(value)
    # only meaningful for python3
    elif sys.version_info.major == 3:
        if isinstance(value, (filter, map, zip)):
            return to_ast(list(value))
    raise ToNotEval()
def visit_Attribute(self, node):
    """Replace loads of known attributes by ``__builtin__.getattr`` calls.

    Stores, method names and attributes not listed in ``attributes`` are
    returned untouched. Attributes of imported modules are validated against
    ``MODULES`` and have their module alias patched.
    """
    node = self.generic_visit(node)

    # storing in an attribute -> not a getattr
    if not isinstance(node.ctx, ast.Load):
        return node

    attr = node.attr
    # method name -> not a getattr
    if attr in methods:
        return node

    # imported module -> not a getattr
    owner = node.value
    if isinstance(owner, ast.Name) and owner.id in self.imports:
        module_id = self.imports[owner.id]
        if attr not in MODULES[self.renamer(module_id, MODULES)[1]]:
            msg = "`{}' is not a member of {} or Pythran does not support it"
            raise PythranSyntaxError(msg.format(attr, module_id), node)
        owner.id = module_id  # patch module aliasing
        self.update = True
        return node

    # not listed as attributed -> not a getattr
    if attr not in attributes:
        return node

    # A getattr !
    self.update = True
    getattr_fn = ast.Attribute(
        ast.Name('__builtin__', ast.Load(), None), 'getattr', ast.Load())
    return ast.Call(getattr_fn, [owner, ast.Str(attr)], [])
def keywords_to_dict(keywords):
    """Build a ``gast.Dict`` node from a sequence of keyword nodes.

    Each keyword contributes one entry: its ``arg`` wrapped in ``gast.Str``
    as the key and its ``value`` node as the value.
    """
    # Single pass, so one-shot iterables behave like they did before.
    pairs = [(gast.Str(kw.arg), kw.value) for kw in keywords]
    return gast.Dict(
        keys=[key for key, _ in pairs],
        values=[value for _, value in pairs])
def visit_Call(self, node):
    """Replace a call by ``ag__.converted_call(func, owner, options, args)``.

    Calls whose qualified name starts with ``ag__.`` are only visited
    generically, as is ``print`` when the BUILTIN_FUNCTIONS feature is not
    in use. The original keywords are attached to the new call unchanged.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    qualified_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    if qualified_name.startswith('ag__.'):
        return self.generic_visit(node)
    if qualified_name == 'print':
        if not self.ctx.program.options.uses(
                converter.Feature.BUILTIN_FUNCTIONS):
            return self.generic_visit(node)

    callee = node.func
    if isinstance(callee, gast.Attribute):
        # Method-style call: pass the attribute name and its owner.
        func, owner = gast.Str(callee.attr), callee.value
    else:
        func, owner = callee, parser.parse_expression('None')

    options = self.ctx.program.options.to_ast(
        self.ctx,
        internal_convert_user_code=self.ctx.program.options.recursive)

    template = """
      ag__.converted_call(func, owner, options, args)
    """
    converted = templates.replace_as_expression(
        template, func=func, owner=owner, options=options, args=node.args)
    # TODO(mdan): Improve the template mechanism to better support this.
    converted.keywords = node.keywords
    return converted
def keywords_to_dict(keywords):
    """Converts a list of ast.keyword objects to a gast.Dict node."""
    dict_node = gast.Dict(keys=[], values=[])
    for keyword in keywords:
        # Keys are string literals; values are the keyword's own nodes.
        dict_node.keys.append(gast.Str(keyword.arg))
        dict_node.values.append(keyword.value)
    return dict_node
def visit_FunctionDef(self, node):
    """Intercepts function definitions.

    Converts function definitions to the corresponding
    `ProgramBuilder.function` construction.

    Args:
      node: An `ast.AST` node representing the function to convert.

    Returns:
      node: An updated node, representing the result.

    Raises:
      ValueError: If the input node does not adhere to the restrictions,
        e.g., failing to have a `return` statement at the end.
    """
    # Check input form: the body has to end with a `return`.
    last_statement = node.body[-1]
    if not isinstance(last_statement, gast.Return):
        raise ValueError(
            'Last node in function body should be Return, not {}.'.format(
                last_statement))

    # Convert all args to _tfp_autobatching_context_.param()
    param_declarations = [
        templates.replace(
            'target = _tfp_autobatching_context_.param(name=target_name)',
            target=arg.id,
            target_name=gast.Str(arg.id))[0]
        for arg in node.args.args
    ]

    # Visit the content of the function, then prepend the declarations.
    node = self.generic_visit(node)
    node.body = param_declarations + node.body

    # Convert the function into a
    # `with _tfp_autobatching_context_.define_function()` block.
    # Wrap the `with` block into a function so additional information (namely,
    # the auto-batching `ProgramBuilder` and the `instruction.Function`s that
    # may be called in the body) can be passed in through regular Python
    # variable references.
    known_function_targets = gast.List(
        [gast.Name(name, ctx=gast.Store(), annotation=None)
         for name in self.known_functions],
        ctx=gast.Store())
    template = '''
      def func(_tfp_autobatching_context_,
               _tfp_autobatching_available_functions_):
        names = _tfp_autobatching_available_functions_
        with _tfp_autobatching_context_.define_function(func):
          body
        return func'''
    return templates.replace(
        template,
        func=node.name,
        names=known_function_targets,
        body=node.body)[0]
def _create_undefined_assigns(self, undefined_symbols):
    """Return ``sym = ag__.Undefined('<sym>')`` statements, one per symbol."""
    template = """
      var = ag__.Undefined(symbol_name)
    """
    assignments = []
    for symbol in undefined_symbols:
        assignments.extend(
            templates.replace(
                template, var=symbol, symbol_name=gast.Str(symbol.ssf())))
    return assignments
def to_ast(value):
    """
    Turn a value into ast expression.

    Raises ``ToNotEval`` for values that must not be folded (builtin
    functions and methods, ufuncs, exceptions, generators, itertools types,
    types, ranges, numpy scalar types) and ``ConversionError`` for anything
    else that cannot be converted.

    >>> a = 1
    >>> print(ast.dump(to_ast(a)))
    Num(n=1)
    >>> a = [1, 2, 3]
    >>> print(ast.dump(to_ast(a)))
    List(elts=[Num(n=1), Num(n=2), Num(n=3)], ctx=Load())
    """
    # ``numpy.float_`` and ``numpy.complex_`` were aliases of float64 and
    # complex128 (both listed below) and no longer exist in NumPy 2, so they
    # are not repeated here.
    numpy_type = (numpy.float64, numpy.float32, numpy.float16,
                  numpy.complex64, numpy.complex128,
                  numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64,
                  numpy.int8, numpy.int16, numpy.int32, numpy.int64,
                  numpy.intp, numpy.intc, numpy.int_, numpy.bool_)
    itertools_t = [
        getattr(itertools, fun) for fun in dir(itertools)
        if isinstance(getattr(itertools, fun), type)
    ]
    unfolded_type = (types.BuiltinFunctionType, types.BuiltinMethodType,
                     numpy.ufunc, type(list.append), BaseException,
                     types.GeneratorType) + tuple(itertools_t)
    python_scalars = (int, float, complex)
    is_python3 = sys.version_info.major != 2
    if is_python3:
        unfolded_type += type, range, type(numpy.array2string)
    else:
        unfolded_type += (types.FunctionType, types.FileType, types.TypeType,
                          types.XRangeType)
        # ``long`` only exists on Python 2.
        python_scalars += (long,)  # noqa: F821

    if isinstance(value, (type(None), bool)):
        return ast.Attribute(ast.Name('__builtin__', ast.Load(), None),
                             str(value), ast.Load())
    elif isinstance(value, numpy_type):
        # ``numpy.asscalar`` has been removed from NumPy; ``item()`` is the
        # equivalent that works on every version.
        return to_ast(value.item())
    elif isinstance(value, python_scalars):
        return ast.Num(value)
    elif isinstance(value, str):
        return ast.Str(value)
    elif isinstance(value, (list, tuple, set, dict, numpy.ndarray)):
        return size_container_folding(value)
    elif hasattr(value, "__module__") and value.__module__ == "__builtin__":
        # TODO Can be done the same way for others modules
        return builtin_folding(value)
    elif isinstance(value, unfolded_type):
        raise ToNotEval()
    elif value in numpy_type:
        raise ToNotEval()
    # only meaningful for python3: on Python 2 filter/map/zip are functions
    # and would make isinstance() itself raise TypeError.
    elif is_python3 and isinstance(value, (filter, map, zip)):
        return to_ast(list(value))
    else:
        raise ConversionError()
def _create_loop_options(self, node):
    """Build the ``gast.Dict`` of loop options attached to ``node``.

    The options are read from the ``set_loop_options`` entry of the node's
    DIRECTIVES annotation. An empty dict node is returned when the node has
    no directives, no ``set_loop_options`` directive, or the directive
    carries no options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
        return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
        return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    # Build both lists directly: the previous ``zip(*items)`` unpacking
    # raised ValueError on an empty options dict. Lists (not tuples) are
    # used because ast and gast don't play well with tuples.
    keys = []
    values = []
    for option_name, option_value in opts_dict.items():
        keys.append(gast.Str(option_name))
        values.append(option_value)
    return gast.Dict(keys, values)
def visit_FunctionDef(self, node):
    """Wrap a function body in a ``with ag__.FunctionScope(...)`` block.

    A fresh context symbol is recorded on the function state and on the
    node's ``function_context_name`` annotation. A leading docstring, if
    present, stays outside (before) the ``with`` block.
    """
    self.state[_Function].enter()

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    context_name = self.ctx.namer.new_symbol('fscope', body_scope.referenced)
    self.state[_Function].context_name = context_name
    anno.setanno(node, 'function_context_name', context_name)

    node = self.generic_visit(node)

    # Detach the docstring so it remains the first statement of the function.
    docstring_node = None
    if node.body:
        head = node.body[0]
        if isinstance(head, gast.Expr) and isinstance(head.value, gast.Str):
            docstring_node = head
            node.body = node.body[1:]

    template = """
      with ag__.FunctionScope(
          function_name, context_name, options) as function_context:
        body
    """
    wrapped_body = templates.replace(
        template,
        function_name=gast.Str(node.name),
        context_name=gast.Str(context_name),
        options=self.ctx.program.options.to_ast(),
        function_context=context_name,
        body=node.body)

    if docstring_node is None:
        node.body = wrapped_body
    else:
        node.body = [docstring_node] + wrapped_body

    self.state[_Function].exit()
    return node
def visit_FunctionDef(self, node):
    """Wrap a function body in ``with tf.name_scope(<name>)``.

    The scope name is the function name, prefixed with the owner type's
    name when this is an outermost function and the context has an owner
    type.
    """
    self._function_level += 1
    try:
        self.generic_visit(node)
    finally:
        self._function_level -= 1

    is_outermost = self._function_level == 0
    if is_outermost and self.context.owner_type is not None:
        scope_name = '{}/{}'.format(
            self.context.owner_type.__name__, node.name)
    else:
        scope_name = node.name

    node.body = templates.replace(
        'with tf.name_scope(scope_name): body',
        scope_name=gast.Str(scope_name),
        body=node.body)
    return node
def visit_FunctionDef(self, node):
    """Wrap a function body in a ``with ag__.FunctionScope(...)`` block.

    The scope receives whether name scopes are used, the sanitized function
    name (or ``None``), and whether automatic control dependencies are used.
    A leading docstring, if present, stays outside the ``with`` block.
    """
    self.state[_Function].enter()

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    context_name = self.ctx.namer.new_symbol(
        'fn_context', body_scope.referenced)
    self.state[_Function].context_name = context_name

    node = self.generic_visit(node)

    # Detach the docstring so it remains the first statement of the function.
    docstring_node = None
    if node.body:
        head = node.body[0]
        if isinstance(head, gast.Expr) and isinstance(head.value, gast.Str):
            docstring_node = head
            node.body = node.body[1:]

    options = self.ctx.program.options
    if options.uses(converter.Feature.NAME_SCOPES):
        use_name_scopes = parser.parse_expression('True')
        scope_name = gast.Str(self._sanitize(node.name))
    else:
        use_name_scopes = parser.parse_expression('False')
        scope_name = parser.parse_expression('None')

    auto_deps_enabled = self.ctx.program.options.uses(
        converter.Feature.AUTO_CONTROL_DEPS)
    use_auto_deps = parser.parse_expression(str(auto_deps_enabled))

    template = """
      with ag__.FunctionScope(
          use_name_scopes, scope_name, use_auto_deps) as function_context_name:
        body
    """
    wrapped_body = templates.replace(
        template,
        use_name_scopes=use_name_scopes,
        scope_name=scope_name,
        use_auto_deps=use_auto_deps,
        function_context_name=context_name,
        body=node.body)

    if docstring_node is None:
        node.body = wrapped_body
    else:
        node.body = [docstring_node] + wrapped_body

    self.state[_Function].exit()
    return node
def visit_Assert(self, node):
    """Convert ``assert test, msg`` into ``ag__.assert_stmt(test, lambda: msg)``.

    A missing message defaults to ``'Assertion error'``. Messages that are
    not string literals raise NotImplementedError.
    """
    self.generic_visit(node)

    message = node.msg
    if message is None:
        message = gast.Str('Assertion error')
    elif not isinstance(message, gast.Str):
        raise NotImplementedError('can only convert string messages for now.')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      ag__.assert_stmt(test, lambda: msg)
    """
    return templates.replace(template, test=node.test, msg=message)
def create_grad(node, namer, tangent=False):
    """Given a variable, create a variable for the gradient.

    Args:
      node: A node to create a gradient for, can be a normal variable (`x`),
          a subscript (`x[i]`) or a string holding a variable name.
      namer: The namer object which will determine the name to use for the
          gradient.
      tangent: Whether a tangent (instead of adjoint) is created.

    Returns:
      node: A node representing the gradient with the correct name e.g. the
          gradient of `x[i]` is `dx[i]`.

          Note that this returns an invalid node, with the `ctx` attribute
          missing. It is assumed that this attribute is filled in later.

          A returned Name node has an `adjoint_var` annotation referring to
          the node it is an adjoint of.

    Raises:
      TypeError: If `node` is not a Subscript, Name or Str node.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError

    # A node standing in for a temporary gets the gradient of that temporary.
    if anno.hasanno(node, 'temp_var'):
        return create_grad(anno.getanno(node, 'temp_var'), namer, tangent)

    if isinstance(node, gast.Subscript):
        base_grad = create_grad(node.value, namer, tangent=tangent)
        base_grad.ctx = gast.Load()
        return gast.Subscript(value=base_grad, slice=node.slice, ctx=None)

    if isinstance(node, gast.Str):
        # The string holds a variable name; return the gradient's name.
        variable = gast.Name(id=node.s, ctx=None, annotation=None)
        return gast.Str(create_grad(variable, namer, tangent=tangent).id)

    # Only a plain Name is left at this point.
    grad_name = namer.grad(node.id, tangent)
    grad_node = gast.Name(id=grad_name, ctx=None, annotation=None)
    anno.setanno(grad_node, 'adjoint_var', node)
    return grad_node
def ast(self):
    """Return a gast node for this qualified name.

    Subscripts and attributes are built recursively on top of the parent's
    node. A simple name becomes a ``Name``, ``Str`` or ``Num`` node
    depending on the type of its base. Contexts are left as ``None``: the
    caller must adjust the context appropriately.
    """
    if self.has_subscript():
        container = self.parent.ast()
        index = gast.Index(self.qn[-1].ast())
        return gast.Subscript(container, index, None)

    if self.has_attr():
        return gast.Attribute(self.parent.ast(), self.qn[-1], None)

    base = self.qn[0]
    if isinstance(base, str):
        return gast.Name(base, None, None)
    if isinstance(base, StringLiteral):
        return gast.Str(base.value)
    if isinstance(base, NumberLiteral):
        return gast.Num(base.value)

    assert False, ('the constructor should prevent types other than '
                   'str, StringLiteral and NumberLiteral')
def test_ast_to_source(self):
    """ast_to_source renders a hand-built if/else with the given indent."""
    indent = ' '

    def assign_a(value):
        return gast.Assign(
            targets=[gast.Name('a', gast.Store(), None)], value=value)

    node = gast.If(
        test=gast.Num(1),
        body=[assign_a(gast.Name('b', gast.Load(), None))],
        orelse=[assign_a(gast.Str('c'))])

    source = compiler.ast_to_source(node, indentation=indent)

    # Expected text is derived from the indentation passed above.
    expected = '\n'.join([
        'if 1:',
        indent + 'a = b',
        'else:',
        indent + "a = 'c'",
    ])
    self.assertEqual(expected, source.strip())
def dispatch(self, tree):
    """Dispatcher function, dispatching tree type T to method _T.

    Any OpenMP directive attached to ``tree`` is emitted first, as a string
    expression. Each of its dependencies is rendered into a temporary
    buffer and substituted into the directive text. A list is dispatched
    element by element.
    """
    # display omp directive in python dump
    for omp in metadata.get(tree, openmp.OMPDirective):
        deps = []
        for dep in omp.deps:
            old_file = self.f
            self.f = io.StringIO()
            try:
                self.dispatch(dep)
                deps.append(self.f.getvalue())
            finally:
                # Always restore the real output stream, even when
                # rendering the dependency fails.
                self.f = old_file
        directive = omp.s.format(*deps)
        self._Expr(ast.Expr(ast.Str(s=directive)))

    if isinstance(tree, list):
        for t in tree:
            self.dispatch(t)
        return

    meth = getattr(self, "_" + tree.__class__.__name__)
    meth(tree)
def to_ast(value):
    """
    Turn a value into ast expression.

    Raises ``PythranSyntaxError`` when an integer constant does not fit the
    platform ``int`` range, and ``ToNotEval`` for values that cannot be
    folded.

    >>> a = 1
    >>> print(ast.dump(to_ast(a)))
    Num(n=1)
    >>> a = [1, 2, 3]
    >>> print(ast.dump(to_ast(a)))
    List(elts=[Num(n=1), Num(n=2), Num(n=3)], ctx=Load())
    """
    if isinstance(value, (type(None), bool)):
        return builtin_folding(value)
    if sys.version_info[0] == 2 and isinstance(value, long):  # noqa: F821
        from pythran.syntax import PythranSyntaxError
        raise PythranSyntaxError("constant folding results in big int")
    if isinstance(value, int):
        # This range check used to sit inside the "value is a type" branch
        # below, where ``isinstance(value, int)`` can never be true, so
        # oversized integers were silently folded. Booleans never reach this
        # point: they are handled by the first test.
        iinfo = np.iinfo(int)
        if not (iinfo.min <= value <= iinfo.max):
            from pythran.syntax import PythranSyntaxError
            raise PythranSyntaxError("constant folding results in big int")
    if any(value is t for t in (bool, int, float)):
        return builtin_folding(value)
    elif isinstance(value, np.generic):
        # ``np.asscalar`` has been removed from NumPy; ``item()`` is the
        # equivalent that works on every version.
        return to_ast(value.item())
    elif isinstance(value, numbers.Number):
        return ast.Num(value)
    elif isinstance(value, str):
        return ast.Str(value)
    elif isinstance(value, (list, tuple, set, dict, np.ndarray)):
        return size_container_folding(value)
    elif hasattr(value, "__module__") and value.__module__ == "__builtin__":
        # TODO Can be done the same way for others modules
        return builtin_folding(value)
    # only meaningful for python3
    elif sys.version_info.major == 3:
        if isinstance(value, (filter, map, zip)):
            return to_ast(list(value))
    raise ToNotEval()
def visit_FunctionDef(self, node):
    """Wrap a function body in ``with tf.name_scope(<current scope name>)``.

    A leading docstring, if present, is kept ahead of the ``with`` block.
    """
    node = self.generic_visit(node)

    statements = node.body
    # Skip any docstring.
    has_docstring = (
        bool(statements)
        and isinstance(statements[0], gast.Expr)
        and isinstance(statements[0].value, gast.Str))
    if has_docstring:
        leading, to_scope = statements[:1], statements[1:]
    else:
        leading, to_scope = [], statements

    template = """
      with tf.name_scope(scope_name):
        body
    """
    scoped = templates.replace(
        template,
        scope_name=gast.Str(self._name_for_current_scope()),
        body=to_scope)
    node.body = leading + scoped
    return node
def _insert_dynamic_conversion(self, node):
    """Inlines a dynamic conversion for a dynamic function.

    The call is replaced by ``ag__.converted_call(func, owner, options,
    args)``; the original keywords are attached to the new call unchanged.
    """
    # TODO(mdan): Pass information on the statically compiled functions.
    # Having access to the statically compiled functions can help avoid
    # unnecessary compilation.
    # For example, this would lead to function `a` being compiled twice:
    #
    #   def a():
    #     v = b
    #     b()
    #   def b():
    #     a()
    #
    # This is really a problem with recursive calls, which currently can
    # only be gated by a static condition, and should be rare.
    # TODO(mdan): It probably makes sense to use dynamic conversion every time.
    # Before we could convert all the time though, we'd need a reasonable
    # caching mechanism.
    callee = node.func
    if isinstance(callee, gast.Attribute):
        # Method-style call: pass the attribute name and its owner.
        func, owner = gast.Str(callee.attr), callee.value
    else:
        func, owner = callee, parser.parse_expression('None')

    options = self.ctx.program.options.to_ast(
        self.ctx.info.namespace,
        internal_convert_user_code=self.ctx.program.options.recursive)

    template = """
      ag__.converted_call(func, owner, options, args)
    """
    converted = templates.replace_as_expression(
        template, func=func, owner=owner, options=options, args=node.args)
    # TODO(mdan): Improve the template mechanism to better support this.
    converted.keywords = node.keywords
    return converted
def visit_FunctionDef(self, node):
    """Wrap a function body in ``with ag__.function_scope(<scope name>)``.

    A leading docstring, if present, is kept ahead of the ``with`` block.
    """
    node = self.generic_visit(node)

    docstring = []
    to_scope = node.body
    if to_scope:
        head = to_scope[0]
        # Skip the docstring, if any.
        if isinstance(head, gast.Expr) and isinstance(head.value, gast.Str):
            docstring = [head]
            to_scope = to_scope[1:]

    template = """
      with ag__.function_scope(scope_name):
        body
    """
    scoped = templates.replace(
        template,
        scope_name=gast.Str(self._name_for_current_scope()),
        body=to_scope)

    node.body = docstring + list(scoped)
    return node
def visit_While(self, node):
    """Convert a ``while`` loop into a call to ``ag__.while_stmt``.

    The loop test and body become local functions. Two templates are used,
    depending on whether the loop has basic loop variables. The generated
    statements are preceded by the assignments produced for possibly
    undefined symbols.
    """
    node = self.generic_visit(node)

    modified_in_body = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(node, modified_in_body)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    state_getter_name = self.ctx.namer.new_symbol(
        'get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol(
        'set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        [gast.Str(str(symbol)) for symbol in basic_loop_vars])
    composite_symbol_names = tuple(
        [gast.Str(str(symbol)) for symbol in composite_loop_vars])

    opts = self._create_loop_options(node)

    # Replacements shared by both templates. The test symbol is requested
    # before the body symbol.
    replacements = dict(
        test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
        test=node.test,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        body=node.body,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        composite_symbol_names=composite_symbol_names,
        opts=opts)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
        template = """
          state_functions
          def body_name(loop_vars):
            body
            return loop_vars,
          def test_name(loop_vars):
            return test
          loop_vars_ast_tuple = ag__.while_stmt(
              test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (loop_vars,),
              (basic_symbol_names,),
              (composite_symbol_names,),
              opts)
        """
        replacements.update(
            loop_vars=loop_vars,
            loop_vars_ast_tuple=loop_vars_ast_tuple,
            basic_symbol_names=basic_symbol_names)
    else:
        template = """
          state_functions
          def body_name():
            body
            return ()
          def test_name():
            return test
          ag__.while_stmt(
              test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (),
              (),
              (composite_symbol_names,),
              opts)
        """
    node = templates.replace(template, **replacements)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node
def visit_Call(self, node):
    """Replace a call by ``ag__.converted_call(func, owner, options, args, kwargs)``.

    Positional arguments are packed into a tuple expression, extended with
    ``tuple(<starred value>)`` when a ``*args`` is present. Keywords are
    packed into a dict expression, merged with the ``**kwargs`` value when
    one is present. Calls whose qualified name starts with ``ag__.`` are
    only visited generically, as is ``print`` when the BUILTIN_FUNCTIONS
    feature is not in use.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    qualified_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    if qualified_name.startswith('ag__.'):
        return self.generic_visit(node)
    if qualified_name == 'print':
        if not self.ctx.program.options.uses(
                converter.Feature.BUILTIN_FUNCTIONS):
            return self.generic_visit(node)

    callee = node.func
    if isinstance(callee, gast.Attribute):
        # Method-style call: pass the attribute name and its owner.
        func, owner = gast.Str(callee.attr), callee.value
    else:
        func, owner = callee, parser.parse_expression('None')

    # Positional arguments, with an optional single *args.
    starred = [a for a in node.args if isinstance(a, gast.Starred)]
    positional = [a for a in node.args if not isinstance(a, gast.Starred)]
    assert len(starred) <= 1, 'Multiple *args should be impossible.'
    if starred:
        args = templates.replace_as_expression(
            '(args,) + tuple(stararg)',
            stararg=starred[-1].value,
            args=positional)
    else:
        args = templates.replace_as_expression('(args,)', args=positional)

    # Keyword arguments, with an optional single **kwargs (arg is None).
    double_starred = [k for k in node.keywords if k.arg is None]
    named = [k for k in node.keywords if k.arg is not None]
    assert len(double_starred) <= 1, 'Multiple **kwargs should be impossible.'
    if double_starred:
        kwargs = templates.replace_as_expression(
            'dict(kwargs, **keywords)',
            kwargs=double_starred[-1].value,
            keywords=ast_util.keywords_to_dict(named))
    else:
        kwargs = ast_util.keywords_to_dict(named)

    options = self.ctx.program.options.to_ast(
        self.ctx,
        internal_convert_user_code=self.ctx.program.options.recursive)

    template = """
      ag__.converted_call(func, owner, options, args, kwargs)
    """
    return templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=options,
        args=args,
        kwargs=kwargs)
ast.BinOp(left=Placeholder(0), op=ast.Sub(), right=ast.Num(n=1)), ast.Num(n=-1), ast.Num(n=-1) ], keywords=[])), # X * X => X ** 2 (ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)), lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))), # a + "..." + b => "...".join((a, b)) (ast.BinOp(left=ast.BinOp(left=Placeholder(0), op=ast.Add(), right=ast.Str(Placeholder(1))), op=ast.Add(), right=Placeholder(2)), lambda: ast.Call(func=ast.Attribute( ast.Attribute(ast.Name('__builtin__', ast.Load(), None), 'str', ast.Load()), 'join', ast.Load()), args=[ ast.Str(Placeholder(1)), ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load()) ], keywords=[])), ] class PlaceholderReplace(Transformation):
def visit_If(self, node):
    """Convert an ``if`` statement into a functional conditional.

    Returns a list of statements, in this order: assignments for symbols
    that may be undefined after the conditional, state functions for
    composite symbols, the two branch function definitions, the assignment
    of the condition to a fresh variable, and the conditional expression
    built by ``_create_cond_expr``.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    # Split the symbols modified by either branch into basic symbols that
    # are live after the conditional, and composite symbols.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
        if s in live_out and not s.is_composite():
            returned_from_cond.add(s)
        if s.is_composite():
            # Special treatment for compound objects, always return them.
            # This allows special handling within the if_stmt itself.
            # For example, in TensorFlow we need to restore the state of composite
            # symbols to ensure that only effects from the executed branch are seen.
            composites.add(s)

    # Note: `-` binds tighter than `&`, so these read as
    # `modified & (returned_from_cond - defined_in)`.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^ set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(
        zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Fresh names for the condition variable, the branch functions and the
    # state getter/setter.
    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    all_referenced = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    # Freeze the sets into tuples so every use below sees the same order.
    returned_from_cond = tuple(returned_from_cond)
    composites = tuple(composites)

    if returned_from_cond:
        if len(returned_from_cond) == 1:
            cond_results = returned_from_cond[0]
        else:
            cond_results = gast.Tuple(
                [s.ast() for s in returned_from_cond], None)

        # Each branch returns the aliased name where the symbol was aliased
        # in that branch, and the original symbol otherwise.
        returned_from_body = tuple(
            alias_body_map[s] if s in need_alias_in_body else s
            for s in returned_from_cond)
        returned_from_orelse = tuple(
            alias_orelse_map[s] if s in need_alias_in_orelse else s
            for s in returned_from_cond)

    else:
        # When the cond would return no value, we leave the cond called without
        # results. That in turn should trigger the side effect guards. The
        # branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        cond_results = None
        # TODO(mdan): Replace with None once side_effect_guards is retired.
        returned_from_body = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name), )
        returned_from_orelse = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name), )

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(composites,
                                                  state_getter_name,
                                                  state_setter_name)

    # String literals naming the returned and composite symbols.
    basic_symbol_names = tuple(
        gast.Str(str(symbol)) for symbol in returned_from_cond)
    composite_symbol_names = tuple(
        gast.Str(str(symbol)) for symbol in composites)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name, basic_symbol_names,
                                       composite_symbol_names)

    if_ast = (undefined_assigns + composite_defs + body_def + orelse_def +
              cond_assign + cond_expr)
    return if_ast
def visit_For(self, node):
    """Convert a ``for`` loop into a call to ``ag__.for_stmt``.

    The loop body becomes a local function that receives the iterate in a
    single fresh variable and expands it into the original target. An
    ``extra_test`` annotation on the node, if present, becomes an additional
    test function. Two templates are used, depending on whether the loop has
    basic loop variables.
    """
    node = self.generic_visit(node)

    modified = (
        anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
        | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified)
    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(node, modified)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
    state_getter_name = self.ctx.namer.new_symbol(
        'get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol(
        'set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    if anno.hasanno(node, 'extra_test'):
        extra_test = anno.getanno(node, 'extra_test')
        extra_test_name = self.ctx.namer.new_symbol(
            'extra_test', reserved_symbols)
        extra_test_template = """
          def extra_test_name(loop_vars):
            return extra_test_expr
        """
        extra_test_function = templates.replace(
            extra_test_template,
            extra_test_name=extra_test_name,
            loop_vars=loop_vars,
            extra_test_expr=extra_test)
    else:
        extra_test_name = parser.parse_expression('None')
        extra_test_function = []

    # Workaround for PEP-3113
    # iterates_var holds a single variable with the iterates, which may be a
    # tuple.
    iterates_var_name = self.ctx.namer.new_symbol(
        'iterates', reserved_symbols)
    expansion_template = """
      iterates = iterates_var_name
    """
    iterate_expansion = templates.replace(
        expansion_template,
        iterates=node.target,
        iterates_var_name=iterates_var_name)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)

    basic_symbol_names = tuple(
        [gast.Str(str(symbol)) for symbol in basic_loop_vars])
    composite_symbol_names = tuple(
        [gast.Str(str(symbol)) for symbol in composite_loop_vars])

    opts = self._create_loop_options(node)

    # Replacements shared by both templates.
    replacements = dict(
        undefined_assigns=undefined_assigns,
        iter_=node.iter,
        iterate_expansion=iterate_expansion,
        iterates_var_name=iterates_var_name,
        extra_test_name=extra_test_name,
        extra_test_function=extra_test_function,
        body_name=body_name,
        body=node.body,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        composite_symbol_names=composite_symbol_names,
        opts=opts)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
        template = """
          undefined_assigns
          state_functions
          def body_name(iterates_var_name, loop_vars):
            iterate_expansion
            body
            return loop_vars,
          extra_test_function
          loop_vars_ast_tuple = ag__.for_stmt(
              iter_,
              extra_test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (loop_vars,),
              (basic_symbol_names,),
              (composite_symbol_names,),
              opts)
        """
        replacements.update(
            loop_vars=loop_vars,
            loop_vars_ast_tuple=loop_vars_ast_tuple,
            basic_symbol_names=basic_symbol_names)
    else:
        template = """
          undefined_assigns
          state_functions
          def body_name(iterates_var_name):
            iterate_expansion
            body
            return ()
          extra_test_function
          ag__.for_stmt(
              iter_,
              extra_test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (),
              (),
              (composite_symbol_names,),
              opts)
        """
    return templates.replace(template, **replacements)
lambda: ast.Call( func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="xrange", ctx=ast.Load()), args=[ast.BinOp(left=Placeholder(0), op=ast.Sub(), right=ast.Num(n=1)), ast.Num(n=-1), ast.Num(n=-1)], keywords=[])), # X * X => X ** 2 (ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)), lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))), # a + "..." + b => "...".join((a, b)) (ast.BinOp(left=ast.BinOp(left=Placeholder(0), op=ast.Add(), right=ast.Str(Placeholder(1))), op=ast.Add(), right=Placeholder(2)), lambda: ast.Call(func=ast.Attribute( ast.Attribute( ast.Name('__builtin__', ast.Load(), None), 'str', ast.Load()), 'join', ast.Load()), args=[ast.Str(Placeholder(1)), ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load())], keywords=[])), ] class PlaceholderReplace(Transformation):