def test_rename_symbols_basic(self):
    """Renaming a plain symbol rewrites it in the generated source."""
    node = parser.parse_str('a + b')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.QN('a'): qual_names.QN('renamed_a')})

    # The renamed Name node must still carry a plain string id.
    self.assertIsInstance(node.body[0].value.left.id, str)
    # ast_to_source returns a (source, _) pair; compare only the source text.
    source, _ = compiler.ast_to_source(node)
    self.assertEqual(source.strip(), 'renamed_a + b')
def test_rename_symbols_annotations(self):
    """Annotations set before renaming survive rename_symbols by identity."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    expected_annotation = anno.getanno(tree, 'foo')

    rename_map = {qual_names.QN('a'): qual_names.QN('b')}
    tree = ast_util.rename_symbols(tree, rename_map)

    self.assertIs(anno.getanno(tree, 'foo'), expected_annotation)
def _check_anno_matches(self, node, anno_name, var_names):
    """Asserts that annotation `anno_name` on `node` equals the given names.

    Args:
      node: the AST node whose annotation is checked.
      anno_name: the annotation key passed to anno.getanno.
      var_names: a single name string or an iterable of them. A dotted name
        such as 'a.b.c' is turned into nested qual_names.QN objects
        (QN(QN(QN('a'), 'b'), 'c')). Entries that are not strings are
        skipped.

    Raises:
      ValueError: if a name contains '[' or ']' (subscripts unsupported).
    """
    names = (var_names,) if isinstance(var_names, str) else var_names
    expected = set()
    for name in names:
        if not isinstance(name, str):
            continue
        if '[' in name or ']' in name:
            raise ValueError('Annotation matching not supported with subscript.')
        parts = name.split('.')
        # Fold the attribute chain left-to-right onto the root name.
        qn = qual_names.QN(parts[0])
        for attr in parts[1:]:
            qn = qual_names.QN(qn, attr)
        expected.add(qn)
    self.assertEqual(anno.getanno(node, anno_name), expected)
def test_rename_symbols_attributes(self):
    """Renaming composite `b.c` rewrites both the write and the read site."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    rename_map = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}

    tree = ast_util.rename_symbols(tree, rename_map)

    generated, _ = compiler.ast_to_source(tree)
    self.assertEqual(generated.strip(), 'renamed_b_c = renamed_b_c.d')
def visit_FunctionDef(self, node):
    """Records a function definition and analyzes its contents in a new scope.

    When there is an enclosing scope, the function's name is marked as
    written in it. Everything under `node` is then visited inside a fresh
    `Scope(current_scope, isolated=True)`, which is attached to `node` as
    NodeAnno.BODY_SCOPE.

    NOTE(review): generic_visit also visits decorator_list and argument
    defaults, so they are attributed to the body scope even though Python
    evaluates them in the enclosing scope — confirm this is intended.

    Args:
      node: the FunctionDef node being visited.

    Returns:
      The same `node`, annotated with its body scope.
    """
    if self.scope:
        # Defining a function binds its name in the enclosing scope.
        self.scope.mark_write(qual_names.QN(node.name))
    current_scope = self.scope
    body_scope = Scope(current_scope, isolated=True)
    self.scope = body_scope
    try:
        self.generic_visit(node)
    finally:
        # Always restore the enclosing scope so an exception raised while
        # visiting the body does not leave the visitor in the child scope.
        self.scope = current_scope
    anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
    return node
def test_rename_symbols(self):
    """rename_symbols handles plain names and attribute chains in a Tuple."""
    node = ast.Tuple([
        ast.Name('a', ast.Load()),
        ast.Name('b', ast.Load()),
        ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
        ast.Attribute(
            ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
    ], None)
    node = qual_names.resolve(node)
    node = ast_util.rename_symbols(
        node, {
            qual_names.QN('a'): qual_names.QN('renamed_a'),
            qual_names.QN(qual_names.QN('b'), attr='c'):
                qual_names.QN('renamed_b_c'),
        })

    # Plain name `a` is renamed and keeps its Load context.
    self.assertEqual(node.elts[0].id, 'renamed_a')
    self.assertIsInstance(node.elts[0].ctx, ast.Load)
    # `b` on its own is not in the rename map and stays untouched.
    self.assertEqual(node.elts[1].id, 'b')
    # The attribute `b.c` becomes a single name and keeps its Store context.
    self.assertEqual(node.elts[2].id, 'renamed_b_c')
    self.assertIsInstance(node.elts[2].ctx, ast.Store)
    # Inside `b.c.d`, the `b.c` prefix is renamed and `.d` remains.
    self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
    self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
def visit_FunctionDef(self, node):
    """Analyzes a function definition using three nested scopes.

    1. A scope that holds the decorators and the function's own name; it is
       attached to `node` as anno.Static.SCOPE.
    2. A scope entered before visiting the arguments and exited last. It is
       entered with `_enter_scope(True)`; presumably the flag requests an
       isolated scope — confirm against _enter_scope.
    3. Nested inside (2), a scope for the body, attached to `node` as
       NodeAnno.BODY_SCOPE.

    Args:
      node: the FunctionDef node being visited.

    Returns:
      `node`, with decorator_list, args and body replaced by their visited
      versions.
    """
    # The FunctionDef node itself has a Scope object that tracks the creation
    # of its name, along with the usage of any decorator accompanying it.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    self.scope.mark_write(qual_names.QN(node.name))
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()

    # A separate Scope tracks the actual function definition.
    self._enter_scope(True)
    node.args = self.visit(node.args)

    # Track the body separately. This is for compatibility reasons, it may not
    # be strictly needed.
    self._enter_scope(False)
    node.body = self.visit_block(node.body)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()

    # Leave the function-definition scope entered before visiting the args.
    self._exit_scope()
    return node
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Every method defined directly on `c` is converted with
    `function_to_graph`. The result is a module holding any import lines
    needed for whitelisted base classes, followed by a ClassDef for the
    converted class. References to `c` and to its base classes are then
    renamed to their converted names.

    Args:
      c: the class to convert.
      program_ctx: the conversion context. It supplies the namer and has its
        name map updated.

    Returns:
      A tuple (node, class_name, class_namespace):
        node: the gast.Module holding the generated code.
        class_name: the name of the converted class.
        class_namespace: the namespaces returned by function_to_graph for
          each converted method, merged into one dict.

    Raises:
      ValueError: if `c` has no member functions or methods.
    """
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # class_namespace starts as a dict, so every namespace is merged in.
        class_namespace.update(namespace)
        converted_members[m] = node

    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    bases = []
    for base in c.__bases__:
        # Compare by identity. `isinstance(object, base)` was also true for
        # bases such as `type` or ABCs like collections.abc.Callable, which
        # would then have been silently replaced with `object`.
        if base is object:
            bases.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        bases.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    output_nodes.append(
        gast.ClassDef(
            class_name,
            bases=bases,
            keywords=[],
            body=list(converted_members.values()),
            decorator_list=[]))
    node = gast.Module(output_nodes)

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    node = qual_names.resolve(node)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    node = ast_util.rename_symbols(node, renames)

    return node, class_name, class_namespace
def visit_Expr(self, node):
    """Wraps a bare call statement in a control-dependency guard.

    An expression statement whose value is a call (e.g. `opt.minimize(loss)`)
    is replaced by a
    `with ag__.utils.control_dependency_on_returns(call):` node.

    Non-composite names used by the call's arguments, other than `self` and
    `tf`, are re-bound inside that block through
    `ag__.utils.alias_tensors`. Names not in `args_scope.parent.modified`
    receive fresh symbols from the namer. The new node is annotated with
    anno.Basic.INDENT_BLOCK_REMAINDER = (node.body, alias_map). Presumably a
    later pass moves the following statements into that body and applies
    alias_map; confirm against _visit_and_reindent.

    Args:
      node: the Expr node being visited.

    Returns:
      The guarding `with` node when the expression is a call; otherwise the
      (generically visited) `node`.
    """
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
        # Patterns of single function calls, like:
        #   opt.minimize(loss)
        # or:
        #   tf.py_func(...)

        # First, attempt to gate future evaluation of args. If that's not
        # possible, gate all remaining statements (and that may fail too, see
        # _visit_and_reindent).
        args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
        # NOTE: We can't guard object attributes because they may not be writable.
        # In addition, avoid renaming well-known names.
        # TODO(mdan): Move these names into config.
        unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
        guarded_args = tuple(s for s in args_scope.used
                             if not s.is_composite() and s not in unguarded_names)

        # TODO(mdan): Include all arguments which depended on guarded_args too.
        # For example, the following will still cause a race:
        #   tf.assign(a, a + 1)
        #   b = a + 1
        #   tf.assign(a, a + 1)  # Control deps here should include `b`
        #   c = b + 1
        # Or maybe we should just raise an "unsafe assign" error?

        if guarded_args:
            # The aliases may need new names to avoid incorrectly making them local.
            # TODO(mdan): This is brutal. It will even rename modules - any fix?
            need_alias = tuple(
                s for s in guarded_args if s not in args_scope.parent.modified)
            aliased_new_names = tuple(
                qual_names.QN(
                    self.context.namer.new_symbol(
                        s.ssf(), args_scope.parent.referenced)) for s in need_alias)
            alias_map = dict(zip(need_alias, aliased_new_names))
            if len(guarded_args) == 1:
                # A single argument is assigned directly, not as a tuple.
                s, = guarded_args
                aliased_guarded_args = alias_map.get(s, s)
            else:
                aliased_guarded_args = gast.Tuple(
                    [alias_map.get(s, s).ast() for s in guarded_args], None)

            template = """
              with ag__.utils.control_dependency_on_returns(call):
                aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
            """
            # The template has a single top-level `with` statement; [-1]
            # selects it from the list templates.replace returns.
            control_deps_guard = templates.replace(
                template,
                call=node.value,
                aliased_guarded_args=aliased_guarded_args,
                guarded_args=guarded_args)[-1]
        else:
            alias_map = {}

            template = """
              with ag__.utils.control_dependency_on_returns(call):
                pass
            """
            control_deps_guard = templates.replace(template, call=node.value)[-1]
            # Drop the placeholder `pass`; the body starts out empty.
            control_deps_guard.body = []

        node = control_deps_guard
        anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                     (node.body, alias_map))
    return node