def setUp(self):
    # Fixture: a small function in which each identifier occurs in several
    # contexts. `all_name_ids` maps every identifier to the expected context
    # object of each of its occurrences, in traversal order.
    self.source = """
def test_fn(x, y):
  a = 1
  x = y + a
  if x > y:
    z = x * x
    z = z + a
  else:
    z = y * y
  return z
"""
    param, store, load = gast.Param, gast.Store, gast.Load
    self.all_name_ids = {
        'x': [param(), store(), load(), load(), load()],
        'a': [store(), load(), load()],
        'y': [param(), load(), load(), load(), load()],
        'z': [store(), load(), store(), store(), load()],
    }
def setUp(self):
    # Fixture: `z` is reassigned inside a conditional. `all_name_ids` maps each
    # identifier to the expected context object of every occurrence, in
    # traversal order.
    self.source = """
def test_fn(x, y):
  z = 1
  if x > y:
    z = x * x
    z = z + y
  return z
"""
    param, store, load = gast.Param, gast.Store, gast.Load
    self.all_name_ids = {
        'x': [param(), load(), load(), load()],
        'y': [param(), load(), load()],
        'z': [store(), store(), load(), store(), load()],
    }
def test_unparse(self):
    # Build `if 1: a = b else: a = 'c'` by hand and check its unparsed source.
    def _name(identifier, ctx):
        return gast.Name(identifier, ctx=ctx, annotation=None, type_comment=None)

    then_stmt = gast.Assign(
        targets=[_name('a', gast.Store())],
        value=_name('b', gast.Load()))
    else_stmt = gast.Assign(
        targets=[_name('a', gast.Store())],
        value=gast.Constant('c', kind=None))
    node = gast.If(
        test=gast.Constant(1, kind=None),
        body=[then_stmt],
        orelse=[else_stmt])

    source = parser.unparse(node, indentation='  ')

    expected = textwrap.dedent("""
        # coding=utf-8
        if 1:
            a = b
        else:
            a = 'c'
    """).strip()
    self.assertEqual(expected, source.strip())
def trivialize_slice(self, node):
    """Reduce a subscript slice node to a simple expression.

    `Slice` and `ExtSlice` nodes are bound to a fresh temporary named by
    `self.namer`, through an assignment handed to `self.prepend`. A
    `gast.Name` loading that temporary is returned. An `Index` node is
    reduced to `self.trivialize(node.value)`.

    Raises:
      ValueError: for any other node type.
    """
    def _load(identifier):
        return gast.Name(id=identifier, ctx=gast.Load(), annotation=None)

    def _declare(slice_node):
        # The assignment is handed to `prepend` before its value is filled
        # in, i.e. before any temporaries for its components are created.
        name = self.namer.name(slice_node)
        stmt = gast.Assign(
            targets=[gast.Name(id=name, ctx=gast.Store(), annotation=None)],
            value=None)
        self.prepend(stmt)
        return name, stmt

    if isinstance(node, gast.Slice):
        name, stmt = _declare(node)
        bounds = []
        for part in (node.lower, node.upper, node.step):
            # Missing bounds become the literal name `None`.
            bounds.append(self.trivialize(part) if part else _load('None'))
        stmt.value = gast.Call(func=_load('slice'), args=bounds, keywords=[])
        return _load(name)

    if isinstance(node, gast.ExtSlice):
        name, stmt = _declare(node)
        dim_names = [self.trivialize_slice(dim).id for dim in node.dims]
        stmt.value = gast.Tuple(
            elts=[_load(dim_name) for dim_name in dim_names],
            ctx=gast.Load())
        return _load(name)

    if isinstance(node, gast.Index):
        return self.trivialize(node.value)

    raise ValueError(node)
def visit_Return(self, node):
    """Replace a `return` with assignments to the return flag and value.

    The innermost entry of `self.func_returned_stack` is set to True. The
    statement becomes `<returned_flag><depth> = True` followed by
    `<returned_value_key> = <value>`, where a bare `return` yields `None`.
    If the visited node is a `gast.If`, the two assignments become its body
    and the `If` is returned. Otherwise the list of assignments is returned.
    """
    modified_node = self.generic_visit(node)
    if node.value is None:
        returned_value = gast.Constant(value=None, kind=None)
    else:
        returned_value = node.value

    self.func_returned_stack[-1] = True
    depth = len(self.func_returned_stack)

    def _store(identifier):
        return gast.Name(id=identifier, ctx=gast.Store(), annotation=None,
                         type_comment=None)

    replacement = [
        gast.Assign(targets=[_store(self.returned_flag + str(depth))],
                    value=gast.Constant(value=True, kind=None)),
        gast.Assign(targets=[_store(self.returned_value_key)],
                    value=returned_value),
    ]
    if not isinstance(modified_node, gast.If):
        return replacement
    # TODO: Add location to returned value.
    modified_node.body = replacement
    return modified_node
def visit_FunctionDef(self, node):
    """Add return-value bookkeeping to a function that contained a `return`.

    Pops this function's entry from `self.func_returned_stack`. When that
    entry is truthy, the body is prefixed with
    `<returned_value_key> = None` and `<returned_flag><depth> = False`, in
    that order, and suffixed with `return <returned_value_key>`.
    """
    modified_node = self.generic_visit(node)
    depth = len(self.func_returned_stack)
    if not self.func_returned_stack.pop():
        return modified_node

    value_init = gast.Assign(
        targets=[gast.Name(id=self.returned_value_key, ctx=gast.Store(),
                           annotation=None, type_comment=None)],
        value=gast.Constant(value=None, kind=None))
    flag_init = gast.Assign(
        targets=[gast.Name(id=self.returned_flag + str(depth),
                           ctx=gast.Store(), annotation=None,
                           type_comment=None)],
        value=gast.Constant(value=False, kind=None))
    node.body[:0] = [value_init, flag_init]
    node.body.append(
        gast.Return(value=gast.Name(id=self.returned_value_key,
                                    ctx=gast.Load(), annotation=None,
                                    type_comment=None)))
    return modified_node
def visit_For(self, node):
    """Reset loop control flags on each iteration and break when one is set.

    Pops this loop's entries from `self.for_continued_stack` and
    `self.for_breaked_stack`. The matching flags are reset to False at the
    top of the body: the break flag ends up first, then the continue flag.
    If a break flag exists, or the innermost enclosing function has recorded
    a `return` on `self.func_returned_stack`, the body ends with
    `<keepgoing_flag> = not cond` and `if cond: break`.
    """
    modified_node = self.generic_visit(node)

    def _name(identifier, ctx):
        return gast.Name(id=identifier, ctx=ctx, annotation=None,
                         type_comment=None)

    def _reset(identifier):
        return gast.Assign(targets=[_name(identifier, gast.Store())],
                           value=gast.Constant(value=False, kind=None))

    continued_id = len(self.for_continued_stack)
    if self.for_continued_stack.pop():
        node.body.insert(0, _reset(self.continued_flag + str(continued_id)))

    breaked_id = len(self.for_breaked_stack)
    stop_conditions = []
    if self.for_breaked_stack.pop():
        breaked_name = self.breaked_flag + str(breaked_id)
        node.body.insert(0, _reset(breaked_name))
        stop_conditions.append(_name(breaked_name, gast.Load()))

    if self.func_returned_stack and self.func_returned_stack[-1]:
        returned_id = len(self.func_returned_stack)
        stop_conditions.append(
            _name(self.returned_flag + str(returned_id), gast.Load()))

    if stop_conditions:
        cond = (stop_conditions[0] if len(stop_conditions) == 1 else
                gast.BoolOp(op=gast.Or(), values=stop_conditions))
        node.body.append(
            gast.Assign(targets=[_name(self.keepgoing_flag, gast.Store())],
                        value=gast.UnaryOp(op=gast.Not(), operand=cond)))
        node.body.append(gast.If(test=cond, body=[gast.Break()], orelse=[]))
    return modified_node
def visit_Assign(self, node):
    """Expand tuple/list unpacking targets into element-wise assignments.

    For each tuple/list target with renamings reported by `traverse_tuples`,
    every renamed element gets its own `Assign` that reads
    `holder[i][j]...` along its recorded index path.
    - If the right-hand side is a `Name` or `Attribute`, `holder` is a copy of
      it and the original assignment is not emitted.
    - Otherwise `holder` is a fresh temporary, the original assignment is kept
      (with that target replaced by the temporary) and emitted first.

    NOTE(review): when the right-hand side is a Name/Attribute and some
    renaming happened, the original `node` is dropped, so any other plain
    target of a chained assignment (e.g. `x = a, b = y`) is not re-emitted.
    Confirm that chained assignments are normalized before this pass.
    """
    self.generic_visit(node)
    no_tmp = isinstance(node.value, (ast.Name, ast.Attribute))
    extra_assign = [] if no_tmp else [node]
    for index, target in enumerate(node.targets):
        if not isinstance(target, (ast.Tuple, ast.List)):
            continue
        renamings = OrderedDict()
        self.traverse_tuples(target, (), renamings)
        if not renamings:
            continue
        if no_tmp:
            gstore = deepcopy(node.value)
        else:
            gstore = ast.Name(self.get_new_id(), ast.Store(), None, None)
        gload = deepcopy(gstore)
        gload.ctx = ast.Load()
        node.targets[index] = gstore

        for rename, path in renamings.items():
            # Walk down the index path: holder[p0][p1]...
            element = gload
            for position in path:
                element = ast.Subscript(element, ast.Constant(position, None),
                                        ast.Load())
            if isinstance(rename, str):
                rename = ast.Name(rename, ast.Store(), None, None)
            extra_assign.append(ast.Assign([rename], element, None))
    return extra_assign or node
def visit_FunctionDef(self, node):
    """Intercepts function definitions.

    Converts function definitions to the corresponding `ProgramBuilder.function`
    construction.

    Args:
      node: An `ast.AST` node representing the function to convert.

    Returns:
      node: An updated node, representing the result.

    Raises:
      ValueError: If the input node does not adhere to the restrictions,
        e.g., failing to have a `return` statement at the end.
    """
    # The function body must end with an explicit `return`.
    return_node = node.body[-1]
    if not isinstance(return_node, gast.Return):
        msg = 'Last node in function body should be Return, not {}.'
        raise ValueError(msg.format(return_node))

    # Every positional parameter becomes a
    # `_tfp_autobatching_context_.param()` declaration.
    local_declarations = [
        templates.replace(
            'target = _tfp_autobatching_context_.param(name=target_name)',
            target=arg.id,
            target_name=gast_util.Str(arg.id))[0]
        for arg in node.args.args
    ]

    node = self.generic_visit(node)
    node.body = local_declarations + node.body

    # Emit the body inside a `with ..define_function()` block, wrapped in an
    # outer function. The `ProgramBuilder` and the callable
    # `instruction.Function`s are then passed in as ordinary Python arguments.
    callable_function_names = [
        gast_util.Name(n, ctx=gast.Store(), annotation=None)
        for n in self.known_functions]
    return templates.replace(
        '''
        def func(_tfp_autobatching_context_,
                 _tfp_autobatching_available_functions_):
          names = _tfp_autobatching_available_functions_
          with _tfp_autobatching_context_.define_function(func):
            body
          return func''',
        func=node.name,
        names=gast.List(callable_function_names, ctx=gast.Store()),
        body=node.body)[0]
def visit_Compare(self, node):
    """Outline a chained comparison into an auxiliary function.

    Comparisons with a single operator are returned unchanged. For
    `a < b < c` a new function `<prefix>_compare<k>` is appended to
    `self.compare_functions`, and a call to it is returned. The function
    takes the imported identifiers as arguments. Each operand that is not
    trivially copied is spilled to `$<k>` just before the comparison that
    needs it, and the function returns `False` at the first failing
    comparison. Later operands are therefore only evaluated when the earlier
    comparisons held.
    """
    node = self.generic_visit(node)
    if len(node.ops) <= 1:
        return node

    imported_ids = sorted(self.passmanager.gather(ImportedIds, node,
                                                  self.ctx))
    binded_args = [ast.Name(i, ast.Load(), None) for i in imported_ids]
    forged_name = "{0}_compare{1}".format(self.prefix,
                                          len(self.compare_functions))
    call = ast.Call(ast.Name(forged_name, ast.Load(), None), binded_args, [])
    arg_names = [ast.Name(i, ast.Param(), None) for i in imported_ids]
    args = ast.arguments(arg_names, None, [], [], None, [])

    body = []

    def _hold(expr, slot):
        # Reuse cheap operands directly; otherwise evaluate once into `$slot`.
        if is_trivially_copied(expr):
            return expr
        holder_id = '${}'.format(slot)
        body.append(ast.Assign([ast.Name(holder_id, ast.Store(), None)],
                               expr))
        return ast.Name(holder_id, ast.Load(), None)

    prev_holder = _hold(node.left, 0)
    for i, exp in enumerate(node.comparators):
        holder = _hold(exp, i + 1)
        cond = ast.Compare(prev_holder, [node.ops[i]], [holder])
        body.append(ast.If(
            cond,
            [ast.Pass()],
            [ast.Return(path_to_attr(('__builtin__', 'False')))]))
        prev_holder = holder
    body.append(ast.Return(path_to_attr(('__builtin__', 'True'))))

    self.compare_functions.append(
        ast.FunctionDef(forged_name, args, body, [], None))
    return call
def get_for_args_stmts(self, iter_name, args_list):
    '''
    Returns 3 gast nodes that emulate ``for iter_name in range(*args_list)``:

    1. The statement initializing the iteration variable.
    2. The loop condition (an expression).
    3. The statement advancing the iteration variable after each iteration.

    The condition is ``iter_name < stop``, or ``iter_name > stop`` when the
    step is a negative numeric literal (e.g. ``-1``). A step that is not a
    literal is assumed to be positive.

    NOTE(TODO): Python allows to access iteration variable after loop, such
    as "for i in range(10)" will create i = 9 after the loop. But using
    current conversion will make i = 10. We should find a way to change it
    '''
    def _is_number(value):
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _is_negative_literal(expr):
        """Return True if `expr` is a numeric literal below zero."""
        if isinstance(expr, gast.Constant):
            return _is_number(expr.value) and expr.value < 0
        # `-3` parses as UnaryOp(USub, Constant(3)).
        if isinstance(expr, gast.UnaryOp) and isinstance(expr.op, gast.USub):
            operand = expr.operand
            return (isinstance(operand, gast.Constant) and
                    _is_number(operand.value) and operand.value > 0)
        return False

    len_range_args = len(args_list)
    assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"

    if len_range_args == 1:
        init_stmt = get_constant_variable_node(iter_name, 0)
    else:
        init_stmt = gast.Assign(
            targets=[
                gast.Name(id=iter_name, ctx=gast.Store(), annotation=None,
                          type_comment=None)
            ],
            value=args_list[0])

    range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
    step_node = args_list[2] if len_range_args == 3 else gast.Constant(
        value=1, kind=None)

    # `range` keeps yielding while value < stop for a positive step, or
    # value > stop for a negative step. Testing `value + step <= stop`
    # instead drops the last element whenever step > 1 (range(0, 10, 3)
    # would stop before 9) and never enters the loop for negative steps.
    compare_op = gast.Gt() if _is_negative_literal(step_node) else gast.Lt()
    cond_stmt = gast.Compare(
        left=gast.Name(id=iter_name, ctx=gast.Load(), annotation=None,
                       type_comment=None),
        ops=[compare_op],
        comparators=[range_max_node])

    change_stmt = gast.AugAssign(
        target=gast.Name(id=iter_name, ctx=gast.Store(), annotation=None,
                         type_comment=None),
        op=gast.Add(),
        value=step_node)

    return init_stmt, cond_stmt, change_stmt
def visit_For(self, node):
    """Reset continue/break flags per iteration and break on a set flag.

    Pops this loop's flag lists from `self.for_continue_stack` and
    `self.for_breaked_stack` and resets each flag to False at the top of
    `node.body`. When break flags exist, the loop body ends with
    `<keepgoing_flag> = not cond` and `if cond: break`. The loop body is
    either the visited `For` itself, or the `For` that is the first
    statement of a visited `If`.
    """
    modified_node = self.generic_visit(node)

    def _reset(flag):
        return gast.Assign(
            targets=[gast.Name(id=flag, ctx=gast.Store(), annotation=None)],
            value=gast.NameConstant(value=False))

    for flag in self.for_continue_stack.pop():
        node.body.insert(0, _reset(flag))

    break_reads = []
    for flag in self.for_breaked_stack.pop():
        node.body.insert(0, _reset(flag))
        break_reads.append(gast.Name(id=flag, ctx=gast.Load(), annotation=None))

    if not break_reads:
        return modified_node
    if len(break_reads) == 1:
        cond = break_reads[0]
    else:
        cond = gast.BoolOp(op=gast.Or(), values=break_reads)

    if isinstance(modified_node, gast.For):
        loop = modified_node
    elif (isinstance(modified_node, gast.If) and
          isinstance(modified_node.body[0], gast.For)):
        loop = modified_node.body[0]
    else:
        return modified_node

    loop.body.append(
        gast.Assign(
            targets=[gast.Name(id=self.keepgoing_flag, ctx=gast.Store(),
                               annotation=None)],
            value=gast.UnaryOp(op=gast.Not(), operand=cond)))
    loop.body.append(gast.If(test=cond, body=[gast.Break()], orelse=[]))
    return modified_node
def visit_Assign(self, node):
    """Expand tuple/list unpacking targets into element-wise assignments.

    Each tuple/list target with renamings reported by `traverse_tuples` is
    replaced by a single name `g`, and one assignment `x = g[i][j]...` is
    produced per renamed element.
    - If the right-hand side is a `Name`, `g` is that name and the original
      assignment is not emitted.
    - Otherwise `g` is a fresh identifier and the (rewritten) original
      assignment is emitted first.

    NOTE(review): when the right-hand side is a Name and some renaming
    happened, other plain targets of a chained assignment are not re-emitted.
    Confirm that chained assignments are normalized before this pass.
    """
    self.generic_visit(node)
    no_tmp = isinstance(node.value, ast.Name)
    extra_assign = [] if no_tmp else [node]
    for index, target in enumerate(node.targets):
        if not isinstance(target, (ast.Tuple, ast.List)):
            continue
        renamings = OrderedDict()
        self.traverse_tuples(target, (), renamings)
        if not renamings:
            continue
        gtarget = node.value.id if no_tmp else self.get_new_id()
        node.targets[index] = ast.Name(gtarget, target.ctx, None)
        for rename, path in renamings.items():
            element = ast.Name(gtarget, ast.Load(), None)
            for position in path:
                element = ast.Subscript(element, ast.Index(ast.Num(position)),
                                        ast.Load())
            if isinstance(rename, str):
                rename = ast.Name(rename, ast.Store(), None)
            extra_assign.append(ast.Assign([rename], element))
    return extra_assign or node
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions."""
    module_node, source = parser.parse_entity(f)
    node = module_node.body[0]

    # inspect.getsource is inexact: it locates the entity by regex around the
    # line number CPython records. Lambdas are the worst case, since the whole
    # containing line is returned. So look for exactly one matching
    # definition.
    nodes = ast_util.find_matching_definitions(node, f)
    if len(nodes) != 1:
        if f.__name__ == '<lambda>':
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.')
        else:
            message = (
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\n. This is an extremely rare occurrence. Please report it to'
                ' the TensorFlow team.')
        raise ValueError(message.format(f, source))
    node, = nodes

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    context = converter.EntityContext(namer, entity_info, program_ctx)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # A converted lambda is bound to a fresh name through an assignment.
        new_name = namer.new_symbol('tf__lambda', ())
        node = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                           value=node)
    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, did_rename = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if did_rename:
            node.name = new_name
        else:
            new_name = f.__name__
            assert node.name == new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.
    return [node], new_name, namespace
def synchronize_lcds(self, node):
    """Append `name = name._synchronize()` for each loop-carried dependency.

    A name counts as loop-carried when an `Assign` in `node.body` stores to
    it and the name was loaded by that statement or an earlier one in the
    body. The synchronizing assignments are appended in the order the
    dependencies are found, so the generated code is the same on every run;
    a `set` iterates in hash order, which varies between runs for strings.

    Raises:
      NotImplementedError: if a loop-carried name is assigned more than once.
    """
    node = FuseAttributes().visit(node)
    loads = defaultdict(list)
    lcds = []  # insertion-ordered; names are few, so `in` on a list is fine
    for child in node.body:
        for n in gast.walk(child):
            if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                loads[n.id].append(n)
        if isinstance(child, gast.Assign):
            # NOTE(review): only the first target is inspected, and it is
            # assumed to be a plain Name after FuseAttributes. Confirm that
            # multi-target assignments cannot reach this point.
            name = child.targets[0].id
            if name in loads:
                if name in lcds:
                    raise NotImplementedError("cannot process LCD "
                                              "stored to twice")
                lcds.append(name)
    node = SplitAttributes().visit(node)
    synchronizes = []
    for name in lcds:
        synchronize = gast.Assign(
            [gast.Name(name, gast.Store(), None)],
            gast.Call(
                gast.Attribute(
                    gast.Name(name, gast.Load(), None),
                    gast.Name('_synchronize', gast.Load(), None),
                    None),
                [], []))
        synchronizes.append(synchronize)
    node.body.extend(synchronizes)
    return node
def get_annotations(object_def, namespace):
    """Create the annotations from a definition node.

    Builds an `annotations = {...}` assignment from the annotations of a
    function or class definition, unparses it and executes it in
    `namespace`. The resulting dictionary is returned. Any `__builtins__`
    entry is removed first, so `exec` inserts the standard builtins.
    `namespace` is mutated: it ends up holding `annotations` and
    `__builtins__`.

    Raises:
      NotImplementedError: if `object_def` is neither a function nor a class
        definition.
    """
    ast_annotations = ast.Assign(
        targets=[extast.Name("annotations", ast.Store())],
        value=ast.Dict(keys=[], values=[]),
        type_comment=None,
    )
    fillers = (
        (ast.FunctionDef, _fill_ast_annotations_function),
        (ast.ClassDef, _fill_ast_annotations_class),
    )
    for definition_type, fill in fillers:
        if isinstance(object_def, definition_type):
            fill(object_def, ast_annotations)
            break
    else:
        raise NotImplementedError

    source = extast.unparse(ast_annotations)
    namespace.pop("__builtins__", None)
    # Evaluates the annotation expressions taken from the analysed code, so
    # only use this on code that is trusted anyway.
    exec(source, namespace)
    return namespace["annotations"]
def _replace_after_node_to_if_in_stmt_list( self, stmt_list, node, return_name, parent_node_of_return): i = index_in_list(stmt_list, node) if i < 0 or i >= len(stmt_list): return False if i == len(stmt_list) - 1: # No need to add, we consider this as added successfully return True if_stmt = gast.If(test=gast.UnaryOp( op=gast.Not(), operand=gast.Name( id=return_name, ctx=gast.Store(), annotation=None, type_comment=None)), body=stmt_list[i + 1:], orelse=[]) stmt_list[i + 1:] = [if_stmt] # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = False' node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_false_node = gast.parse(node_str).body[0] stmt_list[i:i] = [assign_false_node] return True
def _process_variable_assignment(self, source, targets):
    """Record in `self.scope` the value assigned to each target.

    If `source` calls a live value that is a class, the call is annotated
    as a constructor (`is_constructor`, `type`, `type_fqn`). Then, per
    target:
    - a `Name`/`Attribute` target is bound to `source`;
    - a tuple element is bound to a `Subscript` of the value it unpacks,
      recursing into nested tuples.

    Raises:
      ValueError: for any other kind of target.
    """
    if isinstance(source, gast.Call):
        func = source.func
        if anno.hasanno(func, 'live_val'):
            func_obj = anno.getanno(func, 'live_val')
            if tf_inspect.isclass(func_obj):
                anno.setanno(source, 'is_constructor', True)
                anno.setanno(source, 'type', func_obj)
                anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
                # TODO(mdan): Raise an error if constructor has side effects.
                # We can have a whitelist of no-side-effects constructors.
                # We can also step inside the constructor and further analyze.

    def _process_tuple(value, tuple_node):
        for i, elt in enumerate(tuple_node.elts):
            item = gast.Subscript(value, gast.Index(i), ctx=gast.Store())
            if isinstance(elt, gast.Tuple):
                # A nested tuple unpacks `value[i]`; it has no QN of its own.
                _process_tuple(item, elt)
            else:
                self.scope.setval(anno.getanno(elt, anno.Basic.QN), item)

    for t in targets:
        if isinstance(t, gast.Tuple):
            _process_tuple(source, t)
        elif isinstance(t, (gast.Name, gast.Attribute)):
            self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
        else:
            raise ValueError('Dont know how to handle assignment to %s' % t)
def visit_Assign(self, node):
    """Record assigned values in `self.scope` and annotate constructor calls.

    If the right-hand side calls a live value that is a class, the call is
    annotated with `type` and `type_fqn`. Tuple targets bind each element's
    `id` to a `Subscript` of the value. Every other target binds its `id` to
    the value.
    """
    self.generic_visit(node)
    value = node.value
    if isinstance(value, gast.Call) and anno.hasanno(value.func, 'live_val'):
        target_obj = anno.getanno(value.func, 'live_val')
        if tf_inspect.isclass(target_obj):
            # Calling a class: the assignment stores a newly built object.
            anno.setanno(value, 'type', target_obj)
            anno.setanno(value, 'type_fqn', anno.getanno(value.func, 'fqn'))
            # TODO(mdan): Raise an error if constructor has side effects.
            # We can have a whitelist of no-side-effects constructors.
            # We can also step inside the constructor and further analyze.

    for target in node.targets:
        if isinstance(target, gast.Tuple):
            for i, elt in enumerate(target.elts):
                self.scope.setval(
                    elt.id,
                    gast.Subscript(value, gast.Index(i), ctx=gast.Store()))
        else:
            self.scope.setval(target.id, value)
    return node
def visit_Call(self, node):
    """
    Replace function call by inlined function's body.

    We can inline if it aliases on only one function. Each formal parameter
    is bound to a fresh `__pythran_inline<func><arg><count>` variable. The
    binding assignments are appended to `self.defs`. Missing trailing
    arguments take the callee's default values.
    """
    func_aliases = self.aliases[node.func]
    if len(func_aliases) != 1:
        return node
    function_def = next(iter(func_aliases))
    if not (isinstance(function_def, ast.FunctionDef) and
            function_def.name in self.inlinable):
        return node

    self.update = True
    to_inline = copy.deepcopy(self.inlinable[function_def.name])
    formal_args = to_inline.args.args
    defaults = to_inline.args.defaults

    # Build a new list. Extending `node.args` through an alias (as
    # `values = node.args; values += ...` did) mutates the original call.
    values = list(node.args)
    missing = len(formal_args) - len(node.args)
    if missing > 0:
        values.extend(defaults[-missing:])

    arg_to_value = dict()
    for arg_fun, arg_call in zip(formal_args, values):
        v_name = "__pythran_inline{}{}{}".format(function_def.name,
                                                 arg_fun.id,
                                                 self.call_count)
        new_var = ast.Name(id=v_name, ctx=ast.Store(), annotation=None,
                           type_comment=None)
        self.defs.append(ast.Assign(targets=[new_var], value=arg_call))
        arg_to_value[arg_fun.id] = ast.Name(id=v_name, ctx=ast.Load(),
                                            annotation=None,
                                            type_comment=None)
    self.call_count += 1
    return Inliner(arg_to_value).visit(to_inline.body[0])
def visit_loop(self, node, update_mask=None):
    """Rewrite loop-carried dependencies (LCDs) in `node.body`.

    Each `name = expr` in the body whose `name` was loaded by that statement
    or an earlier one becomes `name = name._update(expr, update_mask)`.
    Afterwards, one `name = name._synchronize()` per such name is appended,
    in the order the names were found, so output is stable across runs.

    Args:
      node: node with a `body` statement list (presumably a loop; confirm
        with callers).
      update_mask: AST expression passed as the second argument of each
        `_update` call. Defaults to a fresh `None` constant node per call.

    Returns:
      The transformed node.

    Raises:
      NotImplementedError: for an assignment with several targets, or an
        LCD that is stored to twice.
    """
    if update_mask is None:
        # A node used as the default value would be created once, when `def`
        # runs, and that single object would be spliced into every tree
        # rewritten by this method.
        update_mask = gast.NameConstant(value=None)
    node = FuseAttributes().visit(node)
    loads = defaultdict(list)
    stores = []  # insertion-ordered; a set iterates in hash order
    for child in node.body:
        for n in gast.walk(child):
            if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                loads[n.id].append(n)
        if isinstance(child, gast.Assign):
            if len(child.targets) > 1:
                raise NotImplementedError("cannot process LCD that is "
                                          "part of multiple assignment")
            name = child.targets[0].id
            if name in loads:
                if name in stores:
                    raise NotImplementedError("cannot process LCD "
                                              "stored to twice")
                # $var = $expr -> $var = $var._update($expr)
                child.value = gast.Call(
                    gast.Attribute(gast.Name(name, gast.Load(), None),
                                   gast.Name('_update', gast.Load(), None),
                                   None),
                    [child.value, update_mask], [])
                stores.append(name)
    node = SplitAttributes().visit(node)
    synchronizes = []
    for name in stores:
        synchronize = gast.Assign(
            [gast.Name(name, gast.Store(), None)],
            gast.Call(
                gast.Attribute(
                    gast.Name(name, gast.Load(), None),
                    gast.Name('_synchronize', gast.Load(), None),
                    None),
                [], []))
        synchronizes.append(synchronize)
    node.body.extend(synchronizes)
    return node
def create_while_node(condition_name, body_name, loop_var_names):
    """Build `(v1, ...) = fluid.layers.while_loop(cond, body, [v1, ...])`.

    Args:
      condition_name: name of the loop-condition function.
      body_name: name of the loop-body function.
      loop_var_names: iterable of loop-variable names, in order.

    Returns:
      A `gast.Assign` node.
    """
    # The names are read twice below, so a one-shot iterable is materialized.
    loop_var_names = list(loop_var_names)

    def _name(identifier, ctx):
        return gast.Name(id=identifier, ctx=ctx, annotation=None,
                         type_comment=None)

    # Names passed to the call are reads (Load) and the unpacking targets
    # are writes (Store). Each use gets its own node object, so no node has
    # two parents in the tree.
    while_args = [
        _name(condition_name, gast.Load()),
        _name(body_name, gast.Load()),
        gast.List(elts=[_name(n, gast.Load()) for n in loop_var_names],
                  ctx=gast.Load()),
    ]
    while_func_id = gast.parse('fluid.layers.while_loop').body[0].value
    while_node = gast.Call(func=while_func_id, args=while_args, keywords=[])

    assign_targets = [_name(n, gast.Store()) for n in loop_var_names]
    assign_node = gast.Assign(
        targets=[gast.Tuple(elts=assign_targets, ctx=gast.Store())],
        value=while_node)
    return assign_node
def visit_For(self, node):
    """Replace a tuple/list loop target with a single fresh name.

    When `traverse_tuples` reports renamings, the target becomes a fresh
    name `g`. Assignments `x = g[i][j]...` are placed at the top of the loop
    body, in reverse renaming order. The node is then visited generically.
    """
    target = node.target
    if isinstance(target, (ast.Tuple, ast.List)):
        renamings = OrderedDict()
        self.traverse_tuples(target, (), renamings)
        if renamings:
            gtarget = self.get_new_id()
            node.target = ast.Name(gtarget, target.ctx, None)
            prologue = []
            for rename, path in renamings.items():
                element = ast.Name(gtarget, ast.Load(), None)
                for position in path:
                    element = ast.Subscript(
                        element, ast.Index(ast.Num(position)), ast.Load())
                if isinstance(rename, str):
                    rename = ast.Name(rename, ast.Store(), None)
                prologue.append(ast.Assign([rename], element))
            # Inserting each assignment at index 0 would put the last one
            # first; the reversed list gives that same order.
            node.body[:0] = prologue[::-1]
    self.generic_visit(node)
    return node
def visit_FunctionDef(self, node):
    """Lift a nested function to module level as a `functools.partial`.

    The function is appended to the module body under a fresh
    `pythran_<name><seed>` name. The identifiers it imports from the
    enclosing scope (sorted) become leading parameters, and recursive calls
    are rewritten to pass them. The original definition is replaced by
    `<name> = functools.partial(pythran_<name><seed>, <imported ids...>)`.
    A mangled `functools` import is added to the module if it is not
    already declared.
    """
    self.update = True
    functools_module = MODULES['functools']
    if functools_module not in self.global_declarations.values():
        self.ctx.module.body.insert(
            0, ast.Import([ast.alias('functools', mangle('functools'))]))
        self.global_declarations[mangle('functools')] = functools_module

    self.ctx.module.body.append(node)

    former_name = node.name
    template = "pythran_{}{}"
    seed = 0
    while template.format(former_name, seed) in self.identifiers:
        seed += 1
    new_name = template.format(former_name, seed)
    self.identifiers.add(new_name)

    imported = sorted(self.gather(ImportedIds, node))
    binded_args = [ast.Name(iin, ast.Load(), None, None) for iin in imported]
    node.args.args = (
        [ast.Name(iin, ast.Param(), None, None) for iin in imported] +
        node.args.args)
    metadata.add(node, metadata.Local())

    class Renamer(ast.NodeTransformer):
        # Redirect recursive calls to the lifted name and pass the captures.
        def visit_Call(self, call):
            self.generic_visit(call)
            if (isinstance(call.func, ast.Name) and
                    call.func.id == former_name):
                call.func.id = new_name
                call.args = (
                    [ast.Name(iin, ast.Load(), None, None) for iin in imported] +
                    call.args)
            return call

    Renamer().visit(node)
    node.name = new_name
    self.global_declarations[node.name] = node

    partial_attr = ast.Attribute(
        ast.Name(mangle('functools'), ast.Load(), None, None),
        "partial",
        ast.Load())
    proxy_call = ast.Name(new_name, ast.Load(), None, None)
    new_node = ast.Assign(
        [ast.Name(former_name, ast.Store(), None, None)],
        ast.Call(partial_attr, [proxy_call] + binded_args, []))
    self.generic_visit(node)
    return new_node
def visit_Assign(self, node):
    """Normalize an assignment while the transformer is in trivializing mode.

    - Subscript/Attribute target: the value is trivialized and the target
      visited.
    - Tuple target: the value is visited, the tuple is replaced by a fresh
      name, and one `elt = name[i]` statement per element is appended.
    - Name target: left as is.
    - Anything else: ValueError.

    The node is then visited generically.
    """
    self.src = quoting.unquote(node)
    self.mark(node)
    self.trivializing = True
    first_target = node.targets[0]
    self.namer.target = first_target

    if isinstance(first_target, (gast.Subscript, gast.Attribute)):
        node.value = self.trivialize(node.value)
        node.targets[0] = self.visit(node.targets[0])
    elif isinstance(first_target, gast.Tuple):
        node.value = self.visit(node.value)
        tuple_target = node.targets[0]
        name = self.namer.name(tuple_target)
        for i, elt in enumerate(tuple_target.elts):
            element_read = gast.Subscript(
                value=gast.Name(id=name, ctx=gast.Load(), annotation=None),
                slice=gast.Index(value=gast.Num(n=i)),
                ctx=gast.Load())
            stmt = gast.Assign(targets=[elt], value=element_read)
            self.mark(stmt)
            self.append(stmt)
        node.targets[0] = gast.Name(id=name, ctx=gast.Store(), annotation=None)
    elif not isinstance(first_target, gast.Name):
        raise ValueError

    node = self.generic_visit(node)
    self.namer.target = None
    self.trivializing = False
    return node
def make_control_flow_handlers(self, cont_n, status_n, expected_return,
                               has_cont, has_break):
    '''
    Create the statements in charge of gathering control flow information
    for the static_if result, and execute the expected control flow
    instruction.

    When `expected_return` is non-empty, the result stored in `cont_n` is
    unpacked into it. If the status in `status_n` equals `LOOP_CONT` or
    `LOOP_BREAK`, that unpacking is followed by `continue` or `break`. The
    break test is outermost, then the continue test, then the plain
    unpacking.
    '''
    if expected_return:
        assign = [ast.Assign([ast.Tuple(expected_return, ast.Store())],
                             ast.Name(cont_n, ast.Load(), None, None),
                             None)]
    else:
        assign = []

    handlers = assign
    checks = ((has_cont, LOOP_CONT, ast.Continue),
              (has_break, LOOP_BREAK, ast.Break))
    for enabled, status, jump in checks:
        if not enabled:
            continue
        cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                           [ast.Eq()],
                           [ast.Constant(status, None)])
        handlers = [ast.If(cmpr, deepcopy(assign) + [jump()], handlers)]
    return handlers
def _build_enum_increase_node(self):
    """Build the statement `<self.enum_idx_name> += 1`."""
    counter = gast.Name(id=self.enum_idx_name, ctx=gast.Store(),
                        annotation=None, type_comment=None)
    one = gast.Constant(value=1, kind=None)
    return gast.AugAssign(target=counter, op=gast.Add(), value=one)
def create_assign_node(name, node):
    """
    Create a `gast.Assign` node with `name` as target and `node` as value.

    Returns:
        A `(target, assign_node)` pair. `target` is the node built by
        `generate_name_node(name, ctx=gast.Store())`, and the same object is
        also the assignment's sole target.
    """
    target = generate_name_node(name, ctx=gast.Store())
    return target, gast.Assign(targets=[target], value=node)
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Returns:
      A tuple `((node,), new_name, entity_info)`. `node` is the converted
      function definition, or, for a lambda, an assignment binding the
      converted lambda to `new_name`.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)

    # The parsed AST holds future imports plus one function definition.
    # inspect.getsource locates lambdas by regex around the recorded line
    # number and returns the whole line, which may contain several lambdas
    # (e.g. `x, y = lambda: 1, lambda: 2`), so exactly one match is required.
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node, = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    is_lambda = isinstance(node, gast.Lambda)
    if is_lambda:
        new_name = namer.new_symbol('tf__lambda', ())
    elif do_rename:
        new_name = namer.function_name(f.__name__)
    else:
        new_name = f.__name__

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        lambda_target = gast.Name(new_name, ctx=gast.Store(), annotation=None,
                                  type_comment=None)
        node = gast.Assign(targets=[lambda_target], value=node)
    elif do_rename:
        node.name = new_name
    else:
        assert node.name == new_name

    return (node,), new_name, entity_info
def _process_tuple_assignment(self, source, t):
    """Record in `self.scope` the value bound to each element of tuple `t`.

    Element `i` is bound to `source[i]`, represented as a `gast.Subscript`.
    A nested tuple element is processed recursively against `source[i]`, so
    `(a, b), c = x` binds `a` to `x[0][0]`, `b` to `x[0][1]` and `c` to
    `x[1]`.
    """
    for i, e in enumerate(t.elts):
        item = gast.Subscript(source, gast.Index(i), ctx=gast.Store())
        if isinstance(e, gast.Tuple):
            # Recurse on the i-th item, not on `source` itself; otherwise
            # nested elements would be bound to the wrong position.
            self._process_tuple_assignment(item, e)
        else:
            self.scope.setval(anno.getanno(e, anno.Basic.QN), item)