def visit_If(self, node):
  """Lifts a trailing return out of an `if` when both branches end in one.

  Both `return x` statements become `<common_return_name> = x`, and a single
  `return <common_return_name>` is emitted right after the conditional.

  Args:
    node: the `gast.If` node being visited.

  Returns:
    The `If` node unchanged, or the two-element list `[if_node, return_node]`
    when the return was lifted.
  """
  # Children first, so nested conditionals are handled before this one.
  node = self.generic_visit(node)

  # A mismatch (only one branch returning) is tolerated on purpose: FoldElse
  # may later move a return into a branch, and FoldElse and LiftReturn are
  # run alternately until the code reaches a fixed point.
  if not isinstance(node.body[-1], gast.Return):
    return node
  if not node.orelse or not isinstance(node.orelse[-1], gast.Return):
    return node

  for branch in (node.body, node.orelse):
    branch[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=branch[-1].value)[0]
  lifted_return = templates.replace('return a', a=self.common_return_name)[0]
  self.changes_made = True
  return [node, lifted_return]
def _create_cond_branch(self, body_name, aliased_orig_names,
                        aliased_new_names, body, returns):
  """Builds the function definition for one branch of a converted `if`.

  Args:
    body_name: name given to the generated function.
    aliased_orig_names: symbols whose values are copied into
      `aliased_new_names` at the top of the branch. When empty, no copy
      statement is emitted.
    aliased_new_names: names that receive the aliased values.
    body: statements of the branch.
    returns: value(s) the generated function returns, wrapped in a tuple.

  Returns:
    The statements produced by `templates.replace`.
  """
  replacements = dict(body_name=body_name, body=body, returns=returns)
  if not aliased_orig_names:
    template = """
      def body_name():
        body
        return (returns,)
    """
  else:
    template = """
      def body_name():
        aliased_new_names, = aliased_orig_names,
        body
        return (returns,)
    """
    replacements['aliased_orig_names'] = aliased_orig_names
    replacements['aliased_new_names'] = aliased_new_names
  return templates.replace(template, **replacements)
def replace_as_expression_restrictions(self):
  """Checks that `replace_as_expression` rejects non-expression templates.

  NOTE(review): this method has no `test_` prefix, so unittest does not
  collect it; consider renaming it so it actually runs.
  """
  template = """
    foo(a)
    bar(b)
  """
  # Two statements: not a single expression.
  with self.assertRaises(ValueError):
    templates.replace_as_expression(template)
  # These previously called `templates.replace`, which accepts statement
  # templates such as 'a = b' and therefore cannot raise here.
  with self.assertRaises(ValueError):
    templates.replace_as_expression('')
  with self.assertRaises(ValueError):
    templates.replace_as_expression('a = b')
def replace_as_expression_restrictions(self):
  """Checks that `replace_as_expression` rejects non-expression templates.

  NOTE(review): this method has no `test_` prefix, so unittest does not
  collect it; consider renaming it so it actually runs.
  """
  template = """
    foo(a)
    bar(b)
  """
  # Two statements: not a single expression.
  with self.assertRaises(ValueError):
    templates.replace_as_expression(template)
  # These previously called `templates.replace`, which accepts statement
  # templates such as 'a = b' and therefore cannot raise here.
  with self.assertRaises(ValueError):
    templates.replace_as_expression('')
  with self.assertRaises(ValueError):
    templates.replace_as_expression('a = b')
def test_replace_attribute(self):
  """Checks that a string replacement renames an attribute access.

  Also checks that a non-string replacement for an attribute raises
  ValueError.
  """
  # `imp` (and its `new_module`) was removed in Python 3.12; ModuleType is
  # the supported way to create an empty module object.
  import types

  template = """
    def test_fn(a):
      return a.foo
  """
  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  mod = types.ModuleType('test')
  mod.b = 3
  # `assertEquals` is a deprecated alias, removed in Python 3.12.
  self.assertEqual(3, result.test_fn(mod))

  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def test_replace_attribute(self):
  """Checks that a string replacement renames an attribute access.

  Also checks that a non-string replacement for an attribute raises
  ValueError.
  """
  # `imp` (and its `new_module`) was removed in Python 3.12; ModuleType is
  # the supported way to create an empty module object.
  import types

  template = """
    def test_fn(a):
      return a.foo
  """
  node = templates.replace(template, foo='b')[0]
  result, _ = compiler.ast_to_object(node)
  mod = types.ModuleType('test')
  mod.b = 3
  # `assertEquals` is a deprecated alias, removed in Python 3.12.
  self.assertEqual(3, result.test_fn(mod))

  with self.assertRaises(ValueError):
    templates.replace(template, foo=1)
def visit_For(self, node):
  """Lowers a `for` loop into an index-driven `while` loop.

  The generated code keeps a counter `i` and the length `n` of the iterated
  object, and assigns `loop_iter[i]` to the loop target on each iteration.
  When the node carries an 'extra_cond' annotation, that condition is also
  part of the loop test.

  Args:
    node: the `gast.For` node being visited.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, 'body_scope')
  referenced = body_scope.referenced

  # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
  # Or maybe we should replace range with tf.range?

  # Fresh names, chosen against the symbols the body references.
  counter = gast.Name(self.namer.new_symbol('i', referenced), None, None)
  length = gast.Name(self.namer.new_symbol('n', referenced), None, None)
  replacements = dict(
      loop_iter=node.iter,
      target=node.target,
      body=node.body,
      i=counter,
      n=length)

  if anno.hasanno(node, 'extra_cond'):

    def template(loop_iter, target, body, i, n, extra_cond):  # pylint:disable=unused-argument
      i = 0
      n = len(loop_iter)  # pylint:disable=undefined-variable
      while i < n and extra_cond:
        # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
        target = loop_iter[i]
        body  # pylint:disable=pointless-statement
        i += 1

    replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
  else:

    def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
      i = 0
      n = len(loop_iter)  # pylint:disable=undefined-variable
      while i < n:
        # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
        target = loop_iter[i]
        body  # pylint:disable=pointless-statement
        i += 1

  return templates.replace(template, **replacements)
def _insert_dynamic_conversion(self, node):
  """Inlines a dynamic conversion for a dynamic function.

  The call is rewritten as
  `py2tf_api.converted_call(func, True, False, {}, *args)`, and the original
  keyword arguments are carried over onto the new call.

  Args:
    node: the call node being converted.

  Returns:
    The new call expression.
  """
  # TODO(mdan): Pass information on the statically compiled functions.
  # Having access to the statically compiled functions can help avoid
  # unnecessary compilation.
  # For example, this would lead to function `a` being compiled twice:
  #
  #   def a():
  #     v = b
  #     b()
  #   def b():
  #     a()
  #
  # This is really a problem with recursive calls, which currently can
  # only be gated by a static condition, and should be rare.
  # TODO(mdan): It probably makes sense to use dynamic conversion every time.
  # Before we could convert all the time though, we'd need a reasonable
  # caching mechanism.
  template = """
    py2tf_api.converted_call(func, True, False, {}, original_args)
  """
  statements = templates.replace(
      template, func=node.func, original_args=node.args)
  converted = statements[0].value
  # TODO(mdan): Improve the template mechanism to better support this.
  converted.keywords = node.keywords
  return converted
def _gate_symbols(self, guard_statement, guarded_args):
  """Re-binds each guarded symbol to `tf.identity` of itself inside a guard.

  Args:
    guard_statement: the statement whose body receives the re-bindings.
    guarded_args: the symbols to re-bind.

  Returns:
    `guard_statement`, with its body extended in place.
  """
  identity_rebinds = templates.replace(
      """
      (args,) = (tf.identity(a) for a in (args,))
      """,
      args=tuple(guarded_args))
  guard_statement.body.extend(identity_rebinds)
  return guard_statement
def _convert_len(self, node):
  """Builds the expression `tf.shape(<args>)[0]` from a call's arguments.

  Args:
    node: the call node; its `args` are substituted into the template.

  Returns:
    The replacement expression.
  """

  def template(args):
    tf.shape(args)[0]  # pylint:disable=undefined-variable,expression-not-assigned

  statements = templates.replace(template, args=node.args)
  # The template yields one expression statement; keep only its value.
  return statements[0].value
def _create_break_trigger(self):
  """Returns statements that set the last-registered break flag and continue.

  The flag name is the second item of the last entry in `self.break_uses`.
  """
  flag_name = self.break_uses[-1][1]
  statements = templates.replace(
      """
      var_name = True
      """,
      var_name=flag_name)
  statements.append(gast.Continue())
  return statements
def _wrap_to_py_func_no_return(self, node):
  """Wraps a call into a wrapper function executed through `tf.py_func`.

  The wrapper calls the original function and returns the constant 1, so
  `tf.py_func` has an int64 output to produce.

  Args:
    node: the call node to wrap.

  Returns:
    A `(wrapper_def, call_expr)` pair: the wrapper definition (annotated to
    skip further processing) and the `tf.py_func` call statement.
  """
  func_qn = anno.getanno(node.func, anno.Basic.QN)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  wrapper_name = self.context.namer.new_symbol(func_qn.ssf(),
                                               args_scope.referenced)

  def _qn_of(arg):
    # Arguments without a qualified name (e.g. literals) get a generic one.
    if anno.hasanno(arg, anno.Basic.QN):
      return anno.getanno(arg, anno.Basic.QN)
    return qual_names.QN('arg')

  wrapper_args = [
      self.context.namer.new_symbol(_qn_of(arg).ssf(), args_scope.referenced)
      for arg in node.args
  ]

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  # TODO(mdan): This is best handled as a dynamic dispatch.
  # That way we can separate tensors from non-tensor args.
  template = """
    def wrapper(wrapper_args):
      call(wrapper_args)
      return 1
    tf.py_func(wrapper, original_args, [tf.int64])
  """
  wrapper_def, call_expr = templates.replace(
      template,
      call=node.func,
      wrapper=wrapper_name,
      original_args=gast.List(elts=node.args, ctx=None),
      wrapper_args=wrapper_args)
  anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
  return wrapper_def, call_expr
def _create_continuation_init(self):
  """Returns `<flag> = False` for the last entry of `continuation_uses`."""
  flag_name = self.continuation_uses[-1][1]
  (init_stmt,) = templates.replace(
      """
      var_name = False
      """,
      var_name=flag_name)
  return init_stmt
def test_replace_call_keyword(self):
  """Checks replacement of a keyword argument by a list of keywords.

  Also checks that replacing a keyword by an empty list or by a
  non-keyword value raises ValueError.
  """
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """

  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _ = compiler.ast_to_object(node)
  self.assertEqual(9, result.test_fn())

  # Each invalid replacement gets its own context. In a shared `with` block,
  # the second call never ran once the first raised.
  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
  with self.assertRaises(ValueError):
    templates.replace(template, kws=1)
def visit_Expr(self, node):
  """Adds a control-dependency guard after a bare call statement.

  When the expression statement is a call, a
  `with py2tf_utils.control_dependency_on_returns(tf, call):` guard is built
  with an empty body.

  - If the call uses symbols that the parent scope modifies or returns,
    those symbols are re-bound inside the guard via `_gate_symbols`, and the
    guard is returned directly.
  - Otherwise the guard is not returned. It is recorded in
    `self.next_indent_owner` with `self.indent_next = True`, so the
    re-indenting mechanism can insert it later.

  Args:
    node: the `gast.Expr` node being visited.

  Returns:
    A tuple of replacement statements when the expression is a call,
    otherwise the node itself.
  """
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)
    template = """
      with py2tf_utils.control_dependency_on_returns(tf, call):
        # TODO(mdan): Also insert ops to re-fetch if variables are involved?
        pass  # Will be removed below.
    """
    # TODO(mdan): This is brittle. Reorganize the mechanism.
    statements = templates.replace(template, call=node.value)
    control_deps_guard = statements[-1]
    control_deps_guard.body = []

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
    # NOTE(review): tuple() over a set intersection; the order of
    # guarded_args is not guaranteed to be stable across runs.
    guarded_args = tuple(args_scope.used & (args_scope.parent.modified
                                            | args_scope.parent.returned))

    if guarded_args:
      node = tuple(statements[:-1]) + (self._gate_symbols(
          control_deps_guard, guarded_args), )
    else:
      node = tuple(statements[:-1])
      # The mechanism will insert the guard statement later.
      self.indent_next = True
      self.next_indent_owner = control_deps_guard
  return node
def visit_While(self, node):
  """Converts a `while` loop into a `tf.while_loop` call.

  The loop state is the set of symbols the body modifies but does not
  create. The test and the body are emitted as functions over that state,
  and the `tf.while_loop` result is assigned back to the state symbols.

  Args:
    node: the `gast.While` node being visited.

  Returns:
    The statements produced by `templates.replace`.

  Raises:
    ValueError: if the body modifies no symbol other than ones it creates,
      so there is no loop state to thread through `tf.while_loop`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, 'body_scope')
  # Sorted so the generated code does not depend on set iteration order
  # (which can vary between runs for strings under hash randomization).
  body_closure = tuple(sorted(body_scope.modified - body_scope.created))
  if not body_closure:
    # An empty state would produce a degenerate `() = tf.while_loop(...)`.
    raise ValueError('cannot convert while loop: no outputs')

  if len(body_closure) == 1:
    state = body_closure[0]
    state_ast_tuple = state
  else:
    state = body_closure
    state_ast_tuple = gast.Tuple(
        [gast.Name(n, None, None) for n in state], None)

  template = """
    def test_name(state):
      return test
    def body_name(state):
      body
      return state,
    state_ast_tuple = tf.while_loop(test_name, body_name, [state])
  """
  node = templates.replace(
      template,
      state=state,
      state_ast_tuple=state_ast_tuple,
      test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
      test=node.test,
      body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
      body=node.body)

  return node
def _wrap_to_py_func_no_return(self, node):
  """Wraps a call into a wrapper function executed through `tf.py_func`.

  The wrapper takes every symbol used by the call's arguments, calls the
  original function and returns 1 (an int64 output for `tf.py_func`).

  Args:
    node: the call node; it must carry an 'args_scope' annotation.

  Returns:
    A `(wrapper_def, call_expr)` pair. The wrapper is annotated with
    'skip_processing'; the call's value receives the original 'args_scope'.
  """
  args_scope = anno.getanno(node, 'args_scope')
  # TODO(mdan): Properly handle varargs, kwargs, etc.
  arg_refs = tuple(
      gast.Name(name, gast.Load(), None) for name in args_scope.used)

  # pylint:disable=undefined-variable,unused-argument,function-redefined

  def template(call, wrapper, args):

    def wrapper(args):
      call(args)
      return 1

    tf.py_func(wrapper, [args], [tf.int64])

  # pylint:enable=undefined-variable,unused-argument,function-redefined

  wrapper_ref = gast.Name(
      self.namer.compiled_function_name(node.func.id), gast.Load(), None)
  wrapper_def, call_expr = templates.replace(
      template, call=node.func, wrapper=wrapper_ref, args=arg_refs)
  anno.setanno(call_expr.value, 'args_scope', args_scope)
  anno.setanno(wrapper_def, 'skip_processing', True)
  return wrapper_def, call_expr
def _create_continuation_trigger(self):
  """Returns `<flag> = True` for the last entry of `continuation_uses`."""
  flag_name = self.continuation_uses[-1][1]
  (trigger,) = templates.replace(
      """
      var_name = True
      """,
      var_name=flag_name)
  return trigger
def visit_Assert(self, node):
  """Converts an `assert` statement into a `tf.Assert` call.

  Args:
    node: the `gast.Assert` node being visited.

  Returns:
    The statements produced by `templates.replace`.

  Raises:
    NotImplementedError: if the assertion message is not a string literal.
  """
  self.generic_visit(node)

  # Note: The lone tf.Assert call will be wrapped with control_dependencies
  # by side_effect_guards.
  template = """
    tf.Assert(test, [msg])
  """

  message = node.msg
  if message is None:
    message = gast.Str('Assertion error')
  elif not isinstance(message, gast.Str):
    raise NotImplementedError('Can only convert string messages for now.')
  return templates.replace(template, test=node.test, msg=message)
def _gate_symbols(self, guard_statement, guarded_args):
  """Appends `tf.identity` re-bindings of `guarded_args` to the guard body.

  Args:
    guard_statement: the statement whose body is extended in place.
    guarded_args: the symbols to re-bind.

  Returns:
    `guard_statement`.
  """
  template = """
    (args,) = (tf.identity(a) for a in (args,))
  """
  rebinds = templates.replace(template, args=tuple(guarded_args))
  guard_statement.body.extend(rebinds)
  return guard_statement
def test_replace_call_keyword(self):
  """Checks replacement of a keyword argument by a list of keywords.

  Also checks that replacing a keyword by an empty list or by a
  non-keyword value raises ValueError.
  """
  template = """
    def test_fn():
      def f(a, d, f):
        return a + d + f
      return f(1, kws=None)
  """

  source = parser.parse_expression('f(d=3, f=5)')
  node = templates.replace(template, kws=source.keywords)[0]
  result, _ = compiler.ast_to_object(node)
  self.assertEqual(9, result.test_fn())

  # Each invalid replacement gets its own context. In a shared `with` block,
  # the second call never ran once the first raised.
  with self.assertRaises(ValueError):
    templates.replace(template, kws=[])
  with self.assertRaises(ValueError):
    templates.replace(template, kws=1)
def _wrap_to_py_func_no_return(self, node):
  """Wraps a call into a wrapper function executed through `tf.py_func`.

  The wrapper calls the original function and returns the constant 1, so
  `tf.py_func` has an int64 output to produce.

  Args:
    node: the call node to wrap.

  Returns:
    A `(wrapper_def, call_expr)` pair: the wrapper definition (annotated to
    skip further processing) and the `tf.py_func` call statement.
  """
  func_qn = anno.getanno(node.func, anno.Basic.QN)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  wrapper_name = self.context.namer.new_symbol(func_qn.ssf(),
                                               args_scope.referenced)

  def _qn_of(arg):
    # Arguments without a qualified name (e.g. literals) get a generic one.
    if anno.hasanno(arg, anno.Basic.QN):
      return anno.getanno(arg, anno.Basic.QN)
    return qual_names.QN('arg')

  wrapper_args = [
      self.context.namer.new_symbol(_qn_of(arg).ssf(), args_scope.referenced)
      for arg in node.args
  ]

  # TODO(mdan): Properly handle varargs, kwargs, etc.
  # TODO(mdan): This is best handled as a dynamic dispatch.
  # That way we can separate tensors from non-tensor args.
  template = """
    def wrapper(wrapper_args):
      call(wrapper_args)
      return 1
    tf.py_func(wrapper, original_args, [tf.int64])
  """
  wrapper_def, call_expr = templates.replace(
      template,
      call=node.func,
      wrapper=wrapper_name,
      original_args=gast.List(elts=node.args, ctx=None),
      wrapper_args=wrapper_args)
  anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
  return wrapper_def, call_expr
def visit_For(self, node):
  """Lowers a `for` loop into an index-driven `while` loop.

  The generated code keeps a counter `i` and the length `n` of the iterated
  object, and assigns `loop_iter[i]` to the loop target on each iteration.

  Args:
    node: the `gast.For` node being visited.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, 'body_scope')

  # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
  # Or maybe we should replace range with tf.range?

  def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
    i = 0
    n = len(loop_iter)  # pylint:disable=undefined-variable
    while i < n:
      # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
      target = loop_iter[i]
      body  # pylint:disable=pointless-statement
      i += 1

  referenced = body_scope.referenced
  counter = gast.Name(self.namer.new_symbol('i', referenced), None, None)
  length = gast.Name(self.namer.new_symbol('n', referenced), None, None)
  return templates.replace(
      template,
      loop_iter=node.iter,
      target=node.target,
      body=node.body,
      i=counter,
      n=length)
def _create_continuation_trigger(self):
  """Returns `<flag> = True` for the last entry of `continuation_uses`."""
  flag_name = self.continuation_uses[-1][1]
  (trigger,) = templates.replace(
      """
      var_name = True
      """,
      var_name=flag_name)
  return trigger
def visit_Assert(self, node):
  """Converts an `assert` statement into a `tf.Assert` call.

  The message is wrapped in `tf.constant` in the generated code.

  Args:
    node: the `gast.Assert` node being visited.

  Returns:
    The statements produced by `templates.replace`.

  Raises:
    NotImplementedError: if the assertion message is not a string literal.
  """
  self.generic_visit(node)

  # Note: The lone tf.Assert call will be wrapped with control_dependencies
  # by side_effect_guards.
  template = """
    tf.Assert(test, [tf.constant(msg)])
  """

  message = node.msg
  if message is None:
    message = gast.Str('Assertion error')
  elif not isinstance(message, gast.Str):
    raise NotImplementedError('Can only convert string messages for now.')
  return templates.replace(template, test=node.test, msg=message)
def _create_continuation_init(self):
  """Returns `<flag> = False` for the last entry of `continuation_uses`."""
  flag_name = self.continuation_uses[-1][1]
  (init_stmt,) = templates.replace(
      """
      var_name = False
      """,
      var_name=flag_name)
  return init_stmt
def _rename_compilable_function(self, node):
  """Redirects a call to the compiled name of its target, when applicable.

  Constructor calls always use the compiled class name. Other calls ask the
  namer, which decides whether to rename. For bound methods, the receiver
  (`node.func.value`) is prepended to the positional arguments.

  Args:
    node: a call node whose `func` carries 'live_val' and 'fqn' annotations.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.context.namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    owner_type = (
        anno.getanno(node.func, 'parent_type')
        if anno.hasanno(node.func, 'parent_type')
        # Fallback - not reliable.
        else inspect_utils.getmethodclass(target_entity))
    new_name, do_rename = self.context.namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  # NOTE(review): `[0]` is the first node the template produces; confirm it
  # is a valid expression for `Call.func` (elsewhere `.value` is taken).
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def visit_Expr(self, node):
  """Adds a control-dependency guard after a bare call statement.

  When the expression statement is a call, a
  `with py2tf_utils.control_dependency_on_returns(tf, call):` guard is built
  with an empty body.

  - If the call uses symbols that the parent scope modifies or returns,
    those symbols are re-bound inside the guard via `_gate_symbols`, and the
    guard is returned directly.
  - Otherwise the guard is not returned. It is recorded in
    `self.next_indent_owner` with `self.indent_next = True`, so the
    re-indenting mechanism can insert it later.

  Args:
    node: the `gast.Expr` node being visited.

  Returns:
    A tuple of replacement statements when the expression is a call,
    otherwise the node itself.
  """
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)
    template = """
      with py2tf_utils.control_dependency_on_returns(tf, call):
        # TODO(mdan): Also insert ops to re-fetch if variables are involved?
        pass  # Will be removed below.
    """
    # TODO(mdan): This is brittle. Reorganize the mechanism.
    statements = templates.replace(template, call=node.value)
    control_deps_guard = statements[-1]
    control_deps_guard.body = []

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, 'args_scope')
    # NOTE(review): tuple() over a set intersection; the order of
    # guarded_args is not guaranteed to be stable across runs.
    guarded_args = tuple(args_scope.used & (args_scope.parent.modified
                                            | args_scope.parent.returned))

    if guarded_args:
      node = tuple(statements[:-1]) + (
          self._gate_symbols(control_deps_guard, guarded_args),)
    else:
      node = tuple(statements[:-1])
      # The mechanism will insert the guard statement later.
      self.indent_next = True
      self.next_indent_owner = control_deps_guard
  return node
def visit_While(self, node):
  """Converts a `while` loop into a `tf.while_loop` call.

  The loop state is the set of symbols the body modifies but does not
  create. The test and the body are emitted as functions over that state,
  and the `tf.while_loop` result is assigned back to the state symbols.

  Args:
    node: the `gast.While` node being visited.

  Returns:
    The statements produced by `templates.replace`.

  Raises:
    ValueError: if the body modifies no symbol other than ones it creates,
      so there is no loop state to thread through `tf.while_loop`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, 'body_scope')
  # Sorted so the generated code does not depend on set iteration order
  # (which can vary between runs for strings under hash randomization).
  body_closure = tuple(sorted(body_scope.modified - body_scope.created))
  if not body_closure:
    # An empty state would produce a degenerate `() = tf.while_loop(...)`.
    raise ValueError('cannot convert while loop: no outputs')

  if len(body_closure) == 1:
    state = body_closure[0]
    state_ast_tuple = state
  else:
    state = body_closure
    state_ast_tuple = gast.Tuple(
        [gast.Name(n, None, None) for n in state], None)

  template = """
    def test_name(state):
      return test
    def body_name(state):
      body
      return state,
    state_ast_tuple = tf.while_loop(test_name, body_name, [state])
  """
  node = templates.replace(
      template,
      state=state,
      state_ast_tuple=state_ast_tuple,
      test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
      test=node.test,
      body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
      body=node.body)

  return node
def _rename_compilable_function(self, node):
  """Redirects a call to the compiled name of its target, when applicable.

  Constructor calls always use the compiled class name. Other calls ask the
  namer, which decides whether to rename, given the owner determined by
  `_determine_function_owner`. For bound methods, the receiver
  (`node.func.value`) is prepended to the positional arguments.

  Args:
    node: a call node whose `func` carries 'live_val' and 'fqn' annotations.

  Returns:
    The call node, modified in place when renamed.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.context.namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    owner_type = self._determine_function_owner(target_entity)
    new_name, do_rename = self.context.namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  # NOTE(review): `[0]` is the first node the template produces; confirm it
  # is a valid expression for `Call.func` (elsewhere `.value` is taken).
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def _insert_dynamic_conversion(self, node):
  """Inlines a dynamic conversion for a dynamic function.

  The call is rewritten as
  `py2tf_api.converted_call(func, True, False, {}, *args)`, and the original
  keyword arguments are carried over onto the new call.

  Args:
    node: the call node being converted.

  Returns:
    The new call expression.
  """
  # TODO(mdan): Pass information on the statically compiled functions.
  # Having access to the statically compiled functions can help avoid
  # unnecessary compilation.
  # For example, this would lead to function `a` being compiled twice:
  #
  #   def a():
  #     v = b
  #     b()
  #   def b():
  #     a()
  #
  # This is really a problem with recursive calls, which currently can
  # only be gated by a static condition, and should be rare.
  # TODO(mdan): It probably makes sense to use dynamic conversion every time.
  # Before we could convert all the time though, we'd need a reasonable
  # caching mechanism.
  template = """
    py2tf_api.converted_call(func, True, False, {}, original_args)
  """
  statements = templates.replace(
      template, func=node.func, original_args=node.args)
  converted = statements[0].value
  # TODO(mdan): Improve the template mechanism to better support this.
  converted.keywords = node.keywords
  return converted
def _create_break_trigger(self):
  """Returns statements that set the last-registered break flag and continue.

  The flag name is the second item of the last entry in `self.break_uses`.
  """
  flag_name = self.break_uses[-1][1]
  statements = templates.replace(
      """
      var_name = True
      """,
      var_name=flag_name)
  statements.append(gast.Continue())
  return statements
def visit_With(self, node):
  """Lifts a trailing `return` out of a `with` block.

  `return x` at the end of the body becomes `<common_return_name> = x`, and
  `return <common_return_name>` is emitted right after the `with` statement.

  Args:
    node: the `gast.With` node being visited.

  Returns:
    The node unchanged, or `[with_node, return_node]` when a return was
    lifted.
  """
  # Depth-first: children are processed before this node.
  node = self.generic_visit(node)

  if not isinstance(node.body[-1], gast.Return):
    return node

  node.body[-1] = templates.replace(
      'a = b', a=self.common_return_name, b=node.body[-1].value)[0]
  return_node = templates.replace('return a', a=self.common_return_name)[0]
  # The children were already visited above. The second generic_visit that
  # used to be here re-traversed the whole subtree on every lift (quadratic
  # for nested `with` blocks), and visit_If does not do it either.
  self.changes_made = True
  return [node, return_node]
def _wrap_to_py_func_no_return(self, node):
  """Wraps a call into a wrapper function executed through `tf.py_func`.

  The wrapper takes every symbol used by the call's arguments, calls the
  original function and returns 1 (an int64 output for `tf.py_func`).

  Args:
    node: the call node; it must carry an 'args_scope' annotation.

  Returns:
    A `(wrapper_def, call_expr)` pair. The wrapper is annotated with
    'skip_processing'; the call's value receives the original 'args_scope'.
  """
  args_scope = anno.getanno(node, 'args_scope')
  # TODO(mdan): Properly handle varargs, kwargs, etc.
  arg_refs = tuple(
      gast.Name(name, gast.Load(), None) for name in args_scope.used)

  # pylint:disable=undefined-variable,unused-argument,function-redefined

  def template(call, wrapper, args):

    def wrapper(args):
      call(args)
      return 1

    tf.py_func(wrapper, [args], [tf.int64])

  # pylint:enable=undefined-variable,unused-argument,function-redefined

  wrapper_ref = gast.Name(
      self.namer.compiled_function_name(node.func.id), gast.Load(), None)
  wrapper_def, call_expr = templates.replace(
      template, call=node.func, wrapper=wrapper_ref, args=arg_refs)
  anno.setanno(call_expr.value, 'args_scope', args_scope)
  anno.setanno(wrapper_def, 'skip_processing', True)
  return wrapper_def, call_expr
def _create_cond_expr(self, results, test, body_name, orelse_name):
  """Builds a `py2tf_utils.run_cond(test, body_name, orelse_name)` statement.

  Args:
    results: target(s) for the result, or None to discard it.
    test: the condition expression.
    body_name: name of the function for the true branch.
    orelse_name: name of the function for the false branch.

  Returns:
    The statements produced by `templates.replace`.
  """
  replacements = dict(test=test, body_name=body_name, orelse_name=orelse_name)
  if results is None:
    template = """
      py2tf_utils.run_cond(test, body_name, orelse_name)
    """
  else:
    template = """
      results = py2tf_utils.run_cond(test, body_name, orelse_name)
    """
    replacements['results'] = results
  return templates.replace(template, **replacements)
def _create_continuation_check(self):
  """Returns an `if not <flag>:` statement with an empty body.

  The flag is taken from the last entry of `continuation_uses`; the caller
  is expected to fill in the body.
  """
  flag_name = self.continuation_uses[-1][1]
  (check,) = templates.replace(
      """
      if not var_name:
        pass
      """,
      var_name=flag_name)
  # `pass` only exists so the template parses.
  check.body = []
  return check
def _create_break_check(self):
  """Returns the expression `not <flag>` for the last entry of `break_uses`."""

  def template(var_name):
    (not var_name)  # pylint:disable=pointless-statement

  flag_ref = gast.Name(self.break_uses[-1][1], None, None)
  (stmt,) = templates.replace(template, var_name=flag_ref)
  return stmt.value
def _create_break_init(self):
  """Returns `<flag> = False` for the last entry of `break_uses`."""

  def template(var_name):  # pylint:disable=unused-argument
    var_name = False

  flag_ref = gast.Name(self.break_uses[-1][1], None, None)
  (init_stmt,) = templates.replace(template, var_name=flag_ref)
  return init_stmt
def visit_For(self, node):
  """Lowers a `for` loop into an index-driven `while` loop.

  The iterable is evaluated once into `iterated`. Its length is stored in
  `n`, and `iterated[i]` is assigned to the loop target on each iteration.
  An 'extra_cond' annotation, when present, is added to the loop test.

  Args:
    node: the `gast.For` node being visited.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  referenced = body_scope.referenced
  namer = self.context.namer
  replacements = dict(
      loop_iter=node.iter,
      target=node.target,
      body=node.body,
      i=namer.new_symbol('i', referenced),
      n=namer.new_symbol('n', referenced),
      iterated=namer.new_symbol('iterated', referenced))

  # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
  if anno.hasanno(node, 'extra_cond'):
    template = """
      i = 0
      iterated = loop_iter
      n = len(iterated)
      while i < n and extra_cond:
        target = iterated[i]
        body
        i += 1
    """
    replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
  else:
    template = """
      i = 0
      iterated = loop_iter
      n = len(iterated)
      while i < n:
        target = iterated[i]
        body
        i += 1
    """
  return templates.replace(template, **replacements)
def _create_continuation_trigger(self):
  """Returns `<flag> = True` for the last entry of `continuation_uses`."""

  def template(var_name):  # pylint:disable=unused-argument
    var_name = True

  flag_ref = gast.Name(self.continuation_uses[-1][1], None, None)
  (trigger,) = templates.replace(template, var_name=flag_ref)
  return trigger
def _gate_symbols(self, guard_statement, guarded_args):
  """Appends `tf.identity` re-bindings of `guarded_args` to the guard body.

  Args:
    guard_statement: the statement whose body is extended in place.
    guarded_args: the symbols to re-bind.

  Returns:
    `guard_statement`.
  """
  # TODO(mdan): This won't work for variables.
  rebinds = templates.replace(
      """
      (args,) = (tf.identity(a) for a in (args,))
      """,
      args=tuple(guarded_args))
  guard_statement.body.extend(rebinds)
  return guard_statement
def _create_break_check(self):
  """Returns the expression `not <flag>` for the last entry of `break_uses`."""

  def template(var_name):
    (not var_name)  # pylint:disable=pointless-statement

  flag_ref = gast.Name(self.break_uses[-1][1], None, None)
  (stmt,) = templates.replace(template, var_name=flag_ref)
  return stmt.value
def visit_For(self, node):
  """Lowers a `for` loop into a `while` loop driven by py2tf utilities.

  The iterable is wrapped with `py2tf_utils.dynamic_dataset`. Before the
  loop and at the end of each iteration, `py2tf_utils.dynamic_for_cond`
  yields the continuation flag and the next target value. An 'extra_cond'
  annotation, when present, is added to the loop test.

  Args:
    node: the `gast.For` node being visited.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  referenced = body_scope.referenced
  namer = self.context.namer
  replacements = dict(
      loop_iter=node.iter,
      target=node.target,
      body=node.body,
      i=namer.new_symbol('i', referenced),
      smart_loop_iter=namer.new_symbol('smart_loop_iter', referenced),
      cont=namer.new_symbol('cont', referenced))

  # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
  if anno.hasanno(node, 'extra_cond'):
    template = """
      i = 0
      smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
      cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
      while cont and extra_cond:
        body
        i += 1
        cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
    """
    replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
  else:
    template = """
      i = 0
      smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
      cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
      while cont:
        body
        i += 1
        cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
    """
  return templates.replace(template, **replacements)
def _create_break_init(self):
  """Returns `<flag> = False` for the last entry of `break_uses`."""

  def template(var_name):  # pylint:disable=unused-argument
    var_name = False

  flag_ref = gast.Name(self.break_uses[-1][1], None, None)
  (init_stmt,) = templates.replace(template, var_name=flag_ref)
  return init_stmt
def _wrap_to_py_func_no_return(self, node):
  """Replaces a call by `py2tf_utils.wrap_py_func(func, None, (args,), True)`.

  Args:
    node: the call node to wrap.

  Returns:
    The statements produced by `templates.replace`.
  """
  # TODO(mdan): Properly handle varargs, kwargs, etc.
  wrap_template = """
    py2tf_utils.wrap_py_func(func, None, (original_args,), True)
  """
  replacements = {'func': node.func, 'original_args': node.args}
  return templates.replace(wrap_template, **replacements)
def _gate_symbols(self, guard_statement, guarded_args):
  """Appends `tf.identity` re-bindings of `guarded_args` to the guard body.

  Args:
    guard_statement: the statement whose body is extended in place.
    guarded_args: the symbols to re-bind.

  Returns:
    `guard_statement`.
  """
  # TODO(mdan): This won't work for variables.
  rebinds = templates.replace(
      """
      (args,) = (tf.identity(a) for a in (args,))
      """,
      args=tuple(guarded_args))
  guard_statement.body.extend(rebinds)
  return guard_statement
def visit_For(self, node):
  """Lowers a `for` loop into a `while` loop driven by py2tf utilities.

  The iterable is wrapped with `py2tf_utils.dynamic_dataset`. Before the
  loop and at the end of each iteration, `py2tf_utils.dynamic_for_cond`
  yields the continuation flag and the next target value. An 'extra_cond'
  annotation, when present, is added to the loop test.

  Args:
    node: the `gast.For` node being visited.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  referenced = body_scope.referenced
  namer = self.context.namer
  replacements = dict(
      loop_iter=node.iter,
      target=node.target,
      body=node.body,
      i=namer.new_symbol('i', referenced),
      smart_loop_iter=namer.new_symbol('smart_loop_iter', referenced),
      cont=namer.new_symbol('cont', referenced))

  # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
  if anno.hasanno(node, 'extra_cond'):
    template = """
      i = 0
      smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
      cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
      while cont and extra_cond:
        body
        i += 1
        cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
    """
    replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
  else:
    template = """
      i = 0
      smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
      cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
      while cont:
        body
        i += 1
        cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
    """
  return templates.replace(template, **replacements)
def replace_as_expression(self):
  """Checks that `replace_as_expression` yields the replaced call expression.

  NOTE(review): this method has no `test_` prefix, so unittest does not
  collect it; consider renaming it so it actually runs.
  """
  template = """
    foo(a)
  """

  # `templates.replace` returns a list of statements; the expression variant
  # is what this test is about.
  node = templates.replace_as_expression(template, foo='bar', a='baz')
  # `node is gast.Call` compared an instance with the class (always False).
  self.assertIsInstance(node, gast.Call)
  self.assertEqual(node.func.id, 'bar')
  # Positional arguments live on the Call itself, not on its `func` Name.
  self.assertEqual(node.args[0].id, 'baz')
def test_replace_tuple(self):
  """Checks that a tuple of names can replace a single name in a template."""
  template = """
    def test_fn(a, c):
      return b,
  """

  node = templates.replace(template, b=('a', 'c'))[0]
  # `ast_to_object` returns a pair (see the other tests); the module is first.
  result, _ = compiler.ast_to_object(node)

  # `assertEquals` is a deprecated alias, removed in Python 3.12.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def canonicalize_listcomp(self, result_node, list_comp_node):
  """Expands a list comprehension into explicit loops that build a list.

  Args:
    result_node: the target that receives the list.
    list_comp_node: the `ListComp` node to expand.

  Returns:
    The list-initialization statements followed by the nested loop
    statements.
  """
  init_stmts = templates.replace(
      'list_ = create_list',
      list_=result_node,
      create_list=self.instantiate_list_node())
  nested = self.make_update_list_node(result_node, list_comp_node.elt)

  # Build from the inside out: the last generator and its last condition
  # end up innermost, matching comprehension evaluation order.
  for comp in list_comp_node.generators[::-1]:
    for cond in comp.ifs[::-1]:
      nested = templates.replace(
          'if test: loop_body', test=cond, loop_body=nested)
    nested = templates.replace(
        'for target in iter_: loop_body',
        iter_=comp.iter,
        target=comp.target,
        loop_body=nested)

  return init_stmts + nested
def _create_break_trigger(self):
  """Returns statements that set the last-registered break flag and continue.

  The flag name is the second item of the last entry in `self.break_uses`.
  """

  def template(var_name):  # pylint:disable=unused-argument
    var_name = True

  flag_ref = gast.Name(self.break_uses[-1][1], None, None)
  statements = templates.replace(template, var_name=flag_ref)
  statements.append(gast.Continue())
  return statements
def _inline_tf_op(self, op_name, args):
  """Builds a `tf.<op_name>(args)` call marked as a safe boolean operand.

  Args:
    op_name: the attribute name of the op on `tf`.
    args: the call arguments.

  Returns:
    The call expression, annotated with `SAFE_BOOLEAN_OPERAND`.
  """
  statements = templates.replace(
      """
      tf.op_name(args)
      """,
      op_name=op_name,
      args=args)
  # The template yields one expression statement; keep only its value.
  call = statements[0].value
  anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
  return call
def _create_cond_expr(self, results, test, body_name, orelse_name):
  """Builds a `py2tf_utils.run_cond(test, body_name, orelse_name)` statement.

  Args:
    results: target(s) for the result, or None to discard it.
    test: the condition expression.
    body_name: name of the function for the true branch.
    orelse_name: name of the function for the false branch.

  Returns:
    The statements produced by `templates.replace`.
  """
  replacements = dict(test=test, body_name=body_name, orelse_name=orelse_name)
  if results is None:
    template = """
      py2tf_utils.run_cond(test, body_name, orelse_name)
    """
  else:
    template = """
      results = py2tf_utils.run_cond(test, body_name, orelse_name)
    """
    replacements['results'] = results
  return templates.replace(template, **replacements)
def _gate_symbols(self, guard_statement, guarded_args):
  """Appends `tf.identity` re-bindings of `guarded_args` to the guard body.

  Args:
    guard_statement: the statement whose body is extended in place.
    guarded_args: names of the symbols to re-bind.

  Returns:
    `guard_statement`.
  """

  def template(args):  # pylint:disable=unused-argument
    (args,) = (tf.identity(a) for a in (args,))  # pylint:disable=undefined-variable

  name_refs = tuple(gast.Name(name, None, None) for name in guarded_args)
  guard_statement.body.extend(templates.replace(template, args=name_refs))
  return guard_statement
def _gate_symbols(self, guard_statement, guarded_args):
  """Appends `tf.identity` re-bindings of `guarded_args` to the guard body.

  Args:
    guard_statement: the statement whose body is extended in place.
    guarded_args: names of the symbols to re-bind.

  Returns:
    `guard_statement`.
  """

  def template(args):  # pylint:disable=unused-argument
    (args, ) = (tf.identity(a) for a in (args, ))  # pylint:disable=undefined-variable

  name_refs = tuple(gast.Name(name, None, None) for name in guarded_args)
  guard_statement.body.extend(templates.replace(template, args=name_refs))
  return guard_statement
def _create_continuation_check(self):
  """Returns an `if not <flag>:` statement with an empty body.

  The flag is taken from the last entry of `continuation_uses`; the caller
  is expected to fill in the body.
  """
  flag_name = self.continuation_uses[-1][1]
  (check,) = templates.replace(
      """
      if not var_name:
        pass
      """,
      var_name=flag_name)
  # `pass` only exists so the template parses.
  check.body = []
  return check
def test_replace_tuple(self):
  """Checks that a tuple of names can replace a single name in a template."""
  template = """
    def test_fn(a, c):
      return b,
  """

  node = templates.replace(template, b=('a', 'c'))[0]
  # `ast_to_object` returns a pair (see the other tests); the module is first.
  result, _ = compiler.ast_to_object(node)

  # `assertEquals` is a deprecated alias, removed in Python 3.12.
  self.assertEqual((2, 3), result.test_fn(2, 3))
def replace_as_expression(self):
  """Checks that `replace_as_expression` yields the replaced call expression.

  NOTE(review): this method has no `test_` prefix, so unittest does not
  collect it; consider renaming it so it actually runs.
  """
  template = """
    foo(a)
  """

  # `templates.replace` returns a list of statements; the expression variant
  # is what this test is about.
  node = templates.replace_as_expression(template, foo='bar', a='baz')
  # `node is gast.Call` compared an instance with the class (always False).
  self.assertIsInstance(node, gast.Call)
  self.assertEqual(node.func.id, 'bar')
  # Positional arguments live on the Call itself, not on its `func` Name.
  self.assertEqual(node.args[0].id, 'baz')
def visit_While(self, node):
  """Converts a `while` loop into a `py2tf_utils.run_while` call.

  The loop state is the set of symbols the body modifies but does not
  create. Inside the generated test and body functions, each state symbol is
  renamed to a fresh name from the namer (`ssf_map` skips symbols whose name
  would not change). The `run_while` result is assigned back to the original
  symbols.

  Args:
    node: the `gast.While` node being visited.

  Returns:
    The statements produced by `templates.replace`: the test function, the
    body function and the assignment from `py2tf_utils.run_while`.

  Raises:
    ValueError: if the body modifies no symbol other than ones it creates.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  # NOTE(review): list() over a set difference; the state order is not
  # guaranteed to be stable across runs.
  state = list(body_closure)
  if not state:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose fresh name differs from the original need renaming.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # A single state symbol is passed bare rather than as a one-element tuple.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
  """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=self.context.namer.new_symbol('loop_test',
                                              body_scope.referenced),
      test=test,
      body_name=self.context.namer.new_symbol('loop_body',
                                              body_scope.referenced),
      body=node_body)

  return node
def _create_continuation_check(self):
  """Returns an `if not <flag>:` statement with an empty body.

  The flag is taken from the last entry of `continuation_uses`; the caller
  is expected to fill in the body.
  """

  def template(var_name):
    if not var_name:
      pass

  flag_ref = gast.Name(self.continuation_uses[-1][1], None, None)
  (check,) = templates.replace(template, var_name=flag_ref)
  # `pass` only exists so the template parses.
  check.body = []
  return check