def _create_state_functions(self, loop_vars, nonlocal_declarations,
                            getter_name, setter_name):
    """Builds a getter/setter pair that reads and writes loop state.

    Args:
      loop_vars: the loop-carried variables; may be empty.
      nonlocal_declarations: substituted at the start of the setter body.
      getter_name: name given to the generated getter function.
      setter_name: name given to the generated setter function.

    Returns:
      The nodes produced by `templates.replace` for both definitions.
    """
    if not loop_vars:
        # No loop state: the getter yields an empty tuple and the setter
        # ignores its argument.
        stateless_template = """
            def getter_name():
                return ()
            def setter_name(loop_vars):
                pass
        """
        return templates.replace(
            stateless_template,
            getter_name=getter_name,
            setter_name=setter_name)

    stateful_template = """
        def getter_name():
            return state_vars,
        def setter_name(loop_vars):
            nonlocal_declarations
            state_vars, = loop_vars
    """
    return templates.replace(
        stateful_template,
        nonlocal_declarations=nonlocal_declarations,
        getter_name=getter_name,
        setter_name=setter_name,
        state_vars=tuple(loop_vars))
def _create_cond_expr(self, results, test, body_name, orelse_name,
                      state_getter_name, state_setter_name,
                      basic_symbol_names, composite_symbol_names):
    """Builds the `ag__.if_stmt(...)` call for a converted `if` statement.

    When `results` is not None, the call's value is assigned to `results`;
    otherwise the call is emitted as a bare expression statement.
    """
    shared = dict(
        test=test,
        body_name=body_name,
        orelse_name=orelse_name,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        basic_symbol_names=basic_symbol_names,
        composite_symbol_names=composite_symbol_names)

    if results is None:
        bare_call = """
            ag__.if_stmt(test, body_name, orelse_name, state_getter_name,
                         state_setter_name, (basic_symbol_names,),
                         (composite_symbol_names,))
        """
        return templates.replace(bare_call, **shared)

    assigning_call = """
        results = ag__.if_stmt(test, body_name, orelse_name, state_getter_name,
                               state_setter_name, (basic_symbol_names,),
                               (composite_symbol_names,))
    """
    return templates.replace(assigning_call, results=results, **shared)
def _do_transform_node(self, node):
    """Hoists `node` into a fresh temporary and returns a reference to it.

    The assignment `<temp> = node` is queued with `_add_pending_statement`;
    the returned expression names `<temp>` and stands in for `node`.
    """
    fresh = self._gensym.new_name()
    hoisted = templates.replace('temp_name = expr', temp_name=fresh, expr=node)
    self._add_pending_statement(hoisted[0])
    reference = templates.replace('temp_name', temp_name=fresh)
    return reference[0]
def create_assignment(self, target, expression):
    """Returns the statement nodes for `target = expression`."""
    assignment_template = """
        target = expression
    """
    return templates.replace(
        assignment_template, target=target, expression=expression)
def visit_Return(self, node):
    """Routes a returned value through the function context.

    `return value` becomes
    `return <context>.mark_return_value(value)`, where the context name is
    read from the current `_Function` state. A bare `return` is unchanged.
    """
    returned = node.value
    if returned is not None:
        context_name = self.state[_Function].context_name
        return templates.replace(
            'return function_context_name.mark_return_value(value)',
            function_context_name=context_name,
            value=returned)
    return node
def visit_While(self, node):
    """Lowers `break` inside a while loop to a boolean control variable.

    When the body uses `break`, the loop becomes:
    `<flag> = False; while test and not <flag>: body; else: guarded orelse`.
    DIRECTIVES annotations are copied onto the new loop. Otherwise the
    visited loop is returned as-is.
    """
    loop = node
    body_scope = anno.getanno(loop, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    loop.test = self.visit(loop.test)
    loop.body, break_used = self._process_body(loop.body, break_var)
    # A break in the else clause applies to the containing scope.
    loop.orelse = self.visit_block(loop.orelse)

    if not break_used:
        return loop

    # Python's else clause only runs when the loop exits without a break.
    guarded_orelse = self._guard_if_present(loop.orelse, break_var)
    template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
            body
        else:
            orelse
    """
    lowered = templates.replace(
        template,
        var_name=break_var,
        test=loop.test,
        body=loop.body,
        orelse=guarded_orelse)
    # lowered[0] initializes the flag; lowered[1] is the rewritten loop.
    anno.copyanno(loop, lowered[1], anno.Basic.DIRECTIVES)
    return lowered
def visit_Return(self, node):
    """Replaces `return` with assignments to the do-return and retval vars.

    Every enclosing block, from the innermost one up to and including the
    function block, is marked as having used a return and as needing a
    guard on the statement that follows it.
    """
    for enclosing in reversed(self.state[_Block].stack):
        enclosing.return_used = True
        enclosing.create_guard_next = True
        if enclosing.is_function:
            break

    if node.value:
        retval = node.value
    else:
        retval = parser.parse_expression('None')

    # Note: if `return <expr>` raises, the return is aborted. The try/except
    # below resets the do-return flag so the variables stay consistent.
    template = """
        try:
            do_return_var_name = True
            retval_var_name = retval
        except:
            do_return_var_name = False
            raise
    """
    function_state = self.state[_Function]
    return templates.replace(
        template,
        do_return_var_name=function_state.do_return_var_name,
        retval_var_name=function_state.retval_var_name,
        retval=retval)
def to_ast(self):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents. The standard options are emitted as the shorthand
    `ag__.STD`.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
        return parser.parse_expression('ag__.STD')

    template = """
        ag__.ConversionOptions(
            recursive=recursive_val,
            user_requested=user_requested_val,
            optional_features=optional_features_val,
            internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
        # Each element gets a trailing comma, so a single feature still
        # parses as a 1-tuple: `(x)` alone is a parenthesized expression, not
        # a tuple. An empty input still yields `()`, the empty tuple.
        # Sorting by name keeps the generated code independent of the
        # iteration order of `values` (presumably a set; unordered).
        items = ''.join(
            'ag__.{}, '.format(str(v)) for v in sorted(values, key=str))
        return parser.parse_expression('({})'.format(items))

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        user_requested_val=parser.parse_expression(str(self.user_requested)),
        internal_convert_user_code_val=parser.parse_expression(
            str(self.internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
def _replace_append_call(self, node):
    """Rewrites `target.append(element)` as a functional list append.

    Produces `target = ag__.list_append(target, element)`. Asserts that the
    call has exactly one positional argument and is an attribute call.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    receiver = node.func.value
    appended = node.args[0]
    append_template = """
        target = ag__.list_append(target, element)
    """
    return templates.replace(append_template, target=receiver, element=appended)
def visit_Break(self, node):
    """Replaces `break` with `<control var> = True` followed by `continue`.

    Also records in the current `_Break` state that a break was used.
    """
    break_state = self.state[_Break]
    break_state.used = True
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
        var_name = True
        continue
    """
    return templates.replace(template, var_name=break_state.control_var_name)
def visit_Assert(self, node):
    """Converts `assert test, msg` into `ag__.assert_stmt(test, lambda: msg)`.

    A missing message defaults to the constant 'Assertion error'.

    Raises:
      NotImplementedError: if the message is not a `gast.Constant`.
    """
    self.generic_visit(node)

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    if node.msg is None:
        message = gast.Constant('Assertion error', kind=None)
    elif isinstance(node.msg, gast.Constant):
        message = node.msg
    else:
        raise NotImplementedError(
            'can only convert string messages for now.')

    template = """
        ag__.assert_stmt(test, lambda: msg)
    """
    return templates.replace(template, test=node.test, msg=message)
def _process_single_assignment(self, target, value):
    """Rewrites `x[key] = value` as `x = ag__.set_item(x, key, value)`.

    Returns None for any target that is not a subscript whose slice is a
    `gast.Index`, meaning this assignment is not rewritten.
    """
    # NOTE(review): `gast.Index` exists only in older gast releases (newer
    # ones follow the Python 3.9 AST without Index) — confirm the pinned
    # gast version provides it.
    is_simple_subscript = (
        isinstance(target, gast.Subscript) and
        isinstance(target.slice, gast.Index))
    if not is_simple_subscript:
        return None

    set_item_template = """
        target = ag__.set_item(target, key, item)
    """
    return templates.replace(
        set_item_template,
        target=target.value,
        key=target.slice.value,
        item=value)
def visit_FunctionDef(self, node):
    """Rewrites a function so that all converted returns share one exit.

    If any `return` was converted in the body, the body is rebuilt from a
    template that either initializes the do-return/retval variables and
    returns `ag__.retval(...)` (when `default_to_null_return` is set), or
    simply returns the retval variable at the end. A leading docstring is
    kept as the first statement of the rebuilt body.
    """
    self.state[_Function].enter()
    self.state[_Block].enter()
    self.state[_Block].is_function = True

    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    do_return_var_name = self.ctx.namer.new_symbol('do_return',
                                                   scope.referenced)
    retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
    self.state[_Function].do_return_var_name = do_return_var_name
    self.state[_Function].retval_var_name = retval_var_name

    converted_body = self._visit_statement_block(node, node.body)

    # Avoid placing statements before any eventual docstring.
    # TODO(mdan): Should a docstring even be included in the output?
    docstring = None
    if converted_body:
        first = converted_body[0]
        if (isinstance(first, gast.Expr) and
                isinstance(first.value, gast.Constant)):
            docstring = first
            converted_body = converted_body[1:]

    if self.state[_Block].return_used:
        if self.default_to_null_return:
            # TODO(mdan): Remove the (do_return_var_name,) below.
            # Currently, that line ensures the variable is both defined and
            # alive throughout the function.
            template = """
                do_return_var_name = False
                retval_var_name = ag__.UndefinedReturnValue()
                body
                (do_return_var_name,)
                return ag__.retval(retval_var_name)
            """
        else:
            template = """
                body
                return retval_var_name
            """
        node.body = templates.replace(
            template,
            body=converted_body,
            do_return_var_name=do_return_var_name,
            retval_var_name=retval_var_name)

        # Re-insert the docstring only into a body we rebuilt. When no return
        # was used, node.body still holds the original statements (docstring
        # included), and inserting it again would duplicate it.
        if docstring is not None:
            node.body.insert(0, docstring)

    self.state[_Block].exit()
    self.state[_Function].exit()
    return node
def _create_undefined_assigns(self, undefined_symbols):
    """Builds `sym = ag__.Undefined('<sym>')` statements, one per symbol.

    The symbol name constant comes from each symbol's `ssf()` method.
    """
    template = '''
        var = ag__.Undefined(symbol_name)
    '''
    statements = []
    for symbol in undefined_symbols:
        name_constant = gast.Constant(symbol.ssf(), kind=None)
        statements.extend(
            templates.replace(template, var=symbol, symbol_name=name_constant))
    return statements
def _guard_if_present(self, block, var_name):
    """Prevents the block from executing if var_name is set.

    An empty (falsy) block is returned unchanged; otherwise the block is
    wrapped in `if ag__.not_(var_name): ...`.
    """
    if block:
        guard_template = """
            if ag__.not_(var_name):
                block
        """
        return templates.replace(guard_template, var_name=var_name, block=block)
    return block
def visit_While(self, node):
    """Lowers a while loop into a call to `ag__.pt_while_stmt`.

    The loop body and test become local functions. Loop-carried state is
    exposed through generated getter/setter functions, and possibly
    undefined symbols are pre-assigned with `ag__.Undefined` before the call.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_vars, reserved, possibly_undefs = self._get_loop_vars(
        node, body_scope.modified)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    namer = self.ctx.namer
    getter_name = namer.new_symbol('get_state', reserved)
    setter_name = namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, getter_name, setter_name)
    opts = self._create_loop_options(node)

    # Symbols are allocated in a fixed order: loop_body, then loop_test.
    body_name = namer.new_symbol('loop_body', reserved)
    symbol_names = tuple(gast.Constant(str(s), kind=None) for s in loop_vars)
    test_name = namer.new_symbol('loop_test', reserved)

    template = """
        state_functions
        def body_name():
            nonlocal_declarations
            body
        def test_name():
            return test
        undefined_assigns
        ag__.pt_while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (symbol_names,),
            opts)
    """
    return templates.replace(
        template,
        body=node.body,
        body_name=body_name,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        state_functions=state_functions,
        state_getter_name=getter_name,
        state_setter_name=setter_name,
        symbol_names=symbol_names,
        test=node.test,
        test_name=test_name,
        undefined_assigns=undefined_assigns)
def visit_Assign(self, node):
    """Expands `target = [elt for ... in ... if ...]` into explicit loops.

    Other assignments are visited generically.

    Raises:
      NotImplementedError: for assignments with more than one target.
    """
    comprehension = node.value
    if not isinstance(comprehension, gast.ListComp):
        return self.generic_visit(node)
    if len(node.targets) > 1:
        raise NotImplementedError('multiple assignments')
    target, = node.targets

    init_template = """
        target = []
    """
    append_template = """
        target.append(elt)
    """
    if_template = """
        if test:
            body
    """
    for_template = """
        for target in iter_:
            body
    """

    initialization = templates.replace(init_template, target=target)
    statements = templates.replace(
        append_template, target=target, elt=comprehension.elt)

    # Wrap from the innermost generator outward, so the first generator ends
    # up as the outermost loop and its filters sit directly inside it.
    for generator in reversed(comprehension.generators):
        for condition in reversed(generator.ifs):
            statements = templates.replace(
                if_template, test=condition, body=statements)
        statements = templates.replace(
            for_template,
            iter_=generator.iter,
            target=generator.target,
            body=statements)

    return initialization + statements
def _create_cond_branch(self, body_name, aliased_orig_names,
                        aliased_new_names, body, returns):
    """Builds a local function `body_name` that runs `body` and returns.

    A single return value is returned directly; otherwise a tuple of
    `returns` is returned. When `aliased_orig_names` is non-empty, the
    function starts by assigning those names to `aliased_new_names`.
    """
    if len(returns) == 1:
        return_stmt = templates.replace("""
            return retval
        """, retval=returns[0])
    else:
        return_stmt = templates.replace("""
            return (retvals,)
        """, retvals=returns)

    common = dict(body_name=body_name, body=body, return_stmt=return_stmt)

    if not aliased_orig_names:
        plain_template = """
            def body_name():
                body
                return_stmt
        """
        return templates.replace(plain_template, **common)

    aliasing_template = """
        def body_name():
            aliased_new_names, = aliased_orig_names,
            body
            return_stmt
    """
    return templates.replace(
        aliasing_template,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        **common)
def visit_Continue(self, node):
    """Replaces `continue` with an assignment to the continue control var.

    Marks the enclosing blocks, from the innermost one up to the nearest
    loop, so that the statement following each of them is guarded.
    """
    self.state[_Continue].used = True

    # Guards are needed for every enclosing block up to the innermost loop,
    # not just the current block; see
    # ContinueCanonicalizationTest.test_multiple_continues for an example.
    for enclosing in reversed(self.state[_Block].stack):
        enclosing.create_guard_next = True
        if enclosing.is_loop_type:
            # continue only affects the innermost loop
            break

    flag_template = """
        var_name = True
    """
    return templates.replace(
        flag_template, var_name=self.state[_Continue].control_var_name)
def _postprocess_statement(self, node):
    """Wraps a statement in a continue guard when one is pending.

    Returns:
      A `(node, block)` pair. `block` is the body of the generated guard
      `if`, or None when the statement was left unguarded.
    """
    if not self.state[_Continue].used:
        return node, None

    block = self.state[_Block]
    guard_this_statement = block.create_guard_current
    # Carry the pending guard request over to the next statement.
    block.create_guard_current = block.create_guard_next
    block.create_guard_next = False

    if not guard_this_statement:
        return node, None

    guard_template = """
        if ag__.not_(var_name):
            original_node
    """
    guarded, = templates.replace(
        guard_template,
        var_name=self.state[_Continue].control_var_name,
        original_node=node)
    return guarded, guarded.body
def visit_For(self, node):
    """Lowers `break` inside a for loop to a boolean control variable.

    When the body uses `break`, the loop is preceded by `<flag> = False`, the
    else clause is guarded on the flag, and `ag__.not_(<flag>)` is attached
    to the new loop as its EXTRA_LOOP_TEST annotation. DIRECTIVES
    annotations are copied onto the new loop.
    """
    loop = node
    body_scope = anno.getanno(loop, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    loop.target = self.visit(loop.target)
    loop.iter = self.visit(loop.iter)
    loop.body, break_used = self._process_body(loop.body, break_var)
    # A break in the else clause applies to the containing scope.
    loop.orelse = self.visit_block(loop.orelse)

    if not break_used:
        return loop

    # Python's else clause only runs when the loop exits without a break.
    guarded_orelse = self._guard_if_present(loop.orelse, break_var)
    extra_test = templates.replace_as_expression(
        'ag__.not_(var_name)', var_name=break_var)

    # The extra test is hidden in the AST, which would confuse static
    # analysis; the no-op `(var_name,)` statement keeps the control variable
    # marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
        var_name = False
        for target in iter_:
            (var_name,)
            body
        else:
            orelse
    """
    lowered = templates.replace(
        template,
        var_name=break_var,
        iter_=loop.iter,
        target=loop.target,
        body=loop.body,
        orelse=guarded_orelse)

    new_loop = lowered[1]
    anno.setanno(new_loop, anno.Basic.EXTRA_LOOP_TEST, extra_test)
    anno.copyanno(loop, new_loop, anno.Basic.DIRECTIVES)
    return lowered
def _postprocess_statement(self, node):
    """Wraps a statement in a do-return guard when one is pending.

    Returns:
      A `(node, block)` pair. `block` is the body of the generated guard
      `if`, or None when no guard was added.
    """
    state = self.state[_Block]
    if not state.return_used:
        return node, None

    guard_body = None
    if state.create_guard_now:
        guard_template = """
            if ag__.not_(do_return_var_name):
                original_node
        """
        guarded, = templates.replace(
            guard_template,
            do_return_var_name=self.state[_Function].do_return_var_name,
            original_node=node)
        node = guarded
        guard_body = guarded.body

    # The guard request recorded for the next statement becomes current.
    state.create_guard_now = state.create_guard_next
    state.create_guard_next = False
    return node, guard_body
def visit_FunctionDef(self, node):
    """Wraps a function body in `with ag__.FunctionScope(...) as <ctx>:`.

    A fresh context name (based on 'fscope') is stored in the `_Function`
    state and in the node's 'function_context_name' annotation. A leading
    docstring stays outside the `with` block as the first statement.
    """
    self.state[_Function].enter()
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    context_name = self.ctx.namer.new_symbol('fscope', body_scope.referenced)
    self.state[_Function].context_name = context_name
    anno.setanno(node, 'function_context_name', context_name)

    node = self.generic_visit(node)

    docstring = None
    statements = node.body
    if statements:
        head = statements[0]
        if isinstance(head, gast.Expr) and isinstance(head.value, gast.Constant):
            docstring = head
            statements = statements[1:]
            node.body = statements

    scope_template = """
        with ag__.FunctionScope(
                function_name, context_name, options) as function_context:
            body
    """
    wrapped = templates.replace(
        scope_template,
        function_name=gast.Constant(node.name, kind=None),
        context_name=gast.Constant(context_name, kind=None),
        options=self._function_scope_options().to_ast(),
        function_context=context_name,
        body=statements)

    if docstring is not None:
        wrapped = [docstring] + wrapped
    node.body = wrapped

    self.state[_Function].exit()
    return node
def _visit_loop_body(self, node, nodes):
    """Visits a loop body and lowers its `continue` statements.

    A fresh 'continue_' control variable is allocated. If any continue was
    used, the returned statements start with `<var> = False`.
    """
    self.state[_Continue].enter()
    self.state[_Block].enter()
    self.state[_Block].is_loop_type = True

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    continue_var = self.ctx.namer.new_symbol('continue_', body_scope.referenced)
    self.state[_Continue].control_var_name = continue_var

    visited = self.visit_block(nodes, after_visit=self._postprocess_statement)
    if self.state[_Continue].used:
        init_template = """
            var_name = False
        """
        visited = templates.replace(init_template, var_name=continue_var) + visited

    self.state[_Block].exit()
    self.state[_Continue].exit()
    return visited
def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Rewrites `target.pop([element])` as a call to `ag__.list_pop`.

    Produces `target, pop_var = ag__.list_pop(target, element, opts=...)`.
    The element dtype and shape come from `set_element_type` directives on
    `target`, defaulting to `None`. A missing argument becomes `None`.
    """
    assert isinstance(original_call_node.func, gast.Attribute)
    call_args = original_call_node.args
    pop_element = call_args[0] if call_args else parser.parse_expression('None')

    # The call looks like "target.pop()" and the dtype is hooked to target,
    # hence func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # It is unclear how to annotate the list as a "list of lists with a
    # certain element type" for operations like `l.pop().pop()`.
    def element_type_directive(arg_name):
        return self.get_definition_directive(
            original_call_node.func.value,
            directives.set_element_type,
            arg_name,
            default=templates.replace_as_expression('None'))

    dtype = element_type_directive('dtype')
    shape = element_type_directive('shape')

    pop_template = """
        target, pop_var_name = ag__.list_pop(
            target, element,
            opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        pop_template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
def visit_For(self, node):
    """Lowers a for loop into a call to `ag__.for_stmt`.

    The loop body becomes a local function that receives the iterate and
    unpacks it into the loop target. Loop-carried state is exposed through
    generated getter/setter functions. An EXTRA_LOOP_TEST annotation, if
    present, becomes a local predicate function; otherwise `None` is passed
    in its place.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)
    modified = body_scope.modified | iter_scope.modified
    loop_vars, reserved, possibly_undefs = self._get_loop_vars(node, modified)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    namer = self.ctx.namer
    getter_name = namer.new_symbol('get_state', reserved)
    setter_name = namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, getter_name, setter_name)
    opts = self._create_loop_options(node)

    if not anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
        extra_test_name = parser.parse_expression('None')
        extra_test_function = []
    else:
        extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
        extra_test_name = namer.new_symbol('extra_test', reserved)
        extra_test_template = """
            def extra_test_name():
                nonlocal_declarations
                return extra_test_expr
        """
        extra_test_function = templates.replace(
            extra_test_template,
            extra_test_expr=extra_test,
            extra_test_name=extra_test_name,
            loop_vars=loop_vars,
            nonlocal_declarations=nonlocal_declarations)

    # The body function takes a single argument holding the iterates, which
    # may be a tuple; it is unpacked into the original loop target.
    iterate_arg_name = namer.new_symbol('itr', reserved)
    expansion_template = """
        iterates = iterate_arg_name
    """
    iterate_expansion = templates.replace(
        expansion_template,
        iterate_arg_name=iterate_arg_name,
        iterates=node.target)

    body_name = namer.new_symbol('loop_body', reserved)
    symbol_names = tuple(gast.Constant(str(s), kind=None) for s in loop_vars)

    template = """
        state_functions
        def body_name(iterate_arg_name):
            nonlocal_declarations
            iterate_expansion
            body
        extra_test_function
        undefined_assigns
        ag__.for_stmt(
            iterated,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (symbol_names,),
            opts)
    """
    return templates.replace(
        template,
        body=node.body,
        body_name=body_name,
        extra_test_function=extra_test_function,
        extra_test_name=extra_test_name,
        iterate_arg_name=iterate_arg_name,
        iterate_expansion=iterate_expansion,
        iterated=node.iter,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        symbol_names=symbol_names,
        state_functions=state_functions,
        state_getter_name=getter_name,
        state_setter_name=setter_name,
        undefined_assigns=undefined_assigns)
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
                               factory_name, closure_vars, future_features):
    """Wraps an AST into the body of a dynamic factory.

    This uses the dynamic factory (factory of factory) pattern to achieve the
    following:

     1. The inner factory dynamically creates the entity represented by nodes.
     2. The entity is parametrized by `ag__`, the internal AutoGraph module.
     3. The outer factory creates the inner factory with a lexical scope in
        which `closure_vars` are bound as local variables. This in turn
        allows the caller to control the exact closure (i.e. non-global free
        variables) of the inner factory.

    The AST is expected to define some symbol named by `entity_name`.

    Args:
      nodes: ast.AST, or a list/tuple of them.
      entity_name: Union[Text, ast.AST]
      factory_factory_name: Text
      factory_name: Text
      closure_vars: Iterable[Text]
      future_features: Iterable[Text], see EntityInfo.future_features.

    Returns:
      ast.AST
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes,)

    # These dummy symbol declarations create local variables in a function
    # scope, so that the Python parser correctly marks them as free non-global
    # variables upon load (that is, it creates cell slots for each symbol).
    # Their values are not used, as the cells are swapped with the original
    # entity's cells after the code has been loaded.
    dummy_template = """
        var_name = None
    """
    dummy_closure_defs = []
    for closure_var in closure_vars:
        dummy_closure_defs.extend(
            templates.replace(dummy_template, var_name=closure_var))

    future_imports = []
    if future_features:
        aliases = [gast.alias(name=name, asname=None) for name in future_features]
        future_imports = gast.ImportFrom(
            module='__future__', names=aliases, level=0)

    template = """
        future_imports
        def factory_factory_name():
            dummy_closure_defs
            def factory_name(ag__, ag_source_map__, ag_module__):
                entity_defs
                entity_name.ag_source_map = ag_source_map__
                entity_name.ag_module = ag_module__
                entity_name = ag__.autograph_artifact(entity_name)
                return entity_name
            return factory_name
    """
    return templates.replace(
        template,
        future_imports=future_imports,
        factory_factory_name=factory_factory_name,
        factory_name=factory_name,
        dummy_closure_defs=dummy_closure_defs,
        entity_defs=nodes,
        entity_name=entity_name)