def visit_For(self, node):
    """Rewrite a `for` loop node into its functional replacement.

    The children are visited first. The loop state is then obtained from
    `self._get_loop_state` and `self._state_constructs`, the body is renamed
    with the resulting symbol map, and one of three builders is chosen:

    * state and an 'extra_test' annotation: early-stopping loop;
    * state only: loop with loop-carried state;
    * no state: stateless loop (an 'extra_test' annotation is not allowed).

    Args:
      node: the `for` loop AST node being converted.

    Returns:
      The statements from `self._create_undefined_assigns(possibly_undefs)`
      followed by the nodes produced by the selected loop builder.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
        if has_extra_test:
            # Loop with early stopping (e.g. break or return)
            extra_test = anno.getanno(node, 'extra_test')
            extra_test = ast_util.rename_symbols(extra_test, ssf_map)
            extra_test_name = self.ctx.namer.new_symbol(
                'extra_test', reserved_symbols)
            node = self._create_for_loop_early_stopping(
                loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
                extra_test, body_name, node_body)
        else:
            # Loop with loop-carried state and no early stopping
            node = self._create_for_loop_with_state(
                loop_state, state_ssf, state_ast_tuple, node, body_name,
                node_body)
    else:
        # Loop with no loop-carried state and no early stopping
        assert not has_extra_test, (
            'Early stopping (e.g. break and/or return) '
            'should create state variables.')
        node = self._create_for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node
def visit_For(self, node):
    """Rewrite a `for` loop node into its functional replacement.

    The children are visited first. The loop state is then obtained from
    `self._get_loop_state` and `self._state_constructs`, the body is renamed
    with the resulting symbol map, and one of three builders is chosen:

    * state and an 'extra_test' annotation: `_for_loop_with_extra_test`;
    * state only: `_for_loop_with_state`;
    * no state: `_for_loop_without_state` (an 'extra_test' annotation is
      not allowed in this case).

    Args:
      node: the `for` loop AST node being converted.

    Returns:
      The statements from `self._create_undefined_assigns(possibly_undefs)`
      followed by the nodes produced by the selected loop builder.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
        if has_extra_test:
            # Loop with early stopping (e.g. break or return)
            extra_test = anno.getanno(node, 'extra_test')
            extra_test = ast_util.rename_symbols(extra_test, ssf_map)
            extra_test_name = self.ctx.namer.new_symbol(
                'extra_test', reserved_symbols)
            loop_nodes = self._for_loop_with_extra_test(
                loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
                extra_test, body_name, node_body)
        else:
            # Loop with loop-carried state and no early stopping
            loop_nodes = self._for_loop_with_state(
                loop_state, state_ssf, state_ast_tuple, node, body_name,
                node_body)
    else:
        # Loop with no loop-carried state and no early stopping
        assert not has_extra_test, (
            'Early stopping (e.g. break and/or return) '
            'should create state variables.')
        loop_nodes = self._for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + loop_nodes
def visit_For(self, node):
    """Replace a `for` loop with a call to `ag__.for_stmt`.

    The loop body and the optional 'extra_test' annotation are renamed with
    the symbol map from `self._state_constructs` and wrapped into two
    generated functions. When no 'extra_test' annotation is present the test
    expression is the literal `True`.

    Args:
      node: the `for` loop AST node being converted.

    Returns:
      The result of `templates.replace` for the selected template.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)

    if not anno.hasanno(node, 'extra_test'):
        stop_test = parser.parse_expression('True')
    else:
        stop_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), ssf_map)

    # Reserve the generated function names in the same order for both forms:
    # the test function first, then the body function.
    test_fn_name = self.ctx.namer.new_symbol('extra_test', reserved_symbols)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    if not loop_state:
        # No loop-carried state: nothing is passed in or returned.
        stateless_template = """
            def extra_test_name():
                return extra_test_expr
            def body_name(loop_vars):
                # Workaround for PEP-3113
                iterate = loop_vars
                body
                return ()
            ag__.for_stmt(iter_, extra_test_name, body_name, ())
        """
        return templates.replace(
            stateless_template,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=test_fn_name,
            extra_test_expr=stop_test,
            body_name=body_fn_name,
            body=renamed_body)

    stateful_template = """
        def extra_test_name(state_ssf):
            return extra_test_expr
        def body_name(loop_vars, state_ssf):
            # Workaround for PEP-3113
            iterate = loop_vars
            body
            return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        stateful_template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=test_fn_name,
        extra_test_expr=stop_test,
        body_name=body_fn_name,
        body=renamed_body)
def visit_For(self, node):
    """Replace a `for` loop with a call to `ag__.for_stmt`.

    Two templates exist: one threads the loop state through the generated
    test and body functions, the other is used when `loop_state` is empty.
    Without an 'extra_test' annotation the test expression is `True`.

    Args:
      node: the `for` loop AST node being converted.

    Returns:
      The result of `templates.replace` for the selected template.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)

    if anno.hasanno(node, 'extra_test'):
        stop_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), ssf_map)
    else:
        stop_test = parser.parse_expression('True')

    # Replacements shared by both templates. The dict preserves the original
    # evaluation order: test function name first, then body function name.
    replacements = {
        'iter_': node.iter,
        'iterate': node.target,
        'extra_test_name': self.ctx.namer.new_symbol(
            'extra_test', reserved_symbols),
        'extra_test_expr': stop_test,
        'body_name': self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        'body': renamed_body,
    }

    if loop_state:
        template = """
            def extra_test_name(state_ssf):
                return extra_test_expr
            def body_name(loop_vars, state_ssf):
                # Workaround for PEP-3113
                iterate = loop_vars
                body
                return state_ssf,
            state_ast_tuple = ag__.for_stmt(
                iter_, extra_test_name, body_name, (state,))
        """
        replacements['state'] = loop_state
        replacements['state_ssf'] = state_ssf
        replacements['state_ast_tuple'] = state_ast_tuple
    else:
        template = """
            def extra_test_name():
                return extra_test_expr
            def body_name(loop_vars):
                # Workaround for PEP-3113
                iterate = loop_vars
                body
                return ()
            ag__.for_stmt(iter_, extra_test_name, body_name, ())
        """

    return templates.replace(template, **replacements)
def visit_For(self, node):
    """Replace a `for` loop with a call to `ag__.for_stmt`.

    The loop state is the set of symbols the body modifies but does not
    create. Each state symbol gets a fresh name derived from its `ssf()`
    form; the body and the optional 'extra_test' annotation are renamed
    accordingly before being placed into the template.

    Args:
      node: the `for` loop AST node being converted.

    Returns:
      The result of `templates.replace` for the loop template.
    """
    self.generic_visit(node)
    self._validate_no_live_vars_created(node)

    scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    state_symbols = list(scope.modified - scope.created)
    taken_names = scope.referenced

    ssf_names = []
    for symbol in state_symbols:
        ssf_names.append(self.ctx.namer.new_symbol(symbol.ssf(), taken_names))

    # Only symbols whose generated name differs need renaming in the body.
    ssf_map = {}
    for symbol, ssf_name in zip(state_symbols, ssf_names):
        if str(symbol) != ssf_name:
            ssf_map[symbol] = ssf_name

    if len(state_symbols) != 1:
        state = state_symbols
        state_ssf = ssf_names
        state_ast_tuple = gast.Tuple([n.ast() for n in state_symbols], None)
    else:
        # A single state variable is passed unwrapped.
        state = state_symbols[0]
        state_ssf = ssf_names[0]
        state_ast_tuple = state

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    if not anno.hasanno(node, 'extra_test'):
        stop_test = parser.parse_expression('True')
    else:
        stop_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), ssf_map)

    template = """
        def extra_test_name(state_ssf):
            return extra_test_expr
        def body_name(loop_vars, state_ssf):
            # Workaround for PEP-3113
            iterate = loop_vars
            body
            return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
    """
    test_fn_name = self.ctx.namer.new_symbol('extra_test', taken_names)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', taken_names)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=test_fn_name,
        extra_test_expr=stop_test,
        body_name=body_fn_name,
        body=renamed_body)
def visit_For(self, node):
    """Replace a `for` loop with a call to `ag__.for_stmt`.

    State symbols are those modified but not created in the loop body. Each
    is given a fresh name based on its `ssf()` form, and the loop body plus
    the optional 'extra_test' annotation are renamed to use those names.

    Args:
      node: the `for` loop AST node being converted.

    Returns:
      The result of `templates.replace` for the loop template.
    """
    self.generic_visit(node)
    self._validate_no_live_vars_created(node)

    scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    closure = scope.modified - scope.created
    reserved = scope.referenced

    state = list(closure)
    state_ssf = [
        self.ctx.namer.new_symbol(sym.ssf(), reserved) for sym in state
    ]
    # Skip entries where the generated name equals the original symbol.
    ssf_map = dict(
        (sym, ssf) for sym, ssf in zip(state, state_ssf) if str(sym) != ssf)

    if len(state) == 1:
        # A single state variable is passed unwrapped.
        (state,) = state
        (state_ssf,) = state_ssf
        state_ast_tuple = state
    else:
        state_ast_tuple = gast.Tuple([sym.ast() for sym in state], None)

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)

    if anno.hasanno(node, 'extra_test'):
        stop_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), ssf_map)
    else:
        stop_test = parser.parse_expression('True')

    template = """
        def extra_test_name(state_ssf):
            return extra_test_expr
        def body_name(loop_vars, state_ssf):
            # Workaround for PEP-3113
            iterate = loop_vars
            body
            return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
    """
    replacements = {
        'state': state,
        'state_ssf': state_ssf,
        'state_ast_tuple': state_ast_tuple,
        'iter_': node.iter,
        'iterate': node.target,
        'extra_test_name': self.ctx.namer.new_symbol('extra_test', reserved),
        'extra_test_expr': stop_test,
        'body_name': self.ctx.namer.new_symbol('loop_body', reserved),
        'body': renamed_body,
    }
    return templates.replace(template, **replacements)
def visit_While(self, node):
    """Replace a `while` loop with a call to `ag__.while_stmt`.

    Besides the loop state, the generated call receives `extra_deps`: the
    support sets of every symbol in the condition scope's `used` collection,
    minus the symbols already in the loop state.

    Args:
      node: the `while` loop AST node being converted.

    Returns:
      The result of `templates.replace` for the loop template.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)

    # The loop condition cannot be used to pick a conversion ahead of time:
    # it depends on the state, and evaluating it early could duplicate its
    # effects. Slices and attributes cannot be evaluated either, since they
    # may trigger __getitem__ or __getattribute__.
    #
    # This still fails for ops with side effects on a stateful resource
    # captured in an object, for example:
    #
    #   while self.v.read() > 0:
    #     self.v.assign(1)
    #
    # TODO(mdan): Handle the case above.
    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    cond_closure = set()
    cond_closure.update(*(sym.support_set for sym in cond_scope.used))
    cond_closure -= loop_state

    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
        def test_name(state_ssf):
            return test
        def body_name(state_ssf):
            body
            return state_ssf,
        state_ast_tuple = ag__.while_stmt(
            test_name, body_name, (state,), (extra_deps,))
    """
    test_fn_name = self.ctx.namer.new_symbol('loop_test', reserved_symbols)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
    extra_deps = tuple(sym.ast() for sym in cond_closure)
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn_name,
        test=renamed_test,
        body_name=body_fn_name,
        body=renamed_body,
        extra_deps=extra_deps,
    )
def visit_While(self, node):
    """Replace a `while` loop with a call to `ag__.while_stmt`.

    Besides the loop state, the generated call receives `extra_deps`: the
    support sets of every symbol in the condition scope's `read` collection,
    minus the symbols already in the loop state.

    Args:
      node: the `while` loop AST node being converted.

    Returns:
      The result of `templates.replace` for the loop template.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)

    # The loop condition cannot be used to pick a conversion ahead of time:
    # it depends on the state, and evaluating it early could duplicate its
    # effects. Slices and attributes cannot be evaluated either, since they
    # may trigger __getitem__ or __getattribute__.
    #
    # This still fails for ops with side effects on a stateful resource
    # captured in an object, for example:
    #
    #   while self.v.read() > 0:
    #     self.v.assign(1)
    #
    # TODO(mdan): Handle the case above.
    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    cond_closure = set()
    cond_closure.update(*(sym.support_set for sym in cond_scope.read))
    cond_closure -= loop_state

    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
        def test_name(state_ssf):
            return test
        def body_name(state_ssf):
            body
            return state_ssf,
        state_ast_tuple = ag__.while_stmt(
            test_name, body_name, (state,), (extra_deps,))
    """
    test_fn_name = self.ctx.namer.new_symbol('loop_test', reserved_symbols)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
    extra_deps = tuple(sym.ast() for sym in cond_closure)
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn_name,
        test=renamed_test,
        body_name=body_fn_name,
        body=renamed_body,
        extra_deps=extra_deps,
    )
def visit_While(self, node):
    """Replace a `while` loop with a call to `ag__.while_stmt`.

    Two templates exist: one threads the loop state through the generated
    test and body functions, the other is used when `loop_state` is empty.

    Args:
      node: the `while` loop AST node being converted.

    Returns:
      The statements from `self._create_undefined_assigns(possibly_undefs)`
      followed by the nodes produced from the selected template.
    """
    self.generic_visit(node)

    body_modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
        node, body_modified)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    # Names are reserved in the same order for both forms: test, then body.
    test_fn_name = self.ctx.namer.new_symbol('loop_test', reserved_symbols)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    if not loop_state:
        stateless_template = """
            def test_name():
                return test
            def body_name():
                body
                return ()
            ag__.while_stmt(test_name, body_name, ())
        """
        loop_nodes = templates.replace(
            stateless_template,
            test_name=test_fn_name,
            test=renamed_test,
            body_name=body_fn_name,
            body=renamed_body)
    else:
        stateful_template = """
            def test_name(state_ssf):
                return test
            def body_name(state_ssf):
                body
                return state_ssf,
            state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
        """
        loop_nodes = templates.replace(
            stateful_template,
            state=loop_state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            test_name=test_fn_name,
            test=renamed_test,
            body_name=body_fn_name,
            body=renamed_body)

    return self._create_undefined_assigns(possibly_undefs) + loop_nodes
def _for_loop_with_extra_test(self, loop_state, state_ssf, state_ast_tuple,
                              original_node, extra_test_name, extra_test,
                              body_name, loop_body, ssf_map):
    """Build the `ag__.for_stmt` form of a stateful loop with an extra test.

    The loop target of `original_node` is renamed with `ssf_map` before it
    is placed in the generated body function.

    Returns:
      The result of `templates.replace` for the loop template.
    """
    renamed_target = ast_util.rename_symbols(original_node.target, ssf_map)
    loop_template = """
        def extra_test_name(state_ssf):
            return extra_test_expr
        def body_name(loop_vars, state_ssf):
            # Workaround for PEP-3113
            target = loop_vars
            body
            return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
    """
    replacements = {
        'state': loop_state,
        'state_ssf': state_ssf,
        'state_ast_tuple': state_ast_tuple,
        'iter_': original_node.iter,
        'target': renamed_target,
        'extra_test_name': extra_test_name,
        'extra_test_expr': extra_test,
        'body_name': body_name,
        'body': loop_body,
    }
    return templates.replace(loop_template, **replacements)
def test_rename_symbols_function(self):
    """Renaming `f` to `f1` must rename the function definition itself."""
    renames = {qual_names.QN('f'): qual_names.QN('f1')}
    tree = ast_util.rename_symbols(parser.parse('def f():\n pass'), renames)

    rendered = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(rendered.strip(), 'def f1():\n pass')
def _visit_and_reindent(self, nodes):
    """Visit `nodes` in order, redirecting output when a node requests it.

    Each visited node is appended (or extended, if it is a list or tuple) to
    the current destination list, which starts out as the returned list.
    When a visited node carries the `INDENT_BLOCK_REMAINDER` annotation, the
    annotation supplies a new destination list and an alias map: all
    following nodes are renamed with the accumulated alias map and appended
    to that new destination instead.

    Args:
      nodes: iterable of AST nodes to visit.

    Returns:
      The top-level list of visited nodes.

    Raises:
      ValueError: if a redirect was requested but the final destination list
        is empty, i.e. nothing followed the requesting statement.
    """
    new_nodes = []
    current_dest = new_nodes
    alias_map = {}
    reindent_requested = False
    for n in nodes:
        n = self.visit(n)
        # NOTE: the order in which these statements execute is important; in
        # particular, watch out for ending up with cycles in the AST.
        if alias_map:
            n = ast_util.rename_symbols(n, alias_map)
        if isinstance(n, (list, tuple)):
            current_dest.extend(n)
        else:
            current_dest.append(n)
        if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
            reindent_requested = True
            new_dest, new_alias_map = anno.getanno(
                n, anno.Basic.INDENT_BLOCK_REMAINDER)
            # The annotation is consumed here so it is not processed again.
            anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
            # Aliases accumulated so far override same-key entries in the
            # new map.
            new_alias_map.update(alias_map)
            alias_map = new_alias_map
            current_dest = new_dest
    if reindent_requested and not current_dest:
        # TODO(mdan): There may still be something that could be done.
        raise ValueError(
            'Unable to insert statement into the computation flow: '
            'it is not followed by any computation which '
            'the statement could gate.')
    return new_nodes
def _for_loop_with_extra_test(self, loop_state, state_ssf, state_ast_tuple,
                              original_node, extra_test_name, extra_test,
                              body_name, loop_body, ssf_map):
    """Build the `ag__.for_stmt` form of a stateful loop with an extra test.

    The loop target of `original_node` is renamed with `ssf_map` before it
    is placed in the generated body function.

    Returns:
      The result of `templates.replace` for the loop template.
    """
    renamed_target = ast_util.rename_symbols(original_node.target, ssf_map)
    loop_template = """
        def extra_test_name(state_ssf):
            return extra_test_expr
        def body_name(loop_vars, state_ssf):
            # Workaround for PEP-3113
            target = loop_vars
            body
            return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
    """
    replacements = dict(
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=original_node.iter,
        target=renamed_target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=loop_body)
    return templates.replace(loop_template, **replacements)
def _visit_and_reindent(self, nodes):
    """Visit `nodes` in order, redirecting output when a node requests it.

    Each visited node is appended (or extended, if it is a list or tuple) to
    the current destination list, which starts out as the returned list.
    When a visited node carries the `INDENT_BLOCK_REMAINDER` annotation, the
    annotation supplies a new destination list and an alias map: all
    following nodes are renamed with the accumulated alias map and appended
    to that new destination instead.

    Args:
      nodes: iterable of AST nodes to visit.

    Returns:
      The top-level list of visited nodes.

    Raises:
      ValueError: if a redirect was requested but the final destination list
        is empty, i.e. nothing followed the requesting statement.
    """
    new_nodes = []
    current_dest = new_nodes
    alias_map = {}
    reindent_requested = False
    for n in nodes:
        n = self.visit(n)
        # NOTE: the order in which these statements execute is important; in
        # particular, watch out for ending up with cycles in the AST.
        if alias_map:
            n = ast_util.rename_symbols(n, alias_map)
        if isinstance(n, (list, tuple)):
            current_dest.extend(n)
        else:
            current_dest.append(n)
        if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
            reindent_requested = True
            new_dest, new_alias_map = anno.getanno(
                n, anno.Basic.INDENT_BLOCK_REMAINDER)
            # The annotation is consumed here so it is not processed again.
            anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
            # Aliases accumulated so far override same-key entries in the
            # new map.
            new_alias_map.update(alias_map)
            alias_map = new_alias_map
            current_dest = new_dest
    if reindent_requested and not current_dest:
        # TODO(mdan): There may still be something that could be done.
        raise ValueError('Unable to insert statement into the computation flow: '
                         'it is not followed by any computation which '
                         'the statement could gate.')
    return new_nodes
def visit_While(self, node):
    """Replace a `while` loop with a call to `ag__.while_stmt`.

    The stateful template passes the loop state through the generated test
    and body functions; the stateless one is used when `loop_state` is empty.

    Args:
      node: the `while` loop AST node being converted.

    Returns:
      The statements from `self._create_undefined_assigns(possibly_undefs)`
      followed by the nodes produced from the selected template.
    """
    self.generic_visit(node)

    body_modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
        node, body_modified)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    # Replacements common to both templates; the test function name is
    # reserved before the body function name, as in either branch.
    replacements = {
        'test_name': self.ctx.namer.new_symbol('loop_test', reserved_symbols),
        'test': renamed_test,
        'body_name': self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        'body': renamed_body,
    }

    if loop_state:
        template = """
            def test_name(state_ssf):
                return test
            def body_name(state_ssf):
                body
                return state_ssf,
            state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
        """
        replacements['state'] = loop_state
        replacements['state_ssf'] = state_ssf
        replacements['state_ast_tuple'] = state_ast_tuple
    else:
        template = """
            def test_name():
                return test
            def body_name():
                body
                return ()
            ag__.while_stmt(test_name, body_name, ())
        """

    loop_nodes = templates.replace(template, **replacements)
    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + loop_nodes
def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` must affect every use of it."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}

    tree = ast_util.rename_symbols(tree, renames)

    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` must affect every use of it."""
    old_name = qual_names.from_str('b.c')
    new_name = qual_names.QN('renamed_b_c')

    tree = parser.parse_str('b.c = b.c.d')
    tree = qual_names.resolve(tree)
    tree = ast_util.rename_symbols(tree, {old_name: new_name})

    rendered = compiler.ast_to_source(tree).strip()
    self.assertEqual(rendered, 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` must affect every use of it."""
    tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
    renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}

    tree = ast_util.rename_symbols(tree, renames)

    rendered = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
def test_rename_symbols_global(self):
    """Only the renamed symbol changes inside a `global` statement."""
    tree = qual_names.resolve(parser.parse('global a, b, c'))
    renames = {qual_names.from_str('b'): qual_names.QN('renamed_b')}

    tree = ast_util.rename_symbols(tree, renames)

    rendered = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(rendered.strip(), 'global a, renamed_b, c')
def test_rename_symbols_annotations(self):
    """Annotations set on a node must survive renaming unchanged."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    annotation_before = anno.getanno(tree, 'foo')

    renames = {qual_names.QN('a'): qual_names.QN('b')}
    tree = ast_util.rename_symbols(tree, renames)

    # Identity, not equality: the very same annotation object is expected.
    self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
def test_rename_symbols_basic(self):
    """Renaming `a` leaves `b` alone and stores the new id as a `str`."""
    tree = qual_names.resolve(parser.parse_str('a + b'))
    renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}

    tree = ast_util.rename_symbols(tree, renames)

    renamed_id = tree.body[0].value.left.id
    self.assertIsInstance(renamed_id, str)
    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_a + b')
def test_rename_symbols_basic(self):
    """Renaming `a` leaves `b` alone and stores the new id as a `str`."""
    tree = qual_names.resolve(parser.parse('a + b'))
    renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}

    tree = ast_util.rename_symbols(tree, renames)

    renamed_id = tree.value.left.id
    self.assertIsInstance(renamed_id, str)
    rendered = parser.unparse(tree, include_encoding_marker=False)
    self.assertEqual(rendered.strip(), 'renamed_a + b')
def test_rename_symbols_annotations(self):
    """Annotations set on a node must survive renaming unchanged."""
    old_name = qual_names.QN('a')
    new_name = qual_names.QN('b')

    tree = parser.parse_str('a[i]')
    tree = qual_names.resolve(tree)
    anno.setanno(tree, 'foo', 'bar')
    annotation_before = anno.getanno(tree, 'foo')

    tree = ast_util.rename_symbols(tree, {old_name: new_name})

    # Identity, not equality: the very same annotation object is expected.
    self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
def test_rename_symbols_basic(self):
    """Renaming `a` leaves `b` alone and stores the new id as a `str`."""
    old_name = qual_names.QN('a')
    new_name = qual_names.QN('renamed_a')

    tree = parser.parse_str('a + b')
    tree = qual_names.resolve(tree)
    tree = ast_util.rename_symbols(tree, {old_name: new_name})

    self.assertIsInstance(tree.body[0].value.left.id, str)
    rendered = compiler.ast_to_source(tree).strip()
    self.assertEqual(rendered, 'renamed_a + b')
def test_rename_symbols_basic(self):
    """Renaming `a` leaves `b` alone and stores the new id as a `str`."""
    tree = qual_names.resolve(parser.parse('a + b'))
    renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}

    tree = ast_util.rename_symbols(tree, renames)
    rendered = parser.unparse(tree, include_encoding_marker=False)

    self.assertIsInstance(tree.value.left.id, str)
    # The tree must match both its own unparsed form and the expected source.
    self.assertAstMatches(tree, rendered)
    self.assertAstMatches(tree, 'renamed_a + b')
def visit_If(self, node):
    """Rewrite an `if` statement into two branch functions and a cond call.

    The body and orelse blocks are each turned into a function by
    `self._create_cond_branch`. Symbols modified in either branch that are
    live after the statement and are not composite become the values returned
    by the conditional. Composite symbols modified in either branch are
    collected separately and handed to `self._create_state_functions`.

    Args:
      node: the `if` AST node being converted.

    Returns:
      The concatenation of: undefined-symbol assignments, composite state
      function definitions, the two branch definitions, the condition
      assignment and the conditional expression.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
        if s in live_out and not s.is_composite():
            returned_from_cond.add(s)
        if s.is_composite():
            # Special treatment for compound objects, always return them.
            # This allows special handling within the if_stmt itself.
            # For example, in TensorFlow we need to restore the state of composite
            # symbols to ensure that only effects from the executed branch are seen.
            composites.add(s)

    # Note: `-` binds tighter than `&`, so these read as
    # `modified & (returned_from_cond - defined_in)`.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(
        zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    all_referenced = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    # Freeze the iteration order: the same order is used for the cond results
    # and for the values returned by both branches.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
        if len(returned_from_cond) == 1:
            cond_results = returned_from_cond[0]
        else:
            cond_results = gast.Tuple(
                [s.ast() for s in returned_from_cond], None)

        returned_from_body = tuple(
            alias_body_map[s] if s in need_alias_in_body else s
            for s in returned_from_cond)
        returned_from_orelse = tuple(
            alias_orelse_map[s] if s in need_alias_in_orelse else s
            for s in returned_from_cond)

    else:
        # When the cond would return no value, we leave the cond called without
        # results. That in turn should trigger the side effect guards. The
        # branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        cond_results = None
        # TODO(mdan): Replace with None once side_effect_guards is retired.
        returned_from_body = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name), )
        returned_from_orelse = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name), )

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(composites,
                                                  state_getter_name,
                                                  state_setter_name)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name)

    if_ast = (undefined_assigns + composite_defs + body_def + orelse_def +
              cond_assign + cond_expr)
    return if_ast
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Every function or method that `c` itself defines is converted with
    `function_to_graph`, and the results are assembled into a new class
    definition. Whitelisted base classes are imported under a fresh alias;
    any other base class except `object` is rejected.

    Args:
      c: the class to convert.
      program_ctx: passed through to `function_to_graph`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`: the generated
      import and class definition nodes, the name of the generated class,
      and the merged namespaces of the converted methods.

    Raises:
      ValueError: if `c` has no function or method members.
      NotImplementedError: if a base class other than `object` is not
        whitelisted.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}

    def is_method(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=is_method)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        # `class_namespace` always starts as a dict, so merging is the only
        # case that can occur.
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Identity check: `isinstance(object, base)` is also true for `type`
        # and for ABCs such as `collections.abc.Hashable`, which would
        # silently replace those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(
        class_name,
        bases=bases,
        keywords=[],
        body=list(converted_members.values()),
        decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
def visit_If(self, node):
    """Rewrite an `if` statement into two branch functions and a cond call.

    The body and orelse blocks are each turned into a function by
    `self._create_cond_branch`. A symbol modified in either branch is
    returned from the conditional when it is live after the statement, or
    when it is composite and one of its owners is live after the statement.

    Args:
      node: the `if` AST node being converted.

    Returns:
      The concatenation of: undefined-symbol assignments, the condition
      assignment, the two branch definitions and the conditional expression.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    for s in modified_in_cond:
        if s in live_out:
            returned_from_cond.add(s)
        elif s.is_composite():
            # Special treatment for compound objects: if any of their owner entities
            # are live, then they are outputs as well.
            if live_out & s.owner_set:
                returned_from_cond.add(s)

    # Note: `-` binds tighter than `&`, so these read as
    # `modified & (returned_from_cond - defined_in)`.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    # Freeze the iteration order: the same order is used for the cond results
    # and for the values returned by both branches.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
        if len(returned_from_cond) == 1:
            cond_results = returned_from_cond[0]
        else:
            cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

        returned_from_body = tuple(
            alias_body_map[s] if s in need_alias_in_body else s
            for s in returned_from_cond)
        returned_from_orelse = tuple(
            alias_orelse_map[s] if s in need_alias_in_orelse else s
            for s in returned_from_cond)

    else:
        # When the cond would return no value, we leave the cond called without
        # results. That in turn should trigger the side effect guards. The
        # branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        cond_results = None
        # TODO(mdan): Replace with None once side_effect_guards is retired.
        returned_from_body = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name),)
        returned_from_orelse = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name)

    return (undefined_assigns + cond_assign + body_def + orelse_def + cond_expr)
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method that `c` itself defines with `function_to_graph` and
  assembles a new class definition out of the converted methods.

  Args:
    c: The class to convert.
    program_ctx: Conversion context. Its `new_namer` and `update_name_map`
      methods are used here, and it is forwarded to `function_to_graph`.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`: the generated nodes
    (one import per whitelisted base, followed by the class definition), the
    name given to the converted class, and the namespaces of all converted
    methods merged into a single dict.

  Raises:
    ValueError: If `c` has no function or method members.
  """
  members = tf_inspect.getmembers(
      c, predicate=lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m))
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  converted_members = {}
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # class_namespace always starts out as a dict, so merging is all we need.
    class_namespace.update(namespace)
    converted_members[m] = node[0]

  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check on purpose: `isinstance(object, base)` is also true for
    # `type` and for ABCs whose instance check accepts `object`, which would
    # silently replace such bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])

  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
def visit_If(self, node):
  """Converts an `if` statement into two branch functions and a cond call.

  Symbols modified by either branch that are live after the statement (or,
  for composite symbols, whose owner entities are live) become the values
  returned by both branch functions.

  Args:
    node: The `if` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
      DEFINED_VARS_IN and LIVE_VARS_OUT annotations read below.

  Returns:
    The concatenation of the true-branch definition, the false-branch
    definition and the cond expression built by the helper methods.

  Raises:
    ValueError: If the returned symbols missing from DEFINED_VARS_IN are not
      the same for the two branches.
  """
  node = self.generic_visit(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  # A modified symbol is returned when it is live after the statement.
  # Special treatment for compound objects: if any of their owner entities
  # are live, then they are outputs as well.
  returned_from_cond = {
      s for s in body_scope.modified | orelse_scope.modified
      if s in live_out or (s.is_composite() and live_out & s.owner_set)
  }

  need_alias_in_body = body_scope.modified & defined_in
  need_alias_in_orelse = orelse_scope.modified & defined_in

  # Returned symbols that are not in DEFINED_VARS_IN must be created by both
  # branches alike.
  newly_created = returned_from_cond - defined_in
  created_in_body = body_scope.modified & newly_created
  created_in_orelse = orelse_scope.modified & newly_created
  if created_in_body != created_in_orelse:
    raise ValueError(
        'if statement may not initialize all variables: the true branch'
        ' creates %s, while the false branch creates %s. Make sure all'
        ' these variables are initialized either in both'
        ' branches or before the if statement.' %
        (self._fmt_symbols(created_in_body),
         self._fmt_symbols(created_in_orelse)))

  # Alias the closure variables inside the conditional functions, to allow
  # the functions access to the respective variables.
  # We will alias variables independently for body and orelse scope,
  # because different branches might write different variables.
  aliased_body_orig_names = tuple(need_alias_in_body)
  aliased_orelse_orig_names = tuple(need_alias_in_orelse)
  new_symbol = self.ctx.namer.new_symbol
  aliased_body_new_names = tuple(
      new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names)
  aliased_orelse_new_names = tuple(
      new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names)

  alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
  alias_orelse_map = dict(
      zip(aliased_orelse_orig_names, aliased_orelse_new_names))

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  returned_from_cond = tuple(returned_from_cond)
  if not returned_from_cond:
    # When the cond would return no value, we leave the cond called without
    # results. That in turn should trigger the side effect guards. The
    # branch functions will return a dummy value that ensures cond
    # actually has some return value as well.
    cond_results = None
    # TODO(mdan): This doesn't belong here; it's specific to the operator.
    returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
    returned_from_orelse = (
        templates.replace_as_expression('tf.constant(1)'),)
  else:
    if len(returned_from_cond) == 1:
      cond_results = returned_from_cond[0]
    else:
      cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
    # The alias maps are keyed by exactly the symbols that need aliasing, so
    # a lookup with the symbol itself as fallback picks the right name.
    returned_from_body = tuple(
        alias_body_map.get(s, s) for s in returned_from_cond)
    returned_from_orelse = tuple(
        alias_orelse_map.get(s, s) for s in returned_from_cond)

  body_name = new_symbol('if_true', body_scope.referenced)
  orelse_name = new_symbol('if_false', orelse_scope.referenced)

  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_body_orig_names,
      aliased_new_names=aliased_body_new_names,
      body=node_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orelse_orig_names,
      aliased_new_names=aliased_orelse_new_names,
      body=node_orelse,
      returns=returned_from_orelse)
  cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                     orelse_name)

  return body_def + orelse_def + cond_expr
def visit_If(self, node):
  """Rewrites an `if` statement as a cond over two generated functions.

  The test is first assigned to a fresh `cond` variable. Each branch is then
  turned into a function by `_create_cond_branch`, and `_create_cond_expr`
  builds the expression that ties them together.

  Args:
    node: The `if` node. The BODY_SCOPE, ORELSE_SCOPE, DEFINED_VARS_IN and
      LIVE_VARS_OUT annotations are read from it.

  Returns:
    The concatenation of the cond-variable assignment, the two branch
    definitions and the cond expression.

  Raises:
    ValueError: If one branch creates a returned symbol (one missing from
      DEFINED_VARS_IN) that the other branch does not create.
  """
  node = self.generic_visit(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  returned_from_cond = set()
  for sym in body_scope.modified | orelse_scope.modified:
    # Special treatment for compound objects: if any of their owner entities
    # are live, then they are outputs as well.
    is_output = sym in live_out or (
        sym.is_composite() and live_out & sym.owner_set)
    if is_output:
      returned_from_cond.add(sym)

  need_alias_in_body = body_scope.modified & defined_in
  need_alias_in_orelse = orelse_scope.modified & defined_in

  # `-` binds tighter than `&`; the parentheses only make that explicit.
  created_in_body = body_scope.modified & (returned_from_cond - defined_in)
  created_in_orelse = (
      orelse_scope.modified & (returned_from_cond - defined_in))
  if created_in_body != created_in_orelse:
    created_names = (self._fmt_symbols(created_in_body),
                     self._fmt_symbols(created_in_orelse))
    raise ValueError(
        'if statement may not initialize all variables: the true branch'
        ' creates %s, while the false branch creates %s. Make sure all'
        ' these variables are initialized either in both'
        ' branches or before the if statement.' % created_names)

  # Alias the closure variables inside the conditional functions, to allow
  # the functions access to the respective variables.
  # We will alias variables independently for body and orelse scope,
  # because different branches might write different variables.
  namer = self.ctx.namer
  aliased_body_orig_names = tuple(need_alias_in_body)
  aliased_orelse_orig_names = tuple(need_alias_in_orelse)
  aliased_body_new_names = tuple(
      namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names)
  aliased_orelse_new_names = tuple(
      namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names)

  alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
  alias_orelse_map = dict(
      zip(aliased_orelse_orig_names, aliased_orelse_new_names))

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  cond_var_name = namer.new_symbol('cond', body_scope.referenced)
  body_name = namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = namer.new_symbol('if_false', orelse_scope.referenced)

  returned_from_cond = tuple(returned_from_cond)
  if returned_from_cond:
    cond_results = (
        returned_from_cond[0] if len(returned_from_cond) == 1 else
        gast.Tuple([s.ast() for s in returned_from_cond], None))
    # Each alias map holds exactly the symbols aliased in its branch, so the
    # fallback leaves every other symbol unchanged.
    returned_from_body = tuple(
        alias_body_map.get(s, s) for s in returned_from_cond)
    returned_from_orelse = tuple(
        alias_orelse_map.get(s, s) for s in returned_from_cond)
  else:
    # When the cond would return no value, we leave the cond called without
    # results. That in turn should trigger the side effect guards. The
    # branch functions will return a dummy value that ensures cond
    # actually has some return value as well.
    cond_results = None
    # TODO(mdan): This doesn't belong here; it's specific to the operator.
    dummy_template = 'ag__.match_staging_level(1, cond_var_name)'
    returned_from_body = (
        templates.replace_as_expression(
            dummy_template, cond_var_name=cond_var_name),)
    returned_from_orelse = (
        templates.replace_as_expression(
            dummy_template, cond_var_name=cond_var_name),)

  cond_assign = self.create_assignment(cond_var_name, node.test)
  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_body_orig_names,
      aliased_new_names=aliased_body_new_names,
      body=node_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orelse_orig_names,
      aliased_new_names=aliased_orelse_new_names,
      body=node_orelse,
      returns=returned_from_orelse)
  cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                     orelse_name)

  return cond_assign + body_def + orelse_def + cond_expr
def visit_While(self, node):
  """Rewrites a `while` loop as a call to `ag__.while_stmt`.

  The loop state is the set of symbols the body modifies but does not create.
  The test and the body are wrapped into functions that take that state.

  Args:
    node: The `while` node, annotated with BODY_SCOPE and COND_SCOPE.

  Returns:
    The result of `templates.replace` applied to the loop template.

  Raises:
    ValueError: If the loop has no state, i.e. the body modifies nothing that
      it did not create itself.
  """
  self.generic_visit(node)
  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)

  # Roots of every symbol the test uses, except those created in the body.
  cond_closure = {
      root for used in cond_scope.used for root in used.support_set
      if root not in body_scope.created
  }

  state = list(body_scope.modified - body_scope.created)
  if not state:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  all_referenced = body_scope.referenced
  state_ssf = [
      self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose generated name differs from the original are renamed.
  ssf_map = {}
  for name, ssf in zip(state, state_ssf):
    if str(name) != ssf:
      ssf_map[name] = ssf

  if len(state) > 1:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
  else:
    # A single state variable is passed around bare rather than as a tuple.
    state, = state
    state_ssf, = state_ssf
    state_ast_tuple = state

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  test_name = self.ctx.namer.new_symbol('loop_test', body_scope.referenced)
  body_name = self.ctx.namer.new_symbol('loop_body', body_scope.referenced)

  # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = ag__.while_stmt(
        test_name, body_name, (state,), (extra_deps,))
  """
  return templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=test_name,
      test=test,
      body_name=body_name,
      body=node_body,
      extra_deps=tuple(s.ast() for s in cond_closure),
  )
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Converts each method that `c` itself defines with `convert_func_to_ast` and
  builds a new class definition from the converted methods.

  Args:
    c: The class to convert.
    program_ctx: Conversion context, forwarded to `convert_func_to_ast`.

  Returns:
    A tuple `(output_nodes, class_name, entity_info)`: the generated nodes
    (one import per whitelisted base, followed by the class definition), the
    name given to the converted class, and an `EntityInfo` carrying the merged
    namespace and the future features of the converted methods.

  Raises:
    ValueError: If `c` has no function or method members, or if its methods
      were built with different future features.
    NotImplementedError: If a base class other than `object` is not
      whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: if has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check on purpose: `isinstance(object, base)` is also true for
    # `type` and for ABCs whose instance check accepts `object`, which would
    # silently replace such bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Every method defined directly by `c` is converted with
  `convert_func_to_ast`; the results become the body of a newly generated
  class definition.

  Args:
    c: The class to convert.
    program_ctx: Conversion context, passed through to `convert_func_to_ast`.

  Returns:
    A tuple `(output_nodes, class_name, entity_info)`. `output_nodes` holds an
    import node for each whitelisted base followed by the class definition,
    `class_name` is the name of the generated class, and `entity_info` is an
    `EntityInfo` built from the merged method namespaces and their future
    features.

  Raises:
    ValueError: If `c` has no function or method members, or if two of its
      methods report different future features.
    NotImplementedError: If a base class other than `object` is not
      whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.

  def _is_method(member):
    return tf_inspect.isfunction(member) or tf_inspect.ismethod(member)

  members = tf_inspect.getmembers(c, predicate=_is_method)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  converted_members = {}
  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: if has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))

  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # `isinstance(object, base)` would also match `type` and any ABC whose
    # instance check accepts `object`, turning those bases into `object`.
    # Only `object` itself may be emitted as the literal name.
    if base is object:
      base_names.append('object')
      continue
    if not is_whitelisted_for_graph(base):
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    alias = namer.new_symbol(base.__name__, ())
    output_nodes.append(
        gast.ImportFrom(
            module=base.__module__,
            names=[gast.alias(name=base.__name__, asname=alias)],
            level=0))
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Each method defined directly by `c` is converted with `function_to_graph`;
  the converted methods form the body of a newly generated class definition.

  Args:
    c: The class to convert.
    program_ctx: Conversion context. It supplies the namer (`new_namer`),
      receives the name map (`update_name_map`) and is forwarded to
      `function_to_graph`.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`. `output_nodes`
    holds an import node for each whitelisted base followed by the class
    definition, `class_name` is the name of the generated class, and
    `class_namespace` is the union of the converted methods' namespaces.

  Raises:
    ValueError: If `c` has no function or method members.
  """

  def _is_method(member):
    return tf_inspect.isfunction(member) or tf_inspect.ismethod(member)

  members = tf_inspect.getmembers(c, predicate=_is_method)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  converted_members = {}
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # The namespace dict is never None here, so it is always merged into.
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # `isinstance(object, base)` would also match `type` and any ABC whose
    # instance check accepts `object`, turning those bases into `object`.
    # Only `object` itself may be emitted as the literal name.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      import_node = gast.ImportFrom(
          module=base.__module__,
          names=[gast.alias(name=base.__name__, asname=alias)],
          level=0)
      output_nodes.append(import_node)
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])

  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
def visit_While(self, node):
  """Converts a `while` loop into an `ag__.while_stmt` call.

  Symbols that the body modifies without creating them form the loop state.
  The loop test and the loop body are each wrapped in a function that takes
  the state, and both are handed to `ag__.while_stmt`.

  Args:
    node: The `while` node. The BODY_SCOPE and COND_SCOPE annotations are
      read from it.

  Returns:
    Whatever `templates.replace` produces for the filled-in loop template.

  Raises:
    ValueError: If there is no loop state.
  """
  self.generic_visit(node)
  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  created = body_scope.created
  referenced = body_scope.referenced
  new_symbol = self.ctx.namer.new_symbol

  state = list(body_scope.modified - created)

  # Collect the roots of the symbols used by the test, leaving out anything
  # the body creates.
  cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
  cond_closure = set()
  for used in cond_scope.used:
    cond_closure.update(
        root for root in used.support_set if root not in created)

  if not state:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  state_ssf = [new_symbol(s.ssf(), referenced) for s in state]
  # Rename only where the generated name differs from the original one.
  ssf_map = dict(
      (name, ssf) for name, ssf in zip(state, state_ssf) if str(name) != ssf)

  if len(state) == 1:
    # One state variable: use it directly instead of wrapping it in a tuple.
    state, state_ssf = state[0], state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = ag__.while_stmt(
        test_name, body_name, (state,), (extra_deps,))
  """
  replacements = dict(
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=new_symbol('loop_test', referenced),
      test=test,
      body_name=new_symbol('loop_body', referenced),
      body=node_body,
      extra_deps=tuple(s.ast() for s in cond_closure),
  )
  return templates.replace(template, **replacements)