def test_resolve(self):
  """Checks line, column, source text and comment of resolved origins."""

  def test_fn(x):
    """Docstring."""
    return x  # comment

  node, source = parser.parse_entity(test_fn)
  fn_node = node.body[0]
  origin_info.resolve(fn_node, source)

  origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 1)
  self.assertEqual(origin.loc.col_offset, 0)
  self.assertEqual(origin.source_code_line, 'def test_fn(x):')
  self.assertIsNone(origin.comment)

  # The body of test_fn is indented by two spaces (col_offset == 2), so the
  # expected source lines must include that same leading indentation.
  origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 2)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  """Docstring."""')
  self.assertIsNone(origin.comment)

  origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
  self.assertEqual(origin.loc.lineno, 3)
  self.assertEqual(origin.loc.col_offset, 2)
  self.assertEqual(origin.source_code_line, '  return x  # comment')
  self.assertEqual(origin.comment, 'comment')
def _process_variable_assignment(self, source, targets):
  """Records the value assigned to each target in the current scope.

  If `source` is a call to a statically known class, the call node is also
  annotated as a constructor (`is_constructor`, `type`, `type_fqn`).

  Args:
    source: gast node, the right-hand side of the assignment.
    targets: iterable of gast nodes, the assignment targets. More than one
      target means a multiple assignment (`a = b = c`).

  Raises:
    ValueError: if a target is not a Tuple, Name or Attribute.
  """
  # Special case: constructors.
  if isinstance(source, gast.Call):
    func = source.func
    if anno.hasanno(func, 'live_val'):
      func_obj = anno.getanno(func, 'live_val')
      if tf_inspect.isclass(func_obj):
        anno.setanno(source, 'is_constructor', True)
        anno.setanno(source, 'type', func_obj)
        anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
        # TODO(mdan): Raise an error if constructor has side effects.
        # We can have a whitelist of no-side-effects constructors.
        # We can also step inside the constructor and further analyze.

  # Multiple targets mean multiple assignment.
  for target in targets:
    # Tuple target means unpacking.
    if isinstance(target, gast.Tuple):
      for i, target_item in enumerate(target.elts):
        # Two cases here:
        #   1. Static unpacking, e.g. a, b = c, d
        #   2. Dynamic unpacking, e.g. a, b = c
        # The former case is optimized away.
        if isinstance(source, (gast.Tuple, gast.List)):
          source_item = source.elts[i]
        else:
          # NOTE(review): gast.Index normally wraps an expression node; a raw
          # int index may yield a malformed AST if this node is ever compiled.
          source_item = gast.Subscript(source, gast.Index(i), ctx=None)
        self._process_variable_assignment(source_item, (target_item,))
    elif isinstance(target, (gast.Name, gast.Attribute)):
      target_symbol = anno.getanno(target, anno.Basic.QN)
      self.scope.setval(target_symbol, source)
    else:
      # Report the offending target itself; `target_item` is only bound in
      # the Tuple branch above.
      raise ValueError('assignment target has unknown type: %s' % target)
def visit_Call(self, node):
  """Applies element type/shape annotations from `utils.set_element_type`.

  For a call `utils.set_element_type(target, type[, shape])`, both the call
  node and the current definition of `target` (looked up in `self.scope`) are
  annotated with `element_type` and `element_shape`. The shape argument
  defaults to a parsed `None` expression.

  Raises:
    ValueError: if the marker is called with other than two or three
      arguments, or if its first argument is not a symbol.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if anno.getanno(node.func, 'live_val') is utils.set_element_type:

      if not 2 <= len(node.args) <= 3:
        # Name the function that was actually matched above.
        raise ValueError('"%s" must have either two or three parameters' %
                         utils.set_element_type)
      if len(node.args) == 2:
        target_arg, type_arg = node.args
        shape_arg = parser.parse_expression('None')
      else:
        target_arg, type_arg, shape_arg = node.args
      if not anno.hasanno(target_arg, anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         utils.set_element_type)
      # TODO(mdan): This is vulnerable to symbol renaming.
      element_type = type_arg
      element_shape = shape_arg

      target_symbol = anno.getanno(target_arg, anno.Basic.QN)
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(node, 'element_shape', element_shape)
      anno.setanno(definition, 'element_type', element_type)
      anno.setanno(definition, 'element_shape', element_shape)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def test_if_attributes(self):
  """Attribute reads and writes are tracked separately in each if branch."""

  def test_fn(a):
    if a > 0:
      a.b = -a.c
      d = 2 * a
    else:
      a.b = a.c
      d = 1
    return d

  node, _ = self._parse_and_analyze(test_fn)
  if_node = node.body[0].body[0]
  body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)

  # Both branches are expected to produce the same symbol sets.
  for branch_scope in (body_scope, orelse_scope):
    self.assertScopeIsRmc(branch_scope, ('a', 'a.c'), ('a.b', 'd'), ('d',))

  self.assertScopeIsRmc(
      body_scope.parent, ('a', 'a.c', 'd'), ('a.b', 'd'), ('a', 'd'))
def _rename_compilable_function(self, node):
  """Points a call at the compiled version of its target, when applicable.

  Requires `node.func` to carry `live_val` and `fqn` annotations. Nothing
  changes if `self._should_compile` declines the target or the namer decides
  not to rename. When a method is renamed, its receiver (`node.func.value`)
  is prepended to the positional arguments.

  Returns:
    The (possibly modified) call node.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  entity = anno.getanno(node.func, 'live_val')
  fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, fqn):
    return node

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.ctx.namer.compiled_class_name(fqn, live_entity=entity)
    do_rename = True
  else:
    has_parent_type = anno.hasanno(node.func, 'parent_type')
    # Without a 'parent_type' annotation, fall back to introspection, which
    # is not reliable.
    owner_type = (anno.getanno(node.func, 'parent_type') if has_parent_type
                  else inspect_utils.getmethodclass(entity))
    new_name, do_rename = self.ctx.namer.compiled_function_name(
        fqn, live_entity=entity, owner_type=owner_type)

  if not do_rename:
    return node

  if entity is not None and tf_inspect.ismethod(entity):
    # The renaming process will transform it into a regular function, so the
    # receiver becomes an explicit first argument.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def visit_Call(self, node):
  """Applies element types declared via `self.context.type_annotation_func`.

  For a call `type_annotation_func(symbol, type)`, the call node and the
  current definition of `symbol` (looked up in `self.scope`) are both
  annotated with `element_type`, taken from the second argument's
  `live_val`.

  Raises:
    ValueError: if the marker does not have exactly two arguments, its first
      argument is not a symbol, or its second is not statically resolvable.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if (anno.getanno(node.func, 'live_val') is
        self.context.type_annotation_func):
      # Expecting the actual type to be the second argument.
      if len(node.args) != 2:
        raise ValueError('"%s" must have exactly two parameters' %
                         self.context.type_annotation_func)
      if not anno.hasanno(node.args[0], anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         self.context.type_annotation_func)
      if not anno.hasanno(node.args[1], 'live_val'):
        raise ValueError(
            'the second argument of "%s" must be statically resolvable' %
            self.context.type_annotation_func)
      target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
      element_type = anno.getanno(node.args[1], 'live_val')
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(definition, 'element_type', element_type)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def test_nested_function(self):
  """A nested function gets its own scope, parented by the outer one."""

  def test_fn(a):

    def f(x):
      y = x * x
      return y

    b = a
    for i in a:
      c = b
      b -= f(i)
    return b, c

  node, _ = self._parse_and_analyze(test_fn)
  fn_def_node = node.body[0].body[0]
  inner_scope = anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE)

  self.assertScopeIsRmc(
      inner_scope.parent,
      ('b', 'i', 'f', 'c', 'a'),
      ('f', 'b', 'c', 'i'),
      ('f', 'a', 'b', 'c', 'i'),
  )
  self.assertScopeIsRmc(inner_scope, ('x', 'y'), ('y',), ('x', 'y'))
def test_call_args_subscripts(self):
  """Subscripts passed as call arguments are tracked in the args scope."""

  def foo(*_):
    pass

  def test_fn(a):
    b = 1
    c = 2
    foo(a[0], a[b])
    return a[c]

  node, _ = self._parse_and_analyze(test_fn)
  call_node = node.body[0].body[2].value
  args_scope = anno.getanno(call_node, NodeAnno.ARGS_SCOPE)

  self.assertScopeIsRmc(args_scope, ('a', 'a[0]', 'a[b]', 'b'), (), ())
  self.assertScopeIsRmc(
      args_scope.parent,
      ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'),
      ('b', 'c'),
      ('a', 'b', 'c'),
  )
def visit_Call(self, node):
  """Renames, wraps or dynamically converts a call depending on its target.

  A call to one of `self.nocompile_decorators` marks its first argument as
  `graph_ready`. After the children are visited, a call whose target is
  statically known (`live_val`) is renamed if compilable, or wrapped as a
  py_func if it is in KNOWN_NUMPY_FUNCTIONS. Any other resolved target is
  rejected. A call with an unresolved target is dynamically converted when
  `self.context.recursive` is set, and left as-is otherwise.

  Raises:
    ValueError: if a marker decorator is called without arguments.
    NotImplementedError: if a resolved target is neither compilable nor a
      known numpy function.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if not node.args:
        # Interpolate the decorator so the "%s" placeholder is filled in.
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  elif self.context.recursive:
    node = self._insert_dynamic_conversion(node)
  # Unresolved functions are allowed in non-recursive mode.
  return node
def test_call_with_composite_names(self):
  """Composite names (attributes) are tracked in call args and if branches."""

  def foo(*_):
    pass

  def test_fn(a):
    foo(a.b, a.c)
    if a > 0:
      a.b = 2
    else:
      d = 2
      d.e = a.c
      f = d.e + 1
      a.c = f

  node = self._parse_and_analyze(test_fn)
  fn_body = node.body[0].body

  call_node = fn_body[0].value
  self.assertScopeIsRmc(
      anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (),
      ())

  if_node = fn_body[1]
  self.assertScopeIsRmc(
      anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ())
  self.assertScopeIsRmc(
      anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
      ('a', 'a.c', 'd', 'd.e', 'f'),
      ('a.c', 'd', 'd.e', 'f'),
      ('d', 'f'),
  )
def test_if(self):
  """Each if branch gets its own scope; both share the enclosing scope."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node, _ = self._parse_and_analyze(test_fn)
  if_node = node.body[0].body[0]
  body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
  enclosing_expectation = (('x', 'z', 'u'), ('x', 'y', 'z', 'u'),
                           ('x', 'y', 'z', 'u'))

  self.assertScopeIsRmc(body_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIsRmc(body_scope.parent, *enclosing_expectation)
  self.assertScopeIsRmc(orelse_scope, ('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))
  self.assertScopeIsRmc(orelse_scope.parent, *enclosing_expectation)
def visit(self, node):
  """Propagates dataflow facts forward from a CFG node, depth first.

  The node's incoming facts are the `self.transfer_fn` fold of its
  predecessors' outgoing facts, or an empty frozenset if none have been
  computed yet. Its outgoing facts are `(incoming - kill) | gen`. Successors
  are revisited only when the hash of the outgoing facts changed.
  """
  stmt = node.value
  # Only the exit CfgNode has no value; there is nothing to annotate.
  if not stmt:
    return

  previous_hash = (hash(anno.getanno(stmt, self.out_label))
                   if anno.hasanno(stmt, self.out_label) else None)

  pred_outputs = [
      anno.getanno(p.value, self.out_label)
      for p in node.prev
      if anno.hasanno(p.value, self.out_label)
  ]
  incoming = (functools.reduce(self.transfer_fn, pred_outputs)
              if pred_outputs else frozenset())
  anno.setanno(stmt, self.in_label, incoming)

  gen, kill = self.get_gen_kill(node, incoming)
  anno.setanno(stmt, self.gen_label, gen)
  anno.setanno(stmt, self.kill_label, kill)
  outgoing = (incoming - kill) | gen
  anno.setanno(stmt, self.out_label, outgoing)

  if hash(outgoing) != previous_hash:
    for successor in node.next:
      self.visit(successor)
def visit_For(self, node):
  """Visits the loop target, then processes the body and else blocks.

  Each block is handed to `self._process_block` together with the scope
  stored in the corresponding NodeAnno annotation.
  """
  node.target = self.visit(node.target)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  node.body = self._process_block(body_scope, node.body)
  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
  node.orelse = self._process_block(orelse_scope, node.orelse)
  return node
def _build_source_map(node, code):
  """Maps line numbers of generated code back to the user code's origins.

  The generated source is reparsed to obtain its final line numbers. Then the
  original AST and the reparsed AST are walked in parallel, and each pair of
  nodes that both carry an origin annotation contributes one entry.

  Args:
    node: An AST node of the original generated code, before the source code
      is generated.
    code: The string representation of the source code for the newly
      generated code.

  Returns:
    Dict mapping the `line_number` of a reparsed node's origin annotation to
    the origin annotation of the corresponding node in `node`. Later nodes on
    the same generated line overwrite earlier ones.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)
  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks because if origin information is ever copied over to new
    # nodes then we need to rely on the fact that only the original user code
    # has the origin annotation.
    if (anno.hasanno(before, anno.Basic.ORIGIN) and
        anno.hasanno(after, anno.Basic.ORIGIN)):
      source_info = anno.getanno(before, anno.Basic.ORIGIN)
      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
      source_mapping[new_line_number] = source_info
  return source_mapping
def visit_Attribute(self, node):
  """Resolves `x.attr` statically when the owner's value or type is known.

  If the owner has a `live_val`, the attribute is looked up on that object.
  Otherwise, if the owner has a `type`, it is looked up on the type. In both
  cases the node receives `parent_type`, `live_val` and `fqn` annotations.

  Raises:
    AttributeError: if a known owner object lacks the attribute.
  """
  self.generic_visit(node)
  owner = node.value

  if anno.hasanno(owner, 'live_val'):
    assert anno.hasanno(owner, 'fqn')
    parent_object = anno.getanno(owner, 'live_val')
    if not hasattr(parent_object, node.attr):
      raise AttributeError('%s has no attribute %s' % (parent_object,
                                                       node.attr))
    owner_fqn = anno.getanno(owner, 'fqn')
    anno.setanno(node, 'parent_type', type(parent_object))
    anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
    anno.setanno(node, 'fqn', owner_fqn + (node.attr,))
  # TODO(mdan): Investigate the role built-in annotations can play here.
  elif anno.hasanno(owner, 'type'):
    parent_type = anno.getanno(owner, 'type')
    # Static members such as methods are found on the type. Dynamic members
    # (e.g. function attributes) are not; such nodes are left unannotated
    # for downstream consumers to handle.
    if hasattr(parent_type, node.attr):
      anno.setanno(node, 'parent_type', parent_type)
      anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
      anno.setanno(node, 'fqn',
                   anno.getanno(owner, 'type_fqn') + (node.attr,))
  elif isinstance(owner, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(owner, NodeAnno.IS_LOCAL), owner
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def visit(self, cfg_node):
  """Propagates dataflow facts backward from a CFG node, depth first.

  The node's outgoing facts are the `self.transfer_fn` fold of its
  successors' incoming facts, or an empty frozenset if none have been
  computed yet. Its incoming facts are `(outgoing - kill) | gen`.
  Predecessors are revisited only when the hash of the incoming facts
  changed.
  """
  stmt = cfg_node.value
  # The exit node has no value and is visited only once; it just starts the
  # backward walk from each of its predecessors.
  if not stmt:
    for predecessor in cfg_node.prev:
      self.visit(predecessor)
    return

  previous_hash = (hash(anno.getanno(stmt, self.in_label))
                   if anno.hasanno(stmt, self.in_label) else None)

  succ_inputs = [
      anno.getanno(s.value, self.in_label)
      for s in cfg_node.next
      if anno.hasanno(s.value, self.in_label)
  ]
  outgoing = (functools.reduce(self.transfer_fn, succ_inputs)
              if succ_inputs else frozenset())
  anno.setanno(stmt, self.out_label, outgoing)

  gen, kill = self.get_gen_kill(cfg_node, outgoing)
  anno.setanno(stmt, self.gen_label, gen)
  anno.setanno(stmt, self.kill_label, kill)
  incoming = (outgoing - kill) | gen
  anno.setanno(stmt, self.in_label, incoming)

  if hash(incoming) != previous_hash:
    for predecessor in cfg_node.prev:
      self.visit(predecessor)
def _generate_pop_operation(self, original_call_node, pop_var_name):
  """Builds the `ag__.list_pop` replacement for a `target.pop(...)` call.

  Args:
    original_call_node: gast.Call whose `func` is an Attribute (`target.pop`).
    pop_var_name: name that receives the popped element.

  Returns:
    The result of `templates.replace` for the list_pop template.
  """
  assert isinstance(original_call_node.func, gast.Attribute)
  # The call will be something like "target.pop()", and the dtype is hooked to
  # target, hence the func.value.
  target = original_call_node.func.value
  call_args = original_call_node.args
  pop_element = call_args[0] if call_args else parser.parse_expression('None')

  dtype = anno.getanno(
      target, 'element_type',
      default=templates.replace_as_expression('None'))
  shape = anno.getanno(
      target, 'element_shape',
      default=templates.replace_as_expression('None'))

  template = """
    target, pop_var_name = ag__.list_pop(
        target, element,
        opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
  """
  return templates.replace(
      template,
      target=target,
      pop_var_name=pop_var_name,
      element=pop_element,
      dtype=dtype,
      shape=shape)
def test_reaching(self):
  """Reaching definitions across a `while True` loop."""

  def f(x):
    print(x)
    while True:
      x = x
      x = x
    return x

  node, ctx = self._parse_and_analyze(f, {})
  cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
  body = node.body[0].body
  while_body = body[1].body

  def definition_types(stmt):
    return {type(d[1]) for d in anno.getanno(stmt, 'definitions_in')}

  # Only the argument reaches the expression: one element, x, from arguments.
  self.assertEqual(definition_types(body[0]), {gast.arguments})

  # One definition, two possible sources:
  # - one from an assignment (if the loop is entered)
  # - the other from the arguments (if the loop is not entered)
  self.assertEqual(
      definition_types(while_body[0]), {gast.arguments, gast.Assign})

  # Past the first assignment, only that Assign node reaches.
  self.assertEqual(definition_types(while_body[1]), {gast.Assign})

  # Same situation as while_body[0].
  self.assertEqual(definition_types(body[2]), {gast.arguments, gast.Assign})
def test_if_subscripts(self):
  """Subscript reads and writes are tracked in each if branch."""

  def test_fn(a, b, c, e):
    if a > 0:
      a[b] = -a[c]
      d = 2 * a
    else:
      a[0] = e
      d = 1
    return d

  node, _ = self._parse_and_analyze(test_fn)
  if_node = node.body[0].body[0]
  body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)

  self.assertScopeIsRmc(
      body_scope, ('a', 'b', 'c', 'a[c]'), ('a', 'a[b]', 'd'), ('d',))
  # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
  self.assertScopeIsRmc(orelse_scope, ('a', 'e'), ('a', 'a[0]', 'd'), ('d',))
  self.assertScopeIsRmc(
      orelse_scope.parent,
      ('a', 'b', 'c', 'd', 'e', 'a[c]'),
      ('a', 'd', 'a[b]', 'a[0]'),
      ('a', 'b', 'c', 'd', 'e'),
  )
def visit_With(self, node):
  """Summarizes a `with` statement's dataflow annotations.

  The statement's `in_label` value is the first body statement's `in_label`
  value unioned with each `withitem`'s. Its `out_label` value is that of the
  last body statement.
  """
  self.generic_visit(node)
  incoming = anno.getanno(node.body[0], self.in_label)
  for item in node.items:
    # Use `|` rather than `|=`: if the stored annotation is a mutable set,
    # `|=` would modify the first body statement's annotation in place.
    incoming = incoming | anno.getanno(item, self.in_label)
  outgoing = anno.getanno(node.body[-1], self.out_label)
  anno.setanno(node, self.in_label, incoming)
  anno.setanno(node, self.out_label, outgoing)
def test_entity_scope_tracking(self):
  """The transformer records the chain of enclosing functions/classes."""

  class TestTransformer(transformer.Base):

    # The choice of node to assign to is arbitrary. Using Assign because it's
    # easy to find in the tree.
    def visit_Assign(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # This will show up in the lambda function.
    def visit_BinOp(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

  tr = TestTransformer(self._simple_source_info())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  test_function_node = node.body[0]
  test_class = test_function_node.body[1]
  test_method = test_class.body[0]
  inner_function = test_method.body[1]
  lambda_node = inner_function.body[1].value

  assign_a = test_function_node.body[0]
  assign_b = test_method.body[0]
  assign_c = inner_function.body[0]
  lambda_expr = lambda_node.body

  method_chain = (test_function_node, test_class, test_method)
  inner_chain = method_chain + (inner_function,)

  self.assertEqual((test_function_node,),
                   anno.getanno(assign_a, 'enclosing_entities'))
  self.assertEqual(method_chain, anno.getanno(assign_b, 'enclosing_entities'))
  self.assertEqual(inner_chain, anno.getanno(assign_c, 'enclosing_entities'))
  self.assertEqual(inner_chain + (lambda_node,),
                   anno.getanno(lambda_expr, 'enclosing_entities'))
def visit_For(self, node):
  """Rewrites a `for` loop into a functional `ag__.for_stmt` call.

  The loop body and the `extra_test` condition become generated local
  functions. Symbols the body modifies but does not create are threaded
  through them explicitly as loop state.

  Raises:
    ValueError: via `self._validate_no_live_vars_created`, if the body
      creates variables that are live after the loop.
  """
  self.generic_visit(node)
  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  # Loop state: symbols modified in the body that are not created there.
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  state = list(body_closure)

  # One fresh name per state symbol, derived from `s.ssf()` and chosen by the
  # namer to avoid everything referenced in the body.
  state_ssf = [
      self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose fresh name differs need renaming inside the body.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # A single state symbol is passed bare rather than as a one-element tuple.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    extra_test = parser.parse_expression('True')

  template = """
    def extra_test_name(state_ssf):
      return extra_test_expr
    def body_name(loop_vars, state_ssf):
      # Workaround for PEP-3113
      iterate = loop_vars
      body
      return state_ssf,
    state_ast_tuple = ag__.for_stmt(
        iter_, extra_test_name, body_name, (state,))
  """
  # Keyword arguments are evaluated left to right, so the extra_test symbol
  # is allocated before the loop_body symbol.
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
      extra_test_expr=extra_test,
      body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def test_rename_symbols_annotations(self):
  """Renaming symbols keeps the annotations of the root node."""
  node = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(node, 'foo', 'bar')
  annotation_before = anno.getanno(node, 'foo')

  renaming = {qual_names.QN('a'): qual_names.QN('b')}
  renamed = ast_util.rename_symbols(node, renaming)

  self.assertIs(anno.getanno(renamed, 'foo'), annotation_before)
def visit_While(self, node):
  """Summarizes a `while` loop's dataflow annotations.

  The loop's `in_label` value is the union of the first body statement's and
  the test's. Its `out_label` value is the last body statement's, combined
  via `self.transfer_fn` with the last `else` statement's when an else block
  exists.
  """
  self.generic_visit(node)
  # Use `|` rather than `|=`: if the stored annotation is a mutable set, `|=`
  # would modify the first body statement's annotation in place.
  incoming = (anno.getanno(node.body[0], self.in_label) |
              anno.getanno(node.test, self.in_label))
  outgoing = anno.getanno(node.body[-1], self.out_label)
  if node.orelse:
    orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
    outgoing = self.transfer_fn(outgoing, orelse_outgoing)
  anno.setanno(node, self.in_label, incoming)
  anno.setanno(node, self.out_label, outgoing)
def test_attribute_names(self):
  """Module attributes resolve to their live value and qualified name."""

  def test_fn():
    return constant_op.constant(0)

  node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
  func_node = node.body[0].body[0].value.func
  # assertEqual replaces the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
  self.assertEqual((constant_op.__name__, 'constant'),
                   anno.getanno(func_node, 'fqn'))
def test_rename_symbols_annotations(self):
  """Annotations on the root node survive ast_util.rename_symbols."""
  parsed = parser.parse_str('a[i]')
  node = qual_names.resolve(parsed)
  anno.setanno(node, 'foo', 'bar')
  expected_annotation = anno.getanno(node, 'foo')

  node = ast_util.rename_symbols(
      node, {qual_names.QN('a'): qual_names.QN('b')})

  self.assertIs(anno.getanno(node, 'foo'), expected_annotation)
def visit_For(self, node):
  """Summarizes a `for` loop's dataflow annotations.

  The loop's `in_label` value is the first body statement's, minus the loop
  target's qualified name, stored as a frozenset. Its `out_label` value is the
  last body statement's, combined via `self.transfer_fn` with the last `else`
  statement's when an else block exists.
  """
  self.generic_visit(node)

  # Copy before editing so the body statement's annotation is untouched.
  loop_in = set(anno.getanno(node.body[0], self.in_label))
  loop_in.discard(anno.getanno(node.target, anno.Basic.QN))

  loop_out = anno.getanno(node.body[-1], self.out_label)
  if node.orelse:
    loop_out = self.transfer_fn(
        loop_out, anno.getanno(node.orelse[-1], self.out_label))

  anno.setanno(node, self.in_label, frozenset(loop_in))
  anno.setanno(node, self.out_label, loop_out)
def visit_For(self, node):
  """Summarizes a `for` loop's dataflow annotations.

  The loop's `in_label` value is the first body statement's, without the loop
  target's qualified name, stored as a frozenset. Its `out_label` value is the
  last body statement's, merged via `self.transfer_fn` with the last `else`
  statement's when an else block exists.
  """
  self.generic_visit(node)

  target_qn = anno.getanno(node.target, anno.Basic.QN)
  # set(...) makes a private copy before the target is removed.
  body_in = set(anno.getanno(node.body[0], self.in_label))
  body_in.discard(target_qn)

  body_out = anno.getanno(node.body[-1], self.out_label)
  if node.orelse:
    else_out = anno.getanno(node.orelse[-1], self.out_label)
    body_out = self.transfer_fn(body_out, else_out)

  anno.setanno(node, self.in_label, frozenset(body_in))
  anno.setanno(node, self.out_label, body_out)
def visit_For(self, node):
  """Rewrites a `for` loop into a functional `__ops.for_loop` call.

  The loop body and the `extra_cond` condition become generated local
  functions. Symbols the body modifies but does not create are threaded
  through them explicitly as loop state.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  # Loop state: symbols modified in the body that are not created there.
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  state = list(body_closure)

  # One fresh name per state symbol, derived from `s.ssf()` and chosen by the
  # namer to avoid everything referenced in the body.
  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose fresh name differs need renaming inside the body.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # A single state symbol is passed bare rather than as a one-element tuple.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_cond'):
    extra_cond = anno.getanno(node, 'extra_cond')
    extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
  else:
    extra_cond = parser.parse_expression('True')

  template = """
    def extra_cond_name(state_ssf):
      return extra_cond_expr
    def body_name(iterate, state_ssf):
      body
      return state_ssf,
    state_ast_tuple = __ops.for_loop(
        iterated, extra_cond_name, body_name, (state,))
  """
  # Keyword arguments are evaluated left to right, so the extra_cond symbol
  # is allocated before the loop_body symbol.
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iterated=node.iter,
      iterate=node.target,
      extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                    all_referenced),
      extra_cond_expr=extra_cond,
      body_name=self.context.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def visit_For(self, node):
  """Rewrites a `for` loop into a functional `ag__.for_loop` call.

  The loop body and the `extra_cond` condition become generated local
  functions. Symbols the body modifies but does not create are threaded
  through them explicitly as loop state.
  """
  self.generic_visit(node)

  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  # Loop state: symbols modified in the body that are not created there.
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  state = list(body_closure)

  # One fresh name per state symbol, derived from `s.ssf()` and chosen by the
  # namer to avoid everything referenced in the body.
  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose fresh name differs need renaming inside the body.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # A single state symbol is passed bare rather than as a one-element tuple.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_cond'):
    extra_cond = anno.getanno(node, 'extra_cond')
    extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
  else:
    extra_cond = parser.parse_expression('True')

  template = """
    def extra_cond_name(state_ssf):
      return extra_cond_expr
    def body_name(iterate, state_ssf):
      body
      return state_ssf,
    state_ast_tuple = ag__.for_loop(
        iterated, extra_cond_name, body_name, (state,))
  """
  # Keyword arguments are evaluated left to right, so the extra_cond symbol
  # is allocated before the loop_body symbol.
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iterated=node.iter,
      iterate=node.target,
      extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                    all_referenced),
      extra_cond_expr=extra_cond,
      body_name=self.context.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def test_entity_scope_tracking(self):
  """Nodes are annotated with the tuple of their enclosing entities."""

  class TestTransformer(transformer.Base):

    # The choice of node to assign to is arbitrary. Using Assign because it's
    # easy to find in the tree.
    def visit_Assign(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

    # This will show up in the lambda function.
    def visit_BinOp(self, node):
      anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
      return self.generic_visit(node)

  tr = TestTransformer(self._context_for_testing())

  def test_function():
    a = 0

    class TestClass(object):

      def test_method(self):
        b = 0

        def inner_function(x):
          c = 0
          d = lambda y: (x + y)
          return c, d

        return b, inner_function

    return a, TestClass

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  fn_node = node.body[0]
  class_node = fn_node.body[1]
  method_node = class_node.body[0]
  inner_node = method_node.body[1]
  lambda_node = inner_node.body[1].value

  def entities_of(n):
    return anno.getanno(n, 'enclosing_entities')

  method_chain = (fn_node, class_node, method_node)
  inner_chain = method_chain + (inner_node,)

  self.assertEqual((fn_node,), entities_of(fn_node.body[0]))
  self.assertEqual(method_chain, entities_of(method_node.body[0]))
  self.assertEqual(inner_chain, entities_of(inner_node.body[0]))
  self.assertEqual(inner_chain + (lambda_node,), entities_of(lambda_node.body))
def _validate_no_live_vars_created(self, node):
  """Rejects loops whose body creates variables that are live afterwards.

  Raises:
    ValueError: if any symbol in the loop's LIVE_VARS_OUT annotation is also
      in its body scope's `created` set.
  """
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  live_after_loop = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
  offenders = live_after_loop & body_scope.created
  if not offenders:
    return
  raise ValueError(
      'The following variables are created inside the loop and used later:'
      '\n%s\n'
      'Variables must be declared outside loops because loops may not'
      ' necessarily execute.' % self._fmt_symbol_list(offenders))
def test_constructor_detection(self):
  """A call to a known class is annotated with its type and type_fqn."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = self._parse_and_analyze(test_fn, {'training': training})
  call_node = node.body[0].body[0].value
  # assertEqual replaces the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(training.GradientDescentOptimizer,
                   anno.getanno(call_node, 'type'))
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(call_node, 'type_fqn'))
def _validate_no_live_vars_created(self, node):
  """Ensures no variable first created in the loop body is used after it.

  Raises:
    ValueError: listing the symbols that appear both in the loop's
      LIVE_VARS_OUT annotation and in its body scope's `created` set.
  """
  created_in_body = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).created
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
  leaked = live_out & created_in_body
  if leaked:
    message = (
        'The following variables are created inside the loop and used later:'
        '\n%s\n'
        'Variables must be declared outside loops because loops may not'
        ' necessarily execute.')
    raise ValueError(message % self._fmt_symbol_list(leaked))
def test_namespace(self):
  """Names found in the namespace get live_val and fqn annotations."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = self._parse_and_analyze(test_fn, {'foo': foo})
  func_node = node.body[0].body[0].value.func
  # assertEqual replaces the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
  self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def _try_resolve_target(self, node):
  """Works for methods of objects of known type.

  Args:
    node: gast node, a call target.

  Returns:
    The node's `live_val` annotation if present. Otherwise, for an
    `Attribute` node with a `type` annotation, the attribute `node.attr`
    looked up on that type. Otherwise None.

  Raises:
    ValueError: if the annotated type has no attribute named `node.attr`.
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    # NOTE(review): the owner type is read from `node` itself, not from
    # `node.value`; confirm the annotator stores it on the Attribute node.
    owner_type = anno.getanno(node, 'type')
    if not hasattr(owner_type, node.attr):
      raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                       (owner_type, node.attr))
    return getattr(owner_type, node.attr)
  return None
def source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
        nodes appear in code. The parser always returns a module when parsing
        code. This argument indicates the position in that module's body at
        which the corresponding of node should appear.

  Returns:
    Dict[CodeLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.
  """
  if isinstance(indices_in_code, int):
    # The documented contract also accepts a single index.
    indices_in_code = (indices_in_code,)

  reparsed_module = parser.parse_str(code)
  reparsed_nodes = [reparsed_module.body[i] for i in indices_in_code]

  resolve(reparsed_nodes, code)
  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdan): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. Decorated functions are an exception:
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      # NOTE(review): if `loc.line_loc` is derived from (filename, lineno),
      # equal line_locs imply equal linenos, so this always skips and the
      # first node seen on a line wins. Confirm that is intended.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
def test_primitive_values(self):
  """A primitive from the namespace is annotated with its builtin type fqn."""
  a = None

  def test_fn():
    return a

  node = self._parse_and_analyze(test_fn, {'a': True})
  retval_node = node.body[0].body[0].value
  builtins_module = '__builtin__' if six.PY2 else 'builtins'
  self.assertEqual(
      anno.getanno(retval_node, 'fqn'), (builtins_module, 'bool'))
def test_type_annotation(self):
  """set_element_type annotates both the definition and the marker call."""

  class Foo(object):
    pass

  def test_fn():
    f = []
    f = utils.set_element_type(f, Foo)
    return f

  node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
  fn_body = node.body[0].body
  # fn_body[0].value is the `[]` literal; fn_body[1].value is the marker call.
  for stmt in fn_body[:2]:
    self.assertEqual(anno.getanno(stmt.value, 'element_type'), Foo)
def visit_Call(self, node):
  """Renames, wraps or dynamically converts a call depending on its target.

  A call to one of `self.ctx.program.autograph_decorators` marks its first
  argument as `graph_ready`. After the children are visited:
    * resolved targets (`live_val`) are renamed if compilable, or wrapped as
      py_funcs if listed in KNOWN_NUMPY_FUNCTIONS, and rejected otherwise;
    * unresolved `dict(...)` and `super(...)` calls are left untouched;
    * other unresolved calls are dynamically converted when
      `self.ctx.program.recursive` is set.

  Raises:
    ValueError: if a marker decorator is called without arguments.
    NotImplementedError: if a resolved target is neither compilable nor a
      known numpy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    decorator = anno.getanno(node.func, 'live_val')
    if decorator in self.ctx.program.autograph_decorators:
      if not node.args:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' % decorator)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)

  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (anno.getanno(node.func, 'fqn')
                  if anno.hasanno(node.func, 'fqn') else None)
    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError('py_func with return values (unknown function)')

  if anno.hasanno(node.func, anno.Basic.QN):
    # Special-case a few builtins that otherwise go undetected. This
    # normally doesn't pose a problem, but the dict built-in doesn't
    # work with inspect.getargspec which is required for dynamic functions.
    # Note: expecting this is resilient to aliasing (e.g.
    # dict = an_evil_dict), because in those cases the regular mechanisms
    # process a simple user function.
    qn = anno.getanno(node.func, anno.Basic.QN)
    # Add items to this list as needed.
    if str(qn) in ('dict',):
      return node

  if ast_util.matches(node, 'super(_)'):
    # super() calls are preserved. The class conversion mechanism will
    # ensure that they return the correct value.
    return node

  if self.ctx.program.recursive:
    return self._insert_dynamic_conversion(node)
  return node
def test_for(self):
  """Checks the scope sets of a for-loop body and of its enclosing scope."""

  def test_fn(a):
    b = a
    for _ in a:
      c = b
      b -= 1
    return b, c

  tree, _ = self._parse_and_analyze(test_fn)
  loop = tree.body[0].body[1]
  loop_body_scope = anno.getanno(loop, NodeAnno.BODY_SCOPE)
  self.assertScopeIsRmc(loop_body_scope, ('b', ), ('b', 'c'), ('c', ))
  # The parent scope is the function body, which also sees `a` and the loop
  # target `_`.
  enclosing_scope = loop_body_scope.parent
  self.assertScopeIsRmc(enclosing_scope, ('a', 'b', 'c'), ('b', 'c', '_'),
                        ('a', 'b', 'c', '_'))
def test_class_members_in_with_stmt(self):
  """Type and live-value annotations are resolved inside a `with` statement.

  The context expression `session.Session()` must carry 'type' and
  'type_fqn' annotations, and the callee of `sess.run(x)` must carry a
  'live_val' annotation equal to `session.Session.run`.
  """

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = self._parse_and_analyze(test_fn, {'session': session})
  constructor_call = node.body[0].body[0].items[0].context_expr
  # assertEqual replaces the deprecated assertEquals alias, which was removed
  # from unittest in Python 3.12.
  self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(constructor_call, 'type_fqn'))

  method_call = node.body[0].body[0].body[0].value.func
  self.assertEqual(session.Session.run,
                   anno.getanno(method_call, 'live_val'))
def test_basic(self):
  """An annotation can be set, read back and deleted on an AST node."""
  key = 'foo'
  target = ast.Name()

  def check_absent():
    # A missing annotation is reported by hasanno and raises on getanno.
    self.assertFalse(anno.hasanno(target, key))
    with self.assertRaises(AttributeError):
      anno.getanno(target, key)

  check_absent()

  anno.setanno(target, key, 3)
  self.assertTrue(anno.hasanno(target, key))
  self.assertEqual(3, anno.getanno(target, key))

  anno.delanno(target, key)
  check_absent()
def visit_Name(self, node):
  """Attaches 'live_val' / 'fqn' annotations to names read in Load context.

  Requires an earlier pass to have set NodeAnno.IS_LOCAL,
  NodeAnno.IS_MODIFIED_SINCE_ENTRY and NodeAnno.IS_PARAM on every Load-context
  Name (this is checked with asserts).

  For a loaded name:
    1. If it is neither local nor a parameter, it is looked up first in
       `self.literals`, then in `self.context.namespace`. A namespace hit sets
       'live_val' and, when derivable, 'fqn'.
    2. Independently, if the symbol was not modified since entry and its id
       is a key of `self.context.arg_values`, 'live_val' and 'fqn' are set
       from that argument value, after (and presumably replacing) anything
       set in step 1.

  Args:
    node: a gast.Name node.

  Returns:
    The same node, annotated in place.
  """
  self.generic_visit(node)
  if isinstance(node.ctx, gast.Load):
    # NOTE(review): these checks are plain asserts and disappear under -O.
    assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
    symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
    assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
    symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
    assert anno.hasanno(node, NodeAnno.IS_PARAM), node
    symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)

    if not symbol_is_local and not symbol_is_param:
      if node.id in self.literals:
        anno.setanno(node, 'live_val', self.literals[node.id])
      elif node.id in self.context.namespace:
        obj = self.context.namespace[node.id]
        anno.setanno(node, 'live_val', obj)
        if hasattr(obj, '__name__'):
          # Named objects (e.g. functions, classes, modules) use their own
          # name as a one-element fqn.
          anno.setanno(node, 'fqn', (obj.__name__, ))
        elif hasattr(obj, '__class__'):
          # Otherwise fall back to the (module, class name) of the value's
          # type.
          obj_class = obj.__class__
          anno.setanno(
              node, 'fqn', (obj_class.__module__, obj_class.__name__))
        else:
          # If the symbol value is for example a primitive, then it will not
          # have a name.
          # NOTE(review): every Python object has `__class__`, so this branch
          # appears unreachable.
          pass
      else:
        pass
        # TODO(mdan): Should we raise an error here? id:997
        # https://github.com/imdone/tensorflow/issues/998
        # Can encounter this when:
        #  * a symbol truly lacks reference
        #  * a symbol is new, like the new name of a function we just renamed.
    else:
      pass
      # TODO(mdan): Attempt to trace its value through the local chain. id:730
      # https://github.com/imdone/tensorflow/issues/731
      # TODO(mdan): Use type annotations as fallback. id:700
      # https://github.com/imdone/tensorflow/issues/701

    if not symbol_is_modified:
      if node.id in self.context.arg_values:
        obj = self.context.arg_values[node.id]
        anno.setanno(node, 'live_val', obj)
        # NOTE(review): unlike the namespace branch above, this fqn holds only
        # the class name, without the module; confirm that is intended.
        anno.setanno(node, 'fqn', (obj.__class__.__name__, ))
  return node
def visit_While(self, node):
  """Processes a while loop: its test's children, its body and its else clause.

  The body is handed to `_process_loop_block` together with the loop's body
  scope. The test and each else-clause statement only have their children
  visited (via generic_visit).

  Args:
    node: a gast.While node.

  Returns:
    The same node, with its body replaced.
  """
  self.generic_visit(node.test)
  loop_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  node.body = self._process_loop_block(node.body, loop_scope)
  for else_stmt in node.orelse:
    self.generic_visit(else_stmt)
  return node
def _visit_and_reindent(self, nodes):
  """Visits a list of statements, nesting trailing ones where requested.

  Each statement is visited, renamed with the alias map accumulated so far,
  and appended to the current destination list. When a visited statement
  carries an anno.Basic.INDENT_BLOCK_REMAINDER annotation holding a
  (destination list, alias map) pair:
    * the annotation is removed;
    * all following statements are appended to that destination list;
    * the annotation's alias map is merged into the accumulated one.

  Args:
    nodes: the statements to visit.

  Returns:
    The new top-level list of statements. Statements redirected by an
    annotation are placed in that annotation's destination list instead
    (presumably a body list of a node already emitted).

  Raises:
    ValueError: if a redirection was requested and the last destination list
      is empty once all statements have been processed.
  """
  new_nodes = []
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    # Renaming uses only aliases from earlier statements, so a statement that
    # introduces new aliases is not itself renamed by them.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    # A visit may expand one statement into several.
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # On key conflicts, aliases accumulated earlier win over the new ones.
      # Note this updates the annotation's dict in place.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError('Unable to insert statement into the computation flow: '
                     'it is not followed by any computation which '
                     'the statement could gate.')
  return new_nodes
def visit_Attribute(self, node):
  """Visits the node, then derives its qualified name from its base's, if any.

  Args:
    node: a gast.Attribute node.

  Returns:
    The visited node. When its `value` has an anno.Basic.QN annotation, the node
    is annotated with a QN made of that base QN plus `node.attr`.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.value, anno.Basic.QN):
    return node
  base_qn = anno.getanno(node.value, anno.Basic.QN)
  anno.setanno(node, anno.Basic.QN, QN(base_qn, attr=node.attr))
  return node
def test_nested_if(self):
  """Checks the body and orelse scopes of an if nested inside another if."""

  def test_fn(b):
    if b > 0:
      if b < 5:
        a = b
      else:
        a = b * b
    return a

  tree, _ = self._parse_and_analyze(test_fn)
  outer_if = tree.body[0].body[0]
  inner_if = outer_if.body[0]
  # Both branches of the inner if read `b` and assign `a`.
  expected = (('b', ), ('a', ), ('a', ))
  for scope_key in (NodeAnno.BODY_SCOPE, NodeAnno.ORELSE_SCOPE):
    self.assertScopeIsRmc(anno.getanno(inner_if, scope_key), *expected)
def _track_symbol(self, node):
  """Records a read, write or parameter definition of the node's symbol.

  Updates `self.scope` according to the node's expression context. It then
  annotates the node with the following values, as reported by the scope:
    * NodeAnno.IS_LOCAL
    * NodeAnno.IS_MODIFIED_SINCE_ENTRY
    * NodeAnno.IS_PARAM
  While inside a return statement, the symbol is also marked as returned.

  Args:
    node: a node with a `ctx` attribute (e.g. Name, Attribute, Subscript).

  Raises:
    ValueError: if the context is not Store, Load or Param.
  """
  # Nodes without a qualified name are skipped. This can happen when we have
  # an attribute (or subscript) on a function call, e.g. `a().b`.
  if not anno.hasanno(node, anno.Basic.QN):
    return
  symbol = anno.getanno(node, anno.Basic.QN)
  context = node.ctx

  if isinstance(context, gast.Store):
    self.scope.mark_write(symbol)
  elif isinstance(context, gast.Load):
    self.scope.mark_read(symbol)
  elif isinstance(context, gast.Param):
    # Param contexts appear in function defs, so they have the meaning of
    # defining a variable.
    # TODO(mdan): This may be incorrect with nested functions.
    # For nested functions, we'll have to add the notion of hiding args from
    # the parent scope, not writing to them.
    self.scope.mark_creation(symbol)
    self.scope.mark_param(symbol)
  else:
    raise ValueError('Unknown context %s for node %s.' % (type(context), symbol))

  anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(symbol))
  anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
               self.scope.is_modified_since_entry(symbol))
  anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(symbol))

  if self._in_return_statement:
    self.scope.mark_returned(symbol)
def visit_Subscript(self, node):
  """Rewrites a subscript read into a call to `ag__.get_item`.

  The element dtype passed to `ag__.GetItemOpts` is the 'element_type'
  annotation of the subscripted value. If that annotation is absent, it falls
  back to the expression `None`.

  Args:
    node: a gast.Subscript node.

  Returns:
    The visited node unchanged for non-Load contexts, otherwise the
    replacement expression.

  Raises:
    NotImplementedError: if the slice is not a plain gast.Index.
  """
  node = self.generic_visit(node)
  if not isinstance(node.slice, gast.Index):
    # TODO(mdan): It might make more sense to wave them through.
    raise NotImplementedError('non-index slice')

  if not isinstance(node.ctx, gast.Load):
    # Index writes are handled at a higher level, one at which the rvalue is
    # also available.
    return node

  none_dtype = templates.replace_as_expression('None')
  element_dtype = anno.getanno(node.value, 'element_type', default=none_dtype)
  template = """
    ag__.get_item(
        target,
        key,
        opts=ag__.GetItemOpts(element_dtype=dtype))
  """
  return templates.replace_as_expression(
      template, target=node.value, key=node.slice, dtype=element_dtype)
def visit_FunctionDef(self, node):
  """Runs the definition analysis for a function, then descends into it.

  A new Analyzer is created for the function's subgraph in `self.graphs`. For
  a nested function, its entry state is seeded from the parent analyzer's out
  state at this node, minus the symbols modified in the function body. After
  the forward analysis, `node.args` and `node.body` are visited with the new
  analyzer installed as `self.current_analyzer`.

  Args:
    node: a FunctionDef node; must be a key of `self.graphs`.

  Returns:
    The same node, with `args` and `body` replaced by their visited versions.
  """
  parent_analyzer = self.current_analyzer
  subgraph = self.graphs[node]

  # Preorder tree processing:
  #  1. if this is a child function, the parent was already analyzed and it
  #     has the proper state value for the subgraph's entry
  #  2. analyze the current function body
  #  3. recursively walk the subtree; child functions will be processed
  analyzer = Analyzer(subgraph, self.definition_factory)
  if parent_analyzer is not None:
    # Wire the state between the two subgraphs' analyzers.
    parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
    # Exception: symbols modified in the child function are local to it.
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    # Build a new state object instead of using `-=`. The state is owned by
    # the parent analyzer, and an in-place update would silently alter the
    # parent's results if the state type implements `__isub__`.
    analyzer.extra_in[node.args] = parent_out_state - body_scope.modified

  # Complete the analysis for the local function and annotate its body.
  analyzer.visit_forward()

  # Recursively process any remaining subfunctions.
  self.current_analyzer = analyzer
  try:
    # Note: not visiting name, decorator_list and returns because they don't
    # apply to this analysis.
    # TODO(mdan): Should we still process the function name?
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
  finally:
    # Restore the parent analyzer even if visiting the body fails, so the
    # visitor is not left pointing at this function's analyzer.
    self.current_analyzer = parent_analyzer

  return node
def visit_node(self, node):
  """Recomputes liveness for one CFG node.

  For nodes whose AST carries an anno.Static.SCOPE annotation:
    gen      = scope.used | extra_gen[ast_node]
    kill     = scope.modified
    live_out = union of in_[succ] over node.next
    live_in  = gen | (live_out - kill)
  Nodes without a scope annotation (literal Names, Continue, Break) keep their
  previous live-in, and their live-out is set equal to it.

  Args:
    node: a CFG node with `ast_node` and `next` attributes.

  Returns:
    True if the node's live-in set changed (presumably used by the caller to
    decide whether to keep iterating to a fixed point).
  """
  prev_live_in = self.in_[node]

  if anno.hasanno(node.ast_node, anno.Static.SCOPE):
    node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

    # Symbols used at this node, plus any extra symbols registered for it.
    gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. if x.y is live whether x needs to be added. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    kill = node_scope.modified

    # Backward analysis: live-out is the union of the successors' live-in.
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]
    live_in = gen | (live_out - kill)

  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols.
    # This Name node below is a literal name, e.g. False
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Continue, gast.Break)), type(
                          node.ast_node)
    live_in = prev_live_in
    live_out = live_in

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def visit_While(self, node):
  """Lowers `break` statements in a while loop to a boolean flag.

  A fresh symbol with prefix 'break_' is requested from the namer, passing
  the body scope's referenced names (presumably so it does not clash with
  them). The test, the body (via `_process_body`) and the else clause are
  visited. If the body used the flag, the loop is rebuilt so that:
    * the flag is initialized to `tf.constant(False)` before the loop;
    * the loop runs only while the original test holds and the flag is unset;
    * the else clause is guarded via `_guard_if_present`.

  Args:
    node: a gast.While node.

  Returns:
    The visited node when no break was found, otherwise the result of
    `templates.replace` for the rebuilt loop.
  """
  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

  node.test = self.visit(node.test)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if break_used:
    # Python's else clause only triggers if the loop exited cleanly (i.e.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)

    # The condition placeholder must be named `test` to match the `test=`
    # keyword below. Otherwise the original loop condition is dropped and the
    # generated code refers to an undefined name.
    template = """
      var_name = tf.constant(False)
      while test and not var_name:
        body
      else:
        orelse
    """
    node = templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)

  return node