def test_if_subscripts(self):
  """Checks the scopes recorded for an if/else that reads and writes subscripts."""

  def test_fn(a, b, c, e):
    if a > 0:
      a[b] = -a[c]
      d = 2 * a
    else:
      a[0] = e
      d = 1
    return d

  fn_module, _ = self._parse_and_analyze(test_fn)
  conditional = fn_module.body[0].body[0]

  # Scope of the `if` branch.
  then_scope = anno.getanno(conditional, NodeAnno.BODY_SCOPE)
  self.assertScopeIsRmc(
      then_scope,
      ('a', 'b', 'c', 'a[c]'),
      ('a[b]', 'd'),
      ('d',),
  )

  # Scope of the `else` branch.
  # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
  else_scope = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIsRmc(
      else_scope,
      ('a', 'e'),
      ('a[0]', 'd'),
      ('d',),
  )

  # Scope enclosing the `else` branch.
  enclosing_scope = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE).parent
  self.assertScopeIsRmc(
      enclosing_scope,
      ('a', 'b', 'c', 'd', 'e', 'a[c]'),
      ('d', 'a[b]', 'a[0]'),
      ('a', 'b', 'c', 'd', 'e'),
  )
def test_resolve(self):
  """Checks the origin annotations that `origin_info.resolve` attaches.

  Parses `test_fn`, resolves origin info against its source, and verifies the
  line number, column offset, source line and trailing comment recorded for
  the function definition and for each of its two body statements.
  """

  # The layout of this function is part of the test: the assertions below
  # expect the `def` on line 1, the docstring on line 2 and the `return` on
  # line 3, with body statements at column offset 2.
  def test_fn(x):
    """Docstring."""
    return x # comment

  node, source = parser.parse_entity(test_fn)
  fn_node = node.body[0]

  origin_info.resolve(fn_node, source)

  # Origin of the function definition itself.
  origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 1)
  self.assertEqual(origin.loc.col_offset, 0)
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(origin.comment)

  # NOTE(review): the expected `source_code_line` literals for the body
  # statements start with a single space while `col_offset` is expected to be
  # 2 - confirm these literals were not altered by whitespace normalization.

  # Origin of the docstring statement.
  origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 2)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, ' """Docstring."""')
  self.assertIsNone(origin.comment)

  # Origin of the return statement, which also carries a trailing comment.
  origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 3)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, ' return x # comment')
  self.assertEqual(origin.comment, 'comment')
def test_if(self):
  """Checks the scopes recorded for both branches of an if/else statement."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  fn_module, _ = self._parse_and_analyze(test_fn)
  conditional = fn_module.body[0].body[0]

  # Symbol sets expected on the scope that encloses either branch.
  enclosing_expectation = (
      ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'),
      ('x', 'y', 'z', 'u'),
  )

  then_scope = anno.getanno(conditional, NodeAnno.BODY_SCOPE)
  self.assertScopeIsRmc(then_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))

  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  then_parent = anno.getanno(conditional, NodeAnno.BODY_SCOPE).parent
  self.assertScopeIsRmc(then_parent, *enclosing_expectation)

  else_scope = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIsRmc(else_scope, ('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))

  else_parent = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE).parent
  self.assertScopeIsRmc(else_parent, *enclosing_expectation)
def test_if(self):
  """Checks the scopes recorded for both branches of an if/else statement."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  fn_module, _ = self._parse_and_analyze(test_fn)
  conditional = fn_module.body[0].body[0]

  # Both symbol sets expected on the scope that encloses either branch.
  every_symbol = ('x', 'y', 'z', 'u')

  then_scope = anno.getanno(conditional, NodeAnno.BODY_SCOPE)
  self.assertScopeIs(then_scope, ('x', 'y'), ('x', 'y', 'z'))

  then_parent = anno.getanno(conditional, NodeAnno.BODY_SCOPE).parent
  self.assertScopeIs(then_parent, every_symbol, every_symbol)

  else_scope = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIs(else_scope, ('x', 'y'), ('x', 'y', 'u'))

  else_parent = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE).parent
  self.assertScopeIs(else_parent, every_symbol, every_symbol)
def test_resolve(self):
  """Checks the origin annotations that `origin_info.resolve` attaches.

  Parses a function from a source string, resolves origin info against that
  string, and verifies the line number, column offset, source line and
  trailing comment recorded for the function definition and for each of its
  two body statements.
  """
  # NOTE(review): this literal appears on a single line here, but the
  # assertions below expect the `def`, the docstring and the `return` on lines
  # 2, 3 and 4 - confirm the literal's original line breaks and indentation.
  source = """ def test_fn(x): '''Docstring.''' return x # comment """
  source = textwrap.dedent(source)

  node = parser.parse_str(source)

  origin_info.resolve(node, source)

  # Origin of the function definition itself.
  origin = anno.getanno(node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 2)
  self.assertEqual(origin.loc.col_offset, 0)
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(origin.comment)

  # NOTE(review): the expected `source_code_line` literals for the body
  # statements start with a single space while `col_offset` is expected to be
  # 2 - confirm these literals were not altered by whitespace normalization.

  # Origin of the docstring statement.
  origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 3)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, " '''Docstring.'''")
  self.assertIsNone(origin.comment)

  # Origin of the return statement, which also carries a trailing comment.
  origin = anno.getanno(node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 4)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, ' return x # comment')
  self.assertEqual(origin.comment, 'comment')
def visit_For(self, node):
  """Visits the loop target and processes the body and orelse blocks.

  Each block is handed to `_process_block` together with the scope annotation
  stored on the loop node for it.
  """
  node.target = self.visit(node.target)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  node.body = self._process_block(body_scope, node.body)

  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
  node.orelse = self._process_block(orelse_scope, node.orelse)

  return node
def _get_loop_state(self, node):
  """Determines the loop state variables and the symbols reserved by a loop.

  Args:
    node: the loop node; it must carry the BODY_SCOPE, DEFINED_VARS_IN,
      LIVE_VARS_IN and LIVE_VARS_OUT annotations.

  Returns:
    Tuple (loop_state, reserved_symbols): the symbols modified in the body that
    are also live on entry, and the symbols referenced by the body.

  Raises:
    NameError: if a simple loop state symbol is not defined on loop entry.
    NotImplementedError: if a symbol modified in the body is not live on entry
      but is live after the loop.
  """
  scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  defined_on_entry = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_on_entry = anno.getanno(node, anno.Static.LIVE_VARS_IN)
  live_on_exit = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  # Whether the variables are live after the loop is irrelevant here: if the
  # loop modifies them nonlocally (e.g. one iteration depends on the result of
  # the previous one), they belong to the loop state whether or not they are
  # used afterwards.
  loop_state = scope.modified & live_on_entry

  # Only simple variables must be defined. The composite ones will be
  # implicitly checked at runtime.
  missing_simple = {
      symbol for symbol in loop_state - defined_on_entry if symbol.is_simple()
  }
  if missing_simple:
    message = (
        'cannot convert loop: it includes symbols that are undefined'
        ' when entering the loop: {}')
    raise NameError(message.format(self._fmt_symbols(missing_simple)))

  defined_in_loop_used_later = (scope.modified - live_on_entry) & live_on_exit
  if defined_in_loop_used_later:
    # TODO(mdan): Include reference to explanation why.
    message = (
        'cannot convert loop: it includes symbols that are defined'
        ' inside the loop, but used later: {}. To fix, initialize'
        ' these symbols before the loop')
    raise NotImplementedError(
        message.format(self._fmt_symbols(defined_in_loop_used_later)))

  return loop_state, scope.referenced
def test_if_attributes(self):
  """Checks the scopes recorded for an if/else that reads and writes attributes."""

  def test_fn(a):
    if a > 0:
      a.b = -a.c
      d = 2 * a
    else:
      a.b = a.c
      d = 1
    return d

  fn_module, _ = self._parse_and_analyze(test_fn)
  conditional = fn_module.body[0].body[0]

  # Both branches are expected to produce the same three symbol sets.
  branch_expectation = (('a', 'a.c'), ('a.b', 'd'), ('d',))

  then_scope = anno.getanno(conditional, NodeAnno.BODY_SCOPE)
  self.assertScopeIsRmc(then_scope, *branch_expectation)

  else_scope = anno.getanno(conditional, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIsRmc(else_scope, *branch_expectation)

  enclosing_scope = anno.getanno(conditional, NodeAnno.BODY_SCOPE).parent
  self.assertScopeIsRmc(
      enclosing_scope,
      ('a', 'a.c', 'd'),
      ('a.b', 'd'),
      ('a', 'd'),
  )
def test_origin_info_preserved_in_moved_nodes(self):
  """Checks that statements hoisted out of an `if` keep their origin line."""

  class TestTransformer(transformer.Base):

    def visit_If(self, node):
      # Replace the `if` statement with the statements of its body.
      return node.body

  flattener = TestTransformer(self._simple_context())

  # The layout of this function is part of the test: the two statements in
  # the `if` body sit on lines 4 and 5 of its source.
  def test_fn():
    x = 1
    if x > 0:
      x = 1
      x += 3
    return x

  fn_node, source = parser.parse_entity(test_fn, future_features=())
  origin_info.resolve(fn_node, source)
  fn_node = flattener.visit(fn_node)

  moved_assign = fn_node.body[1]
  moved_aug_assign = fn_node.body[2]

  assign_origin = anno.getanno(moved_assign, anno.Basic.ORIGIN)
  self.assertEqual(assign_origin.loc.lineno, 4)

  aug_assign_origin = anno.getanno(moved_aug_assign, anno.Basic.ORIGIN)
  self.assertEqual(aug_assign_origin.loc.lineno, 5)
def _rename_compilable_function(self, node):
  """Points a call at the compiled name of its target, when the namer asks to.

  Args:
    node: the call node; `node.func` must carry the 'live_val' and 'fqn'
      annotations.

  Returns:
    The same call node, with `func` (and, for bound methods, `args`) updated
    when a rename was requested.
  """
  callee = node.func
  assert anno.hasanno(callee, 'live_val')
  assert anno.hasanno(callee, 'fqn')
  entity = anno.getanno(callee, 'live_val')
  fqn = anno.getanno(callee, 'fqn')

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.ctx.namer.compiled_class_name(fqn, live_entity=entity)
    do_rename = True
  else:
    if anno.hasanno(callee, 'parent_type'):
      owner_type = anno.getanno(callee, 'parent_type')
    else:
      # Fallback - not reliable.
      owner_type = inspect_utils.getmethodclass(entity)
    new_name, do_rename = self.ctx.namer.compiled_function_name(
        fqn, live_entity=entity, owner_type=owner_type)

  if not do_rename:
    return node

  if entity is not None and tf_inspect.ismethod(entity):
    # The renaming process will transform it into a regular function, so the
    # object the method was accessed on becomes the first argument.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [callee.value] + node.args
  node.func = templates.replace_as_expression('func_name', func_name=new_name)
  return node
def visit_For(self, node):
  """Replaces a `for` loop with a templated call to `ag__.for_stmt`.

  The loop state consists of the symbols the body modifies but does not
  create. Each state symbol gets a fresh name from the namer, the body and the
  optional 'extra_test' annotation are rewritten to use those names, and the
  result is substituted into the template below.

  Args:
    node: the `for` node; it must carry the BODY_SCOPE annotation.

  Returns:
    The value returned by `templates.replace` for the filled-in template.
  """
  self.generic_visit(node)

  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  # Symbols modified by the body that were not created by it.
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  state = list(body_closure)

  # One new symbol per state variable, derived from its `ssf()` form and
  # reserved against everything the body references.
  state_ssf = [
      self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose new name differs from the original need renaming.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  if len(state) == 1:
    # A single state variable is passed as is rather than as a 1-tuple node.
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    # Without an extra test, the loop condition is always true.
    extra_test = parser.parse_expression('True')

  # NOTE(review): this template appears on a single line here; it reads like
  # a multi-line code template - confirm its original line breaks.
  template = """ def extra_test_name(state_ssf): return extra_test_expr def body_name(loop_vars, state_ssf): # Workaround for PEP-3113 iterate = loop_vars body return state_ssf, state_ast_tuple = ag__.for_stmt( iter_, extra_test_name, body_name, (state,)) """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
      extra_test_expr=extra_test,
      body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def test_attribute_names(self):
  """Checks the 'live_val' and 'fqn' annotations of an attribute call target."""

  def test_fn():
    return constant_op.constant(0)

  node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
  func_node = node.body[0].body[0].value.func
  # assertEqual is used instead of the deprecated assertEquals alias, which
  # was removed from unittest in Python 3.12.
  self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
  self.assertEqual((constant_op.__name__, 'constant'),
                   anno.getanno(func_node, 'fqn'))
def test_rename_symbols_annotations(self):
  """Checks that `rename_symbols` keeps a custom annotation on the root node."""
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  expected_value = anno.getanno(tree, 'foo')

  renames = {qual_names.QN('a'): qual_names.QN('b')}
  tree = ast_util.rename_symbols(tree, renames)

  # The annotation on the renamed tree must be the very same object.
  self.assertIs(anno.getanno(tree, 'foo'), expected_value)
def test_entity_scope_tracking(self):
  """Checks the `enclosing_entities` seen at various nesting depths."""

  class TestTransformer(transformer.Base):

    # The choice of node to annotate is arbitrary. Assign is used because it
    # is easy to find in the tree.
    def visit_Assign(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # BinOp is annotated as well because it shows up in the lambda function.
    def visit_BinOp(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

  tr = TestTransformer(self._simple_context())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  node, _ = parser.parse_entity(test_function, future_features=())
  node = tr.visit(node)

  # Entities, from the outermost to the innermost.
  test_function_node = node
  test_class = test_function_node.body[1]
  test_method = test_class.body[0]
  inner_function = test_method.body[1]
  lambda_node = inner_function.body[1].value

  # Annotated nodes, one per nesting depth.
  a = test_function_node.body[0]
  b = test_method.body[0]
  c = inner_function.body[0]
  lambda_expr = lambda_node.body

  function_chain = (test_function_node,)
  method_chain = function_chain + (test_class, test_method)
  inner_chain = method_chain + (inner_function,)
  lambda_chain = inner_chain + (lambda_node,)

  self.assertEqual(function_chain, anno.getanno(a, 'enclosing_entities'))
  self.assertEqual(method_chain, anno.getanno(b, 'enclosing_entities'))
  self.assertEqual(inner_chain, anno.getanno(c, 'enclosing_entities'))
  self.assertEqual(lambda_chain,
                   anno.getanno(lambda_expr, 'enclosing_entities'))
def _validate_no_live_vars_created(self, node):
  """Rejects loops whose body creates symbols that are live after the loop.

  Args:
    node: the loop node; it must carry the BODY_SCOPE and LIVE_VARS_OUT
      annotations.

  Raises:
    ValueError: if any symbol created in the body is live after the loop.
  """
  scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  live_after_loop = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  offending = live_after_loop & scope.created
  if not offending:
    return

  message = (
      'The following variables are created inside the loop and used later:'
      '\n%s\n'
      'Variables must be declared outside loops because loops may not'
      ' necessarily execute.')
  raise ValueError(message % self._fmt_symbol_list(offending))
def test_constructor_detection(self):
  """Checks the 'type' and 'type_fqn' annotations of a constructor call."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = self._parse_and_analyze(test_fn, {'training': training})
  call_node = node.body[0].body[0].value
  # assertEqual is used instead of the deprecated assertEquals alias, which
  # was removed from unittest in Python 3.12.
  self.assertEqual(training.GradientDescentOptimizer,
                   anno.getanno(call_node, 'type'))
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(call_node, 'type_fqn'))
def _try_resolve_target(self, node):
  """Resolves the object a node refers to; works for methods of known types.

  Args:
    node: the node to resolve, typically the `func` of a call.

  Returns:
    The 'live_val' annotation when present; otherwise, for an attribute node
    annotated with 'type', the attribute looked up on that type; otherwise
    None.

  Raises:
    ValueError: if the annotated type has no attribute with the given name.
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')

  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    owner_type = anno.getanno(node, 'type')
    if not hasattr(owner_type, node.attr):
      raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                       (owner_type, node.attr))
    return getattr(owner_type, node.attr)

  return None
def test_namespace(self):
  """Checks the annotations of a call to a function found in the namespace."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = self._parse_and_analyze(test_fn, {'foo': foo})
  func_node = node.body[0].body[0].value.func
  # assertEqual is used instead of the deprecated assertEquals alias, which
  # was removed from unittest in Python 3.12.
  self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
  self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def visit_Call(self, node):
  """Converts a call based on what is statically known about its target.

  When `node.func` carries a 'live_val' annotation, the call is dispatched on
  that value: compilable functions are renamed or visited generically, known
  NumPy functions are wrapped, builtins are returned untouched and named
  tuples are visited generically. Without the annotation, `super` calls are
  preserved and any other call goes through `_insert_dynamic_conversion`.

  Args:
    node: the call node to process.

  Returns:
    The node that should take the place of `node`.

  Raises:
    NotImplementedError: if the target is known but matches none of the
      handled categories.
  """
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    # The fully qualified name is optional; it is only needed for the NumPy
    # lookup and for `_should_compile`.
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      if self._should_compile(node, target_fqn):
        node = self._rename_compilable_function(node)
      else:
        node = self.generic_visit(node)
        return node
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    elif inspect_utils.isbuiltin(target_entity):
      # Note: Any builtin that passed the builtins converter is assumed to be
      # safe for graph mode.
      return node
    elif inspect_utils.isnamedtuple(target_entity):
      # Although not compilable, we assume they are safe for graph mode.
      node = self.generic_visit(node)
      return node
    else:
      # TODO(mdan): Insert dynamic conversion here instead.
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    # Special cases
    # TODO(mdan): These need a systematic review - there may be more.

    # 1. super() calls - these are preserved. The class conversion mechanism
    # will ensure that they return the correct value.
    if ast_util.matches(node, 'super(_)'):
      return node

    # 2. super().method calls - these are preserved as well, when the
    # conversion processes the entire class.
    if (ast_util.matches(node, 'super(_)._(_)') and
        self.ctx.info.owner_type is not None):
      return node

    # Anything else has an unknown target and is converted dynamically.
    node = self._insert_dynamic_conversion(node)
  return node
def test_params(self):
  """Checks the scopes and the parameter set recorded for a function."""

  def test_fn(a, b):  # pylint: disable=unused-argument
    return b

  module_node, _ = self._parse_and_analyze(test_fn)
  function_def = module_node.body[0]

  fn_scope = anno.getanno(function_def, NodeAnno.BODY_SCOPE)
  self.assertScopeIs(fn_scope, ('b',), ())
  self.assertScopeIs(fn_scope.parent, ('b',), ('a', 'b'))

  arguments_scope = anno.getanno(function_def.args, anno.Static.SCOPE)
  param_names = arguments_scope.params.keys()
  self.assertSymbolSetsAre(('a', 'b'), param_names, 'params')
def test_primitive_values(self):
  """Checks the 'fqn' annotation of a name bound to a bool in the namespace."""

  a = None

  def test_fn():
    return a

  node = self._parse_and_analyze(test_fn, {'a': True})
  retval_node = node.body[0].body[0].value

  # The module that hosts `bool` is named differently in Python 2 and 3.
  builtins_module = '__builtin__' if six.PY2 else 'builtins'
  self.assertEqual(
      anno.getanno(retval_node, 'fqn'), (builtins_module, 'bool'))
def visit_Call(self, node):
  """Converts a call based on what is statically known about its target.

  Calls to one of the configured marker decorators have their first argument
  annotated as 'graph_ready'. After visiting the children, calls with a known
  target are renamed or wrapped, and calls with an unknown target are either
  preserved (`dict`, `super`) or converted dynamically when the `recursive`
  option is set.

  Args:
    node: the call node to process.

  Returns:
    The node that should take the place of `node`.

  Raises:
    ValueError: if a marker decorator is called without positional arguments.
    NotImplementedError: if the target is known but is neither compilable nor
      a known NumPy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.ctx.program.options.strip_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    # The fully qualified name is optional; it is only needed for the NumPy
    # lookup.
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    if anno.hasanno(node.func, anno.Basic.QN):
      # Special-case a few builtins that otherwise go undetected. This
      # normally doesn't pose a problem, but the dict built-in doesn't
      # work with inspect.getargspec which is required for dynamic functions.
      # Note: expecting this is resilient to aliasing (e.g.
      # dict = an_evil_dict), because in those cases the regular mechanisms
      # process a simple user function.
      qn = anno.getanno(node.func, anno.Basic.QN)
      # Add items to this list as needed.
      if str(qn) in ('dict',):
        return node

    if ast_util.matches(node, 'super(_)'):
      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      return node

    # Unknown targets are only converted when recursive conversion is on.
    if self.ctx.program.options.recursive:
      node = self._insert_dynamic_conversion(node)
  return node
def create_source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  The code is parsed again, the resulting nodes receive fresh origin
  information, and they are walked in parallel with `nodes`. For each pair in
  which both nodes have an origin, the line of the reparsed node is mapped to
  the origin of the original node.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
      nodes appear in code. The parser always returns a module when parsing
      code. This argument indicates the position in that module's body at
      which the corresponding node should appear.
      NOTE(review): the body iterates over this argument, so a bare int does
      not look supported - confirm against the callers.

  Returns:
    Dict[CodeLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.
  """
  reparsed_nodes = parser.parse_str(code)
  # Keep only the module-level statements that correspond to `nodes`.
  reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]

  resolve(reparsed_nodes, code)
  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdan): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. The exception is decorated functions, where
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
def test_class_members_in_with_stmt(self):
  """Checks type annotations for an object constructed in a `with` statement."""

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = self._parse_and_analyze(test_fn, {'session': session})

  # assertEqual is used instead of the deprecated assertEquals alias, which
  # was removed from unittest in Python 3.12.
  constructor_call = node.body[0].body[0].items[0].context_expr
  self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(constructor_call, 'type_fqn'))

  # The method called on the context variable resolves to the class member.
  method_call = node.body[0].body[0].body[0].value.func
  self.assertEqual(session.Session.run,
                   anno.getanno(method_call, 'live_val'))
def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive argument for a symbol.

  See lang/directives.py for details on directives.

  Example:
     # Given a directive in the code:
     ag.foo_directive(bar, baz=1)

     # One can write for an AST node Name(id='bar'):
     get_definition_directive(node, ag.foo_directive, 'baz', None)

  Args:
    node: ast.AST, the node representing the symbol for which the directive
      argument is needed.
    directive: Callable[..., Any], the directive to search.
    arg: str, the directive argument to return.
    default: Any, the value returned when no definition of the symbol carries
      the requested directive argument.

  Returns:
    The value of the directive argument, or `default` when none is found.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not definitions:
    return default

  candidates = [
      definition.directives[directive][arg]
      for definition in definitions
      if (directive in definition.directives and
          arg in definition.directives[directive])
  ]
  if not candidates:
    return default

  # If multiple annotations reach the symbol, they must all match. If they do,
  # return any of them. With a single candidate there is nothing to compare.
  chosen = candidates[0]
  for other in candidates[1:]:
    if ast_util.matches(chosen, other):
      continue
    qn = anno.getanno(node, anno.Basic.QN)
    raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                     (qn, directive.__name__, arg,
                      compiler.ast_to_source(other).strip(),
                      compiler.ast_to_source(chosen).strip()))
  return chosen
def _visit_and_reindent(self, nodes):
  """Visits a list of nodes, redirecting the remainder of the block on request.

  Each node is visited and appended to the current destination list. When a
  visited node carries the INDENT_BLOCK_REMAINDER annotation, the annotation
  supplies a new destination list and an alias map: the nodes that follow are
  appended to that list and have their symbols renamed according to the
  accumulated alias map.

  Args:
    nodes: the nodes to visit, in order.

  Returns:
    The list that received the nodes before any redirection took place.

  Raises:
    ValueError: if a redirection was requested but the final destination
      received no nodes.
  """
  new_nodes = []
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    # A visit may return several nodes in place of one.
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # The annotation is consumed here so that it is not processed again.
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # Aliases accumulated so far override those supplied by the annotation.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError('Unable to insert statement into the computation flow: '
                     'it is not followed by any computation which '
                     'the statement could gate.')
  return new_nodes
def test_nested_if(self):
  """Checks the scopes of both branches of an `if` nested in another `if`."""

  def test_fn(b):
    if b > 0:
      if b < 5:
        a = b
      else:
        a = b * b
    return a

  fn_module, _ = self._parse_and_analyze(test_fn)
  nested_conditional = fn_module.body[0].body[0].body[0]

  then_scope = anno.getanno(nested_conditional, NodeAnno.BODY_SCOPE)
  self.assertScopeIs(then_scope, ('b',), ('a',))

  else_scope = anno.getanno(nested_conditional, NodeAnno.ORELSE_SCOPE)
  self.assertScopeIs(else_scope, ('b',), ('a',))
def visit_While(self, node):
  """Rewrites a `while` loop so that `break` is expressed through a flag.

  The loop body is processed by `_process_body` with a freshly named flag
  symbol. When that reports the flag as used, the loop is replaced with a
  template that initializes the flag, adds it to the loop condition and guards
  the `else` clause with it.

  Args:
    node: the `while` node; it must carry the BODY_SCOPE annotation.

  Returns:
    The original node, or the template expansion when a break was used.
  """
  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  # The flag name is reserved against every symbol the body references.
  break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

  node.test = self.visit(node.test)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if break_used:
    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)

    # NOTE(review): this template appears on a single line here; it reads like
    # a multi-line code template - confirm its original line breaks.
    template = """ var_name = tf.constant(False) while test and not var_name: body else: orelse """
    node = templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)

  return node
def test_for(self):
  """Checks the scopes recorded for a `for` loop body and its parent."""

  def test_fn(a):
    b = a
    for _ in a:
      c = b
      b -= 1
    return b, c

  fn_module, _ = self._parse_and_analyze(test_fn)
  loop = fn_module.body[0].body[1]

  loop_body_scope = anno.getanno(loop, NodeAnno.BODY_SCOPE)
  self.assertScopeIs(loop_body_scope, ('b',), ('b', 'c'))

  enclosing_scope = anno.getanno(loop, NodeAnno.BODY_SCOPE).parent
  self.assertScopeIs(enclosing_scope, ('a', 'b', 'c'), ('b', 'c', '_'))
def visit_node(self, node):
  """Recomputes the live-in and live-out symbol sets of a graph node.

  Args:
    node: the graph node to process; its successors are in `node.next`.

  Returns:
    True if the live-in set of the node changed, False otherwise.
  """
  previous_live_in = self.in_[node]
  ast_node = node.ast_node

  if anno.hasanno(ast_node, anno.Static.SCOPE):
    scope = anno.getanno(ast_node, anno.Static.SCOPE)

    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. if x.y is live whether x needs to be added. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    generated = scope.used | self.extra_gen.get(ast_node, frozenset())
    killed = scope.modified

    # Live-out is the union of the live-in sets of all successors.
    live_out = set()
    for successor in node.next:
      live_out |= self.in_[successor]

    live_in = generated | (live_out - killed)
  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols.
    # A Name node reaching this point is a literal name, e.g. False.
    assert isinstance(ast_node,
                      (gast.Name, gast.Continue, gast.Break)), type(ast_node)
    live_in = previous_live_in
    live_out = live_in

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return previous_live_in != live_in
def visit(self, node):
  """Visits `node` while maintaining entity, origin and scope bookkeeping.

  On top of the superclass dispatch, this method tracks the enclosing
  function, class and lambda nodes, exposes the origin of the node being
  processed through `self.ctx.current_origin`, unwraps `Expr` nodes whose
  value was replaced by statements, and copies origin information onto
  replacement nodes that have none.

  Nodes annotated with SKIP_PROCESSING are not dispatched and are returned
  unchanged.

  Args:
    node: gast.AST, the node to visit.

  Returns:
    The result of the superclass visit, possibly unwrapped from an `Expr`
    node; or `node` itself when processing was skipped.

  Raises:
    ValueError: if `node` is not an AST node, e.g. a list of nodes.
    AssertionError: if the local scope stack changed size during the visit.
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error needs to be raised before the exception handler
    # below is installed, because said handler will mess up if `node` is not,
    # in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  did_enter_function = False
  local_scope_size_at_entry = len(self._local_scope_state)
  processing_expr_node = False

  parent_origin = self.ctx.current_origin
  if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
    did_enter_function = True
  elif isinstance(node, gast.Expr):
    processing_expr_node = True

  if did_enter_function:
    self._enclosing_entities.append(node)

  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  if processing_expr_node:
    # Remembered so that a replaced value can be detected after the visit.
    entry_expr_value = node.value

  if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
    result = super(Base, self).visit(node)
  else:
    # Skipped nodes are passed through unchanged. Without this binding,
    # `result` would be undefined in the code below.
    result = node
  self.ctx.current_origin = parent_origin

  # Adjust for consistency: replacing the value of an Expr with
  # an Assign node removes the need for the Expr node.
  if processing_expr_node:
    if isinstance(result, gast.Expr) and result.value != entry_expr_value:
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value,
                    (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

  # By default, all replacements receive the origin info of the replaced node.
  if result is not node and result is not None:
    if isinstance(result, (list, tuple)):
      nodes_to_adjust = result
    else:
      nodes_to_adjust = (result,)

    for n in nodes_to_adjust:
      if not anno.hasanno(n, anno.Basic.ORIGIN):
        inherited_origin = anno.getanno(
            node, anno.Basic.ORIGIN, default=parent_origin)
        if inherited_origin is not None:
          anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)

  # On exception, the local scope integrity is not guaranteed.
  if did_enter_function:
    self._enclosing_entities.pop()

  if local_scope_size_at_entry != len(self._local_scope_state):
    raise AssertionError(
        'Inconsistent local scope stack. Before entering node %s, the'
        ' stack had length %d, after exit it has length %d. This'
        ' indicates enter_local_scope and exit_local_scope are not'
        ' well paired.' % (node, local_scope_size_at_entry,
                           len(self._local_scope_state)))
  return result
def assertClosureTypes(self, node, expected):
  """Asserts that each expected entry appears in the node's closure types.

  The keys of the CLOSURE_TYPES annotation are converted to strings before
  the comparison. Entries that are not listed in `expected` are ignored.
  """
  closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES)
  types_by_name = {
      str(symbol): types for symbol, types in closure_types.items()
  }
  for name, expected_types in expected.items():
    self.assertIn(name, types_by_name)
    self.assertEqual(types_by_name[name], expected_types)
def visit_FunctionDef(self, node):
  """Visits the arguments and decorators, then processes the function body.

  The body is handed to `_process_block` together with the SCOPE annotation
  stored on the function node.
  """
  node.args = self.generic_visit(node.args)
  node.decorator_list = self.visit_block(node.decorator_list)

  function_scope = anno.getanno(node, anno.Static.SCOPE)
  node.body = self._process_block(function_scope, node.body)
  return node
def from_str(qn_str):
  """Returns the QN annotation of the expression parsed from `qn_str`."""
  resolved = resolve(parser.parse_expression(qn_str))
  return anno.getanno(resolved, anno.Basic.QN)
def visit_Call(self, node):
  """Wraps a call in `ag__.converted_call`, unless it is exempt.

  Exempt calls are returned after their children were visited: calls into
  `ag__`, calls on the function context manager, `pdb.set_trace`,
  `ipdb.set_trace` and `breakpoint`, and `print` when the BUILTIN_FUNCTIONS
  feature is not in use. For every other call, the positional and keyword
  arguments are repacked into a tuple expression and a dict expression.

  Args:
    node: the call node to process.

  Returns:
    The visited node when the call is exempt, otherwise the expression that
    replaces it.
  """
  # The name is read before the children are visited.
  full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
  function_context_name = self.state[_Function].context_name
  node = self.generic_visit(node)

  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  if full_name.startswith('ag__.'):
    return node

  # Calls to the function context manager (inserted by function_scopes) are
  # also safe.
  if full_name.startswith(function_context_name + '.'):
    return node

  # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
  # the normal mechanisms to bypass these literals because they are sensitive
  # to the frame they are being called from.
  # TODO(mdan): Generalize this to a "static whitelist" config.
  if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
    # Module-level flag, used to issue the warning only once.
    global set_trace_warned
    if not set_trace_warned:
      # TODO(mdan): Update and shorten once available on tensorflow.org.
      ag_logging.warn(
          'Detected `pdb.set_trace()` in converted code. The code'
          ' generated by AutoGraph is not optimized for step-by-step'
          ' debugging. See https://github.com/tensorflow/tensorflow/'
          'blob/master/tensorflow/python/autograph/g3doc/reference/'
          'debugging.md.')
      set_trace_warned = True
    return node

  if (full_name == 'print' and
      not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
    return node

  func = node.func

  # Separate the optional *args argument from the plain positional ones.
  starred_arg = None
  normal_args = []
  for a in node.args:
    if isinstance(a, gast.Starred):
      assert starred_arg is None, 'Multiple *args should be impossible.'
      starred_arg = a
    else:
      normal_args.append(a)
  if starred_arg is None:
    args = templates.replace_as_expression('(args,)', args=normal_args)
  else:
    args = templates.replace_as_expression(
        '(args,) + tuple(stararg)',
        stararg=starred_arg.value,
        args=normal_args)

  # Separate the optional **kwargs argument (a keyword without a name) from
  # the plain keyword arguments.
  kwargs_arg = None
  normal_keywords = []
  for k in node.keywords:
    if k.arg is None:
      assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
      kwargs_arg = k
    else:
      normal_keywords.append(k)
  if kwargs_arg is None:
    if not normal_keywords:
      # No keyword arguments at all are represented by a literal None.
      kwargs = parser.parse_expression('None')
    else:
      kwargs = ast_util.keywords_to_dict(normal_keywords)
  else:
    kwargs = templates.replace_as_expression(
        'dict(kwargs, **keywords)',
        kwargs=kwargs_arg.value,
        keywords=ast_util.keywords_to_dict(normal_keywords))

  # NOTE(review): this template appears on a single line with surrounding
  # spaces here - confirm its original line breaks and indentation.
  template = """ ag__.converted_call(func, options, args, kwargs) """
  new_call = templates.replace_as_expression(
      template,
      func=func,
      options=parser.parse_expression(function_context_name + '.callopts'),
      args=args,
      kwargs=kwargs)

  return new_call
def assertDifferentAnno(self, first, second, key):
  """Asserts that two nodes hold distinct objects under annotation `key`."""
  first_value = anno.getanno(first, key)
  second_value = anno.getanno(second, key)
  self.assertIsNot(first_value, second_value)
def visit_For(self, node):
  """Replaces a `for` loop with a templated call to `ag__.for_stmt`.

  The loop state comes from `_get_loop_state` and is turned into template
  arguments by `_state_constructs`. Two templates exist: one for loops that
  carry state and one for loops that carry none. Assignments produced by
  `_create_undefined_assigns` are placed before the generated statements.

  Args:
    node: the `for` node to convert.

  Returns:
    The undefined-symbol assignments followed by the template expansion.
  """
  self.generic_visit(node)

  loop_state, reserved_symbols, possibly_undef = self._get_loop_state(
      node)
  loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
      loop_state, reserved_symbols)
  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    # Without an extra test, the loop condition is always true.
    extra_test = parser.parse_expression('True')

  # NOTE(review): both templates below appear on a single line here; they read
  # like multi-line code templates - confirm their original line breaks.
  if loop_state:
    template = """ def extra_test_name(state_ssf): return extra_test_expr def body_name(loop_vars, state_ssf): # Workaround for PEP-3113 iterate = loop_vars body return state_ssf, state_ast_tuple = ag__.for_stmt( iter_, extra_test_name, body_name, (state,)) """
    node = templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol(
            'extra_test', reserved_symbols),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol(
            'loop_body', reserved_symbols),
        body=node_body)
  else:
    # Stateless variant: the generated functions take no state arguments.
    template = """ def extra_test_name(): return extra_test_expr def body_name(loop_vars): # Workaround for PEP-3113 iterate = loop_vars body return () ag__.for_stmt(iter_, extra_test_name, body_name, ()) """
    node = templates.replace(
        template,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol(
            'extra_test', reserved_symbols),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol(
            'loop_body', reserved_symbols),
        body=node_body)

  undefined_assigns = self._create_undefined_assigns(possibly_undef)
  return undefined_assigns + node
def assertQNStringIs(self, node, qn_str):
  """Asserts that the string form of the node's QN annotation is `qn_str`."""
  qualified_name = anno.getanno(node, anno.Basic.QN)
  self.assertEqual(str(qualified_name), qn_str)
def res_call(self, ns, types_ns, node, f_type, args, keywords):
  """Checks the call being resolved and reports `float` as its type.

  Verifies through the enclosing test case (`test_self`) that the callee
  type is `(str,)` and that the called function's qualified name is `g`.

  Returns:
    A pair: the set `{float}` and `None`.
  """
  test_self.assertEqual(f_type, (str,))
  callee_qn = anno.getanno(node.func, anno.Basic.QN)
  expected_qn = qual_names.QN('g')
  test_self.assertEqual(callee_qn, expected_qn)
  return ({float}, None)
def assertNotSameDef(self, first, second):
  """Asserts that both nodes have exactly one definition, and not the same one."""
  for candidate in (first, second):
    self.assertHasDefs(candidate, 1)
  first_def = anno.getanno(first, anno.Static.DEFINITIONS)[0]
  second_def = anno.getanno(second, anno.Static.DEFINITIONS)[0]
  self.assertIsNot(first_def, second_def)
def assertHasDefs(self, node, num):
  """Asserts that `node` is annotated with exactly `num` definitions.

  Every entry of the DEFINITIONS annotation must also be an instance of
  `reaching_definitions.Definition`.
  """
  definitions = anno.getanno(node, anno.Static.DEFINITIONS)
  self.assertEqual(len(definitions), num)
  for definition in definitions:
    self.assertIsInstance(definition, reaching_definitions.Definition)
def visit_With(self, node):
  """Visits the items of a `with` statement, then processes its body.

  The body is handed to `_process_block` together with the BODY_SCOPE
  annotation of the node.

  Returns:
    The same node, with `items` and `body` replaced.
  """
  node.items = self.visit_block(node.items)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  node.body = self._process_block(body_scope, node.body)
  return node
def assertClosureTypes(self, node, expected):
  """Asserts the CLOSURE_TYPES annotation of `node`, keyed by symbol string.

  Args:
    node: node carrying the CLOSURE_TYPES annotation.
    expected: dict mapping the string form of each symbol to its value.
  """
  closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES)
  by_name = dict(
      (str(symbol), types) for symbol, types in closure_types.items())
  self.assertDictEqual(by_name, expected)
def visit_For(self, node):
  """Converts a `for` statement into a templated `ag__.for_stmt` call.

  Loop variables are the symbols modified in the body or iterate scope, as
  split by `_get_loop_vars` into basic and composite ones. Basic variables
  are passed through the generated body function; composite ones are handled
  by the generated state getter/setter functions.

  Args:
    node: the `for` statement node being converted.

  Returns:
    The nodes produced by `templates.replace` for the selected template.
  """
  node = self.generic_visit(node)

  (basic_loop_vars, composite_loop_vars,
   reserved_symbols, possibly_undefs) = self._get_loop_vars(
       node,
       (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
        | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
  loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
      basic_loop_vars)
  body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

  state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
  state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
  state_functions = self._create_state_functions(
      composite_loop_vars, state_getter_name, state_setter_name)

  # The optional 'extra_test' annotation becomes a separate test function;
  # without it, the literal None is passed in place of the function name.
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test_name = self.ctx.namer.new_symbol(
        'extra_test', reserved_symbols)
    template = """
      def extra_test_name(loop_vars):
        return extra_test_expr
    """
    extra_test_function = templates.replace(
        template,
        extra_test_name=extra_test_name,
        loop_vars=loop_vars,
        extra_test_expr=extra_test)
  else:
    extra_test_name = parser.parse_expression('None')
    extra_test_function = []

  # Workaround for PEP-3113
  # iterates_var holds a single variable with the iterates, which may be a
  # tuple.
  iterates_var_name = self.ctx.namer.new_symbol('iterates', reserved_symbols)
  template = """
    iterates = iterates_var_name
  """
  iterate_expansion = templates.replace(
      template,
      iterates=node.target,
      iterates_var_name=iterates_var_name)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)

  # Symbol names are passed to for_stmt as string nodes.
  basic_symbol_names = tuple(
      gast.Str(str(symbol)) for symbol in basic_loop_vars)
  composite_symbol_names = tuple(
      gast.Str(str(symbol)) for symbol in composite_loop_vars)

  opts = self._create_loop_options(node)

  # TODO(mdan): Use a single template.
  # If the body and test functions took a single tuple for loop_vars, instead
  # of *loop_vars, then a single template could be used.
  if loop_vars:
    template = """
      undefined_assigns
      state_functions
      def body_name(iterates_var_name, loop_vars):
        iterate_expansion
        body
        return loop_vars,
      extra_test_function
      loop_vars_ast_tuple = ag__.for_stmt(
          iter_,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (loop_vars,),
          (basic_symbol_names,),
          (composite_symbol_names,),
          opts)
    """
    return templates.replace(
        template,
        undefined_assigns=undefined_assigns,
        loop_vars=loop_vars,
        loop_vars_ast_tuple=loop_vars_ast_tuple,
        iter_=node.iter,
        iterate_expansion=iterate_expansion,
        iterates_var_name=iterates_var_name,
        extra_test_name=extra_test_name,
        extra_test_function=extra_test_function,
        body_name=body_name,
        body=node.body,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        basic_symbol_names=basic_symbol_names,
        composite_symbol_names=composite_symbol_names,
        opts=opts)
  else:
    # No basic loop variables: the body takes only the iterate and returns
    # an empty tuple.
    template = """
      undefined_assigns
      state_functions
      def body_name(iterates_var_name):
        iterate_expansion
        body
        return ()
      extra_test_function
      ag__.for_stmt(
          iter_,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (),
          (),
          (composite_symbol_names,),
          opts)
    """
    return templates.replace(
        template,
        undefined_assigns=undefined_assigns,
        iter_=node.iter,
        iterate_expansion=iterate_expansion,
        iterates_var_name=iterates_var_name,
        extra_test_name=extra_test_name,
        extra_test_function=extra_test_function,
        body_name=body_name,
        body=node.body,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        composite_symbol_names=composite_symbol_names,
        opts=opts)
def assertSameAnno(self, first, second, key):
  """Asserts that two nodes hold the very same annotation object under `key`.

  Args:
    first: node whose annotation is read first.
    second: node whose annotation is read second.
    key: annotation key looked up on both nodes.
  """
  first_value = anno.getanno(first, key)
  second_value = anno.getanno(second, key)
  self.assertIs(first_value, second_value)
def visit_Call(self, node):
  """Rewrites a call as `ag__.converted_call(func, owner, options, args, kwargs)`.

  Calls into `ag__` are left as calls and only visited generically. The same
  applies to `print` when the BUILTIN_FUNCTIONS feature is not in use.

  Args:
    node: the call node being converted.

  Returns:
    The generically visited node for the exempt cases above, otherwise the
    expression built from the `converted_call` template.
  """
  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  callee_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  if callee_name.startswith('ag__.') or (
      callee_name == 'print' and not self.ctx.program.options.uses(
          converter.Feature.BUILTIN_FUNCTIONS)):
    return self.generic_visit(node)

  # For attribute calls the function is passed by name, next to its owner.
  if not isinstance(node.func, gast.Attribute):
    func = node.func
    owner = parser.parse_expression('None')
  else:
    func = gast.Str(node.func.attr)
    owner = node.func.value

  # Positional arguments: at most one *args, kept apart from the rest.
  starred_arg = None
  normal_args = []
  for arg in node.args:
    if not isinstance(arg, gast.Starred):
      normal_args.append(self.visit(arg))
      continue
    assert starred_arg is None, 'Multiple *args should be impossible.'
    starred_arg = arg

  if starred_arg is not None:
    args = templates.replace_as_expression(
        '(args,) + tuple(stararg)',
        stararg=starred_arg.value,
        args=normal_args)
  else:
    args = templates.replace_as_expression('(args,)', args=normal_args)

  # Keyword arguments: at most one **kwargs, kept apart from the rest.
  kwargs_arg = None
  normal_keywords = []
  for keyword in node.keywords:
    if keyword.arg is not None:
      normal_keywords.append(self.visit(keyword))
      continue
    assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
    kwargs_arg = keyword

  if kwargs_arg is not None:
    kwargs = templates.replace_as_expression(
        'dict(kwargs, **keywords)',
        kwargs=kwargs_arg.value,
        keywords=ast_util.keywords_to_dict(normal_keywords))
  elif normal_keywords:
    kwargs = ast_util.keywords_to_dict(normal_keywords)
  else:
    kwargs = parser.parse_expression('None')

  program_options = self.ctx.program.options
  template = """
    ag__.converted_call(func, owner, options, args, kwargs)
  """
  return templates.replace_as_expression(
      template,
      func=func,
      owner=owner,
      options=program_options.to_ast(
          internal_convert_user_code=program_options.recursive),
      args=args,
      kwargs=kwargs)
def _live_tensors(f, attr_name="inputs"):
  """Returns the indices of the used inputs.

  Note: This currently only handles direct index accesses e.g. op.inputs[1].
  If the function has slicing or list comprehension on attr_name then returns
  _ALL. This ensures that the result is correct even if inefficient.

  Args:
    f: A grad function, taking the op as first argument.
    attr_name: op attr to track. "inputs" or "outputs".

  Returns:
    Either one of:
      * set of integers representing individual indices of inputs used
      * the value _ALL, if indices are used but cannot be determined which
      * empty set, if no inputs are used
  """
  node, _ = parser.parse_entity(f, ())
  entity_info = transformer.EntityInfo(
      name=f.__name__,
      source_code=None,
      source_file=None,
      future_features=(),
      namespace=sys.modules[f.__module__].__dict__)
  ctx = transformer.Context(entity_info, None, None)

  # Static analysis passes; each one annotates the tree for the next.
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  # The op is the first argument of the grad function; track op.<attr_name>.
  op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
  op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

  special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,))
  node = special_tracker.visit(node)

  live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
  inputs_outputs_used_qns = set()
  for v in special_tracker.complex_reads:
    # Complicated patterns like op.inputs[:3]. Could be smarter about them
    # if they matter much.
    if v == op_inputs_outputs_name:
      return _ALL
  for v in live_vars_in:
    if v in special_tracker.reads:
      if (v.has_subscript() and v.parent == op_inputs_outputs_name):
        inputs_outputs_used_qns.add(v)
      elif v == op_inputs_outputs_name:
        # When op.{attr_name} is used directly, assume all tensors are
        # used for now. In that case, no point digging further.
        # TODO(mdan): We can descend into tuple expansions.
        return _ALL

  # Functions that receive the op may use further indices; recurse into them.
  function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
  node = function_calls_tracker.visit(node)

  input_output_indices = set()

  for called_f in function_calls_tracker.calls:
    child_indices = _live_tensors(called_f, attr_name=attr_name)
    if child_indices is _ALL:
      return _ALL
    input_output_indices |= child_indices

  for v in inputs_outputs_used_qns:
    assert v.has_subscript()
    _, subscript = v.qn
    if not subscript.is_simple():
      # Not a number, assuming it can be anything.
      return _ALL
    subscript_val, = subscript.qn
    # Both conditions are required for a usable index, so failing either one
    # must bail out. The previous `and` evaluated `.value` on non-literal
    # subscripts (e.g. a plain variable name) and accepted literals whose
    # value is not an int.
    if (not isinstance(subscript_val, qual_names.Literal) or
        not isinstance(subscript_val.value, int)):
      # Not a number, assuming it can be anything.
      return _ALL
    input_output_indices.add(subscript_val.value)
  return input_output_indices
def visit_While(self, node):
  """Replaces a `while` statement with a templated `ag__.while_stmt` call.

  The loop test and body are renamed using the symbol map produced by
  `_state_constructs`. The support sets of all symbols read by the condition
  are passed to `while_stmt` as extra dependencies.

  Args:
    node: the `while` statement node being converted.

  Returns:
    The assignments created by `_create_undefined_assigns` followed by the
    nodes generated from the template.
  """
  self.generic_visit(node)

  loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)

  # Note: one might expect we can dispatch based on the loop condition.
  # But because that is dependent on the state, it cannot be evaluated ahead
  # of time - doing that would risk duplicating any effects the condition has.
  # Furthermore, we cannot evaluate slices and attributes, because they might
  # trigger __getitem__ or __getattribute__.
  #
  # A case where this fails includes ops with side effects on a stateful
  # resource captured in an object:
  #
  #   while self.v.read() > 0:
  #     self.v.assign(1)
  #
  # TODO(mdan): Handle the case above.
  cond_closure = set()
  for symbol in anno.getanno(node, annos.NodeAnno.COND_SCOPE).read:
    cond_closure |= symbol.support_set

  loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
      loop_state, reserved_symbols)
  renamed_body = ast_util.rename_symbols(node.body, ssf_map)
  renamed_test = ast_util.rename_symbols(node.test, ssf_map)

  # Replacements shared by both templates. The generated names are requested
  # in the same order as before: the loop test first, then the loop body.
  replacements = dict(
      test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
      test=renamed_test,
      body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
      body=renamed_body,
      extra_deps=tuple(symbol.ast() for symbol in cond_closure))

  if loop_state:
    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
    replacements.update(
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple)
  else:
    template = """
      def test_name():
        return test
      def body_name():
        body
        return ()
      ag__.while_stmt(test_name, body_name, (), (extra_deps,))
    """

  converted = templates.replace(template, **replacements)
  undefined_assigns = self._create_undefined_assigns(possibly_undefs)
  return undefined_assigns + converted
def visit_While(self, node):
  """Converts a `while` statement into a templated `ag__.while_stmt` call.

  Loop variables are the symbols modified in the body scope, as split by
  `_get_loop_vars` into basic and composite ones. Basic variables are passed
  through the generated test and body functions; composite ones are handled
  by the generated state getter/setter functions.

  Args:
    node: the `while` statement node being converted.

  Returns:
    The assignments created by `_create_undefined_assigns` followed by the
    nodes generated from the template.
  """
  node = self.generic_visit(node)

  (basic_loop_vars, composite_loop_vars,
   reserved_symbols, possibly_undefs) = self._get_loop_vars(
       node, anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
  loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
      basic_loop_vars)

  state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
  state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
  state_functions = self._create_state_functions(
      composite_loop_vars, state_getter_name, state_setter_name)

  # Symbol names are passed to while_stmt as string nodes.
  basic_symbol_names = tuple(
      gast.Str(str(symbol)) for symbol in basic_loop_vars)
  composite_symbol_names = tuple(
      gast.Str(str(symbol)) for symbol in composite_loop_vars)

  opts = self._create_loop_options(node)

  # TODO(mdan): Use a single template.
  # If the body and test functions took a single tuple for loop_vars, instead
  # of *loop_vars, then a single template could be used.
  if loop_vars:
    template = """
      state_functions
      def body_name(loop_vars):
        body
        return loop_vars,
      def test_name(loop_vars):
        return test
      loop_vars_ast_tuple = ag__.while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (loop_vars,),
          (basic_symbol_names,),
          (composite_symbol_names,),
          opts)
    """
    node = templates.replace(
        template,
        loop_vars=loop_vars,
        loop_vars_ast_tuple=loop_vars_ast_tuple,
        test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
        test=node.test,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        body=node.body,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        basic_symbol_names=basic_symbol_names,
        composite_symbol_names=composite_symbol_names,
        opts=opts)
  else:
    # No basic loop variables: the generated functions take no arguments and
    # the body returns an empty tuple.
    template = """
      state_functions
      def body_name():
        body
        return ()
      def test_name():
        return test
      ag__.while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (),
          (),
          (composite_symbol_names,),
          opts)
    """
    node = templates.replace(
        template,
        test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
        test=node.test,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        body=node.body,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        composite_symbol_names=composite_symbol_names,
        opts=opts)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)
  return undefined_assigns + node
def visit_If(self, node):
  """Converts an `if` statement into branch functions plus a cond expression.

  Each branch becomes a function built by `_create_cond_branch`. Symbols
  modified in either branch that are live after the statement are returned
  from the branch functions; composite symbols are handled through generated
  state getter/setter functions instead.

  Args:
    node: the `if` statement node being converted.

  Returns:
    The concatenation of: undefined-symbol assignments, the condition
    assignment, the state functions, both branch definitions and the cond
    expression.
  """
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  # Note: this information needs to be extracted before the body conversion
  # that happens in the call to generic_visit below, because the conversion
  # generates nodes that lack static analysis annotations.
  need_alias_in_body = self._determine_aliased_symbols(
      body_scope, defined_in, node.body)
  need_alias_in_orelse = self._determine_aliased_symbols(
      orelse_scope, defined_in, node.orelse)

  node = self.generic_visit(node)

  modified_in_cond = body_scope.modified | orelse_scope.modified
  returned_from_cond = set()
  composites = set()
  for s in modified_in_cond:
    if s in live_out and not s.is_composite():
      returned_from_cond.add(s)
    if s.is_composite():
      # Special treatment for compound objects, always return them.
      # This allows special handling within the if_stmt itself.
      # For example, in TensorFlow we need to restore the state of composite
      # symbols to ensure that only effects from the executed branch are seen.
      composites.add(s)

  # Note: `-` binds tighter than `&`, so these read as
  # modified & (returned_from_cond - defined_in).
  created_in_body = body_scope.modified & returned_from_cond - defined_in
  created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

  basic_created_in_body = tuple(
      s for s in created_in_body if not s.is_composite())
  basic_created_in_orelse = tuple(
      s for s in created_in_orelse if not s.is_composite())

  # These variables are defined only in a single branch. This is fine in
  # Python so we pass them through. Another backend, e.g. Tensorflow, may need
  # to handle these cases specially or throw an Error.
  possibly_undefined = (set(basic_created_in_body) ^
                        set(basic_created_in_orelse))

  # Alias the closure variables inside the conditional functions, to allow
  # the functions access to the respective variables.
  # We will alias variables independently for body and orelse scope,
  # because different branches might write different variables.
  aliased_body_orig_names = tuple(need_alias_in_body)
  aliased_orelse_orig_names = tuple(need_alias_in_orelse)
  aliased_body_new_names = tuple(
      self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
      for s in aliased_body_orig_names)
  aliased_orelse_new_names = tuple(
      self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
      for s in aliased_orelse_orig_names)

  alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
  alias_orelse_map = dict(
      zip(aliased_orelse_orig_names, aliased_orelse_new_names))

  node_body = ast_util.rename_symbols(node.body, alias_body_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

  cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
  body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
  orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
  all_referenced = body_scope.referenced | orelse_scope.referenced
  state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
  state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

  # Freeze the iteration order so the cond results and both branches list
  # the returned symbols in the same order.
  returned_from_cond = tuple(returned_from_cond)
  if returned_from_cond:
    if len(returned_from_cond) == 1:
      cond_results = returned_from_cond[0]
    else:
      cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

    returned_from_body = tuple(
        alias_body_map[s] if s in need_alias_in_body else s
        for s in returned_from_cond)
    returned_from_orelse = tuple(
        alias_orelse_map[s] if s in need_alias_in_orelse else s
        for s in returned_from_cond)

  else:
    # When the cond would return no value, we leave the cond called without
    # results. That in turn should trigger the side effect guards. The
    # branch functions will return a dummy value that ensures cond
    # actually has some return value as well.
    cond_results = None
    # TODO(mdan): Replace with None once side_effect_guards is retired.
    returned_from_body = (templates.replace_as_expression(
        'ag__.match_staging_level(1, cond_var_name)',
        cond_var_name=cond_var_name),)
    returned_from_orelse = (templates.replace_as_expression(
        'ag__.match_staging_level(1, cond_var_name)',
        cond_var_name=cond_var_name),)

  cond_assign = self.create_assignment(cond_var_name, node.test)
  body_def = self._create_cond_branch(
      body_name,
      aliased_orig_names=aliased_body_orig_names,
      aliased_new_names=aliased_body_new_names,
      body=node_body,
      returns=returned_from_body)
  orelse_def = self._create_cond_branch(
      orelse_name,
      aliased_orig_names=aliased_orelse_orig_names,
      aliased_new_names=aliased_orelse_new_names,
      body=node_orelse,
      returns=returned_from_orelse)
  undefined_assigns = self._create_undefined_assigns(possibly_undefined)
  composite_defs = self._create_state_functions(
      composites, state_getter_name, state_setter_name)

  cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                     orelse_name, state_getter_name,
                                     state_setter_name)

  return (undefined_assigns + cond_assign + composite_defs + body_def +
          orelse_def + cond_expr)
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  Note: this function assumes nodes, code and filepath correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: if walking `nodes` and the reparsed code in parallel raises
      ValueError; the new message includes a diff of the two trees.
  """
  reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
  for node in reparsed_nodes:
    resolve(node, code, filepath, node.lineno, node.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Note: the keys are by line only, excluding the column offset.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. The exception is decorated functions,
        # where both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of column overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info

  except ValueError as err:
    # Rebuild the error with a per-node diff of the two trees.
    new_msg = 'Inconsistent ASTs detected. This is a bug. Cause: \n'
    new_msg += str(err)
    new_msg += 'Diff:\n'

    for n, rn in zip(nodes, reparsed_nodes):
      nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
      reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
      diff = difflib.context_diff(
          nodes_str.split('\n'),
          reparsed_nodes_str.split('\n'),
          fromfile='Original nodes',
          tofile='Reparsed nodes',
          n=7)
      diff = '\n'.join(diff)
      new_msg += diff + '\n'
    raise ValueError(new_msg)

  return source_map
def visit_FunctionDef(self, node):
  """Visits a function definition inside a fresh `_Function` state entry.

  The entry's `scope` is set to the BODY_SCOPE annotation of the node before
  its children are visited.
  """
  with self.state[_Function] as function_state:
    function_state.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    visited = self.generic_visit(node)
    return visited
def visit(self, node):
  """Visits a single node, tracking origin info and enclosing entities.

  Nodes annotated with SKIP_PROCESSING are not dispatched to the visitor
  methods and are returned unchanged.

  Args:
    node: a gast.AST node. Lists of nodes must go through `visit_block`.

  Returns:
    The result of the node-specific visitor. When an Expr node's value was
    replaced by a list, tuple, Assign or AugAssign, that value is returned
    in place of the Expr.

  Raises:
    ValueError: if `node` is not a gast.AST instance.
    AssertionError: if the local scope stack size differs on exit.
    AutoGraphParseError: wraps ValueError, AttributeError, KeyError and
      NotImplementedError raised while visiting, when an origin is known.
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`.  The error needs to be raised before the exception handler
    # below is installed, because said handler will mess up if `node` is not,
    # in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  did_enter_function = False
  local_scope_size_at_entry = len(self._local_scope_state)
  processing_expr_node = False

  parent_origin = self.ctx.current_origin
  try:
    if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
      did_enter_function = True
    elif isinstance(node, gast.Expr):
      processing_expr_node = True

    if did_enter_function:
      self._enclosing_entities.append(node)

    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    if processing_expr_node:
      entry_expr_value = node.value

    if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      result = super(Base, self).visit(node)
    else:
      # Skipped nodes pass through untouched. Without this branch `result`
      # was never bound and the code below failed with UnboundLocalError.
      result = node
    self.ctx.current_origin = parent_origin

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if processing_expr_node:
      if isinstance(result, gast.Expr) and result.value != entry_expr_value:
        # When the replacement is a list, it is assumed that the list came
        # from a template that contained a number of statements, which
        # themselves are standalone and don't require an enclosing Expr.
        if isinstance(result.value,
                      (list, tuple, gast.Assign, gast.AugAssign)):
          result = result.value

    # On exception, the local scope integrity is not guaranteed.
    if did_enter_function:
      self._enclosing_entities.pop()

    if local_scope_size_at_entry != len(self._local_scope_state):
      raise AssertionError(
          'Inconsistent local scope stack. Before entering node %s, the'
          ' stack had length %d, after exit it has length %d. This'
          ' indicates enter_local_scope and exit_local_scope are not'
          ' well paired.' % (node, local_scope_size_at_entry,
                             len(self._local_scope_state)))
    return result

  except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
    if not self.ctx.current_origin:
      # No origin to attach: re-raise as-is, keeping the original traceback.
      raise
    original_file_path = self.ctx.current_origin.loc.filename
    original_line_number = self.ctx.current_origin.loc.lineno
    original_col_offset = self.ctx.current_origin.loc.col_offset
    original_source_line = (
        self.ctx.current_origin.source_code_line)
    msg = '%s: %s.' % (e.__class__.__name__, str(e))

    # TODO(mdan): Avoid the printing of the original exception.
    # In other words, we need to find how to suppress the "During handling
    # of the above exception, another exception occurred" message.
    six.reraise(
        AutoGraphParseError,
        AutoGraphParseError(msg, (original_file_path, original_line_number,
                                  original_col_offset, original_source_line)),
        sys.exc_info()[2])
def visit_For(self, node):
  """Converts a `for` statement into a templated `ag__.for_stmt` call.

  Loop variables are the symbols modified in the body or iterate scope, as
  returned by `_get_loop_vars`. The generated functions receive the nonlocal
  declarations built for them, and the loop options gain an 'iterate_names'
  entry holding the unparsed loop target.

  Args:
    node: the `for` statement node being converted.

  Returns:
    The nodes produced by `templates.replace` for the loop template.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

  loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
      node, body_scope.modified | iter_scope.modified)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)

  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
  state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  # Add the source text of the loop target to the loop options.
  opts = self._create_loop_options(node)
  opts.keys.append(gast.Constant('iterate_names', kind=None))
  opts.values.append(
      gast.Constant(
          parser.unparse(node.target, include_encoding_marker=False),
          kind=None))

  # The optional extra loop test becomes a separate function; without it,
  # the literal None is passed in place of the function name.
  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
    extra_test_name = self.ctx.namer.new_symbol(
        'extra_test', reserved_symbols)
    template = """
      def extra_test_name():
        nonlocal_declarations
        return extra_test_expr
    """
    extra_test_function = templates.replace(
        template,
        extra_test_expr=extra_test,
        extra_test_name=extra_test_name,
        loop_vars=loop_vars,
        nonlocal_declarations=nonlocal_declarations)
  else:
    extra_test_name = parser.parse_expression('None')
    extra_test_function = []

  # iterate_arg_name holds a single arg with the iterates, which may be a
  # tuple.
  iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols)
  template = """
    iterates = iterate_arg_name
  """
  iterate_expansion = templates.replace(
      template, iterate_arg_name=iterate_arg_name, iterates=node.target)

  template = """
    state_functions
    def body_name(iterate_arg_name):
      nonlocal_declarations
      iterate_expansion
      body
    extra_test_function
    undefined_assigns
    ag__.for_stmt(
        iterated,
        extra_test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  return templates.replace(
      template,
      body=node.body,
      body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
      extra_test_function=extra_test_function,
      extra_test_name=extra_test_name,
      iterate_arg_name=iterate_arg_name,
      iterate_expansion=iterate_expansion,
      iterated=node.iter,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      symbol_names=tuple(
          gast.Constant(str(s), kind=None) for s in loop_vars),
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      undefined_assigns=undefined_assigns)
def visit_Lambda(self, node):
  """Visits a lambda inside a fresh `_Function` state entry.

  The entry's `scope` is set to the SCOPE annotation of the node before its
  children are visited.
  """
  with self.state[_Function] as function_state:
    function_state.scope = anno.getanno(node, anno.Static.SCOPE)
    visited = self.generic_visit(node)
    return visited
def assertTypes(self, node, expected):
  """Asserts that the TYPES annotation of `node` equals `expected` as a set.

  Args:
    node: node carrying the TYPES annotation.
    expected: a tuple of types, or a single value that is treated as a
      one-element tuple.
  """
  expected_types = expected if isinstance(expected, tuple) else (expected,)
  actual_types = set(anno.getanno(node, anno.Static.TYPES))
  self.assertSetEqual(actual_types, set(expected_types))
def visit_Expr(self, node):
  """Wraps a standalone call statement in a control dependency guard.

  Only expression statements whose value is a call are rewritten. The call
  becomes the argument of a `with ag__.utils.control_dependency_on_returns`
  block, which replaces the node and is annotated with
  INDENT_BLOCK_REMAINDER (its body and the alias map).

  Args:
    node: the expression statement being visited.

  Returns:
    The generated `with` node for call statements, otherwise the node itself.
  """
  self.generic_visit(node)
  if isinstance(node.value, gast.Call):
    # Patterns of single function calls, like:
    #   opt.minimize(loss)
    # or:
    #   tf.py_func(...)

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
    # NOTE: We can't guard object attributes because they may not be writable.
    # In addition, avoid renaming well-known names.
    # TODO(mdan): Move these names into config.
    unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
    guarded_args = tuple(
        s for s in args_scope.used
        if not s.is_composite() and s not in unguarded_names)

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if guarded_args:
      # The aliases may need new names to avoid incorrectly making them local.
      # TODO(mdan): This is brutal. It will even rename modules - any fix?
      need_alias = tuple(
          s for s in guarded_args if s not in args_scope.parent.modified)
      aliased_new_names = tuple(
          qual_names.QN(
              self.ctx.namer.new_symbol(
                  s.ssf(), args_scope.parent.referenced)) for s in need_alias)
      alias_map = dict(zip(need_alias, aliased_new_names))
      # Symbols without an alias keep their own name on the assignment target.
      if len(guarded_args) == 1:
        s, = guarded_args
        aliased_guarded_args = alias_map.get(s, s)
      else:
        aliased_guarded_args = gast.Tuple(
            [alias_map.get(s, s).ast() for s in guarded_args], None)

      template = """
        with ag__.utils.control_dependency_on_returns(call):
          aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
      """
      control_deps_guard = templates.replace(
          template,
          call=node.value,
          aliased_guarded_args=aliased_guarded_args,
          guarded_args=guarded_args)[-1]
    else:
      alias_map = {}

      template = """
        with ag__.utils.control_dependency_on_returns(call):
          pass
      """
      control_deps_guard = templates.replace(template, call=node.value)[-1]
      # Drop the placeholder `pass` so the guard starts with an empty body.
      control_deps_guard.body = []

    node = control_deps_guard
    anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (node.body, alias_map))
  return node
def visit_FunctionDef(self, node):
  """Filters the decorators of a function and registers required imports.

  Decorators named `classmethod` and those whose 'live_val' annotation is in
  `options.strip_decorators` are dropped. For each decorator that is kept,
  an import statement is added to `ctx.program.additional_imports`.

  Args:
    node: the function definition node being visited.

  Returns:
    The same node, with `decorator_list` reduced to the kept decorators.

  Raises:
    ValueError: if the root symbol of a decorator has no 'live_val'
      annotation, or if a non-module decorator is declared in `__main__`.
  """
  self.generic_visit(node)
  kept_decorators = []
  for dec in node.decorator_list:
    # For decorators with arguments, inspect the called function.
    if isinstance(dec, gast.Call):
      dec_func = dec.func
    else:
      dec_func = dec

    # Special cases.
    # TODO(mdan): Is there any way we can treat these more generically?
    # We may want to forego using decorators altogether if we can't
    # properly support them.
    if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',):
      # Assumption: decorators are only visible in the AST when converting
      # a function inline (via another decorator).
      # In that case, the converted function is no longer part of the
      # original object that it was declared into.
      # This is currently verified by tests.
      continue

    original_dec = anno.getanno(dec_func, anno.Basic.QN)
    dec_value = anno.getanno(dec_func, 'live_val')

    if dec_value in self.ctx.program.options.strip_decorators:
      continue

    # When using foo.bar.baz, we only really need to grab foo and import
    # that.
    dec_support_node = dec_func
    while isinstance(dec_support_node, gast.Attribute):
      dec_support_node = dec_support_node.value

    if not anno.hasanno(dec_support_node, 'live_val'):
      raise ValueError(
          'could not resolve symbol "%s" when looking up decorator "%s"' %
          (anno.getanno(dec_support_node, anno.Basic.QN), original_dec))

    dec_support = anno.getanno(dec_support_node, 'live_val')
    # The tuple contains:
    #  * the AST that represents the decorator
    #  * the entity supporting the decorator (i.e., what we need to import)
    #  * the name of the module that needs to be imported for this decorator
    #    to properly resolve.
    # Examples:
    #  for foo.bar, the tuple is (<ast>, <module foo>, 'foo')
    #  for baz, the tuple is (<ast>, <module baz.__module__>, 'baz')
    kept_decorators.append((dec, dec_support,
                            anno.getanno(dec_support_node, anno.Basic.QN)))

  # Modules are imported under the name used in the code; anything else is
  # imported from the module that declares it.
  for _, dec_support, name in kept_decorators:
    if tf_inspect.ismodule(dec_support):
      self.ctx.program.additional_imports.add(
          'import %s as %s' % (dec_support.__name__, name))
    else:
      if dec_support.__module__ == '__main__':
        raise ValueError(
            'decorator "%s" was not allowed because it is declared '
            'in the module "%s". To fix this, declare it in a separate '
            'module that we can import it from.' % (dec_support,
                                                    dec_support.__module__))
      self.ctx.program.additional_imports.add(
          'from %s import %s' % (dec_support.__module__, name))

  node.decorator_list = [dec for dec, _, _ in kept_decorators]
  return node