def to_ast(self):
  """Returns a representation of this object as an AST node.

  The AST node encodes a constructor that would create an object with the same
  contents.

  Returns:
    ast.Node: `ag__.STD` when this equals STANDARD_OPTIONS, otherwise an
    `ag__.ConversionOptions(...)` call expression.
  """
  if self == STANDARD_OPTIONS:
    return parser.parse_expression('ag__.STD')

  template = """
    ag__.ConversionOptions(
        recursive=recursive_val,
        user_requested=user_requested_val,
        optional_features=optional_features_val,
        internal_convert_user_code=internal_convert_user_code_val)
  """

  def list_of_features(values):
    names = ['ag__.{}'.format(str(v)) for v in values]
    # `(x)` parses as a plain parenthesized expression rather than a 1-tuple,
    # so a trailing comma is required to keep the single-feature case a tuple.
    source = '({},)'.format(', '.join(names)) if names else '()'
    return parser.parse_expression(source)

  expr_ast = templates.replace(
      template,
      recursive_val=parser.parse_expression(str(self.recursive)),
      user_requested_val=parser.parse_expression(str(self.user_requested)),
      internal_convert_user_code_val=parser.parse_expression(
          str(self.internal_convert_user_code)),
      optional_features_val=list_of_features(self.optional_features))
  # The template is a single expression statement; return its expression.
  return expr_ast[0].value
def visit_FunctionDef(self, node):
  """Converts a function definition and adjusts its decorator list.

  Pushes a new `_Function` state for the duration of the visit and records
  the node's `function_context_name` annotation on it. It then visits the
  arguments, body and return annotation. Top-level functions (level < 2) have
  their decorators cleared. Nested ones get `ag__.autograph_artifact`
  appended.

  Args:
    node: the FunctionDef node being visited.

  Returns:
    The same node, modified in place.
  """
  self.state[_Function].enter()

  # Every function is expected to carry a context name from the
  # function_scopes converter. If conversion ever starts creating helper
  # functions, this assumption will no longer hold.
  assert anno.hasanno(node, 'function_context_name'), (
      'The function_scopes converter always creates a scope for functions.')
  self.state[_Function].context_name = anno.getanno(
      node, 'function_context_name')

  node.args = self.visit(node.args)
  node.body = self.visit_block(node.body)

  is_top_level = self.state[_Function].level < 2
  if not is_top_level:
    # TODO(mdan): Fix the tests so that we can always add this decorator.
    # Inner functions are already converted, so a marker decorator prevents
    # converting them a second time. Double conversion would also work, but
    # this saves the overhead.
    artifact_marker = parser.parse_expression('ag__.autograph_artifact')
    node.decorator_list.append(artifact_marker)
  else:
    # Conversion is just-in-time: by the time it happens, the top-level
    # function's decorators have already been applied, so they are dropped.
    node.decorator_list = []

  if node.returns:
    node.returns = self.visit(node.returns)

  self.state[_Function].exit()
  return node
def visit_Return(self, node):
  """Rewrites a `return` statement into assignments to the return state.

  Walks the `_Block` stack from innermost outward up to and including the
  nearest function block. Each block visited is flagged with `return_used`
  and `create_guard_next`.

  Args:
    node: the Return node being visited.

  Returns:
    The statements produced by `templates.replace` for the try/except
    template below.
  """
  for enclosing in reversed(self.state[_Block].stack):
    enclosing.return_used = True
    enclosing.create_guard_next = True
    if enclosing.is_function:
      break

  if node.value:
    retval = node.value
  else:
    # A bare `return` behaves like `return None`.
    retval = parser.parse_expression('None')

  # Note: if evaluating `return <expr>` raises, the return is aborted. The
  # try/except below keeps the return-state variables consistent in that case.
  template = """
    try:
      do_return_var_name = True
      retval_var_name = retval
    except:
      do_return_var_name = False
      raise
  """
  return templates.replace(
      template,
      do_return_var_name=self.state[_Function].do_return_var_name,
      retval_var_name=self.state[_Function].retval_var_name,
      retval=retval)
def _kwargs_to_dict(self, node):
  """Ties together all keyword and **kwarg arguments in a single dict.

  Args:
    node: a Call node whose `keywords` are collected.

  Returns:
    A `dict(...)` Call node carrying `node.keywords`, or a `None` expression
    node when the call has no keyword arguments.
  """
  if not node.keywords:
    return parser.parse_expression('None')
  dict_name = gast.Name(
      'dict', ctx=gast.Load(), annotation=None, type_comment=None)
  # AST sequence fields must be lists: CPython's AST validation rejects a
  # tuple here, and later passes may mutate `args` in place.
  return gast.Call(dict_name, args=[], keywords=node.keywords)
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
  It tries to break down the unpacking if possible. In effect, it has the same
  effect as passing the assigned values in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  When the right-hand side is a tuple/list literal whose length differs from
  the target, or which contains starred elements, it is indexed like a
  non-literal value. For example, `a, b, c = x, *y` results in
  apply_fn(a, (x, *y)[0]), and so on.

  apply_fn is called once per single assignment, recursing into nested
  tuple/list targets.

  Args:
    targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST], should
      be used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
      respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)
  for target in targets:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue

    # Pair elements one-to-one only when that pairing is exact. Otherwise
    # `values.elts[i]` could raise IndexError or mis-pair with a starred
    # element.
    pair_directly = (
        isinstance(values, (gast.Tuple, gast.List))
        and len(values.elts) == len(target.elts)
        and not any(isinstance(v, gast.Starred) for v in values.elts))

    for i, target_el in enumerate(target.elts):
      if pair_directly:
        value_el = values.elts[i]
      else:
        idx = parser.parse_expression(str(i))
        value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
def _generate_pop_operation(self, original_call_node, pop_var_name):
  """Builds the statements replacing a `target.pop(...)` call.

  Args:
    original_call_node: the Call node for the pop; its `func` must be a
      gast.Attribute (e.g. `target.pop`).
    pop_var_name: name that receives the popped element.

  Returns:
    The result of `templates.replace` for an `ag__.list_pop` assignment. The
    element dtype/shape come from `set_element_type` directives on the target,
    defaulting to `None` expressions.
  """
  assert isinstance(original_call_node.func, gast.Attribute)
  call_args = original_call_node.args
  pop_element = call_args[0] if call_args else parser.parse_expression('None')

  # The call will be something like "target.pop()", and the dtype is hooked
  # to target, hence func.value.
  # TODO(mdan): For lists of lists, this won't work.
  # The reason why it won't work is because it's unclear how to annotate
  # the list as a "list of lists with a certain element type" when using
  # operations like `l.pop().pop()`.
  target = original_call_node.func.value

  def element_type_directive(arg_name):
    # Each lookup builds its own `None` default, so the dtype and shape slots
    # never share an AST node.
    return self.get_definition_directive(
        target,
        directives.set_element_type,
        arg_name,
        default=templates.replace_as_expression('None'))

  dtype = element_type_directive('dtype')
  shape = element_type_directive('shape')

  template = """
    target, pop_var_name = ag__.list_pop(
        target, element,
        opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
  """
  return templates.replace(
      template,
      target=target,
      pop_var_name=pop_var_name,
      element=pop_element,
      dtype=dtype,
      shape=shape)
def from_str(qn_str):
  """Parses `qn_str` and returns the qualified name annotated on it.

  Args:
    qn_str: source text of an expression (presumably a dotted name).

  Returns:
    The `anno.Basic.QN` annotation read from the node returned by `resolve`.
  """
  return anno.getanno(resolve(parser.parse_expression(qn_str)), anno.Basic.QN)
def list_of_features(values):
  """Builds a tuple expression node of `ag__.<feature>` references.

  Args:
    values: iterable of features; each is rendered with `str()` and prefixed
      with `ag__.`.

  Returns:
    The expression node parsed from e.g. `(ag__.A, ag__.B,)`, from
    `(ag__.A,)` for a single feature, or from `()` when `values` is empty.
  """
  names = ['ag__.{}'.format(str(v)) for v in values]
  # `(x)` parses as a plain parenthesized expression rather than a 1-tuple,
  # so a trailing comma is required to keep the single-feature case a tuple.
  source = '({},)'.format(', '.join(names)) if names else '()'
  return parser.parse_expression(source)
def visit_For(self, node):
  """Converts a `for` loop into a call to `ag__.for_stmt`.

  Emits, in order: the state getter/setter functions from
  `_create_state_functions`, and a local body function taking the iterate as
  its single argument. It then emits an optional extra-test function (when
  the node has an `anno.Basic.EXTRA_LOOP_TEST` annotation), the
  undefined-variable assignments, and finally the `ag__.for_stmt(...)` call
  that receives all of them.

  Args:
    node: the For node being visited.

  Returns:
    The statements produced by `templates.replace` for the final template.
  """
  # Visit (convert) child nodes first.
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

  # Candidate loop variables: symbols modified in the body or in the iterate
  # scope; `_get_loop_vars` presumably narrows these down.
  loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
      node, body_scope.modified | iter_scope.modified)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)
  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  # NOTE(review): the new_symbol calls in this method run in a fixed order;
  # if the namer assigns names sequentially, reordering them would change the
  # generated names. Confirm before reshuffling.
  state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
  state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  opts = self._create_loop_options(node)

  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
    extra_test_name = self.ctx.namer.new_symbol(
        'extra_test', reserved_symbols)
    template = """
      def extra_test_name():
        nonlocal_declarations
        return extra_test_expr
    """
    extra_test_function = templates.replace(
        template,
        extra_test_expr=extra_test,
        extra_test_name=extra_test_name,
        loop_vars=loop_vars,
        nonlocal_declarations=nonlocal_declarations)
  else:
    # No extra test: for_stmt receives a literal None and no function is
    # emitted.
    extra_test_name = parser.parse_expression('None')
    extra_test_function = []

  # iterate_arg_name holds a single arg with the iterates, which may be a
  # tuple.
  iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols)
  # Unpacks the single body-function argument into the original loop target.
  template = """
    iterates = iterate_arg_name
  """
  iterate_expansion = templates.replace(
      template, iterate_arg_name=iterate_arg_name, iterates=node.target)

  template = """
    state_functions
    def body_name(iterate_arg_name):
      nonlocal_declarations
      iterate_expansion
      body
    extra_test_function
    undefined_assigns
    ag__.for_stmt(
        iterated,
        extra_test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  return templates.replace(
      template,
      body=node.body,
      body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
      extra_test_function=extra_test_function,
      extra_test_name=extra_test_name,
      iterate_arg_name=iterate_arg_name,
      iterate_expansion=iterate_expansion,
      iterated=node.iter,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      # Loop variable names are passed to for_stmt as string constants.
      symbol_names=tuple(
          gast.Constant(str(s), kind=None) for s in loop_vars),
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      undefined_assigns=undefined_assigns)
def _as_unary_function(self, func_name, arg):
  """Builds a call expression node `func_name(arg)`.

  Args:
    func_name: source text of the callee, parsed as an expression.
    arg: AST node for the single argument.

  Returns:
    The expression node produced by `templates.replace_as_expression`.
  """
  callee = parser.parse_expression(func_name)
  return templates.replace_as_expression(
      'func_name(arg)', func_name=callee, arg=arg)