def visit_Lambda(self, node):
  """Runs the body of a top-level lambda inside a function-scope context.

  Only the top-level lambda gets the full wrapper. In theory every lambda
  could be wrapped, but that produces excessive boilerplate when lambdas are
  nested, so nested ones (scope level above 2) are merely tagged with
  `ag__.autograph_artifact`.

  Args:
    node: the lambda node being visited.

  Returns:
    An `ag__.autograph_artifact(...)` expression for nested lambdas;
    otherwise the lambda itself, whose body now calls
    `ag__.with_function_scope`.
  """
  with self.state[_Function] as fn_scope:
    node = self.generic_visit(node)

    # TODO(mdan): Look more closely for use cases that actually require this.
    is_nested = fn_scope.level > 2
    if is_nested:
      return templates.replace_as_expression(
          'ag__.autograph_artifact(l)', l=node)

    lambda_scope = anno.getanno(node, anno.Static.SCOPE)
    scope_name = self.ctx.namer.new_symbol('lscope', lambda_scope.referenced)
    fn_scope.context_name = scope_name
    anno.setanno(node, 'function_context_name', scope_name)

    wrapper_template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    scope_options = self._function_scope_options(fn_scope).to_ast()
    node.body = templates.replace_as_expression(
        wrapper_template,
        options=scope_options,
        function_context=scope_name,
        function_context_name=gast.Constant(scope_name, kind=None),
        body=node.body)
    return node
def visit_Lambda(self, node):
  """Wraps a lambda's body in `ag__.with_function_scope` when at top level.

  Lambdas whose `_Function` scope level exceeds 2 are instead returned
  wrapped in `ag__.autograph_artifact(...)`.

  Args:
    node: the lambda node being visited.

  Returns:
    The rewritten lambda, or an `ag__.autograph_artifact` call expression.
  """
  with self.state[_Function] as fn_scope:
    node = self.generic_visit(node)

    # TODO(mdan): Fix the tests so that we can always add this decorator.
    if fn_scope.level <= 2:
      referenced = anno.getanno(node, anno.Static.SCOPE).referenced
      ctx_name = self.ctx.namer.new_symbol('lscope', referenced)
      fn_scope.context_name = ctx_name
      anno.setanno(node, 'function_context_name', ctx_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      opts_ast = self._function_scope_options(fn_scope).to_ast()
      node.body = templates.replace_as_expression(
          template,
          options=opts_ast,
          function_context=ctx_name,
          function_context_name=gast.Constant(ctx_name, kind=None),
          body=node.body)
      return node

    return templates.replace_as_expression(
        'ag__.autograph_artifact(l)', l=node)
def visit_Subscript(self, node):
  """Rewrites single-key subscript reads into `ag__.get_item` calls.

  Subscripts keyed by a tuple or a slice, and subscripts that are not in a
  Load context, are returned after only their children have been visited.

  Args:
    node: the subscript node being visited.

  Returns:
    The visited node, or an `ag__.get_item(...)` call expression.
  """
  node = self.generic_visit(node)
  key = node.slice

  if isinstance(key, (gast.Tuple, gast.Slice)):
    return node
  # Index writes are handled at a higher level, one at which the rvalue is
  # also available.
  if not isinstance(node.ctx, gast.Load):
    return node

  element_dtype = self.get_definition_directive(
      node.value,
      directives.set_element_type,
      'dtype',
      default=templates.replace_as_expression('None'))

  get_item_template = """
    ag__.get_item(
        target,
        key,
        opts=ag__.GetItemOpts(element_dtype=dtype))
  """
  return templates.replace_as_expression(
      get_item_template, target=node.value, key=key, dtype=element_dtype)
def test_replace_as_expression_restrictions(self):
  """A template holding two statements must be rejected with ValueError."""
  two_statements = """
    foo(a)
    bar(b)
  """
  with self.assertRaises(ValueError):
    templates.replace_as_expression(two_statements)
def test_replace_as_expression_restrictions(self):
  """`replace_as_expression` raises ValueError for a multi-statement body."""
  multi_statement_template = """
    foo(a)
    bar(b)
  """
  expect_value_error = self.assertRaises(ValueError)
  with expect_value_error:
    templates.replace_as_expression(multi_statement_template)
def visit_Name(self, node):
  """Wraps reads of original-code names in `ag__.ld(...)`.

  Only the loads that existed in the original code, i.e. names carrying the
  `ORIG_DEFINITIONS` annotation, are overloaded. Every other name is
  returned untouched.

  Args:
    node: the name node being visited.

  Returns:
    `ag__.ld(node)` for an annotated Load, otherwise `node` itself.
  """
  from_original_code = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
  if from_original_code and isinstance(node.ctx, gast.Load):
    return templates.replace_as_expression('ag__.ld(var_)', var_=node)
  return node
def visit_For(self, node):
  """Lowers `break` statements found in the body of a `for` loop.

  If the body contains no `break`, the loop is rebuilt with its `orelse`
  statements following it, and the original loop's EXTRA_LOOP_TEST and
  DIRECTIVES annotations are copied over.

  Otherwise a fresh `break_` flag is initialized to False. The `else`
  statements are guarded on that flag, and `ag__.not_(flag)` is attached
  to the new loop as its EXTRA_LOOP_TEST annotation.

  Args:
    node: the `for` node being visited.

  Returns:
    The list of replacement statements produced by `templates.replace`.
  """
  original_node = node
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if break_used:
    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    loop_condition = templates.replace_as_expression(
        'ag__.not_(var_name)', var_name=break_var)

    # The extra test is hidden in the AST, which will confuse the static
    # analysis. To mitigate that, we insert a no-op statement that ensures
    # the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    with_break_template = """
      var_name = False
      for target in iter_:
        (var_name,)
        body
      orelse
    """
    statements = templates.replace(
        with_break_template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=guarded_orelse)
    # statements[0] is the flag initialization; statements[1] is the loop.
    rebuilt_loop = statements[1]
    anno.setanno(rebuilt_loop, anno.Basic.EXTRA_LOOP_TEST, loop_condition)
  else:
    plain_template = """
      for target in iter_:
        body
      orelse
    """
    statements = templates.replace(
        plain_template,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=node.orelse)
    rebuilt_loop = statements[0]
    anno.copyanno(original_node, rebuilt_loop, anno.Basic.EXTRA_LOOP_TEST)

  anno.copyanno(original_node, rebuilt_loop, anno.Basic.DIRECTIVES)
  return statements
def visit_Lambda(self, node):
  """Wraps the body of a top-level lambda in `ag__.with_function_scope`.

  The `_Function` state is entered on the way in and is always exited on the
  way out. A `finally` clause guarantees the exit even when visiting or
  template expansion raises, so the state stack stays balanced.

  Args:
    node: the lambda node being visited.

  Returns:
    The lambda node. Its body is rewritten only for the top-level lambda.
  """
  self.state[_Function].enter()
  try:
    node = self.generic_visit(node)

    # Only wrap the top-level function. Theoretically, we can and should wrap
    # everything, but that can lead to excessive boilerplate when lambdas are
    # nested.
    # TODO(mdan): Look more closely for use cases that actually require this.
    if self.state[_Function].level > 2:
      return node

    scope = anno.getanno(node, anno.Static.SCOPE)
    function_context_name = self.ctx.namer.new_symbol(
        'lscope', scope.referenced)
    self.state[_Function].context_name = function_context_name
    anno.setanno(node, 'function_context_name', function_context_name)

    template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    node.body = templates.replace_as_expression(
        template,
        options=self.ctx.program.options.to_ast(),
        function_context=function_context_name,
        function_context_name=gast.Str(function_context_name),
        body=node.body)
    return node
  finally:
    # Previously exit() was called manually on each return path and was
    # skipped whenever an exception escaped.
    self.state[_Function].exit()
def visit_Call(self, node):
  """Replaces a call with `ag__.converted_call(func, owner, options, args)`.

  The call's subexpressions (function, receiver and arguments) are visited
  first. Calls nested inside them, such as `g` in `f(g(x))`, are therefore
  converted too. Previously only the two early-return paths visited children.

  Calls into the internal `ag__` module, and `print` when the
  BUILTIN_FUNCTIONS feature is disabled, are visited but not wrapped.

  Args:
    node: the call node being visited.

  Returns:
    The visited call, or a new `ag__.converted_call(...)` expression that
    carries over the original keyword arguments.
  """
  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  # Read the qualified name before visiting, since visiting may replace
  # `node.func`.
  full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
  node = self.generic_visit(node)

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  if full_name.startswith('ag__.'):
    return node
  if (full_name == 'print' and
      not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
    return node

  template = """
    ag__.converted_call(func, owner, options, args)
  """
  if isinstance(node.func, gast.Attribute):
    # Method call: pass the attribute name and the (visited) receiver.
    func = gast.Str(node.func.attr)
    owner = node.func.value
  else:
    func = node.func
    owner = parser.parse_expression('None')

  new_call = templates.replace_as_expression(
      template,
      func=func,
      owner=owner,
      options=self.ctx.program.options.to_ast(
          self.ctx,
          internal_convert_user_code=self.ctx.program.options.recursive),
      args=node.args)
  # TODO(mdan): Improve the template mechanism to better support this.
  new_call.keywords = node.keywords
  return new_call
def visit_IfExp(self, node):
  """Lowers `a if test else b` into an `ag__.if_stmt` call.

  Each branch expression becomes a zero-argument lambda. The two trailing
  arguments are the fixed lambdas `lambda: ()` and `lambda _: None`.

  Args:
    node: the conditional-expression node being visited.

  Returns:
    The `ag__.if_stmt(...)` call expression.
  """
  if_stmt_template = '''ag__.if_stmt(test, lambda: true_expr, lambda: false_expr, lambda: (), lambda _: None)'''
  branches = dict(true_expr=node.body, false_expr=node.orelse)
  return templates.replace_as_expression(
      if_stmt_template, test=node.test, **branches)
def _rename_compilable_function(self, node):
  """Points a call at the compiled name of its target when the namer asks.

  The call's function node must carry `live_val` and `fqn` annotations.
  Constructor calls (annotated `is_constructor`) always get the compiled
  class name. For other calls, the namer decides both the name and whether
  to rename. If the target is a bound method, its receiver becomes the first
  positional argument.

  Args:
    node: the call node to rewrite.

  Returns:
    The same node, possibly with its `func` and `args` replaced.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not anno.hasanno(node, 'is_constructor'):
    if anno.hasanno(node.func, 'parent_type'):
      owner_type = anno.getanno(node.func, 'parent_type')
    else:
      # Fallback - not reliable.
      owner_type = inspect_utils.getmethodclass(target_entity)
    new_name, do_rename = self.ctx.namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)
  else:
    new_name = self.ctx.namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function, so the
    # receiver is passed explicitly.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = templates.replace_as_expression(
      'func_name', func_name=new_name)
  return node
def _as_binary_operation(self, op, arg1, arg2):
  """Builds the comparison expression `arg1 <op> arg2`.

  The template uses `is` purely as a placeholder. Its first operator slot is
  then overwritten with `op`.

  Args:
    op: the comparison operator node to install.
    arg1: the left operand.
    arg2: the right operand.

  Returns:
    The comparison expression node.
  """
  comparison = templates.replace_as_expression(
      'arg1 is arg2', arg1=arg1, arg2=arg2)
  comparison.ops[0] = op
  return comparison
def visit_Call(self, node):
  """Replaces a call with `ag__.converted_call(func, owner, options, args)`.

  The call's subexpressions (function, receiver and arguments) are visited
  first. Calls nested inside them, such as `g` in `f(g(x))`, are therefore
  converted too. Previously only the two early-return paths visited children.

  Calls into the internal `ag__` module, and `print` when the
  BUILTIN_FUNCTIONS feature is disabled, are visited but not wrapped.

  Args:
    node: the call node being visited.

  Returns:
    The visited call, or a new `ag__.converted_call(...)` expression that
    carries over the original keyword arguments.
  """
  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  # Read the qualified name before visiting, since visiting may replace
  # `node.func`.
  full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
  node = self.generic_visit(node)

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  if full_name.startswith('ag__.'):
    return node
  if (full_name == 'print' and not self.ctx.program.options.uses(
      converter.Feature.BUILTIN_FUNCTIONS)):
    return node

  template = """
    ag__.converted_call(func, owner, options, args)
  """
  if isinstance(node.func, gast.Attribute):
    # Method call: pass the attribute name and the (visited) receiver.
    func = gast.Str(node.func.attr)
    owner = node.func.value
  else:
    func = node.func
    owner = parser.parse_expression('None')

  new_call = templates.replace_as_expression(
      template,
      func=func,
      owner=owner,
      options=self.ctx.program.options.to_ast(
          self.ctx,
          internal_convert_user_code=self.ctx.program.options.recursive),
      args=node.args)
  # TODO(mdan): Improve the template mechanism to better support this.
  new_call.keywords = node.keywords
  return new_call
def _as_function(self, func_name, args):
  """Builds the call `func_name(arg0, arg1, ...)` from a list of argument nodes.

  The resulting call node is annotated as a safe boolean operand.

  Args:
    func_name: source text of the callee, parsed as an expression.
    args: the argument nodes to substitute into the call.

  Returns:
    The annotated call expression.
  """
  call_template = """
    func_name(args)
  """
  callee = parser.parse_expression(func_name)
  call_node = templates.replace_as_expression(
      call_template, func_name=callee, args=args)
  anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
  return call_node
def _as_function(self, func_name, args):
  """Returns `func_name(*args)` as an AST call marked SAFE_BOOLEAN_OPERAND.

  Args:
    func_name: string naming the function; it is parsed as an expression.
    args: argument nodes placed into the call in order.

  Returns:
    The call node, carrying the SAFE_BOOLEAN_OPERAND annotation.
  """
  template = """
    func_name(args)
  """
  result = templates.replace_as_expression(
      template,
      args=args,
      func_name=parser.parse_expression(func_name))
  anno.setanno(result, SAFE_BOOLEAN_OPERAND, True)
  return result
def test_replace_as_expression(self):
  """String replacements become names in the produced call expression."""
  template = """
    foo(a)
  """
  call = templates.replace_as_expression(template, foo='bar', a='baz')

  self.assertIsInstance(call, gast.Call)
  self.assertEqual(call.func.id, 'bar')
  first_arg = call.args[0]
  self.assertEqual(first_arg.id, 'baz')
def _replace_stack_call(self, node):
  """Rewrites a one-argument stack call into `ag__.list_stack`.

  The element dtype comes from the `set_element_type` directive on the
  argument's definition, defaulting to `None`. The original callee is
  forwarded as `original_call`.

  Args:
    node: a call node with exactly one positional argument.

  Returns:
    The `ag__.list_stack(...)` call expression.
  """
  assert len(node.args) == 1
  stacked_list = node.args[0]
  element_dtype = self.get_definition_directive(
      stacked_list,
      directives.set_element_type,
      'dtype',
      default=templates.replace_as_expression('None'))

  list_stack_template = """
    ag__.list_stack(
        target,
        opts=ag__.ListStackOpts(
            element_dtype=dtype,
            original_call=orig_call))
  """
  return templates.replace_as_expression(
      list_stack_template,
      dtype=element_dtype,
      target=stacked_list,
      orig_call=node.func)
def _to_reference(self, node):
  """Turns a variable into a read of `_tfp_autobatching_context_.var.<name>`.

  Args:
    node: a name node or qualified name.

  Returns:
    The attribute-access expression referencing the variable.

  Raises:
    ValueError: if `node` is a literal (not yet supported) or any other
      non-trivial expression.
  """
  is_variable = isinstance(node, (gast.Name, qual_names.QN))
  if is_variable:
    return templates.replace_as_expression(
        '_tfp_autobatching_context_.var.name', name=node)
  if gast_util.is_literal(node):
    raise ValueError('TODO(axch): Support literals, not just variables')
  msg = 'Expected trivial node, got {}. Is the input in A-normal form?'
  raise ValueError(msg.format(node))
def test_replace_as_expression(self):
  """`replace_as_expression` substitutes callee and argument names."""
  single_call = """
    foo(a)
  """
  result = templates.replace_as_expression(single_call, foo='bar', a='baz')

  self.assertIsInstance(result, gast.Call)
  callee_name = result.func.id
  arg_name = result.args[0].id
  self.assertEqual(callee_name, 'bar')
  self.assertEqual(arg_name, 'baz')
def visit_If(self, node): """Intercepts if statements. Converts each `if` to up to two separate `with` statements, `ProgramBuilder.if_(condition_variable)` and `ProgramBuilder.else_()`. If the incoming `if` had one arm, returns the transformed AST node; if it had two, returns two nodes in a list. Args: node: An `ast.AST` node representing the `if` statement to convert. Returns: then_node: A node representing the `with`-guarded consequent branch. else_node: A node representing the `with`-guarded alternate branch, if present. """ # Transform a branch # NOTE: this is a little hackery to make sure that prepending works # properly. Wrapping a list of statements in a Module ensures # that the AST-visiting machinery won't choke on, e.g., a list. then = self.generic_visit(gast_util.Module(node.body)).body # Construct header (goes in the `with`s). then_header = templates.replace_as_expression( '_tfp_autobatching_context_.if_(cond)', cond=self._to_reference(node.test)) # Construct `with` node. # TODO(axch): Test that this form actually works with multiline bodies. then_node = templates.replace('with header: body', header=then_header, body=then)[0] if node.orelse: orelse = self.generic_visit(gast_util.Module(node.orelse)).body orelse_header = templates.replace_as_expression( '_tfp_autobatching_context_.else_()') orelse_node = templates.replace('with header: body', header=orelse_header, body=orelse)[0] # Return both return [then_node, orelse_node] else: return then_node
def _convert_builtin(self, f, args, as_expression):
  """Replaces a call to builtin `f` with `ag__.<overload of f>(args)`.

  Args:
    f: the builtin being called.
    args: argument nodes for the call.
    as_expression: if true, build an expression node; otherwise build
      statements via `templates.replace`.

  Returns:
    The expression, or the statement list from `templates.replace`.
  """
  template = """
    ag__.func(args)
  """
  builder = (templates.replace_as_expression if as_expression
             else templates.replace)
  overload_name = py_builtins.overload_of(f).__name__
  return builder(template, func=overload_name, args=args)
def _replace_stack_call(self, node):
  """Converts `stack(x)` into `ag__.list_stack(x, opts=...)`.

  Args:
    node: a call node that must have exactly one positional argument.

  Returns:
    The `ag__.list_stack` call expression. Its options carry the element
    dtype (from the `set_element_type` directive, else `None`) and the
    original callee.
  """
  assert len(node.args) == 1
  (operand,) = node.args
  dtype_default = templates.replace_as_expression('None')
  dtype = self.get_definition_directive(
      operand, directives.set_element_type, 'dtype', default=dtype_default)
  template = """
    ag__.list_stack(
        target,
        opts=ag__.ListStackOpts(
            element_dtype=dtype,
            original_call=orig_call))
  """
  return templates.replace_as_expression(
      template, dtype=dtype, target=operand, orig_call=node.func)
def _wrap_to_py_func_single_return(self, node, dtype):
  """Wraps call `node` in `ag__.utils.wrap_py_func` with a single return.

  Args:
    node: the call to wrap. Its positional arguments become a tuple and its
      keywords become a dict.
    dtype: source text of the return dtype, parsed as an expression.

  Returns:
    The `ag__.utils.wrap_py_func(...)` call expression.
  """
  # TODO(mdan): Properly handle varargs, etc.
  wrap_template = """
    ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
  """
  dtype_node = parser.parse_expression(dtype)
  keyword_dict = ast_util.keywords_to_dict(node.keywords)
  return templates.replace_as_expression(
      wrap_template,
      func=node.func,
      dtype=dtype_node,
      args=node.args,
      kwargs=keyword_dict)
def _convert_builtin(self, f, args, as_expression):
  """Emits a call to AutoGraph's overload of builtin `f`.

  Args:
    f: the builtin function.
    args: the call's argument nodes.
    as_expression: selects `templates.replace_as_expression` (true) or
      `templates.replace` (false).

  Returns:
    Whatever the selected template helper returns.
  """
  template = """
    ag__.func(args)
  """
  if not as_expression:
    return templates.replace(
        template, func=py_builtins.overload_of(f).__name__, args=args)
  return templates.replace_as_expression(
      template, func=py_builtins.overload_of(f).__name__, args=args)
def visit_For(self, node):
  """Visits a `for` loop and adds the return-flag check to its loop test.

  When the body uses `return` (per the `_Return` state), the loop's
  'extra_test' annotation becomes `ag__.not_(<do_return>)`. If an extra
  test was already present, the two are combined with `ag__.and_`.

  Args:
    node: the `for` node being visited.

  Returns:
    The same node, updated in place.
  """
  node.iter = self.visit(node.iter)
  node.target = self.visit(node.target)

  # Add the check for return to the loop condition.
  node.body = self._visit_statement_block(node, node.body)
  if self.state[_Return].used:
    prior_test = anno.getanno(node, 'extra_test', default=None)
    if prior_test is None:
      combined = templates.replace_as_expression(
          'ag__.not_(control_var)',
          control_var=self.state[_Function].do_return_var_name)
    else:
      combined = templates.replace_as_expression(
          'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
          extra_test=prior_test,
          control_var=self.state[_Function].do_return_var_name)
    anno.setanno(node, 'extra_test', combined)

  node.orelse = self._visit_statement_block(node, node.orelse)
  return node
def visit_For(self, node):
  """Visits a `for` loop and stops it once a `return` has executed.

  If the block state reports `return_used`, the EXTRA_LOOP_TEST annotation
  is set to `not <do_return>`, and-ed with any pre-existing extra test.

  Args:
    node: the `for` node being visited.

  Returns:
    The same node, updated in place.
  """
  node.iter = self.visit(node.iter)
  node.target = self.visit(node.target)

  # Add the check for return to the loop condition.
  node.body = self._visit_statement_block(node, node.body)
  if self.state[_Block].return_used:
    existing = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
    if existing is None:
      loop_test = templates.replace_as_expression(
          'not control_var',
          control_var=self.state[_Function].do_return_var_name)
    else:
      loop_test = templates.replace_as_expression(
          'not control_var and extra_test',
          extra_test=existing,
          control_var=self.state[_Function].do_return_var_name)
    anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, loop_test)

  node.orelse = self._visit_statement_block(node, node.orelse)
  return node
def visit_For(self, node):
  """Visits a `for` loop, adding a return-flag check to its 'extra_test'.

  Args:
    node: the `for` node being visited.

  Returns:
    The same node, updated in place.
  """
  node.iter = self.visit(node.iter)
  node.target = self.visit(node.target)

  # Add the check for return to the loop condition.
  node.body = self._visit_statement_block(node, node.body)
  if not self.state[_Block].return_used:
    node.orelse = self._visit_statement_block(node, node.orelse)
    return node

  current_test = anno.getanno(node, 'extra_test', default=None)
  if current_test is not None:
    new_test = templates.replace_as_expression(
        'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
        extra_test=current_test,
        control_var=self.state[_Function].do_return_var_name)
  else:
    new_test = templates.replace_as_expression(
        'ag__.not_(control_var)',
        control_var=self.state[_Function].do_return_var_name)
  anno.setanno(node, 'extra_test', new_test)

  node.orelse = self._visit_statement_block(node, node.orelse)
  return node
def _as_function(self, func_name, args, args_as_lambda=False):
  """Builds a call to `func_name` with zero, one or two arguments.

  Args:
    func_name: source text of the callee, parsed as an expression.
    args: the argument nodes (at most two).
    args_as_lambda: if true, each argument is first wrapped as `lambda: arg`.

  Returns:
    The call node, annotated SAFE_BOOLEAN_OPERAND.

  Raises:
    NotImplementedError: for more than two arguments.
  """
  if args_as_lambda:
    lambda_template = """
      lambda: arg
    """
    args = [
        templates.replace_as_expression(lambda_template, arg=a) for a in args
    ]

  if not args:
    template = """
      func_name()
    """
    replacement = templates.replace_as_expression(
        template, func_name=parser.parse_expression(func_name))
  elif len(args) == 1:
    template = """
      func_name(arg)
    """
    (only_arg,) = args
    replacement = templates.replace_as_expression(
        template, func_name=parser.parse_expression(func_name), arg=only_arg)
  elif len(args) == 2:
    template = """
      func_name(arg1, arg2)
    """
    left, right = args
    replacement = templates.replace_as_expression(
        template,
        func_name=parser.parse_expression(func_name),
        arg1=left,
        arg2=right)
  else:
    raise NotImplementedError('{} arguments for {}'.format(
        len(args), func_name))

  anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
  return replacement
def visit_While(self, node):
  """Visits a `while` loop, and-ing `ag__.not_(<do_return>)` into its test.

  The extra condition is added only when the `_Return` state reports that
  the body used `return`.

  Args:
    node: the `while` node being visited.

  Returns:
    The same node, updated in place.
  """
  node.test = self.visit(node.test)

  # Add the check for return to the loop condition.
  node.body = self._visit_statement_block(node, node.body)
  returns_inside = self.state[_Return].used
  if returns_inside:
    flag_name = self.state[_Function].do_return_var_name
    node.test = templates.replace_as_expression(
        'ag__.and_(lambda: ag__.not_(control_var), lambda: test)',
        test=node.test,
        control_var=flag_name)

  node.orelse = self._visit_statement_block(node, node.orelse)
  return node
def visit_Return(self, node):
  """Intercepts return statements.

  `return x` becomes the expression statement
  `_tfp_autobatching_context_.return_(<reference to x>)`.

  Args:
    node: An `ast.AST` node representing the `return` statement to convert.

  Returns:
    node: A `gast.Expr` statement wrapping the `return_` call.
  """
  returned_ref = self._to_reference(node.value)
  return_call = templates.replace_as_expression(
      '_tfp_autobatching_context_.return_(value)', value=returned_ref)
  return gast.Expr(return_call)
def _assignment_construct_recur(self, target):
  """Builds the pattern object mirroring an assignment target.

  Tuples and lists are rebuilt recursively with the same container type.
  Any other target becomes `_tfp_autobatching_context_.var.<name>`.

  Args:
    target: the assignment target node.

  Returns:
    The pattern expression node.
  """
  if not isinstance(target, (gast.Tuple, gast.List)):
    return templates.replace_as_expression(
        '_tfp_autobatching_context_.var.name', name=target)

  parts = [self._assignment_construct_recur(elt) for elt in target.elts]
  # Context is not Store anymore, because this section is constructing the
  # pattern object.
  container_cls = gast.Tuple if isinstance(target, gast.Tuple) else gast.List
  return container_cls(parts, ctx=gast.Load())
def visit_While(self, node): node.test = self.visit(node.test) # Add the check for return to the loop condition. node.body = self._visit_statement_block(node, node.body) if self.state[_Block].return_used: node.test = templates.replace_as_expression( 'ag__.and_(lambda: ag__.not_(control_var), lambda: test)', test=node.test, control_var=self.state[_Function].do_return_var_name) node.orelse = self._visit_statement_block(node, node.orelse) return node
def _as_function(self, func_name, args, args_as_lambda=False):
  """Returns a call node `func_name(...)` for up to two argument nodes.

  Args:
    func_name: string naming the callee; parsed as an expression.
    args: list of argument nodes.
    args_as_lambda: when true, every argument is wrapped as `lambda: arg`.

  Returns:
    The call node, annotated SAFE_BOOLEAN_OPERAND.

  Raises:
    NotImplementedError: if more than two arguments are given.
  """
  if args_as_lambda:
    wrapped = []
    for original_arg in args:
      template = """
        lambda: arg
      """
      wrapped.append(
          templates.replace_as_expression(template, arg=original_arg))
    args = wrapped

  if args and len(args) > 2:
    raise NotImplementedError('{} arguments for {}'.format(
        len(args), func_name))

  if not args:
    template = """
      func_name()
    """
    call = templates.replace_as_expression(
        template, func_name=parser.parse_expression(func_name))
  elif len(args) == 1:
    template = """
      func_name(arg)
    """
    call = templates.replace_as_expression(
        template, func_name=parser.parse_expression(func_name), arg=args[0])
  else:
    template = """
      func_name(arg1, arg2)
    """
    call = templates.replace_as_expression(
        template,
        func_name=parser.parse_expression(func_name),
        arg1=args[0],
        arg2=args[1])

  anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
  return call
def visit_IfExp(self, node):
  """Lowers a conditional expression into `ag__.utils.run_cond`.

  Each branch is turned into a function via `_create_branch`. The function
  names derive from the test's qualified name, or from 'ifexp' when the test
  has none.

  Args:
    node: the conditional-expression node being visited.

  Returns:
    The `ag__.utils.run_cond(...)` call expression.
  """
  test_has_name = anno.hasanno(node.test, anno.Basic.QN)
  name_root = (anno.getanno(node.test, anno.Basic.QN).ssf()
               if test_has_name else 'ifexp')

  true_fn_name = self._create_branch(node.body, '%s_true' % name_root)
  false_fn_name = self._create_branch(node.orelse, '%s_false' % name_root)

  return templates.replace_as_expression(
      'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
      test=node.test,
      true_fn_name=true_fn_name,
      false_fn_name=false_fn_name)
def _generate_pop_operation(self, original_call_node, pop_var_name):
  """Emits `target, pop_var = ag__.list_pop(target, element, opts=...)`.

  Args:
    original_call_node: the `target.pop(...)` call. Its function must be an
      attribute access.
    pop_var_name: name that receives the popped value.

  Returns:
    The statement list produced by `templates.replace`.
  """
  assert isinstance(original_call_node.func, gast.Attribute)

  # The call will be something like "target.pop()", and the dtype is hooked
  # to target, hence the func.value.
  list_node = original_call_node.func.value
  call_args = original_call_node.args
  pop_element = call_args[0] if call_args else parser.parse_expression('None')

  # TODO(mdan): For lists of lists, this won't work.
  # The reason why it won't work is because it's unclear how to annotate
  # the list as a "list of lists with a certain element type" when using
  # operations like `l.pop().pop()`.
  element_dtype = self.get_definition_directive(
      list_node,
      directives.set_element_type,
      'dtype',
      default=templates.replace_as_expression('None'))
  element_shape = self.get_definition_directive(
      list_node,
      directives.set_element_type,
      'shape',
      default=templates.replace_as_expression('None'))

  pop_template = """
    target, pop_var_name = ag__.list_pop(
        target, element,
        opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
  """
  return templates.replace(
      pop_template,
      target=list_node,
      pop_var_name=pop_var_name,
      element=pop_element,
      dtype=element_dtype,
      shape=element_shape)
def _generate_pop_operation(self, original_call_node, pop_var_name):
  """Builds the statement that performs a list pop via `ag__.list_pop`.

  Element dtype and shape come from the `set_element_type` directives on the
  list's definition, each defaulting to `None`. The popped index is the
  call's first argument, or `None` when absent.

  Args:
    original_call_node: a `<list>.pop(...)` call node.
    pop_var_name: variable name receiving the popped element.

  Returns:
    Statements from `templates.replace`.
  """
  assert isinstance(original_call_node.func, gast.Attribute)
  if not original_call_node.args:
    pop_element = parser.parse_expression('None')
  else:
    pop_element = original_call_node.args[0]

  # The dtype is hooked to the list being popped (func.value).
  # TODO(mdan): For lists of lists, this won't work. It is unclear how to
  # annotate the list as a "list of lists with a certain element type" when
  # using operations like `l.pop().pop()`.
  def _directive(name):
    return self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        name,
        default=templates.replace_as_expression('None'))

  dtype = _directive('dtype')
  shape = _directive('shape')

  template = """
    target, pop_var_name = ag__.list_pop(
        target, element,
        opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
  """
  return templates.replace(
      template,
      target=original_call_node.func.value,
      pop_var_name=pop_var_name,
      element=pop_element,
      dtype=dtype,
      shape=shape)
def visit_Call(self, node):
  """Replaces a call with `ag__.converted_call(func, args, kwargs, ctx)`.

  Children are visited first. The following calls are returned without
  wrapping:
    * calls into the internal `ag__` module;
    * calls on the function context object inserted by function_scopes;
    * `pdb.set_trace`, `ipdb.set_trace` and `breakpoint`, which logs a
      one-time warning;
    * `print`, when the BUILTIN_FUNCTIONS feature is disabled.

  Args:
    node: the call node being visited.

  Returns:
    The visited call, or the `ag__.converted_call(...)` expression.
  """
  global set_trace_warned

  # Read these before visiting, since visiting may replace `node.func`.
  full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
  function_context_name = self.state[_Function].context_name
  node = self.generic_visit(node)

  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  if full_name.startswith('ag__.'):
    return node

  # Calls to the function context manager (inserted by function_scopes) are
  # also safe.
  if full_name.startswith(function_context_name + '.'):
    return node

  # Debugger entry points are never converted. The usual bypass mechanisms
  # can't be used because these calls are sensitive to the calling frame.
  # TODO(mdan): Generalize this to a "static whitelist" config.
  if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
    if not set_trace_warned:
      # TODO(mdan): Update and shorten once available on tensorflow.org.
      ag_logging.warn(
          'Detected `pdb.set_trace()` in converted code. The code'
          ' generated by AutoGraph is not optimized for step-by-step'
          ' debugging. See https://github.com/tensorflow/tensorflow/'
          'blob/master/tensorflow/python/autograph/g3doc/reference/'
          'debugging.md.')
      set_trace_warned = True
    return node

  if full_name == 'print' and not self.ctx.program.options.uses(
      converter.Feature.BUILTIN_FUNCTIONS):
    return node

  converted_call_template = """
    ag__.converted_call(func, args, kwargs, function_ctx)
  """
  packed_args = self._args_to_tuple(node)
  packed_kwargs = self._kwargs_to_dict(node)
  return templates.replace_as_expression(
      converted_call_template,
      func=node.func,
      args=packed_args,
      kwargs=packed_kwargs,
      function_ctx=function_context_name)
def visit_Print(self, node):
  """Rewrites a print statement into `ag__.converted_call('print', ...)`.

  Args:
    node: the print node being visited.

  Returns:
    The `ag__.converted_call(...)` expression.
  """
  node = self.generic_visit(node)
  printed = node.values
  # `print(a, b)` arrives as a single tuple value; unpack its elements.
  if len(printed) == 1 and isinstance(printed[0], gast.Tuple):
    printed = printed[0].elts

  template = """
    ag__.converted_call(func, None, options, args, {})
  """
  options_ast = self.ctx.program.options.to_ast()
  return templates.replace_as_expression(
      template, func='print', options=options_ast, args=printed)
def visit_IfExp(self, node):
  """Converts `a if test else b` into `ag__.utils.run_cond`.

  Args:
    node: the conditional-expression node being visited.

  Returns:
    The `ag__.utils.run_cond(test, <true fn>, <false fn>)` expression.
  """
  name_root = 'ifexp'
  if anno.hasanno(node.test, anno.Basic.QN):
    # Prefer a readable prefix derived from the test's qualified name.
    name_root = anno.getanno(node.test, anno.Basic.QN).ssf()

  branch_names = {}
  branch_names['true_fn_name'] = self._create_branch(
      node.body, '%s_true' % name_root)
  branch_names['false_fn_name'] = self._create_branch(
      node.orelse, '%s_false' % name_root)

  return templates.replace_as_expression(
      'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
      test=node.test,
      **branch_names)
def visit_Subscript(self, node):
  """Rewrites simple-index reads `x[i]` into `ag__.get_item` calls.

  Non-index slices (slices, extended slices) and non-Load contexts are left
  alone. Index writes are handled at a higher level, one at which the rvalue
  is also available.

  Args:
    node: the subscript node being visited.

  Returns:
    The visited node, or an `ag__.get_item(...)` call expression.
  """
  node = self.generic_visit(node)
  is_plain_index = isinstance(node.slice, gast.Index)
  is_read = isinstance(node.ctx, gast.Load)
  if not (is_plain_index and is_read):
    return node

  dtype = self.get_definition_directive(
      node.value,
      directives.set_element_type,
      'dtype',
      default=templates.replace_as_expression('None'))

  template = """
    ag__.get_item(
        target,
        key,
        opts=ag__.GetItemOpts(element_dtype=dtype))
  """
  index_expr = node.slice.value
  return templates.replace_as_expression(
      template, target=node.value, key=index_expr, dtype=dtype)
def visit_Print(self, node):
  """Routes a print statement through `ag__.converted_call`.

  Args:
    node: the print node being visited.

  Returns:
    An `ag__.converted_call('print', None, options, args, {})` expression.
  """
  node = self.generic_visit(node)
  values = node.values
  # Following is the case when calling print(a, b): a single tuple value.
  single_tuple = len(values) == 1 and isinstance(values[0], gast.Tuple)
  print_args = values[0].elts if single_tuple else values

  return templates.replace_as_expression(
      """
        ag__.converted_call(func, None, options, args, {})
      """,
      func='print',
      options=self.ctx.program.options.to_ast(),
      args=print_args)
def visit_IfExp(self, node):
  """Rewrites `a if test else b` into `ag__.if_exp(...)`.

  The branches become zero-argument lambdas. The unparsed, stripped source
  text of the test is passed as the final `expr_repr` argument.

  Args:
    node: the conditional-expression node being visited.

  Returns:
    The `ag__.if_exp(...)` call expression.
  """
  test_source = parser.unparse(
      node.test, include_encoding_marker=False).strip()
  if_exp_template = '''
    ag__.if_exp(
        test,
        lambda: true_expr,
        lambda: false_expr,
        expr_repr)
  '''
  return templates.replace_as_expression(
      if_exp_template,
      test=node.test,
      true_expr=node.body,
      false_expr=node.orelse,
      expr_repr=gast.Constant(test_source, kind=None))
def visit_For(self, node):
  """Lowers `break` inside a `for` loop into a flag-controlled loop.

  Without a `break`, the visited node is returned as-is. Otherwise the loop
  is preceded by `<flag> = tf.constant(False)`, its `else` clause is guarded
  on the flag, and `not <flag>` is attached as the loop's 'extra_test'.

  Args:
    node: the `for` node being visited.

  Returns:
    The visited node, or the statement list from `templates.replace`.
  """
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if not break_used:
    return node

  # Python's else clause only triggers if the loop exited cleanly (e.g.
  # break did not trigger).
  guarded_orelse = self._guard_if_present(node.orelse, break_var)
  extra_test = templates.replace_as_expression(
      'not var_name', var_name=break_var)

  # The extra test is hidden in the AST, which will confuse the static
  # analysis. To mitigate that, we insert a no-op statement that ensures
  # the control variable is marked as used.
  # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
  template = """
    var_name = tf.constant(False)
    for target in iter_:
      (var_name,)
      body
    else:
      orelse
  """
  statements = templates.replace(
      template,
      var_name=break_var,
      iter_=node.iter,
      target=node.target,
      body=node.body,
      orelse=guarded_orelse)
  loop_stmt = statements[1]
  anno.setanno(loop_stmt, 'extra_test', extra_test)
  return statements
def _replace_pop_call(self, node):
  """Swaps a `<list>.pop(...)` expression for a fresh variable name.

  Expressions that use pop() are converted to a statement + expression.
  For example:

      print(target.pop())

  is converted to:

      target, target_pop = ag__.list_pop(target)
      print(target_pop)

  This method only generates the variable name and records
  `(node, name)` in the current statement's `pop_uses`. The
  `_generate_pop_operation` method emits the statement itself. Multiple
  uses are allowed, e.g. `print(target.pop(), target.pop())` or
  `print(target.pop().pop())`.

  Args:
    node: a call node whose function is an attribute access.

  Returns:
    A name expression referring to the new variable.
  """
  assert isinstance(node.func, gast.Attribute)
  scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  target_node = node.func.value

  # Attempt to use a related name if one exists; otherwise use something
  # generic.
  has_qn = anno.hasanno(target_node, anno.Basic.QN)
  target_name = (anno.getanno(target_node, anno.Basic.QN).ssf()
                 if has_qn else 'list_')
  pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)

  statement_state = self.state[_Statement]
  if statement_state.pop_uses is None:
    statement_state.pop_uses = []
  statement_state.pop_uses.append((node, pop_var_name))

  return templates.replace_as_expression('var_name', var_name=pop_var_name)
def _insert_dynamic_conversion(self, node):
  """Inlines a dynamic conversion for a dynamic function.

  The call becomes `ag__.converted_call(func, owner, options, args)`. For
  method calls, `func` is the attribute name as a string and `owner` is the
  receiver. Otherwise `owner` is `None`. The original keyword arguments are
  reattached to the new call.

  Args:
    node: the call node to convert.

  Returns:
    The new `ag__.converted_call(...)` call node.
  """
  # TODO(mdan): Pass information on the statically compiled functions.
  # Having access to the statically compiled functions can help avoid
  # unnecessary compilation.
  # For example, this would lead to function `a` being compiled twice:
  #
  #   def a():
  #     v = b
  #     b()
  #   def b():
  #     a()
  #
  # This is really a problem with recursive calls, which currently can
  # only be gated by a static condition, and should be rare.
  # TODO(mdan): It probably makes sense to use dynamic conversion every time.
  # Before we could convert all the time though, we'd need a reasonable
  # caching mechanism.
  callee = node.func
  if isinstance(callee, gast.Attribute):
    func, owner = gast.Str(callee.attr), callee.value
  else:
    func, owner = callee, parser.parse_expression('None')

  program_options = self.ctx.program.options
  options_ast = program_options.to_ast(
      self.ctx.info.namespace,
      internal_convert_user_code=program_options.recursive)
  converted = templates.replace_as_expression(
      """
        ag__.converted_call(func, owner, options, args)
      """,
      func=func,
      owner=owner,
      options=options_ast,
      args=node.args)
  # TODO(mdan): Improve the template mechanism to better support this.
  converted.keywords = node.keywords
  return converted
def _replace_pop_call(self, node):
  """Replaces a `<list>.pop(...)` expression with a new variable reference.

  For example, `print(target.pop())` is eventually converted to:

      target, target_pop = ag__.list_pop(target)
      print(target_pop)

  Here only the variable name is generated and swapped in. The
  `(node, name)` pair is appended to the POP_USES local so that
  `_generate_pop_operation` can emit the statement later. Multiple uses are
  allowed, e.g. `print(target.pop(), target.pop())` or
  `print(target.pop().pop())`.

  Args:
    node: a call node whose function is an attribute access.

  Returns:
    A name expression for the generated variable.
  """
  assert isinstance(node.func, gast.Attribute)
  scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  target_node = node.func.value

  # Attempt to use a related name if one exists. Otherwise use something
  # generic.
  target_name = 'list_'
  if anno.hasanno(target_node, anno.Basic.QN):
    target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
  new_name = self.ctx.namer.new_symbol(target_name, scope.referenced)

  recorded = self.get_local(POP_USES, [])
  recorded.append((node, new_name))
  self.set_local(POP_USES, recorded)

  return templates.replace_as_expression('var_name', var_name=new_name)
def visit_Call(self, node):
  """Replaces a call with `ag__.converted_call(func, owner, options, args, kwargs)`.

  Calls into the internal `ag__` module, and `print` when the
  BUILTIN_FUNCTIONS feature is disabled, are only visited, not converted.

  Positional arguments are packed into one tuple expression and keyword
  arguments into one dict expression:
    * Source order is preserved even when `*x` unpackings are interleaved
      with plain arguments. `f(*a, b)` becomes `tuple(a) + (b,)`; it was
      previously reordered to `(b,) + tuple(a)`.
    * Any number of `*` / `**` unpackings is accepted, as Python 3.5+
      allows (PEP 448). These previously tripped an assertion.
    * Every subexpression is visited, so nested calls are converted too.
      This covers the receiver of a method call and the unpacked `*`/`**`
      values, which were previously skipped.

  Args:
    node: the call node being visited.

  Returns:
    The visited call, or the new `ag__.converted_call(...)` expression.
  """
  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
  if full_name.startswith('ag__.'):
    return self.generic_visit(node)
  if (full_name == 'print' and
      not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
    return self.generic_visit(node)

  if isinstance(node.func, gast.Attribute):
    func = gast.Str(node.func.attr)
    owner = self.visit(node.func.value)
  else:
    func = self.visit(node.func)
    owner = parser.parse_expression('None')

  # Positional arguments: consecutive plain arguments form a tuple literal,
  # each `*x` becomes `tuple(x)`, and the pieces are joined with `+` in
  # source order.
  arg_pieces = []
  plain_run = []
  for a in node.args:
    if isinstance(a, gast.Starred):
      if plain_run:
        arg_pieces.append(
            templates.replace_as_expression('(args,)', args=plain_run))
        plain_run = []
      arg_pieces.append(
          templates.replace_as_expression(
              'tuple(stararg)', stararg=self.visit(a.value)))
    else:
      plain_run.append(self.visit(a))
  if plain_run or not arg_pieces:
    arg_pieces.append(
        templates.replace_as_expression('(args,)', args=plain_run))
  args = arg_pieces[0]
  for piece in arg_pieces[1:]:
    args = templates.replace_as_expression('lhs + rhs', lhs=args, rhs=piece)

  # Keyword arguments: named keywords form a dict literal; `**m` mappings
  # are merged in order, with the named keywords applied last.
  unpacked_mappings = []
  normal_keywords = []
  for k in node.keywords:
    if k.arg is None:
      unpacked_mappings.append(self.visit(k.value))
    else:
      normal_keywords.append(self.visit(k))
  keywords_dict = ast_util.keywords_to_dict(normal_keywords)
  if not unpacked_mappings:
    kwargs = keywords_dict
  else:
    kwargs = unpacked_mappings[0]
    for mapping in unpacked_mappings[1:]:
      kwargs = templates.replace_as_expression(
          'dict(kwargs, **keywords)', kwargs=kwargs, keywords=mapping)
    kwargs = templates.replace_as_expression(
        'dict(kwargs, **keywords)', kwargs=kwargs, keywords=keywords_dict)

  template = """
    ag__.converted_call(func, owner, options, args, kwargs)
  """
  new_call = templates.replace_as_expression(
      template,
      func=func,
      owner=owner,
      options=self.ctx.program.options.to_ast(
          internal_convert_user_code=self.ctx.program.options.recursive),
      args=args,
      kwargs=kwargs)
  return new_call
def visit_If(self, node):
  """Converts an `if` statement into branch functions plus a cond expression.

  The test is assigned to a fresh `cond` variable (via `create_assignment`).
  Each branch body is wrapped in a generated function (via
  `_create_cond_branch`). Each such function returns the symbols that are
  modified in the conditional and live after it. `_create_cond_expr` builds
  the code that selects between the two functions.

  Symbols created in only one of the branches are passed to
  `_create_undefined_assigns`. Presumably this gives them a placeholder value
  so the returns stay well-formed; TODO confirm against that helper.

  Args:
    node: the `if` statement node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
      DEFINED_VARS_IN and LIVE_VARS_OUT annotations read below.

  Returns:
    `undefined_assigns + cond_assign + body_def + orelse_def + cond_expr`.
    This is the concatenation of the generated statements (presumably lists
    of statement nodes).
  """
  # Static-analysis annotations attached to the `if` node itself.
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  # Note: this information needs to be extracted before the body conversion
  # that happens in the call to generic_visit below, because the conversion
  # generates nodes that lack static analysis annotations.
  need_alias_in_body = self._determine_aliased_symbols(
      body_scope, defined_in, node.body)
  need_alias_in_orelse = self._determine_aliased_symbols(
      orelse_scope, defined_in, node.orelse)

  node = self.generic_visit(node)

  # Outputs of the conditional: symbols modified in either branch that are
  # live after the statement.
  modified_in_cond = body_scope.modified | orelse_scope.modified
  returned_from_cond = set()
  for s in modified_in_cond:
    if s in live_out:
      returned_from_cond.add(s)
    elif s.is_composite():
      # Special treatment for compound objects: if any of their owner entities
      # are live, then they are outputs as well.
      if live_out & s.owner_set:
        returned_from_cond.add(s)

  # Outputs not defined before the statement, i.e. created inside the
  # respective branch. `-` binds tighter than `&`, so this evaluates as
  # `modified & (returned_from_cond - defined_in)`.
  created_in_body = body_scope.modified & returned_from_cond - defined_in
  created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

  # Only non-composite symbols take part in the possibly-undefined check.
  basic_created_in_body = tuple(
      s for s in created_in_body if not s.is_composite())
  basic_created_in_orelse = tuple(
      s for s in created_in_orelse if not s.is_composite())

  # These variables are defined only in a single branch. This is fine in
  # Python, so we pass them through. Another backend, e.g. TensorFlow, may
  # need to handle these cases specially or throw an error.
  possibly_undefined = (set(basic_created_in_body) ^
                        set(basic_created_in_orelse))

  # Alias the closure variables inside the conditional functions, to allow
  # the functions access to the respective variables.
  # We will alias variables independently for body and orelse scope,
  # because different branches might write different variables.
  # NOTE(review): the namer is called in a fixed order below (body aliases,
  # orelse aliases, cond, if_true, if_false). If new_symbol tracks the names
  # it hands out, reordering these calls would change the generated names.
  aliased_body_orig_names = tuple(need_alias_in_body)
  aliased_orelse_orig_names = tuple(need_alias_in_orelse)
  aliased_body_new_names = tuple(
      self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names)
  aliased_orelse_new_names = tuple(
      self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names)

  alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
  alias_orelse_map = dict(
      zip(aliased_orelse_orig_names, aliased_orelse_new_names))

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
  body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

  # Freeze the set into one tuple. cond_results and both per-branch return
  # tuples then list the symbols in the same order.
  returned_from_cond = tuple(returned_from_cond)
  if returned_from_cond:
    if len(returned_from_cond) == 1:
      cond_results = returned_from_cond[0]
    else:
      cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

    # Each branch returns its alias of a symbol when it has one, otherwise
    # the symbol itself.
    returned_from_body = tuple(
        alias_body_map[s] if s in need_alias_in_body else s
        for s in returned_from_cond)
    returned_from_orelse = tuple(
        alias_orelse_map[s] if s in need_alias_in_orelse else s
        for s in returned_from_cond)

  else:
    # When the cond would return no value, we leave the cond called without
    # results. That in turn should trigger the side effect guards. The
    # branch functions will return a dummy value that ensures cond
    # actually has some return value as well.
    cond_results = None
    # TODO(mdan): Replace with None once side_effect_guards is retired.
    returned_from_body = (templates.replace_as_expression(
        'ag__.match_staging_level(1, cond_var_name)',
        cond_var_name=cond_var_name),)
    returned_from_orelse = (templates.replace_as_expression(
        'ag__.match_staging_level(1, cond_var_name)',
        cond_var_name=cond_var_name),)

  cond_assign = self.create_assignment(cond_var_name, node.test)
  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_body_orig_names,
      aliased_new_names=aliased_body_new_names,
      body=node_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orelse_orig_names,
      aliased_new_names=aliased_orelse_new_names,
      body=node_orelse,
      returns=returned_from_orelse)
  undefined_assigns = self._create_undefined_assigns(possibly_undefined)

  cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                     orelse_name)

  return (undefined_assigns + cond_assign + body_def + orelse_def + cond_expr)
def visit_If(self, node):
  """Lowers an `if` statement into two branch functions and a cond call.

  Children are converted first. Each branch body is then wrapped in a
  generated function (via `_create_cond_branch`). Each such function returns
  the symbols that the conditional modifies and that are live after it.
  `_create_cond_expr` builds the code that selects between the functions
  using `node.test`.

  Args:
    node: the `if` statement node, annotated with BODY_SCOPE, ORELSE_SCOPE,
      DEFINED_VARS_IN and LIVE_VARS_OUT.

  Returns:
    `body_def + orelse_def + cond_expr`, the concatenated generated
    statements (presumably lists of statement nodes).

  Raises:
    ValueError: if the two branches create different sets of symbols that
      are live after the statement and were not defined before it.
  """
  node = self.generic_visit(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  # A modified symbol is an output when it is live after the statement. A
  # composite symbol (e.g. `a.b`) is also an output when any of its owners
  # is live.
  outputs = {
      sym for sym in body_scope.modified | orelse_scope.modified
      if sym in live_out or (sym.is_composite() and live_out & sym.owner_set)
  }

  # Symbols that exist before the statement and are written by a branch are
  # aliased inside that branch's function.
  body_aliased = body_scope.modified & defined_in
  orelse_aliased = orelse_scope.modified & defined_in

  # Outputs first created inside each branch. `-` binds tighter than `&`.
  body_creates = body_scope.modified & outputs - defined_in
  orelse_creates = orelse_scope.modified & outputs - defined_in
  if body_creates != orelse_creates:
    raise ValueError(
        'if statement may not initialize all variables: the true branch'
        ' creates %s, while the false branch creates %s. Make sure all'
        ' these variables are initialized either in both'
        ' branches or before the if statement.' %
        (self._fmt_symbols(body_creates),
         self._fmt_symbols(orelse_creates)))

  # Aliases are generated independently per branch, since the branches may
  # write different variables. The namer is queried in a fixed order (body
  # aliases, orelse aliases, if_true, if_false), which keeps the generated
  # names stable.
  body_orig = tuple(body_aliased)
  orelse_orig = tuple(orelse_aliased)
  body_new = tuple(
      self.ctx.namer.new_symbol(sym.ssf(), body_scope.referenced)
      for sym in body_orig)
  orelse_new = tuple(
      self.ctx.namer.new_symbol(sym.ssf(), orelse_scope.referenced)
      for sym in orelse_orig)

  body_renames = dict(zip(body_orig, body_new))
  orelse_renames = dict(zip(orelse_orig, orelse_new))

  renamed_body = ast_util.rename_symbols(node.body, body_renames)
  renamed_orelse = ast_util.rename_symbols(node.orelse, orelse_renames)

  # One fixed ordering of the outputs is shared by the cond results and both
  # branch return tuples.
  outputs = tuple(outputs)
  if not outputs:
    # With no outputs the cond is called without results, which should in
    # turn trigger the side effect guards. Each branch returns a dummy value
    # so that cond still has some return value.
    cond_results = None
    # TODO(mdan): This doesn't belong here; it's specific to the operator.
    returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
    returned_from_orelse = (
        templates.replace_as_expression('tf.constant(1)'),)
  else:
    cond_results = (
        outputs[0] if len(outputs) == 1 else
        gast.Tuple([sym.ast() for sym in outputs], None))
    returned_from_body = tuple(
        body_renames[sym] if sym in body_aliased else sym for sym in outputs)
    returned_from_orelse = tuple(
        orelse_renames[sym] if sym in orelse_aliased else sym
        for sym in outputs)

  body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=body_orig,
      aliased_new_names=body_new,
      body=renamed_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=orelse_orig,
      aliased_new_names=orelse_new,
      body=renamed_orelse,
      returns=returned_from_orelse)
  cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                     orelse_name)

  return body_def + orelse_def + cond_expr
def visit_IfExp(self, node):
  """Converts `true_expr if test else false_expr` into `ag__.if_stmt`.

  The result is `ag__.if_stmt(test, lambda: true_expr, lambda: false_expr)`.
  Wrapping the branches in lambdas defers their evaluation to `if_stmt`.

  Args:
    node: gast.IfExp node.

  Returns:
    The replacement expression node.
  """
  # Convert the sub-expressions first. Conditional expressions (and anything
  # else this transformer rewrites) nested in the test or in either branch
  # are then converted too; otherwise only the outermost one would be.
  node = self.generic_visit(node)
  return templates.replace_as_expression(
      'ag__.if_stmt(test, lambda: true_expr, lambda: false_expr)',
      test=node.test,
      true_expr=node.body,
      false_expr=node.orelse)
def test_function_call_in_list(self):
  """Substitutes a list literal containing nested calls into a call template."""
  template = """
    foo(bar)
  """
  source = parser.parse_expression('[a(b(1))]')
  node = templates.replace_as_expression(template, bar=source)

  # Besides not raising, the substitution must yield `foo([a(b(1))])`: a call
  # to `foo` whose single argument is the list literal.
  self.assertIsInstance(node, gast.Call)
  self.assertEqual(node.func.id, 'foo')
  self.assertEqual(len(node.args), 1)
  self.assertIsInstance(node.args[0], gast.List)
def visit_List(self, node):
  """Wraps list literals in `ag__.new_list(...)`.

  Only list displays in a load context are rewritten. Lists used as
  assignment or deletion targets (e.g. `[a, b] = x`, `for [a, b] in ...`)
  are returned unchanged, because a call expression is not a valid target.

  Args:
    node: gast.List node.

  Returns:
    The replacement expression, or `node` itself when it is not a load.
  """
  node = self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node
  template = """
    ag__.new_list(elements)
  """
  return templates.replace_as_expression(template, elements=node)