def size_container_folding(value):
    """
    Convert value to ast expression if size is not too big.

    Converter for sized container.

    Args:
        value: A list, tuple, set, dict or ``numpy.ndarray`` obtained from
            constant evaluation.

    Returns:
        An AST node rebuilding ``value``: a List / Tuple / Set / Dict
        literal, or a ``numpy.array(...)`` call for arrays.

    Raises:
        ToNotEval: If ``value`` holds ``MAX_LEN`` or more elements, or is an
            empty set, so the original expression is kept unfolded.
        ConversionError: If ``value`` is not one of the supported containers.
    """
    # For arrays, bound the *total* element count: len() only measures the
    # first axis (a (10, 10**6) array has len 10) and raises on 0-d arrays.
    if isinstance(value, numpy.ndarray):
        size = value.size
    else:
        size = len(value)
    if size >= MAX_LEN:
        raise ToNotEval()

    if isinstance(value, list):
        return ast.List([to_ast(elt) for elt in value], ast.Load())
    elif isinstance(value, tuple):
        return ast.Tuple([to_ast(elt) for elt in value], ast.Load())
    elif isinstance(value, set):
        if not value:
            # Python has no empty-set literal (``{}`` is a dict), so an empty
            # Set node has no faithful source form; leave it unfolded.
            raise ToNotEval()
        return ast.Set([to_ast(elt) for elt in value])
    elif isinstance(value, dict):
        keys = [to_ast(elt) for elt in value.keys()]
        values = [to_ast(elt) for elt in value.values()]
        return ast.Dict(keys, values)
    elif isinstance(value, numpy.ndarray):
        # Rebuilt as numpy.array(<nested list literal>).
        return ast.Call(
            func=ast.Attribute(
                ast.Name(mangle('numpy'), ast.Load(), None),
                'array',
                ast.Load()),
            args=[to_ast(value.tolist())],
            keywords=[])
    else:
        raise ConversionError()
def create_while_node(condition_name, body_name, loop_var_names):
    """Build ``v1, v2, ... = fluid.layers.while_loop(cond, body, [v1, v2, ...])``.

    Args:
        condition_name: Identifier of the loop-condition function.
        body_name: Identifier of the loop-body function.
        loop_var_names: Identifiers of the loop variables, in order.

    Returns:
        A ``gast.Assign`` whose value calls ``fluid.layers.while_loop`` and
        whose target is a Store-context tuple of the loop variables.
    """
    def _param_name(identifier):
        return gast.Name(
            id=identifier, ctx=gast.Param(), annotation=None, type_comment=None)

    # NOTE: the same Name nodes -- and the very same list object -- are used
    # both in the call's loop-variable list and in the assignment target.
    loop_vars = [_param_name(name) for name in loop_var_names]

    call_args = [
        _param_name(condition_name),
        _param_name(body_name),
        gast.List(elts=loop_vars, ctx=gast.Param()),
    ]
    callee = gast.parse('fluid.layers.while_loop').body[0].value
    loop_call = gast.Call(func=callee, args=call_args, keywords=[])

    return gast.Assign(
        targets=[gast.Tuple(elts=loop_vars, ctx=gast.Store())],
        value=loop_call)
def visit_ListComp(self, node):
    """Lower a list comprehension.

    ``[c for _ in range(n)]`` -- constant element, a single generator, no
    filters, one-argument ``range`` -- becomes ``[c] * len(range(n))`` (the
    length call built from the ``('builtins', 'len')`` path). Everything else
    is handed to ``self.visitComp`` with a factory building
    ``builtins.list(builtins.map(...))``.
    """
    def build_list_of_map(*args):
        # builtins.list(builtins.map(*args))
        map_attr = ast.Attribute(
            value=ast.Name(id='builtins', ctx=ast.Load(),
                           annotation=None, type_comment=None),
            attr='map', ctx=ast.Load())
        map_call = ast.Call(map_attr, list(args), [])
        list_attr = ast.Attribute(
            ast.Name('builtins', ast.Load(), None, None), 'list', ast.Load())
        return ast.Call(list_attr, [map_call], [])

    single_constant_gen = (isinstance(node.elt, ast.Constant)
                           and len(node.generators) == 1)
    if single_constant_gen:
        generator = node.generators[0]
        iterable = generator.iter
        if not generator.ifs and isinstance(iterable, ast.Call):
            range_path = 'pythonic', 'builtins', 'functor', 'range'
            try:
                callee_path = attr_to_path(iterable.func)[1]
                if callee_path == range_path and len(iterable.args) == 1:
                    self.update = True
                    repeated = ast.List([node.elt], ast.Load())
                    length = ast.Call(path_to_attr(('builtins', 'len')),
                                      [iterable], [])
                    return ast.BinOp(repeated, ast.Mult(), length)
            except TypeError:
                # Presumably raised when the callee is not a resolvable
                # attribute path; use the generic lowering instead.
                pass
    return self.visitComp(node, build_list_of_map)
def _wrap_to_py_func_no_return(self, node):
    """Route a call through ``tf.py_func`` using a generated wrapper.

    A fresh ``wrapper`` function is emitted that performs the original call
    on freshly named parameters and then returns the constant ``1``; the call
    site becomes ``tf.py_func(wrapper, original_args, [tf.int64])``. The
    original call's own result is therefore discarded.

    Args:
      node: Call node. ``node.func`` must carry a ``QN`` annotation and
        ``node`` an ``ARGS_SCOPE`` annotation.

    Returns:
      Tuple ``(wrapper_def, call_expr)`` of the two nodes produced from the
      template; ``wrapper_def`` is annotated with ``SKIP_PROCESSING``.
    """
    func_qn = anno.getanno(node.func, anno.Basic.QN)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    reserved = args_scope.referenced
    new_symbol = self.context.namer.new_symbol

    wrapper_name = new_symbol(func_qn.ssf(), reserved)

    def _qn_of(arg):
        # Arguments without a qualified name fall back to the base name 'arg'.
        if anno.hasanno(arg, anno.Basic.QN):
            return anno.getanno(arg, anno.Basic.QN)
        return qual_names.QN('arg')

    wrapper_args = [new_symbol(_qn_of(arg).ssf(), reserved)
                    for arg in node.args]

    # TODO(mdan): Properly handle varargs, kwargs, etc.
    # TODO(mdan): This is best handled as a dynamic dispatch.
    # That way we can separate tensors from non-tensor args.
    template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
    wrapper_def, call_expr = templates.replace(
        template,
        call=node.func,
        wrapper=wrapper_name,
        original_args=gast.List(elts=node.args, ctx=None),
        wrapper_args=wrapper_args)
    anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
    return wrapper_def, call_expr
def visit_FunctionDef(self, node):
    """Intercepts function definitions.

    Converts function definitions to the corresponding
    `ProgramBuilder.function` construction.

    Args:
      node: An `ast.AST` node representing the function to convert.

    Returns:
      node: An updated node, representing the result.

    Raises:
      ValueError: If the input node does not adhere to the restrictions,
        e.g., failing to have a `return` statement at the end.
    """
    # The function must end in an explicit `return`.
    final_stmt = node.body[-1]
    if not isinstance(final_stmt, gast.Return):
        raise ValueError(
            'Last node in function body should be Return, not {}.'.format(
                final_stmt))

    def declare_param(arg):
        # `<arg> = _tfp_autobatching_context_.param(name=<arg as string>)`
        return templates.replace(
            'target = _tfp_autobatching_context_.param(name=target_name)',
            target=arg.id,
            target_name=gast_util.Str(arg.id))[0]

    declarations = [declare_param(arg) for arg in node.args.args]

    # Transform the body first, then put the parameter declarations in front.
    node = self.generic_visit(node)
    node.body = declarations + node.body

    # Wrap the `with _tfp_autobatching_context_.define_function(...)` block in
    # an outer function so the auto-batching `ProgramBuilder` and the
    # `instruction.Function`s callable from the body can be passed in as
    # ordinary Python variables.
    function_targets = [
        gast_util.Name(n, ctx=gast.Store(), annotation=None)
        for n in self.known_functions
    ]
    wrapped = templates.replace(
        '''
      def func(_tfp_autobatching_context_,
               _tfp_autobatching_available_functions_):
        names = _tfp_autobatching_available_functions_
        with _tfp_autobatching_context_.define_function(func):
          body
        return func''',
        func=node.name,
        names=gast.List(function_targets, ctx=gast.Store()),
        body=node.body)
    return wrapped[0]
def visit_DictComp(self, node):
    """Treat a dict comprehension as a generic comprehension.

    ``{k: v for ...}`` receives a synthesized ``elt`` of ``[(k, v)]`` and is
    then forwarded to ``visit_AnyComp`` with the arguments ``"dict"``,
    ``"__dispatch__"`` and ``"update"``; its result is returned.
    """
    # Quick fix: DictComp carries key/value rather than elt, so an elt is
    # synthesized to match what visit_AnyComp expects (room for improvement
    # there).
    key_value = ast.Tuple([node.key, node.value], ast.Load())
    node.elt = ast.List([key_value], ast.Load())
    return self.visit_AnyComp(node, "dict", "__dispatch__", "update")
def inlineBuiltinsXMap(self, node):
    """Inline ``map(f, [a0, a1, ...], [b0, b1, ...], ...)`` over literal lists.

    Returns the list literal ``[f(a0, b0, ...), f(a1, b1, ...), ...]``,
    truncated to the shortest argument sequence, and sets ``self.update``.
    Every argument after the function must expose an ``elts`` attribute.
    """
    self.update = True
    sequences = [arg.elts for arg in node.args[1:]]
    # With no sequence at all, min() raises ValueError.
    width = min(map(len, sequences))
    # NOTE: the same function node is reused by every generated call.
    func = node.args[0]
    calls = [ast.Call(func, [seq[i] for seq in sequences], [])
             for i in range(width)]
    return ast.List(calls, ast.Load())
def _assignment_construct_recur(self, target):
    """Build the pattern expression mirroring an assignment target.

    Tuple and List targets are rebuilt recursively with the same node class;
    any other target ``t`` becomes ``_tfp_autobatching_context_.var.t``.
    """
    if not isinstance(target, (gast.Tuple, gast.List)):
        return templates.replace_as_expression(
            '_tfp_autobatching_context_.var.name', name=target)

    parts = [self._assignment_construct_recur(elt) for elt in target.elts]
    # The rebuilt node is a pattern object under construction rather than an
    # assignment target, hence Load instead of Store.
    node_cls = gast.Tuple if isinstance(target, gast.Tuple) else gast.List
    return node_cls(parts, ctx=gast.Load())
class TuplePattern(Pattern):
    """Rewrite ``__builtin__.tuple([X, ..., Z])`` into ``(X, ..., Z)``."""

    # Matches a call to __builtin__.tuple whose single argument is a list
    # literal; the list elements are captured as placeholder 0.
    pattern = ast.Call(
        func=ast.Attribute(
            value=ast.Name(id='__builtin__', ctx=ast.Load(),
                           annotation=None, type_comment=None),
            attr="tuple",
            ctx=ast.Load()),
        args=[ast.List(elts=Placeholder(0), ctx=ast.Load())],
        keywords=[])

    @staticmethod
    def sub():
        # Same captured elements, now as a tuple literal.
        return ast.Tuple(elts=Placeholder(0), ctx=ast.Load())
def visit_Call(self, node):
    """Normalize shape arguments of shape-taking calls.

    - If every alias of the callee is in ``patterns`` and the first
      positional argument satisfies ``istuple``, that argument is replaced by
      ``toshape(...)`` of itself.
    - Otherwise, if every alias is in ``reshape_patterns`` and more than two
      positional arguments are given (presumably ``reshape(a, d0, d1, ...)``),
      the arguments after the first are packed into one
      ``toshape([d0, d1, ...])`` argument.

    ``self.update`` is set whenever a rewrite happens; children are then
    visited through ``generic_visit``.
    """
    func_aliases = self.aliases.get(node.func, None)
    if func_aliases is not None:
        if func_aliases.issubset(patterns):
            # A call may pass everything by keyword (e.g. ``shape=...``);
            # without this guard node.args[0] raises IndexError.
            if node.args and istuple(node.args[0]):
                self.update = True
                node.args[0] = toshape(node.args[0])
        elif func_aliases.issubset(reshape_patterns):
            if len(node.args) > 2:
                self.update = True
                node.args[1:] = [
                    toshape(ast.List(node.args[1:], ast.Load()))
                ]
    return self.generic_visit(node)
def size_container_folding(value):
    """
    Convert value to ast expression if size is not too big.

    Converter for sized container.

    Args:
        value: A list, tuple, set, dict or ``np.ndarray`` obtained from
            constant evaluation.

    Returns:
        An AST node rebuilding ``value``: a List / Tuple / Set / Dict
        literal, ``builtins.set()`` for an empty set, ``numpy.empty(shape,
        dtype)`` for an array without elements, or ``numpy.array(data,
        dtype)`` for any other array.

    Raises:
        ToNotEval: If ``value`` holds ``MAX_LEN`` or more elements.
        ConversionError: If ``value`` is not one of the supported containers.
    """
    def size(x):
        if isinstance(x, np.ndarray):
            # Total element count, without the copy flatten() would make.
            return x.size
        return len(getattr(x, 'flatten', lambda: x)())

    if size(value) >= MAX_LEN:
        raise ToNotEval()

    if isinstance(value, list):
        return ast.List([to_ast(elt) for elt in value], ast.Load())
    elif isinstance(value, tuple):
        return ast.Tuple([to_ast(elt) for elt in value], ast.Load())
    elif isinstance(value, set):
        if value:
            return ast.Set([to_ast(elt) for elt in value])
        # No empty-set literal exists ({} is a dict): emit builtins.set().
        return ast.Call(
            func=ast.Attribute(
                ast.Name(mangle('builtins'), ast.Load(), None, None),
                'set',
                ast.Load()),
            args=[],
            keywords=[])
    elif isinstance(value, dict):
        keys = [to_ast(elt) for elt in value.keys()]
        values = [to_ast(elt) for elt in value.values()]
        return ast.Dict(keys, values)
    elif isinstance(value, np.ndarray):
        # Test emptiness on the element count: len() raises TypeError on
        # 0-d arrays and ignores empty trailing axes.
        if value.size == 0:
            return ast.Call(
                func=ast.Attribute(
                    ast.Name(mangle('numpy'), ast.Load(), None, None),
                    'empty',
                    ast.Load()),
                args=[to_ast(value.shape), dtype_to_ast(value.dtype.name)],
                keywords=[])
        # NOTE(review): for 0-d arrays tolist() yields a scalar; this assumes
        # totuple passes scalars through unchanged -- confirm.
        return ast.Call(
            func=ast.Attribute(
                ast.Name(mangle('numpy'), ast.Load(), None, None),
                'array',
                ast.Load()),
            args=[
                to_ast(totuple(value.tolist())),
                dtype_to_ast(value.dtype.name),
            ],
            keywords=[])
    else:
        raise ConversionError()
def _to_reference_list(self, names):
    """Return a Load-context ``gast.List`` of ``self._to_reference(name)``
    for each name, preserving order."""
    references = list(map(self._to_reference, names))
    return gast.List(references, ctx=gast.Load())
lambda: ast.Call(func=ast.Attribute(value=ast.Attribute(value=ast.Name( id='__builtin__', ctx=ast.Load(), annotation=None), attr="pythran", ctx=ast.Load()), attr="abssqr", ctx=ast.Load()), args=[Placeholder(0)], keywords=[])), # __builtin__.tuple([X, ..., Z]) => (X, ..., Z) (ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="tuple", ctx=ast.Load()), args=[ast.List(Placeholder(0), ast.Load())], keywords=[]), lambda: ast.Tuple(Placeholder(0), ast.Load())), # __builtin__.reversed(__builtin__.xrange(X)) => # __builtin__.xrange(X-1, -1, -1) # FIXME : We should do it even when begin/end/step are given (ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="reversed", ctx=ast.Load()), args=[ ast.Call(func=ast.Attribute(value=ast.Name(id='__builtin__', ctx=ast.Load(), annotation=None), attr="xrange",