def visit_Assign(self, node):
    """Expands `target = [list comprehension]` into explicit statements.

    Assignments whose value is not a list comprehension are handed to
    `generic_visit` unchanged.

    Args:
      node: The assignment node being visited.

    Returns:
      The result of `generic_visit` for ordinary assignments; otherwise the
      list of statements that initializes the target to an empty list and
      then fills it using nested `for`/`if` statements.

    Raises:
      NotImplementedError: if the comprehension is assigned to several targets.
    """
    comprehension = node.value
    if not isinstance(comprehension, gast.ListComp):
        return self.generic_visit(node)
    if len(node.targets) > 1:
        raise NotImplementedError('multiple assignments')

    target, = node.targets

    init_template = """
      target = []
    """
    append_template = """
      target.append(elt)
    """
    guard_template = """
      if test:
        body
    """
    loop_template = """
      for target in iter_:
        body
    """

    init_stmts = templates.replace(init_template, target=target)
    loop_stmts = templates.replace(
        append_template, target=target, elt=comprehension.elt)

    # Build from the innermost generator outwards, so that the first
    # generator of the comprehension ends up as the outermost loop.
    for generator in reversed(comprehension.generators):
        for condition in reversed(generator.ifs):
            loop_stmts = templates.replace(
                guard_template, test=condition, body=loop_stmts)
        loop_stmts = templates.replace(
            loop_template,
            iter_=generator.iter,
            target=generator.target,
            body=loop_stmts)

    return init_stmts + loop_stmts
def replace_as_expression_restrictions(self):
    """Checks that invalid template inputs are rejected with ValueError."""
    two_statements = """
      foo(a)
      bar(b)
    """
    # A template with two statements cannot be reduced to one expression.
    with self.assertRaises(ValueError):
        templates.replace_as_expression(two_statements)

    with self.assertRaises(ValueError):
        templates.replace('')

    with self.assertRaises(ValueError):
        templates.replace('a = b')
def test_replace_attribute(self):
    """Replaces an attribute name in a template and runs the result.

    Also checks that a non-string, non-AST replacement for an attribute is
    rejected with ValueError.
    """
    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = imp.new_module('test')
    mod.b = 3
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
        templates.replace(template, foo=1)
def visit_With(self, node):
    """Lifts a trailing `return` out of a `with` statement.

    Args:
      node: The `with` node being visited.

    Returns:
      The visited node when its body does not end in a return; otherwise a
      two-element list `[with_node, return_node]` in which the return inside
      the body has been replaced by an assignment to
      `self.common_return_name`.
    """
    # Depth-first traversal of syntax
    node = self.generic_visit(node)

    last_stmt = node.body[-1]
    if not isinstance(last_stmt, gast.Return):
        return node

    # The with statement returns: store the value, then return it afterwards.
    node.body[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=last_stmt.value)[0]
    return_node = templates.replace(
        'return a', a=self.common_return_name)[0]
    node = self.generic_visit(node)
    self.changes_made = True
    return [node, return_node]
def _insert_dynamic_conversion(self, node):
    """Inlines a dynamic conversion for a dynamic function.

    Args:
      node: The call node to wrap.

    Returns:
      A call to `ag__.converted_call` that receives the original callee and
      positional arguments, with the original keywords attached.
    """
    # TODO(mdan): Pass information on the statically compiled functions.
    # Having access to the statically compiled functions can help avoid
    # unnecessary compilation.
    # For example, this would lead to function `a` being compiled twice:
    #
    #   def a():
    #     v = b
    #     b()
    #   def b():
    #     a()
    #
    # This is really a problem with recursive calls, which currently can
    # only be gated by a static condition, and should be rare.
    # TODO(mdan): It probably makes sense to use dynamic conversion every time.
    # Before we could convert all the time though, we'd need a reasonable
    # caching mechanism.
    template = """
      ag__.converted_call(func, True, False, False, {}, args)
    """
    wrapped_stmts = templates.replace(
        template, func=node.func, args=node.args)
    converted = wrapped_stmts[0].value
    # TODO(mdan): Improve the template mechanism to better support this.
    converted.keywords = node.keywords
    return converted
def visit_For(self, node):
    """Rewrites a `for` loop so `break` is expressed through a control variable.

    Args:
      node: The `for` node being visited.

    Returns:
      The visited node when its body uses no break; otherwise the statement
      list produced from the template (control variable initialization
      followed by the loop), with the loop annotated with an 'extra_test'
      expression.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol('break__', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
        return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    node.orelse = self._guard_if_present(node.orelse, break_var)
    template = """
      var_name = False
      for_stmt
    """
    stmts = templates.replace(template, var_name=break_var, for_stmt=node)
    extra_test = templates.replace_as_expression(
        'not var_name', var_name=break_var)
    # stmts[1] is the loop; stmts[0] is the control variable initialization.
    anno.setanno(stmts[1], 'extra_test', extra_test)
    return stmts
def visit_While(self, node):
    """Rewrites a `while` loop so `break` is expressed through a control variable.

    Args:
      node: The `while` node being visited.

    Returns:
      The visited node when its body uses no break; otherwise the statements
      that initialize the control variable and loop while the original test
      holds and the control variable is not set.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
        return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    template = """
      var_name = False
      while test and not var_name:
        body
      else:
        orelse
    """
    return templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)
def test_replace_call_keyword(self):
    """Replaces a call keyword with a list of keywords and runs the result.

    Also checks that an empty list and a non-keyword value are each rejected
    with ValueError.
    """
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own context manager: inside a shared
    # block the second call would never run once the first one raised.
    with self.assertRaises(ValueError):
        templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
        templates.replace(template, kws=1)
def _rename_compilable_function(self, node):
    """Points a call at the compiled version of its target, if there is one.

    Args:
      node: The call node; `node.func` must carry the 'live_val' and 'fqn'
        annotations.

    Returns:
      The same node, with `func` (and, for methods, `args`) updated when the
      namer decides the target should be renamed.
    """
    func = node.func
    assert anno.hasanno(func, 'live_val')
    assert anno.hasanno(func, 'fqn')
    target_entity = anno.getanno(func, 'live_val')
    target_fqn = anno.getanno(func, 'fqn')

    if not self._should_compile(node, target_fqn):
        return node

    if anno.hasanno(node, 'is_constructor'):
        new_name = self.ctx.namer.compiled_class_name(
            target_fqn, live_entity=target_entity)
        do_rename = True
    else:
        if anno.hasanno(node.func, 'parent_type'):
            owner_type = anno.getanno(node.func, 'parent_type')
        else:
            # Fallback - not reliable.
            owner_type = inspect_utils.getmethodclass(target_entity)
        new_name, do_rename = self.ctx.namer.compiled_function_name(
            target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
        return node

    if target_entity is not None and tf_inspect.ismethod(target_entity):
        # The renaming process will transform it into a regular function.
        # TODO(mdan): Is this complete? How does it work with nested members?
        node.args = [node.func.value] + node.args
    node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
def _create_break_trigger(self):
    """Builds the statements that replace a `break`.

    Returns:
      A list holding an assignment of True to the innermost break variable,
      followed by a `continue` statement.
    """
    break_var = self.break_uses[-1][1]
    template = """
      var_name = True
    """
    stmts = templates.replace(template, var_name=break_var)
    stmts.append(gast.Continue())
    return stmts
def visit_Break(self, node):
    """Replaces `break` with setting the break variable and `continue`.

    Marks the innermost entry of `self.break_uses` as used.
    """
    innermost = self.break_uses[-1]
    innermost[0] = True
    template = """
      var_name = True
      continue
    """
    return templates.replace(template, var_name=innermost[1])
def visit_Assert(self, node):
    """Converts an `assert` statement into a `tf.Assert` call.

    Args:
      node: The assert node being visited.

    Returns:
      The statements produced from the `tf.Assert` template.

    Raises:
      NotImplementedError: if the assertion message is not a string literal.
    """
    self.generic_visit(node)

    if node.msg is None:
        msg = gast.Str('Assertion error')
    elif isinstance(node.msg, gast.Str):
        msg = node.msg
    else:
        raise NotImplementedError('can only convert string messages for now.')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      tf.Assert(test, (msg,))
    """
    return templates.replace(template, test=node.test, msg=msg)
def visit_Continue(self, node):
    """Replaces `continue` with setting the local control variable to True.

    Also records in the local scope that a continue statement was used.
    """
    self.set_local(CONTINUE_USED, True)
    control_var = self.get_local(CONTROL_VAR_NAME)
    template = """
      var_name = True
    """
    return templates.replace(template, var_name=control_var)
def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the `ag__.list_pop` statements replacing a `target.pop(...)` call.

    Args:
      original_call_node: The original call; its `func` must be an attribute
        access.
      pop_var_name: Name that receives the popped value.

    Returns:
      The statements produced from the `ag__.list_pop` template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    call_args = original_call_node.args
    if call_args:
        pop_element = call_args[0]
    else:
        pop_element = parser.parse_expression('None')

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    dtype = anno.getanno(
        original_call_node.func.value,
        'element_type',
        default=templates.replace_as_expression('None'))
    shape = anno.getanno(
        original_call_node.func.value,
        'element_shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
def _create_cond_expr(self, results, test, body_name, orelse_name):
    """Builds the `ag__.utils.run_cond` statement for a conditional.

    Args:
      results: Assignment target for the call's result, or None to emit a
        bare call.
      test: The condition expression.
      body_name: Name of the true-branch function.
      orelse_name: Name of the false-branch function.

    Returns:
      The statements produced from the matching template.
    """
    if results is None:
        template = """
          ag__.utils.run_cond(test, body_name, orelse_name)
        """
        return templates.replace(
            template, test=test, body_name=body_name, orelse_name=orelse_name)

    template = """
      results = ag__.utils.run_cond(test, body_name, orelse_name)
    """
    return templates.replace(
        template,
        test=test,
        results=results,
        body_name=body_name,
        orelse_name=orelse_name)
def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    The loop body and the optional 'extra_test' annotation become local
    functions that take the loop state as parameters.

    Args:
      node: The `for` node being visited.

    Returns:
      The statements produced from the template.
    """
    self.generic_visit(node)
    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    referenced = body_scope.referenced

    # Loop state: symbols the body modifies but does not create.
    state = list(body_scope.modified - body_scope.created)
    state_ssf = [
        self.ctx.namer.new_symbol(sym.ssf(), referenced) for sym in state
    ]
    renames = {}
    for sym, ssf_name in zip(state, state_ssf):
        if str(sym) != ssf_name:
            renames[sym] = ssf_name

    if len(state) == 1:
        state = state[0]
        state_ssf = state_ssf[0]
        state_ast_tuple = state
    else:
        state_ast_tuple = gast.Tuple([sym.ast() for sym in state], None)

    renamed_body = ast_util.rename_symbols(node.body, renames)
    if anno.hasanno(node, 'extra_test'):
        extra_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), renames)
    else:
        extra_test = parser.parse_expression('True')

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    # The two function names are requested in the same order as before.
    extra_test_fn = self.ctx.namer.new_symbol('extra_test', referenced)
    body_fn = self.ctx.namer.new_symbol('loop_body', referenced)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_fn,
        extra_test_expr=extra_test,
        body_name=body_fn,
        body=renamed_body)
def visit_Break(self, node):
    """Replaces `break` with setting the control variable and `continue`.

    Marks the current `_Break` state as used.
    """
    break_state = self.state[_Break]
    break_state.used = True
    control_var = self.state[_Break].control_var_name
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = True
      continue
    """
    return templates.replace(template, var_name=control_var)
def visit_Break(self, node):
    """Replaces `break` with setting the control variable and `continue`.

    Also records in the local scope that a break statement was used.
    """
    self.set_local(BREAK_USED, True)
    control_var = self.get_local(CONTROL_VAR_NAME)
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = True
      continue
    """
    return templates.replace(template, var_name=control_var)
def _process_single_assignment(self, target, value):
    """Converts `target[key] = value` into an `ag__.set_item` assignment.

    Args:
      target: The assignment target.
      value: The assigned value.

    Returns:
      The replacement statements, or None when the target is not a subscript.
    """
    if isinstance(target, gast.Subscript):
        template = """
          target = ag__.set_item(target, key, item)
        """
        return templates.replace(
            template, target=target.value, key=target.slice, item=value)
    return None
def replace_as_expression(self):
    """Checks that a one-expression template is returned as a Call node.

    NOTE(review): the name lacks the `test_` prefix, so default unittest
    discovery would not collect this method - confirm whether that is
    intentional.
    """
    template = """
      foo(a)
    """
    # Use replace_as_expression: templates.replace returns a list of
    # statements, which can never be a Call node.
    node = templates.replace_as_expression(template, foo='bar', a='baz')
    # `node is gast.Call` compared an instance with the class itself and was
    # always False; an isinstance check is what was intended.
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    # Call arguments live on the call node, not on its `func`.
    self.assertEqual(node.args[0].id, 'baz')
def _convert_builtin(self, f, args, as_expression):
    """Builds a call to the `ag__` overload of a builtin.

    Args:
      f: The builtin to look up with `py_builtins.overload_of`.
      args: The call arguments.
      as_expression: If true, return a single expression; otherwise return
        the statements produced by the template.

    Returns:
      The replacement expression or statements.
    """
    template = """
      ag__.func(args)
    """
    build = templates.replace_as_expression if as_expression else templates.replace
    return build(template, func=py_builtins.overload_of(f).__name__, args=args)
def canonicalize_listcomp(self, result_node, list_comp_node):
    """Expands a list comprehension into statements that build `result_node`.

    Args:
      result_node: The node that receives the list.
      list_comp_node: The list comprehension to expand.

    Returns:
      The list-creation statements followed by the nested `for`/`if`
      statements that populate the list.
    """
    create_stmts = templates.replace(
        'list_ = create_list',
        list_=result_node,
        create_list=self.instantiate_list_node())
    nested = self.make_update_list_node(result_node, list_comp_node.elt)

    # Wrap from the innermost generator outwards.
    for generator in reversed(list_comp_node.generators):
        for condition in reversed(generator.ifs):
            nested = templates.replace(
                'if test: loop_body', test=condition, loop_body=nested)
        nested = templates.replace(
            'for target in iter_: loop_body',
            iter_=generator.iter,
            target=generator.target,
            loop_body=nested)

    return create_stmts + nested
def test_replace_tuple(self):
    """Replaces a name with a tuple of names and runs the result."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_name_with_dict(self):
    """Replaces a name with a dict literal and runs the result."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(3, result.test_fn())
def _replace_append_call(self, node):
    """Converts `target.append(element)` into an `ag__.list_append` assignment.

    Args:
      node: The call node; it must have exactly one argument and an
        attribute access as its `func`.

    Returns:
      The statements produced from the template.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    list_node = node.func.value
    new_element = node.args[0]
    template = """
      target = ag__.list_append(target, element)
    """
    return templates.replace(template, target=list_node, element=new_element)
def _create_branch(self, expr, name_stem):
    """Wraps an expression in a new zero-argument function definition.

    The generated definition is appended to the current `_FunctionDefs`
    state.

    Args:
      expr: The expression the function returns (as a one-element tuple).
      name_stem: Stem used to generate the function's name.

    Returns:
      The name of the new function.
    """
    referenced = self.state[_Statement].scope.referenced
    name = self.ctx.namer.new_symbol(name_stem, referenced)
    template = """
      def name():
        return expr,
    """
    branch_def = templates.replace(template, name=name, expr=expr)
    self.state[_FunctionDefs].nodes.append(branch_def)
    return name
def _wrap_to_py_func_no_return(self, node):
    """Wraps a call in `ag__.utils.wrap_py_func`, passing None as second argument.

    Args:
      node: The call node to wrap.

    Returns:
      The statements produced from the template.
    """
    # TODO(mdan): Properly handle varargs, etc.
    template = """
      ag__.utils.wrap_py_func(func, None, (args,), kwargs, True)
    """
    keywords_dict = ast_util.keywords_to_dict(node.keywords)
    return templates.replace(
        template, func=node.func, args=node.args, kwargs=keywords_dict)
def visit_For(self, node):
    """Converts a `for` loop into a call to `__ops.for_loop`.

    The loop body and the optional 'extra_cond' annotation become local
    functions that take the loop state as parameters.

    Args:
      node: The `for` node being visited.

    Returns:
      The statements produced from the template.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    referenced = body_scope.referenced

    # Loop state: symbols the body modifies but does not create.
    state = list(body_scope.modified - body_scope.created)
    state_ssf = [
        self.context.namer.new_symbol(sym.ssf(), referenced) for sym in state
    ]
    renames = {}
    for sym, ssf_name in zip(state, state_ssf):
        if str(sym) != ssf_name:
            renames[sym] = ssf_name

    if len(state) == 1:
        state = state[0]
        state_ssf = state_ssf[0]
        state_ast_tuple = state
    else:
        state_ast_tuple = gast.Tuple([sym.ast() for sym in state], None)

    renamed_body = ast_util.rename_symbols(node.body, renames)
    if anno.hasanno(node, 'extra_cond'):
        extra_cond = ast_util.rename_symbols(
            anno.getanno(node, 'extra_cond'), renames)
    else:
        extra_cond = parser.parse_expression('True')

    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(iterate, state_ssf):
        body
        return state_ssf,
      state_ast_tuple = __ops.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    # The two function names are requested in the same order as before.
    extra_cond_fn = self.context.namer.new_symbol('extra_cond', referenced)
    body_fn = self.context.namer.new_symbol('loop_body', referenced)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=extra_cond_fn,
        extra_cond_expr=extra_cond,
        body_name=body_fn,
        body=renamed_body)
def test_replace_function_name(self):
    """Replaces the name of a function definition and runs the result."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(7, result.test_fn(2))
def visit_Print(self, node):
    """Converts a print statement into a `print(...)` call and visits it.

    Args:
      node: The print node being visited.

    Returns:
      The result of visiting the generated call statement.
    """
    self.generic_visit(node)
    values = node.values
    # Following is the case when calling print(a, b)
    if len(values) == 1 and isinstance(values[0], gast.Tuple):
        values = values[0].elts
    template = """
      fname(args)
    """
    call_stmt = templates.replace(template, fname='print', args=values)[0]
    return self.visit(call_stmt)
def _ensure_node_is_trivial(self, node):
    """Returns a trivial stand-in for `node`, hoisting expressions if needed.

    None and nodes of the trivial types are returned as-is. Lists are
    processed element-wise, keywords have their value processed, and
    Starred/withitem/slice nodes have their fields processed. Any other
    expression is assigned to a fresh temporary (queued as a pending
    statement) and replaced by a reference to that temporary.

    Args:
      node: A node, a list of nodes, or None.

    Returns:
      The trivial replacement.

    Raises:
      ValueError: if the node is of a kind that is not handled.
    """
    if node is None or isinstance(node, self._trivial_nodes):
        return node
    if isinstance(node, list):
        # If something's field was actually a list, e.g., variadic arguments.
        return [self._ensure_node_is_trivial(item) for item in node]
    if isinstance(node, gast.keyword):
        node.value = self._ensure_node_is_trivial(node.value)
        return node
    if isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
        return self._ensure_fields_trivial(node)
    if isinstance(node, gast.expr):
        temp_name = self._gensym.new_name()
        assignment = templates.replace(
            'temp_name = expr', temp_name=temp_name, expr=node)[0]
        self._add_pending_statement(assignment)
        return templates.replace('temp_name', temp_name=temp_name)[0]
    raise ValueError('Do not know how to treat {}'.format(node))
def test_replace_complex_context(self):
    """Checks ctx values after replacing an assignment target with an expression.

    The outermost replaced node must be a Store; the nodes nested inside the
    call argument must stay Load.
    """
    template = """
      def test_fn(foo):
        foo = 0
    """

    replacement = parser.parse_expression('bar(([a, b],)).baz')
    node = templates.replace(template, foo=replacement)[0]

    assign_target = node.body[0].targets[0]
    self.assertIsInstance(assign_target.ctx, gast.Store)

    function_call_arg = node.body[0].targets[0].value.args[0]
    inner_list = function_call_arg.elts[0]
    self.assertIsInstance(inner_list.ctx, gast.Load)
    self.assertIsInstance(inner_list.elts[0].ctx, gast.Load)
    self.assertIsInstance(inner_list.elts[1].ctx, gast.Load)
def _process_single_assignment(self, target, value):
    """Converts `target[key] = value` into an `ag__.set_item` assignment.

    Args:
      target: The assignment target.
      value: The assigned value.

    Returns:
      The replacement statements, or None when the target is not a subscript
      with an Index slice.
    """
    is_indexed_subscript = (
        isinstance(target, gast.Subscript) and
        isinstance(target.slice, gast.Index))
    if not is_indexed_subscript:
        return None

    template = """
      target = ag__.set_item(target, key, item)
    """
    return templates.replace(
        template, target=target.value, key=target.slice.value, item=value)
def _guard_if_present(self, block, var_name):
    """Prevents the block from executing if var_name is set.

    Args:
      block: The statements to guard; an empty block is returned as-is.
      var_name: Name of the variable that disables the block.

    Returns:
      The original block when it is empty, otherwise the statements of an
      `if not var_name:` wrapper around it.
    """
    if block:
        template = """
          if not var_name:
            block
        """
        return templates.replace(template, var_name=var_name, block=block)
    return block
def visit_If(self, node):
    """Lifts the return out of a conditional when both branches end in one.

    Args:
      node: The `if` node being visited.

    Returns:
      The visited node when at least one branch does not end in a return;
      otherwise `[if_node, return_node]`, where each branch's return has been
      replaced by an assignment to `self.common_return_name`.
    """
    # Depth-first traversal of if statements
    node = self.generic_visit(node)

    # We check if both branches return, and if so, lift the return out of the
    # conditional. We don't enforce that the true and false branches either
    # both return or both do not, because FoldElse might move a return
    # into a branch after this transform completes. FoldElse and LiftReturn
    # are alternately run until the code reaches a fixed point.
    body_returns = isinstance(node.body[-1], gast.Return)
    orelse_returns = len(node.orelse) and isinstance(
        node.orelse[-1], gast.Return)
    if not (body_returns and orelse_returns):
        return node

    node.body[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=node.body[-1].value)[0]
    node.orelse[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=node.orelse[-1].value)[0]
    return_node = templates.replace(
        'return a', a=self.common_return_name)[0]
    self.changes_made = True
    return [node, return_node]
def visit_FunctionDef(self, node):
    """Wraps a function body in a `tf.name_scope` block.

    The scope name is the function name, prefixed with the owner type's name
    for top-level functions when the context has an owner type.

    Args:
      node: The function definition being visited.

    Returns:
      The same node with its body replaced.
    """
    self._function_level += 1
    try:
        self.generic_visit(node)
    finally:
        # Restore the nesting level even if visiting the body fails.
        self._function_level -= 1

    is_top_level = self._function_level == 0
    if is_top_level and self.context.owner_type is not None:
        scope_name = '{}/{}'.format(
            self.context.owner_type.__name__, node.name)
    else:
        scope_name = node.name

    node.body = templates.replace(
        'with tf.name_scope(scope_name): body',
        scope_name=gast.Str(scope_name),
        body=node.body)
    return node
def test_replace_name_with_call(self):
    """Replaces a name with a chained call expression and runs the result."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(15, result.test_fn())
def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names,
                        body, returns):
    """Builds the function definition for one branch of a conditional.

    Args:
      body_name: Name of the generated function.
      aliased_orig_names: Original names to alias at the top of the function;
        when empty, no aliasing assignment is emitted.
      aliased_new_names: The aliases, parallel to `aliased_orig_names`.
      body: The branch statements.
      returns: The values the function returns.

    Returns:
      The statements produced from the matching template.
    """
    if not aliased_orig_names:
        template = """
          def body_name():
            body
            return (returns,)
        """
        return templates.replace(
            template, body_name=body_name, body=body, returns=returns)

    template = """
      def body_name():
        aliased_new_names, = aliased_orig_names,
        body
        return (returns,)
    """
    return templates.replace(
        template,
        body_name=body_name,
        body=body,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        returns=returns)
def _guard_if_present(self, block, var_name):
    """Prevents the block from executing if var_name is set.

    Args:
      block: The statements to guard; an empty block is replaced by `pass`.
      var_name: Name of the variable that disables the block.

    Returns:
      The statements of an `if not var_name:` wrapper around the block.
    """
    # If we don't have statements that immediately depend on the break
    # we still need to make sure that the break variable remains
    # used, in case the break becomes useful in later stages of transformation.
    # Not having this broke the break_in_inner_loop test.
    guarded = block if block else [gast.Pass()]
    template = """
      if not var_name:
        block
    """
    return templates.replace(template, var_name=var_name, block=guarded)
def test_replace_code_block(self):
    """Replaces a placeholder with a list of statements and runs the result."""
    template = """
      def test_fn(a):
        block
        return a
    """

    # Two copies of `a = a + 1`, so test_fn(1) must return 3.
    node = templates.replace(
        template,
        block=[
            gast.Assign([gast.Name('a', None, None)],
                        gast.BinOp(
                            gast.Name('a', None, None), gast.Add(),
                            gast.Num(1))),
        ] * 2)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(3, result.test_fn(1))
def _visit_loop_body(self, node, nodes):
    """Visits a loop body inside its own local scope.

    A fresh control variable is registered for the scope. If the visited
    body used `continue`, an initialization of that variable to
    `tf.constant(False)` is prepended to the body.

    Args:
      node: The loop node that owns the body.
      nodes: The body statements.

    Returns:
      The visited body statements.
    """
    self.enter_local_scope()
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    continue_var = self.ctx.namer.new_symbol(
        'continue_', body_scope.referenced)
    self.set_local(CONTROL_VAR_NAME, continue_var)

    visited = self.visit_block(
        nodes, after_visit=self._postprocess_statement)

    if self.get_local(CONTINUE_USED, False):
        template = """
          var_name = tf.constant(False)
        """
        init_stmts = templates.replace(template, var_name=continue_var)
        visited = init_stmts + visited

    self.exit_local_scope()
    return visited
def visit_Expr(self, node):
    """Rewrites a standalone single-argument `x.append(e)` call.

    Args:
      node: The expression statement being visited.

    Returns:
      The visited node, or the `autograph_utils.dynamic_list_append`
      assignment statements when the node is such an append call whose
      callee has a qualified-name annotation.
    """
    node = self.generic_visit(node)
    call_node = node.value
    if not isinstance(call_node, gast.Call):
        return node
    if not anno.hasanno(call_node.func, anno.Basic.QN):
        return node

    qn = anno.getanno(call_node.func, anno.Basic.QN)
    if qn.qn[-1] == 'append' and (len(call_node.args) == 1):
        template = """
          target = autograph_utils.dynamic_list_append(target, element)
        """
        node = templates.replace(
            template, target=qn.parent.ast(), element=call_node.args[0])
    return node
def visit_For(self, node):
    """Rewrites a `for` loop whose body breaks to use a control variable.

    Args:
      node: The `for` node being visited.

    Returns:
      The visited node when no break was recorded for this loop; otherwise
      the statements that initialize the control variable followed by the
      loop, with the loop annotated with an 'extra_cond' expression.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol(
        'break_requested', body_scope.referenced)

    # Entry layout: [break was used, control variable name].
    self.break_uses.append([False, break_var])
    node = self.generic_visit(node)

    break_was_used = self.break_uses[-1][0]
    if break_was_used:
        template = """
          var_name = False
          original_for
        """
        node = templates.replace(
            template, var_name=break_var, original_for=node)
        extra_cond = templates.replace_as_expression(
            'not var_name', var_name=break_var)
        # node[1] is the loop; node[0] is the variable initialization.
        anno.setanno(node[1], 'extra_cond', extra_cond)

    self.break_uses.pop()
    return node
def visit_FunctionDef(self, node):
    """Wraps a function body, except a leading docstring, in `tf.name_scope`.

    Args:
      node: The function definition being visited.

    Returns:
      The visited node with its body replaced.
    """
    node = self.generic_visit(node)

    body = node.body
    has_docstring = (
        bool(body) and
        isinstance(body[0], gast.Expr) and
        isinstance(body[0].value, gast.Str))
    if has_docstring:
        # Skip any docstring.
        docstring_stmts, to_scope = body[:1], body[1:]
    else:
        docstring_stmts, to_scope = [], body

    template = """
      with tf.name_scope(scope_name):
        body
    """
    scoped_stmts = templates.replace(
        template,
        scope_name=gast.Str(self._name_for_current_scope()),
        body=to_scope)
    node.body = docstring_stmts + scoped_stmts
    return node
def _postprocess_statement(self, node):
    """Guards the statements that follow a `continue` in the same block.

    Args:
      node: The statement that was just visited.

    Returns:
      A pair `(node, None)` when the statement is kept as-is, or
      `(cond, cond.body)` when the statement was wrapped in a newly created
      `if not <control variable>:` node.
    """
    # Example of how the state machine below works:
    #
    #   1| stmt           # State: CONTINUE_USED = False
    #    |                # Action: none
    #   2| if cond:
    #   3|   continue     # State: CONTINUE_USED = True,
    #    |                #        GUARD_CREATED = False,
    #    |                #        CREATE_GUARD_NEXT = False
    #    |                # Action: set CREATE_GUARD_NEXT = True
    #   4| stmt           # State: CONTINUE_USED = True,
    #    |                #        GUARD_CREATED = False,
    #    |                #        CREATE_GUARD_NEXT = True
    #    |                # Action: create `if not continue_used`,
    #    |                #         set GUARD_CREATED = True
    #   5| stmt           # State: CONTINUE_USED = True, GUARD_CREATED = True
    #    |                # Action: none (will be wrapped under previously
    #    |                #         created if node)
    if not self.get_local(CONTINUE_USED, False):
        return node, None

    if self.get_local(GUARD_CREATED, False):
        return node, None

    if not self.get_local(CREATE_GUARD_NEXT, False):
        self.set_local(CREATE_GUARD_NEXT, True)
        return node, None

    self.set_local(GUARD_CREATED, True)
    template = """
      if not var_name:
        original_node
    """
    cond, = templates.replace(
        template,
        var_name=self.get_local(CONTROL_VAR_NAME),
        original_node=node)
    return cond, cond.body
def visit_For(self, node):
    """Rewrites a `for` loop so `break` is expressed through a control variable.

    Args:
      node: The `for` node being visited.

    Returns:
      The visited node when its body uses no break; otherwise the statements
      that initialize the control variable followed by the rebuilt loop, with
      the loop annotated with an 'extra_test' expression.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
        return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    extra_test = templates.replace_as_expression(
        'not var_name', var_name=break_var)

    # The extra code is hidden in the AST, which will confuse the static
    # analysis. To mitigate that, we insert a no-op statement that ensures
    # the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
      var_name = tf.constant(False)
      for target in iter_:
        (var_name,)
        body
      else:
        orelse
    """
    stmts = templates.replace(
        template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=guarded_orelse)
    # stmts[1] is the loop; stmts[0] is the control variable initialization.
    anno.setanno(stmts[1], 'extra_test', extra_test)
    return stmts
def visit_While(self, node):
    """Rewrites a `while` loop whose body breaks to use a control variable.

    Args:
      node: The `while` node being visited.

    Returns:
      The visited node when no break was recorded for this loop; otherwise
      the statements that initialize the control variable and loop while the
      original test holds and the control variable is not set.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol(
        'break_requested', body_scope.referenced)

    # Entry layout: [break was used, control variable name].
    self.break_uses.append([False, break_var])
    node = self.generic_visit(node)

    break_was_used = self.break_uses[-1][0]
    if break_was_used:
        template = """
          var_name = False
          while original_test and not var_name:
            original_body
          else:
            original_orelse
        """
        node = templates.replace(
            template,
            var_name=break_var,
            original_test=node.test,
            original_body=node.body,
            original_orelse=node.orelse)

    self.break_uses.pop()
    return node
def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the `ag__.list_pop` statements replacing a `target.pop(...)` call.

    Args:
      original_call_node: The original call; its `func` must be an attribute
        access.
      pop_var_name: Name that receives the popped value.

    Returns:
      The statements produced from the `ag__.list_pop` template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    call_args = original_call_node.args
    if call_args:
        pop_element = call_args[0]
    else:
        pop_element = parser.parse_expression('None')

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    dtype = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    shape = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_loop`.

    The test and the body become local functions that take the loop state
    as parameters.

    Args:
      node: The `while` node being visited.

    Returns:
      The statements produced from the template.

    Raises:
      ValueError: if the body modifies no symbol that it did not create.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    referenced = body_scope.referenced

    # Roots the condition depends on that the body does not create.
    cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
    cond_closure = set()
    for sym in cond_scope.referenced:
        cond_closure.update(
            root for root in sym.support_set
            if root not in body_scope.created)

    # Loop state: symbols the body modifies but does not create.
    state = list(body_scope.modified - body_scope.created)
    if not state:
        # TODO (mdan): Implement this properly. id:486
        # https://github.com/imdone/tensorflow/issues/487
        # To complete this statement, we need to check whether any variable
        # created inside the body scope is used before being modified outside the
        # scope. This should be done during activity analysis, and in general
        # should cover the case where variables may not be initialized.
        raise ValueError('cannot convert while loop: no outputs')

    state_ssf = [
        self.context.namer.new_symbol(sym.ssf(), referenced) for sym in state
    ]
    renames = {}
    for sym, ssf_name in zip(state, state_ssf):
        if str(sym) != ssf_name:
            renames[sym] = ssf_name

    if len(state) == 1:
        state = state[0]
        state_ssf = state_ssf[0]
        state_ast_tuple = state
    else:
        state_ast_tuple = gast.Tuple([sym.ast() for sym in state], None)

    renamed_body = ast_util.rename_symbols(node.body, renames)
    renamed_test = ast_util.rename_symbols(node.test, renames)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_loop(
          test_name, body_name, (state,), (extra_deps,))
    """
    # The two function names are requested in the same order as before.
    test_fn = self.context.namer.new_symbol(
        'loop_test', body_scope.referenced)
    body_fn = self.context.namer.new_symbol(
        'loop_body', body_scope.referenced)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn,
        test=renamed_test,
        body_name=body_fn,
        body=renamed_body,
        extra_deps=tuple(sym.ast() for sym in cond_closure),
    )
def make_update_list_node(self, list_, elt):
    """Returns the statement node for `list_.append(elt)`."""
    append_stmts = templates.replace(
        'list_.append(elt)', list_=list_, elt=elt)
    return append_stmts[0]
def generate_Print(self):
    """Returns a `print(x)` statement node with a generated expression as x."""
    printed_expr = self.generate_expression()
    print_stmts = templates.replace('print(x)', x=printed_expr)
    return print_stmts[0]
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    The function will also recursively compile all the entities that `o`
    references, updating `dependency_cache`.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to `o`,
              but which when executed it creates a TF graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted entity,
              keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is a lambda or an arbitrary object.
      ValueError: if the entity type is not supported.
    """
    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        # TODO(mdan): This is not a reliable mechanism.
        # The most reliable way is to check the source code, the AST will contain
        # a Lambda node instead of a FunctionDef
        if o.__name__ == '<lambda>':
            raise NotImplementedError(
                'lambda functions are not yet supported; declare the function'
                ' using def instead: %s' % o)
        else:
            node, name, ns = function_to_graph(
                o, program_ctx, arg_values, arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(
            o, program_ctx, arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. it should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
        entity.autograph_info__ = {}
    '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if program_ctx.recursive:
        while True:
            candidate = None
            for obj in program_ctx.name_map.keys():
                if obj in program_ctx.dependency_cache:
                    continue
                if (hasattr(obj, 'im_class') and
                        getattr(obj, 'im_class') not in program_ctx.partial_types):
                    # Class members are converted with their objects, unless they're
                    # only converted partially. They are skipped here, during the
                    # search: skipping after selection (with `continue` in the outer
                    # loop) would pick the same unconverted member again on every
                    # pass and never terminate.
                    continue
                candidate = obj
                break
            if candidate is None:
                break
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
def _create_break_check(self):
    """Returns the expression `(not <innermost break variable>)`."""
    break_var = self.break_uses[-1][1]
    template = """
      (not var_name)
    """
    # The template yields exactly one expression statement; unwrap its value.
    check_stmt, = templates.replace(template, var_name=break_var)
    return check_stmt.value
def _create_break_init(self):
    """Returns the statement that sets the innermost break variable to False."""
    break_var = self.break_uses[-1][1]
    template = """
      var_name = False
    """
    # The template yields exactly one assignment statement.
    init_stmt, = templates.replace(template, var_name=break_var)
    return init_stmt
def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions and a cond call.

    Each branch becomes a function built by `_create_cond_branch`; the
    statement itself is replaced by the expression built by
    `_create_cond_expr`. Only symbols that are modified by a branch and live
    after the conditional (directly, or through a member of their support
    set) are returned from the branch functions.

    Args:
      node: The `If` node to convert. It must carry the BODY_SCOPE,
        ORELSE_SCOPE, 'live_out' and 'defined_in' annotations.

    Returns:
      The concatenation of the body branch definition, the else branch
      definition and the cond expression.

    Raises:
      ValueError: if a live symbol is created by only one of the branches, or
        if a branch would have to return a symbol that it neither creates nor
        finds in the 'defined_in' annotation.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
    body_defs = body_scope.created | body_scope.modified
    orelse_defs = orelse_scope.created | orelse_scope.modified
    live = anno.getanno(node, 'live_out')

    # We'll need to check if we're closing over variables that are defined
    # elsewhere in the function
    # NOTE: we can only detect syntactic closure in the scope
    # of the code passed in. If the AutoGraph'd function itself closes
    # over other variables, this analysis won't take that into account.
    defined = anno.getanno(node, 'defined_in')

    # We only need to return variables that are
    # - modified by one or both branches
    # - live (or has a live parent) at the end of the conditional
    modified = []
    for def_ in body_defs | orelse_defs:
        def_with_parents = {def_} | def_.support_set
        if live & def_with_parents:
            modified.append(def_)

    # We need to check if live created variables are balanced
    # in both branches
    created = live & (body_scope.created | orelse_scope.created)

    # The if statement is illegal if there are variables that are created,
    # that are also live, but both branches don't create them.
    if created:
        if created != (body_scope.created & live):
            raise ValueError(
                'The main branch does not create all live symbols that the else '
                'branch does.')
        if created != (orelse_scope.created & live):
            raise ValueError(
                'The else branch does not create all live symbols that the main '
                'branch does.')

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
    aliased_orelse_orig_names = tuple(
        orelse_scope.modified - orelse_scope.created)
    aliased_body_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(
        zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    if not modified:
        # When the cond would return no value, we leave the cond called
        # without results. That in turn should trigger the side effect guards.
        # The branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        results = None
    elif len(modified) == 1:
        results = modified[0]
    else:
        results = gast.Tuple([s.ast() for s in modified], None)

    body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.context.namer.new_symbol('if_false',
                                                orelse_scope.referenced)

    if modified:

        def build_returns(aliased_names, alias_map, scope, branch_label):
            """Builds list of return variables for a branch of a conditional.

            `branch_label` names the branch in the error message, so that a
            failure in the else branch is not reported against the true one.
            """
            returns = []
            for s in modified:
                if s in aliased_names:
                    returns.append(alias_map[s])
                    continue
                if s not in scope.created | defined:
                    raise ValueError(
                        'Attempting to return variable "%s" from the %s branch of '
                        'a conditional, but it was not closed over, or created in '
                        'this branch.' % (s, branch_label))
                returns.append(s)
            return tuple(returns)

        body_returns = build_returns(
            aliased_body_orig_names, alias_body_map, body_scope, 'true')
        orelse_returns = build_returns(
            aliased_orelse_orig_names, alias_orelse_map, orelse_scope, 'false')
    else:
        body_returns = orelse_returns = templates.replace(
            'tf.ones(())')[0].value

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=body_returns)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=orelse_returns)
    cond_expr = self._create_cond_expr(results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
def _convert_builtin(self, node):
    """Builds `autograph_utils.dynamic_builtin(func, args)` from a call node.

    Args:
      node: A call node; its `func` and `args` fill the template.

    Returns:
      The `value` of the first statement produced by the template.
    """
    template = """
      autograph_utils.dynamic_builtin(func, args)
    """
    statements = templates.replace(template, func=node.func, args=node.args)
    return statements[0].value
def _convert_print(self, node):
    """Builds `autograph_utils.dynamic_print(args)` from a call node.

    Args:
      node: A call node; its `args` fill the template.

    Returns:
      The `value` of the first statement produced by the template.
    """
    template = """
      autograph_utils.dynamic_print(args)
    """
    statements = templates.replace(template, args=node.args)
    return statements[0].value
def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions and a cond call.

    Both branches must create the same set of symbols. Every symbol modified
    by either branch is returned from both branch functions, so the same
    return list is used for the two of them.

    Args:
      node: The `If` node to convert. It must carry the BODY_SCOPE and
        ORELSE_SCOPE annotations.

    Returns:
      The concatenation of the body branch definition, the else branch
      definition and the cond expression.

    Raises:
      ValueError: if one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    # The set of created symbols has to be identical in both branches.
    if body_scope.created - orelse_scope.created:
        raise ValueError(
            'The if branch creates new symbols that the else branch does not.')
    if orelse_scope.created - body_scope.created:
        raise ValueError(
            'The else branch creates new symbols that the if branch does not.')

    written = body_scope.modified | orelse_scope.modified
    modified = tuple(written)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    aliased_orig_names = tuple(
        written - (body_scope.created | orelse_scope.created))
    aliased_new_names = tuple(
        self.context.namer.new_symbol(symbol.ssf(), all_referenced)
        for symbol in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    if len(modified) > 1:
        results = gast.Tuple([symbol.ast() for symbol in modified], None)
    elif modified:
        results = modified[0]
    else:
        # When the cond would return no value, we leave the cond called
        # without results. That in turn should trigger the side effect guards.
        # The branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        results = None

    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

    if modified:
        # Aliased symbols are returned under their alias, the rest as they are.
        branch_returns = tuple(
            alias_map.get(symbol, symbol) for symbol in modified)
    else:
        branch_returns = templates.replace('tf.ones(())')[0].value

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        body=node_body,
        returns=branch_returns)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        body=node_orelse,
        returns=branch_returns)
    cond_expr = self._create_cond_expr(results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
def visit_Expr(self, node):
    """Guards a standalone call statement with a control dependency block.

    An expression statement whose value is a call is replaced by a
    `with ag__.utils.control_dependency_on_returns(call):` node. That node is
    annotated with INDENT_BLOCK_REMAINDER, holding its body and the alias map.
    Any other expression statement is returned as is.

    Args:
      node: The `Expr` node to process.

    Returns:
      The control dependency `with` node for call statements, otherwise
      `node`.
    """
    self.generic_visit(node)

    call = node.value
    if not isinstance(call, gast.Call):
        return node

    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(call, NodeAnno.ARGS_SCOPE)
    # NOTE: We can't guard object attributes because they may not be
    # writable. In addition, avoid renaming well-known names.
    # TODO(mdan): Move these names into config.
    unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
    guarded_args = tuple(
        s for s in args_scope.used
        if not (s.is_composite() or s in unguarded_names))

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if not guarded_args:
        alias_map = {}
        template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
        guard = templates.replace(template, call=call)[-1]
        # The `pass` only keeps the template parseable; drop it.
        guard.body = []
    else:
        # The aliases may need new names to avoid incorrectly making them
        # local.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        parent_scope = args_scope.parent
        need_alias = tuple(
            s for s in guarded_args if s not in parent_scope.modified)
        alias_map = {}
        for s in need_alias:
            fresh_name = self.context.namer.new_symbol(
                s.ssf(), parent_scope.referenced)
            alias_map[s] = qual_names.QN(fresh_name)

        if len(guarded_args) == 1:
            (only_arg,) = guarded_args
            aliased_guarded_args = alias_map.get(only_arg, only_arg)
        else:
            aliased_guarded_args = gast.Tuple(
                [alias_map.get(s, s).ast() for s in guarded_args], None)

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
        guard = templates.replace(
            template,
            call=call,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]

    anno.setanno(guard, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (guard.body, alias_map))
    return guard