def visit_For(self, node):
  """Lowers a `for` statement into an `ag__.for_stmt` call.

  Symbols that the loop body modifies but does not create form the
  loop-carried state. The body and the optional `extra_test` annotation are
  turned into local functions that receive that state under fresh names; the
  body function returns the updated state.

  Args:
    node: the `for` statement node, annotated with `BODY_SCOPE`.

  Returns:
    The nodes produced by `templates.replace` for the functional form.
  """
  self.generic_visit(node)
  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  reserved = body_scope.referenced
  namer = self.ctx.namer

  # Loop-carried state: symbols the body modifies that existed before it.
  carried = list(body_scope.modified - body_scope.created)
  carried_ssf = [namer.new_symbol(sym.ssf(), reserved) for sym in carried]

  # Only symbols whose fresh name differs from their own need renaming.
  renames = {}
  for sym, fresh in zip(carried, carried_ssf):
    if str(sym) != fresh:
      renames[sym] = fresh

  if len(carried) != 1:
    carried_target = gast.Tuple([sym.ast() for sym in carried], None)
  else:
    # A single state value is threaded through bare, not as a tuple.
    (carried,) = carried
    (carried_ssf,) = carried_ssf
    carried_target = carried

  renamed_body = ast_util.rename_symbols(node.body, renames)
  if not anno.hasanno(node, 'extra_test'):
    extra_test = parser.parse_expression('True')
  else:
    extra_test = ast_util.rename_symbols(
        anno.getanno(node, 'extra_test'), renames)

  template = """
    def extra_test_name(state_ssf):
      return extra_test_expr
    def body_name(loop_vars, state_ssf):
      # Workaround for PEP-3113
      iterate = loop_vars
      body
      return state_ssf,
    state_ast_tuple = ag__.for_stmt(
        iter_, extra_test_name, body_name, (state,))
  """
  return templates.replace(
      template,
      state=carried,
      state_ssf=carried_ssf,
      state_ast_tuple=carried_target,
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=namer.new_symbol('extra_test', reserved),
      extra_test_expr=extra_test,
      body_name=namer.new_symbol('loop_body', reserved),
      body=renamed_body)
def visit_For(self, node):
  """Lowers a `for` statement into an `ag__.for_loop` call.

  Symbols that the loop body modifies but does not create form the
  loop-carried state. The body and the optional `extra_cond` annotation are
  turned into local functions that receive that state under fresh names; the
  body function returns the updated state.

  Args:
    node: the `for` statement node, annotated with `BODY_SCOPE`.

  Returns:
    The nodes produced by `templates.replace` for the functional form.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  reserved = body_scope.referenced
  namer = self.context.namer

  # Loop-carried state: symbols the body modifies that existed before it.
  carried = list(body_scope.modified - body_scope.created)
  carried_ssf = [namer.new_symbol(sym.ssf(), reserved) for sym in carried]
  renames = {
      sym: fresh
      for sym, fresh in zip(carried, carried_ssf)
      if str(sym) != fresh
  }

  if len(carried) == 1:
    # A single state value is threaded through bare, not as a tuple.
    carried, carried_ssf = carried[0], carried_ssf[0]
    carried_target = carried
  else:
    carried_target = gast.Tuple([sym.ast() for sym in carried], None)

  renamed_body = ast_util.rename_symbols(node.body, renames)
  if not anno.hasanno(node, 'extra_cond'):
    extra_cond = parser.parse_expression('True')
  else:
    extra_cond = ast_util.rename_symbols(
        anno.getanno(node, 'extra_cond'), renames)

  template = """
    def extra_cond_name(state_ssf):
      return extra_cond_expr
    def body_name(iterate, state_ssf):
      body
      return state_ssf,
    state_ast_tuple = ag__.for_loop(
        iterated, extra_cond_name, body_name, (state,))
  """
  return templates.replace(
      template,
      state=carried,
      state_ssf=carried_ssf,
      state_ast_tuple=carried_target,
      iterated=node.iter,
      iterate=node.target,
      extra_cond_name=namer.new_symbol('extra_cond', reserved),
      extra_cond_expr=extra_cond,
      body_name=namer.new_symbol('loop_body', reserved),
      body=renamed_body)
def visit_For(self, node):
  """Lowers a `for` statement into an `__ops.for_loop` call.

  Symbols that the loop body modifies but does not create form the
  loop-carried state. The body and the optional `extra_cond` annotation are
  turned into local functions that receive that state under fresh names; the
  body function returns the updated state.

  Args:
    node: the `for` statement node, annotated with `BODY_SCOPE`.

  Returns:
    The nodes produced by `templates.replace` for the functional form.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  reserved = body_scope.referenced
  namer = self.context.namer

  # Loop-carried state: symbols the body modifies that existed before it.
  carried = list(body_scope.modified - body_scope.created)
  carried_ssf = [namer.new_symbol(sym.ssf(), reserved) for sym in carried]
  renames = {
      sym: fresh
      for sym, fresh in zip(carried, carried_ssf)
      if str(sym) != fresh
  }

  if len(carried) == 1:
    # A single state value is threaded through bare, not as a tuple.
    carried, carried_ssf = carried[0], carried_ssf[0]
    carried_target = carried
  else:
    carried_target = gast.Tuple([sym.ast() for sym in carried], None)

  renamed_body = ast_util.rename_symbols(node.body, renames)
  if anno.hasanno(node, 'extra_cond'):
    raw_cond = anno.getanno(node, 'extra_cond')
    extra_cond = ast_util.rename_symbols(raw_cond, renames)
  else:
    extra_cond = parser.parse_expression('True')

  template = """
    def extra_cond_name(state_ssf):
      return extra_cond_expr
    def body_name(iterate, state_ssf):
      body
      return state_ssf,
    state_ast_tuple = __ops.for_loop(
        iterated, extra_cond_name, body_name, (state,))
  """
  return templates.replace(
      template,
      state=carried,
      state_ssf=carried_ssf,
      state_ast_tuple=carried_target,
      iterated=node.iter,
      iterate=node.target,
      extra_cond_name=namer.new_symbol('extra_cond', reserved),
      extra_cond_expr=extra_cond,
      body_name=namer.new_symbol('loop_body', reserved),
      body=renamed_body)
def _visit_and_reindent(self, nodes):
  """Visits a list of statements, redirecting later ones when requested.

  Each node is visited in turn and appended to the current destination list.
  When a visited node carries the `anno.Basic.INDENT_BLOCK_REMAINDER`
  annotation, that annotation supplies a `(new_dest, new_alias_map)` pair.
  All statements that follow are appended to `new_dest` instead, and are
  renamed with `ast_util.rename_symbols` using the accumulated alias map.

  Args:
    nodes: iterable of statement nodes to visit.

  Returns:
    The list of statements collected before the first redirection took
    effect. Later statements live in the destination lists supplied by the
    annotations (presumably nested inside the requesting nodes; confirm with
    the code that sets the annotation).

  Raises:
    ValueError: if a redirection was requested and, at the end, the last
      supplied destination is still empty, i.e. nothing followed the
      statement that asked to gate the rest of the block.
  """
  new_nodes = []
  # Where visited statements are appended; retargeted on each redirection.
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    # A visit may replace one statement with several.
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # The annotation is consumed so it is not processed again.
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # On key conflicts, aliases already in effect override the new ones.
      # NOTE(review): this mutates the dict taken from the annotation.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError('Unable to insert statement into the computation flow: '
                     'it is not followed by any computation which '
                     'the statement could gate.')
  return new_nodes
def test_rename_symbols_basic(self):
  """Renaming `a` in `a + b` touches only `a` and yields a str identifier."""
  tree = qual_names.resolve(parser.parse_str('a + b'))
  renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
  tree = ast_util.rename_symbols(tree, renames)

  left_operand = tree.body[0].value.left
  self.assertIsInstance(left_operand.id, str)
  self.assertEqual(compiler.ast_to_source(tree).strip(), 'renamed_a + b')
def test_rename_symbols_attributes(self):
  """An attribute chain `b.c` is renamed as a whole, on both sides."""
  tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
  renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
  tree = ast_util.rename_symbols(tree, renames)

  self.assertEqual(
      compiler.ast_to_source(tree).strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_attributes(self):
  """An attribute chain `b.c` is renamed as a whole, on both sides."""
  tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
  renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
  tree = ast_util.rename_symbols(tree, renames)

  # `ast_to_source` returns a pair here; only the source text is checked.
  generated_source, _unused = compiler.ast_to_source(tree)
  self.assertEqual(generated_source.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_annotations(self):
  """Renaming keeps the very same annotation object on the root node."""
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  annotation_before = anno.getanno(tree, 'foo')

  renames = {qual_names.QN('a'): qual_names.QN('b')}
  tree = ast_util.rename_symbols(tree, renames)

  self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
def test_rename_symbols_basic(self):
  """Renaming `a` in `a + b` touches only `a` and yields a str identifier."""
  renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
  tree = ast_util.rename_symbols(
      qual_names.resolve(parser.parse_str('a + b')), renames)

  self.assertIsInstance(tree.body[0].value.left.id, str)
  generated = compiler.ast_to_source(tree)
  self.assertEqual(generated.strip(), 'renamed_a + b')
def test_rename_symbols_annotations(self):
  """Renaming keeps the very same annotation object on the root node."""
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  original_annotation = anno.getanno(tree, 'foo')

  tree = ast_util.rename_symbols(
      tree, {qual_names.QN('a'): qual_names.QN('b')})

  self.assertIs(anno.getanno(tree, 'foo'), original_annotation)
def test_rename_symbols(self):
  """Renames plain names and attribute chains, preserving their contexts."""
  node = ast.Tuple([
      ast.Name('a', ast.Load()),
      ast.Name('b', ast.Load()),
      ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
      ast.Attribute(
          ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
  ], None)
  node = qual_names.resolve(node)
  node = ast_util.rename_symbols(
      node, {
          qual_names.QN('a'):
              qual_names.QN('renamed_a'),
          qual_names.QN(qual_names.QN('b'), attr='c'):
              qual_names.QN('renamed_b_c'),
      })

  self.assertEqual(node.elts[0].id, 'renamed_a')
  self.assertIsInstance(node.elts[0].ctx, ast.Load)
  # `b` alone is not in the rename map and must stay untouched.
  self.assertEqual(node.elts[1].id, 'b')
  # The whole `b.c` attribute is replaced by a single renamed name, which
  # keeps the expression context of the attribute it replaced.
  self.assertEqual(node.elts[2].id, 'renamed_b_c')
  self.assertIsInstance(node.elts[2].ctx, ast.Store)
  self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
  self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
def test_rename_symbols(self):
  """Renames plain names and attribute chains, preserving their contexts."""
  node = ast.Tuple([
      ast.Name('a', ast.Load()),
      ast.Name('b', ast.Load()),
      ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
      ast.Attribute(
          ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
  ], None)
  node = qual_names.resolve(node)
  node = ast_util.rename_symbols(
      node, {
          qual_names.QN('a'):
              qual_names.QN('renamed_a'),
          qual_names.QN(qual_names.QN('b'), attr='c'):
              qual_names.QN('renamed_b_c'),
      })

  self.assertEqual(node.elts[0].id, 'renamed_a')
  self.assertIsInstance(node.elts[0].ctx, ast.Load)
  # `b` alone is not in the rename map and must stay untouched.
  self.assertEqual(node.elts[1].id, 'b')
  # The whole `b.c` attribute is replaced by a single renamed name, which
  # keeps the expression context of the attribute it replaced.
  self.assertEqual(node.elts[2].id, 'renamed_b_c')
  self.assertIsInstance(node.elts[2].ctx, ast.Store)
  self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
  self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
def visit_While(self, node):
  """Lowers a `while` statement into an `ag__.while_loop` call.

  Symbols that the body modifies but does not create form the loop state.
  The loop test and the body become local functions that receive that state
  under fresh names. Symbols read by the test whose roots the body does not
  create are passed as extra dependencies.

  Args:
    node: the `while` statement, annotated with `BODY_SCOPE` and `COND_SCOPE`.

  Returns:
    The nodes produced by `templates.replace` for the functional form.

  Raises:
    ValueError: if the body modifies no pre-existing symbol.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  body_closure = body_scope.modified - body_scope.created
  cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)

  # Every generated name below (state aliases, the test function and the body
  # function) is emitted next to expressions taken from the loop condition:
  # inside the test function and in the extra-dependency tuple. Reserving
  # only the body's symbols could let a fresh name coincide with a symbol
  # that only the condition reads, silently changing what the test refers
  # to. Reserve the condition's symbols as well.
  all_referenced = body_scope.referenced | cond_scope.referenced

  cond_closure = set()
  for s in cond_scope.referenced:
    for root in s.support_set:
      if root not in body_scope.created:
        cond_closure.add(root)

  state = list(body_closure)
  if not state:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = ag__.while_loop(
        test_name, body_name, (state,), (extra_deps,))
  """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=self.context.namer.new_symbol('loop_test', all_referenced),
      test=test,
      body_name=self.context.namer.new_symbol('loop_body', all_referenced),
      body=node_body,
      extra_deps=tuple(s.ast() for s in cond_closure),
  )

  return node
def visit_If(self, node):
  """Lowers an `if` statement into two branch functions plus a conditional.

  Args:
    node: the `if` statement, annotated with `BODY_SCOPE` and `ORELSE_SCOPE`.

  Returns:
    The true-branch definition, the false-branch definition and the
    conditional expression, concatenated.

  Raises:
    ValueError: if one branch creates symbols that the other does not.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

  # Both branches must create the same new symbols.
  if body_scope.created - orelse_scope.created:
    raise ValueError(
        'The if branch creates new symbols that the else branch does not.')
  if orelse_scope.created - body_scope.created:
    raise ValueError(
        'The else branch creates new symbols that the if branch does not.')

  modified_syms = body_scope.modified | orelse_scope.modified
  modified = tuple(modified_syms)
  all_referenced = body_scope.referenced | orelse_scope.referenced
  namer = self.context.namer

  # Alias the closure variables inside the conditional functions to avoid
  # errors caused by the local variables created in the branch functions.
  aliased_orig_names = tuple(
      modified_syms - (body_scope.created | orelse_scope.created))
  aliased_new_names = tuple(
      [namer.new_symbol(s.ssf(), all_referenced) for s in aliased_orig_names])
  alias_map = {
      orig: alias for orig, alias in zip(aliased_orig_names, aliased_new_names)
  }
  node_body = ast_util.rename_symbols(node.body, alias_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

  if len(modified) > 1:
    results = gast.Tuple([s.ast() for s in modified], None)
  elif modified:
    results = modified[0]
  else:
    # With nothing to return, the cond is called without results, which in
    # turn should trigger the side effect guards.
    results = None

  body_name = namer.new_symbol('if_true', all_referenced)
  orelse_name = namer.new_symbol('if_false', all_referenced)

  if not modified:
    # Dummy value so that cond still has some return value.
    branch_returns = templates.replace('tf.ones(())')[0].value
  else:
    branch_returns = tuple(alias_map.get(s, s) for s in modified)

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_orig_names,
      aliased_new_names=aliased_new_names,
      body=node_body,
      returns=branch_returns)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orig_names,
      aliased_new_names=aliased_new_names,
      body=node_orelse,
      returns=branch_returns)
  cond_expr = self._create_cond_expr(results, node.test, body_name,
                                     orelse_name)
  return body_def + orelse_def + cond_expr
def visit_If(self, node):
  """Lowers an `if` statement into two branch functions plus a conditional.

  Only symbols that a branch defines and that are live after the `if` (or
  have a live parent symbol) are returned from the branch functions.

  Args:
    node: the `if` statement, annotated with `BODY_SCOPE`, `ORELSE_SCOPE`,
      `'live_out'` and `'defined_in'`.

  Returns:
    The true-branch definition, the false-branch definition and the
    conditional expression, concatenated.

  Raises:
    ValueError: if a live symbol is created by only one branch, or if a
      branch would have to return a symbol it neither aliases, creates, nor
      finds in `defined_in`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
  body_defs = body_scope.created | body_scope.modified
  orelse_defs = orelse_scope.created | orelse_scope.modified
  live = anno.getanno(node, 'live_out')

  # We'll need to check if we're closing over variables that are defined
  # elsewhere in the function.
  # NOTE: we can only detect syntactic closure in the scope of the code
  # passed in. If the AutoGraph'd function itself closes over other
  # variables, this analysis won't take that into account.
  defined = anno.getanno(node, 'defined_in')

  # We only need to return variables that are
  #  - modified by one or both branches
  #  - live (or have a live parent) at the end of the conditional
  modified = []
  for def_ in body_defs | orelse_defs:
    def_with_parents = set((def_,)) | def_.support_set
    if live & def_with_parents:
      modified.append(def_)

  # The if statement is illegal if there are variables that are created and
  # live, but that both branches don't create.
  created = live & (body_scope.created | orelse_scope.created)
  if created:
    if created != (body_scope.created & live):
      raise ValueError(
          'The main branch does not create all live symbols that the else '
          'branch does.')
    if created != (orelse_scope.created & live):
      raise ValueError(
          'The else branch does not create all live symbols that the main '
          'branch does.')

  # Alias the closure variables inside the conditional functions to avoid
  # errors caused by the local variables created in the branch functions.
  # Aliasing is done independently per branch, because different branches
  # might write different variables.
  aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
  aliased_orelse_orig_names = tuple(orelse_scope.modified -
                                    orelse_scope.created)
  aliased_body_new_names = tuple(
      self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names)
  aliased_orelse_new_names = tuple(
      self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names)

  alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
  alias_orelse_map = dict(
      zip(aliased_orelse_orig_names, aliased_orelse_new_names))

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  if not modified:
    # When the cond would return no value, we leave the cond called without
    # results. That in turn should trigger the side effect guards. The
    # branch functions will return a dummy value that ensures cond
    # actually has some return value as well.
    results = None
  elif len(modified) == 1:
    results = modified[0]
  else:
    results = gast.Tuple([s.ast() for s in modified], None)

  body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = self.context.namer.new_symbol('if_false',
                                              orelse_scope.referenced)

  if modified:

    def build_returns(aliased_names, alias_map, scope, branch_name):
      """Builds the tuple of values returned by one branch function.

      Args:
        aliased_names: symbols aliased inside this branch.
        alias_map: maps each aliased symbol to its alias.
        scope: the scope annotation of this branch.
        branch_name: 'true' or 'false'; only used in the error message, so
          that failures point at the branch that actually lacks the symbol.

      Returns:
        A tuple with one entry per symbol in `modified`.

      Raises:
        ValueError: if a symbol in `modified` is not aliased in this branch,
          not created by it, and not in `defined`.
      """
      returns = []
      for s in modified:
        if s in aliased_names:
          returns.append(alias_map[s])
        elif s not in scope.created | defined:
          raise ValueError(
              'Attempting to return variable "%s" from the %s branch of '
              'a conditional, but it was not closed over, or created in '
              'this branch.' % (str(s), branch_name))
        else:
          returns.append(s)
      return tuple(returns)

    body_returns = build_returns(aliased_body_orig_names, alias_body_map,
                                 body_scope, 'true')
    orelse_returns = build_returns(aliased_orelse_orig_names,
                                   alias_orelse_map, orelse_scope, 'false')
  else:
    body_returns = orelse_returns = templates.replace(
        'tf.ones(())')[0].value

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=tuple(aliased_body_orig_names),
      aliased_new_names=tuple(aliased_body_new_names),
      body=node_body,
      returns=body_returns)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=tuple(aliased_orelse_orig_names),
      aliased_new_names=tuple(aliased_orelse_new_names),
      body=node_orelse,
      returns=orelse_returns)
  cond_expr = self._create_cond_expr(results, node.test, body_name,
                                     orelse_name)

  return body_def + orelse_def + cond_expr
def visit_If(self, node):
  """Lowers an `if` statement into two branch functions plus a conditional.

  A symbol modified in either branch is returned from the conditional when it
  is live after the statement, or when it is composite and one of its owners
  is live.

  Args:
    node: the `if` statement, annotated with scopes and with the
      `DEFINED_VARS_IN` / `LIVE_VARS_OUT` static annotations.

  Returns:
    The true-branch definition, the false-branch definition and the
    conditional expression, concatenated.

  Raises:
    ValueError: if the two branches would initialize different sets of
      returned symbols that are not defined before the statement.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  def is_output(sym):
    if sym in live_out:
      return True
    # Compound objects are outputs too when any of their owners is live.
    return sym.is_composite() and any(
        owner in live_out for owner in sym.owner_set)

  returned_from_cond = {
      s for s in body_scope.modified | orelse_scope.modified if is_output(s)
  }

  need_alias_in_body = body_scope.modified & defined_in
  need_alias_in_orelse = orelse_scope.modified & defined_in

  # Returned symbols that each branch assigns without a prior definition.
  created_in_body = body_scope.modified & (returned_from_cond - defined_in)
  created_in_orelse = orelse_scope.modified & (returned_from_cond - defined_in)

  if created_in_body != created_in_orelse:
    raise ValueError(
        'if statement may not initialize all variables: the true branch'
        ' creates %s, while the false branch creates %s. Make sure all'
        ' these variables are initialized either in both'
        ' branches or before the if statement.' %
        (self._fmt_symbol_list(created_in_body),
         self._fmt_symbol_list(created_in_orelse)))

  # Each branch aliases the pre-existing variables it writes, so the branch
  # functions can access them; the aliasing is independent per branch.
  namer = self.ctx.namer
  aliased_body_orig_names = tuple(need_alias_in_body)
  aliased_orelse_orig_names = tuple(need_alias_in_orelse)
  aliased_body_new_names = tuple([
      namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names
  ])
  aliased_orelse_new_names = tuple([
      namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names
  ])
  alias_body_map = {
      orig: alias
      for orig, alias in zip(aliased_body_orig_names, aliased_body_new_names)
  }
  alias_orelse_map = {
      orig: alias
      for orig, alias in zip(aliased_orelse_orig_names,
                             aliased_orelse_new_names)
  }

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  outputs = tuple(returned_from_cond)
  if not outputs:
    # With nothing to return, the cond is called without results, which in
    # turn should trigger the side effect guards; each branch returns a dummy
    # value so cond still has a return value.
    # TODO(mdan): This doesn't belong here; it's specific to the operator.
    cond_results = None
    returned_from_body = templates.replace_as_expression('1')
    returned_from_orelse = templates.replace_as_expression('1')
  else:
    if len(outputs) > 1:
      cond_results = gast.Tuple([s.ast() for s in outputs], None)
    else:
      # TODO(mdan): Move this quirk into the operator implementation.
      cond_results = outputs[0]
    returned_from_body = tuple(alias_body_map.get(s, s) for s in outputs)
    returned_from_orelse = tuple(alias_orelse_map.get(s, s) for s in outputs)

  body_name = namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = namer.new_symbol('if_false', orelse_scope.referenced)

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_body_orig_names,
      aliased_new_names=aliased_body_new_names,
      body=node_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orelse_orig_names,
      aliased_new_names=aliased_orelse_new_names,
      body=node_orelse,
      returns=returned_from_orelse)
  cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                     orelse_name)
  return body_def + orelse_def + cond_expr
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method defined directly on `c` and assembles the results
  into a new class definition, preceded by import statements for any
  whitelisted base classes.

  Args:
    c: the class to convert.
    program_ctx: the program context; provides the namer and receives the
      updated name map.

  Returns:
    A tuple `(node, class_name, class_namespace)`: a `gast.Module` with the
    base-class imports followed by the converted class definition, the name
    of the converted class, and the merged namespace of converted methods.

  Raises:
    ValueError: if `c` has no member methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # `class_namespace` always starts as a dict, so simply merge into it.
    class_namespace.update(namespace)
    converted_members[m] = node

  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    if isinstance(object, base):
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class. `ClassDef.bases` must
  # hold expression nodes, not raw strings, for the AST to be valid.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  output_nodes.append(
      gast.ClassDef(
          class_name,
          bases=bases,
          keywords=[],
          body=list(converted_members.values()),
          decorator_list=[]))

  node = gast.Module(output_nodes)

  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  node = qual_names.resolve(node)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  node = ast_util.rename_symbols(node, renames)

  return node, class_name, class_namespace
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method defined directly on `c` (with error rewriting
  disabled) and builds a new class definition from them.

  Args:
    c: the class to convert.
    program_ctx: the program context; provides the namer and receives the
      updated name map.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`: the list of
    base-class import nodes followed by the converted class definition, the
    name of the converted class, and the merged namespace of its methods.

  Raises:
    ValueError: if `c` has no member methods.
  """

  def is_method(member):
    return tf_inspect.isfunction(member) or tf_inspect.ismethod(member)

  members = tf_inspect.getmembers(c, predicate=is_method)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  converted_members = {}
  class_namespace = {}
  for _, member in members:
    # Inherited members are skipped; only those defined on `c` are converted.
    if inspect_utils.getdefiningclass(member, c) is not c:
      continue
    converted, _, member_namespace = function_to_graph(
        member,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c,
        rewrite_errors=False)
    class_namespace.update(member_namespace)
    converted_members[member] = converted[0]

  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # Bases of a whitelisted type get an absolute import line; any other base is
  # marked for conversion as a side effect of namer.compiled_class_name()
  # followed by program_ctx.update_name_map(namer).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    if isinstance(object, base):
      base_names.append('object')
      continue
    if not is_whitelisted_for_graph(base):
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    else:
      alias = namer.new_symbol(base.__name__, ())
      import_alias = gast.alias(name=base.__name__, asname=alias)
      output_nodes.append(
          gast.ImportFrom(module=base.__module__, names=[import_alias],
                          level=0))
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  class_def = gast.ClassDef(
      class_name,
      bases=[gast.Name(name, gast.Load(), None) for name in base_names],
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])

  # Final pass: references to the class itself or its bases (most commonly in
  # super().__init__() calls) are redirected to the converted names.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  output_nodes.append(ast_util.rename_symbols(class_def, renames))
  return output_nodes, class_name, class_namespace
def visit_While(self, node):
  """Lowers a `while` statement into an `ag__.while_stmt` call.

  Symbols that the body modifies but does not create form the loop state.
  The loop test and the body become local functions that receive that state
  under fresh names. Symbols read by the test whose roots the body does not
  create are passed as extra dependencies.

  Args:
    node: the `while` statement, annotated with `BODY_SCOPE` and `COND_SCOPE`.

  Returns:
    The nodes produced by `templates.replace` for the functional form.

  Raises:
    ValueError: if the body modifies no pre-existing symbol.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  body_closure = body_scope.modified - body_scope.created
  cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)

  # Every generated name below (state aliases, the test function and the body
  # function) is emitted next to expressions taken from the loop condition:
  # inside the test function and in the extra-dependency tuple. Reserving
  # only the body's symbols could let a fresh name coincide with a symbol
  # that only the condition reads, silently changing what the test refers
  # to. Reserve the condition's symbols as well.
  all_referenced = body_scope.referenced | cond_scope.referenced

  cond_closure = set()
  for s in cond_scope.referenced:
    for root in s.support_set:
      if root not in body_scope.created:
        cond_closure.add(root)

  state = list(body_closure)
  if not state:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = ag__.while_stmt(
        test_name, body_name, (state,), (extra_deps,))
  """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=self.context.namer.new_symbol('loop_test', all_referenced),
      test=test,
      body_name=self.context.namer.new_symbol('loop_body', all_referenced),
      body=node_body,
      extra_deps=tuple(s.ast() for s in cond_closure),
  )

  return node
def visit_If(self, node):
  """Lowers an `if` statement into two branch functions plus a conditional.

  Only symbols that a branch defines and that are live after the `if` (or
  have a live parent symbol) are returned from the branch functions.

  Args:
    node: the `if` statement, annotated with `BODY_SCOPE`, `ORELSE_SCOPE`,
      `'live_out'` and `'defined_in'`.

  Returns:
    The true-branch definition, the false-branch definition and the
    conditional expression, concatenated.

  Raises:
    ValueError: if a live symbol is created by only one branch, or if a
      branch would have to return a symbol it neither aliases, creates, nor
      finds in `defined_in`.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
  body_defs = body_scope.created | body_scope.modified
  orelse_defs = orelse_scope.created | orelse_scope.modified
  live = anno.getanno(node, 'live_out')

  # We'll need to check if we're closing over variables that are defined
  # elsewhere in the function.
  # NOTE: we can only detect syntactic closure in the scope of the code
  # passed in. If the AutoGraph'd function itself closes over other
  # variables, this analysis won't take that into account.
  defined = anno.getanno(node, 'defined_in')

  # We only need to return variables that are
  #  - modified by one or both branches
  #  - live (or have a live parent) at the end of the conditional
  modified = []
  for def_ in body_defs | orelse_defs:
    def_with_parents = set((def_,)) | def_.support_set
    if live & def_with_parents:
      modified.append(def_)

  # The if statement is illegal if there are variables that are created and
  # live, but that both branches don't create.
  created = live & (body_scope.created | orelse_scope.created)
  if created:
    if created != (body_scope.created & live):
      raise ValueError(
          'The main branch does not create all live symbols that the else '
          'branch does.')
    if created != (orelse_scope.created & live):
      raise ValueError(
          'The else branch does not create all live symbols that the main '
          'branch does.')

  # Alias the closure variables inside the conditional functions to avoid
  # errors caused by the local variables created in the branch functions.
  # Aliasing is done independently per branch, because different branches
  # might write different variables.
  aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
  aliased_orelse_orig_names = tuple(orelse_scope.modified -
                                    orelse_scope.created)
  aliased_body_new_names = tuple(
      self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names)
  aliased_orelse_new_names = tuple(
      self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names)

  alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
  alias_orelse_map = dict(
      zip(aliased_orelse_orig_names, aliased_orelse_new_names))

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  if not modified:
    # When the cond would return no value, we leave the cond called without
    # results. That in turn should trigger the side effect guards. The
    # branch functions will return a dummy value that ensures cond
    # actually has some return value as well.
    results = None
  elif len(modified) == 1:
    results = modified[0]
  else:
    results = gast.Tuple([s.ast() for s in modified], None)

  body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = self.context.namer.new_symbol('if_false',
                                              orelse_scope.referenced)

  if modified:

    def build_returns(aliased_names, alias_map, scope, branch_name):
      """Builds the tuple of values returned by one branch function.

      Args:
        aliased_names: symbols aliased inside this branch.
        alias_map: maps each aliased symbol to its alias.
        scope: the scope annotation of this branch.
        branch_name: 'true' or 'false'; only used in the error message, so
          that failures point at the branch that actually lacks the symbol.

      Returns:
        A tuple with one entry per symbol in `modified`.

      Raises:
        ValueError: if a symbol in `modified` is not aliased in this branch,
          not created by it, and not in `defined`.
      """
      returns = []
      for s in modified:
        if s in aliased_names:
          returns.append(alias_map[s])
        elif s not in scope.created | defined:
          raise ValueError(
              'Attempting to return variable "%s" from the %s branch of '
              'a conditional, but it was not closed over, or created in '
              'this branch.' % (str(s), branch_name))
        else:
          returns.append(s)
      return tuple(returns)

    body_returns = build_returns(aliased_body_orig_names, alias_body_map,
                                 body_scope, 'true')
    orelse_returns = build_returns(aliased_orelse_orig_names,
                                   alias_orelse_map, orelse_scope, 'false')
  else:
    body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=tuple(aliased_body_orig_names),
      aliased_new_names=tuple(aliased_body_new_names),
      body=node_body,
      returns=body_returns)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=tuple(aliased_orelse_orig_names),
      aliased_new_names=tuple(aliased_orelse_new_names),
      body=node_orelse,
      returns=orelse_returns)
  cond_expr = self._create_cond_expr(results, node.test, body_name,
                                     orelse_name)

  return body_def + orelse_def + cond_expr
def visit_If(self, node):
  """Lowers an `if` statement into two branch functions plus a conditional.

  A symbol modified in either branch is returned from the conditional when it
  is live after the statement, or when it is composite and one of its owners
  is live.

  Args:
    node: the `if` statement, annotated with scopes and with the
      `DEFINED_VARS_IN` / `LIVE_VARS_OUT` static annotations.

  Returns:
    The true-branch definition, the false-branch definition and the
    conditional expression, concatenated.

  Raises:
    ValueError: if the two branches would initialize different sets of
      returned symbols that are not defined before the statement.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  def is_output(sym):
    if sym in live_out:
      return True
    # Compound objects are outputs too when any of their owners is live.
    return sym.is_composite() and any(
        owner in live_out for owner in sym.owner_set)

  returned_from_cond = {
      s for s in body_scope.modified | orelse_scope.modified if is_output(s)
  }

  need_alias_in_body = body_scope.modified & defined_in
  need_alias_in_orelse = orelse_scope.modified & defined_in

  # Returned symbols that each branch assigns without a prior definition.
  created_in_body = body_scope.modified & (returned_from_cond - defined_in)
  created_in_orelse = orelse_scope.modified & (returned_from_cond - defined_in)

  if created_in_body != created_in_orelse:
    raise ValueError(
        'if statement may not initialize all variables: the true branch'
        ' creates %s, while the false branch creates %s. Make sure all'
        ' these variables are initialized either in both'
        ' branches or before the if statement.' %
        (self._fmt_symbol_list(created_in_body),
         self._fmt_symbol_list(created_in_orelse)))

  # Each branch aliases the pre-existing variables it writes, so the branch
  # functions can access them; the aliasing is independent per branch.
  namer = self.ctx.namer
  aliased_body_orig_names = tuple(need_alias_in_body)
  aliased_orelse_orig_names = tuple(need_alias_in_orelse)
  aliased_body_new_names = tuple([
      namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names
  ])
  aliased_orelse_new_names = tuple([
      namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names
  ])
  alias_body_map = {
      orig: alias
      for orig, alias in zip(aliased_body_orig_names, aliased_body_new_names)
  }
  alias_orelse_map = {
      orig: alias
      for orig, alias in zip(aliased_orelse_orig_names,
                             aliased_orelse_new_names)
  }

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  outputs = tuple(returned_from_cond)
  if not outputs:
    # With nothing to return, the cond is called without results, which in
    # turn should trigger the side effect guards; each branch returns a dummy
    # value so cond still has a return value.
    # TODO(mdan): This doesn't belong here; it's specific to the operator.
    cond_results = None
    returned_from_body = templates.replace_as_expression('tf.constant(1)')
    returned_from_orelse = templates.replace_as_expression('tf.constant(1)')
  else:
    if len(outputs) > 1:
      cond_results = gast.Tuple([s.ast() for s in outputs], None)
    else:
      # TODO(mdan): Move this quirk into the operator implementation.
      cond_results = outputs[0]
    returned_from_body = tuple(alias_body_map.get(s, s) for s in outputs)
    returned_from_orelse = tuple(alias_orelse_map.get(s, s) for s in outputs)

  body_name = namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = namer.new_symbol('if_false', orelse_scope.referenced)

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_body_orig_names,
      aliased_new_names=aliased_body_new_names,
      body=node_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orelse_orig_names,
      aliased_new_names=aliased_orelse_new_names,
      body=node_orelse,
      returns=returned_from_orelse)
  cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                     orelse_name)
  return body_def + orelse_def + cond_expr