def visit(self, cfg_node):
  """Propagates dataflow information backwards through the CFG.

  The out-set of `cfg_node` is the combination (via `self.transfer_fn`) of the
  in-sets of those successors that already carry one. The in-set is
  `(out - kill) | gen`. Whenever the in-set changes, the predecessors are
  revisited, so the walk continues until a fixed point is reached.

  Args:
    cfg_node: The CFG node to visit. `cfg_node.value` is the AST node holding
      the dataflow annotations, or None for the exit node.
  """
  # cfg_node.value is None for the exit node, which will be visited only once.
  if not cfg_node.value:
    for pred in cfg_node.prev:
      self.visit(pred)
    return

  # Remember the previous set itself rather than its hash. Distinct sets can
  # share a hash (e.g. hash(-1) == hash(-2) in CPython), which would end the
  # propagation before the fixed point is reached.
  if anno.hasanno(cfg_node.value, self.in_label):
    before = anno.getanno(cfg_node.value, self.in_label)
  else:
    before = None

  succs = [
      anno.getanno(succ.value, self.in_label)
      for succ in cfg_node.next
      if anno.hasanno(succ.value, self.in_label)
  ]
  if succs:
    incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
  else:
    incoming = frozenset()
  anno.setanno(cfg_node.value, self.out_label, incoming)

  gen, kill = self.get_gen_kill(cfg_node, incoming)
  anno.setanno(cfg_node.value, self.gen_label, gen)
  anno.setanno(cfg_node.value, self.kill_label, kill)
  after = (incoming - kill) | gen
  anno.setanno(cfg_node.value, self.in_label, after)

  if after != before:
    for pred in cfg_node.prev:
      self.visit(pred)
def _build_source_map(node, code):
  """Maps line numbers of generated code back to the user code's origin info.

  The generated source is re-parsed to obtain its final line numbers. The
  re-parsed tree is then walked in parallel with the tree it was generated
  from.

  Args:
    node: An AST node of the original generated code, before the source code
      is generated.
    code: The string representation of the source code for the newly
      generated code.

  Returns:
    A dict keyed by the `line_number` of the ORIGIN annotation on each
    re-parsed node. Each value is the ORIGIN annotation of the corresponding
    node in `node`. Only node pairs that both carry an ORIGIN annotation
    contribute. When several pairs share a line number, the last one visited
    by the parallel walk wins.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)
  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks because if origin information is ever copied over to new
    # nodes then we need to rely on the fact that only the original user code
    # has the origin annotation.
    if (anno.hasanno(before, anno.Basic.ORIGIN) and
        anno.hasanno(after, anno.Basic.ORIGIN)):
      source_info = anno.getanno(before, anno.Basic.ORIGIN)
      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
      source_mapping[new_line_number] = source_info
  return source_mapping
def visit(self, node):
  """Depth-first walking the CFG, applying dataflow info propagation.

  The in-set of `node` is the combination (via `self.transfer_fn`) of the
  out-sets of those predecessors that already carry one. The out-set is
  `(in - kill) | gen`. Successors are revisited whenever the out-set changes,
  until a fixed point is reached.

  Args:
    node: The CFG node to visit. `node.value` is the AST node holding the
      dataflow annotations; it is None only for the exit node.
  """
  # node.value is None only for the exit CfgNode.
  if not node.value:
    return

  # Remember the previous set itself rather than its hash. Distinct sets can
  # share a hash (e.g. hash(-1) == hash(-2) in CPython), which would end the
  # propagation before the fixed point is reached.
  if anno.hasanno(node.value, self.out_label):
    before = anno.getanno(node.value, self.out_label)
  else:
    before = None

  preds = [
      anno.getanno(pred.value, self.out_label)
      for pred in node.prev
      if anno.hasanno(pred.value, self.out_label)
  ]
  if preds:
    incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
  else:
    incoming = frozenset()
  anno.setanno(node.value, self.in_label, incoming)

  gen, kill = self.get_gen_kill(node, incoming)
  anno.setanno(node.value, self.gen_label, gen)
  anno.setanno(node.value, self.kill_label, kill)
  after = (incoming - kill) | gen
  anno.setanno(node.value, self.out_label, after)

  if after != before:
    for succ in node.next:
      self.visit(succ)
def _rename_compilable_function(self, node):
  """Points a call at the compiled counterpart of its callee.

  The callee's static value and fully qualified name are read from the
  'live_val' and 'fqn' annotations of `node.func`. Constructors get a
  compiled class name. Other callees get a compiled function name, which
  the namer may decline to apply.

  Args:
    node: A call node whose `func` carries 'live_val' and 'fqn' annotations.

  Returns:
    `node`, possibly with its `func` (and, for bound methods, its `args`)
    replaced.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.ctx.namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    owner_type = (
        anno.getanno(node.func, 'parent_type')
        if anno.hasanno(node.func, 'parent_type')
        # Fallback - not reliable.
        else inspect_utils.getmethodclass(target_entity))
    new_name, do_rename = self.ctx.namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function, so the
    # receiver is passed explicitly as the first argument. This presumes
    # node.func is an attribute access such as `obj.method`.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def visit_Call(self, node):
  """Applies the element type requested by the type annotation marker.

  A call `type_annotation_func(symbol, type)` annotates two nodes with
  'element_type' set to the static value of `type`: the call node itself and
  the definition of `symbol` returned by `self.scope.getval`. Every call is
  then visited generically.

  Raises:
    ValueError: if the marker call does not have exactly two arguments, if its
      first argument has no qualified name, or if its second argument has no
      'live_val' annotation.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if (anno.getanno(node.func, 'live_val') is
        self.context.type_annotation_func):
      # Expecting the actual type to be the second argument.
      if len(node.args) != 2:
        raise ValueError('"%s" must have exactly two parameters' %
                         self.context.type_annotation_func)
      if not anno.hasanno(node.args[0], anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         self.context.type_annotation_func)
      if not anno.hasanno(node.args[1], 'live_val'):
        raise ValueError(
            'the second argument of "%s" must be statically resolvable' %
            self.context.type_annotation_func)
      target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
      element_type = anno.getanno(node.args[1], 'live_val')
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(definition, 'element_type', element_type)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def _build_source_map(node, code):
  """Maps line numbers of generated code back to the user code's origin info.

  The generated source is re-parsed to obtain its final line numbers. The
  re-parsed tree is then walked in parallel with the tree it was generated
  from.

  Args:
    node: An AST node of the original generated code, before the source code
      is generated.
    code: The string representation of the source code for the newly
      generated code.

  Returns:
    A dict keyed by the `line_number` of the ORIGIN annotation on each
    re-parsed node. Each value is the ORIGIN annotation of the corresponding
    node in `node`. Only node pairs that both carry an ORIGIN annotation
    contribute. When several pairs share a line number, the last one visited
    by the parallel walk wins.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)
  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks because if origin information is ever copied over to new
    # nodes then we need to rely on the fact that only the original user code
    # has the origin annotation.
    if (anno.hasanno(before, anno.Basic.ORIGIN) and
        anno.hasanno(after, anno.Basic.ORIGIN)):
      source_info = anno.getanno(before, anno.Basic.ORIGIN)
      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
      source_mapping[new_line_number] = source_info
  return source_mapping
def visit_Attribute(self, node):
  """Statically resolves an attribute access when its parent is known.

  If the parent expression has a static value ('live_val'), the attribute is
  looked up on that value. If it only has a static 'type', the attribute is
  looked up on the type. Either way, the node gets 'parent_type', 'live_val'
  and 'fqn' annotations. Unresolved parents that are plain names must be
  local symbols.

  Raises:
    AttributeError: if the parent value has no such attribute.
  """
  self.generic_visit(node)
  parent = node.value

  if anno.hasanno(parent, 'live_val'):
    assert anno.hasanno(parent, 'fqn')
    parent_object = anno.getanno(parent, 'live_val')
    if not hasattr(parent_object, node.attr):
      raise AttributeError('%s has no attribute %s' %
                           (parent_object, node.attr))
    anno.setanno(node, 'parent_type', type(parent_object))
    anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
    anno.setanno(node, 'fqn', anno.getanno(parent, 'fqn') + (node.attr,))
  # TODO(mdan): Investigate the role built-in annotations can play here.
  elif anno.hasanno(parent, 'type'):
    parent_type = anno.getanno(parent, 'type')
    # Static members (e.g. methods) are found on the type; dynamic members
    # (e.g. function attributes) are not, and such nodes are deliberately
    # left unannotated for downstream consumers to handle.
    if hasattr(parent_type, node.attr):
      anno.setanno(node, 'parent_type', parent_type)
      anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
      anno.setanno(node, 'fqn',
                   anno.getanno(parent, 'type_fqn') + (node.attr,))
  elif isinstance(parent, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(parent, NodeAnno.IS_LOCAL), parent
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def visit_Call(self, node):
  """Applies the element type/shape requested by `utils.set_element_type`.

  A call `set_element_type(symbol, type[, shape])` annotates two nodes with
  'element_type' and 'element_shape': the call node itself and the definition
  of `symbol` returned by `self.scope.getval`. The values stored are the
  argument AST nodes; a missing shape defaults to the parsed expression
  `None`. Every call is then visited generically.

  Raises:
    ValueError: if the marker call does not have two or three arguments, or
      if its first argument has no qualified name.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if anno.getanno(node.func, 'live_val') is utils.set_element_type:
      # Report the function that was actually matched above.
      if len(node.args) < 2 or len(node.args) > 3:
        raise ValueError('"%s" must have either two or three parameters' %
                         utils.set_element_type)
      if len(node.args) == 2:
        target_arg, type_arg = node.args
        shape_arg = parser.parse_expression('None')
      else:
        target_arg, type_arg, shape_arg = node.args
      if not anno.hasanno(target_arg, anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         utils.set_element_type)
      # TODO(mdan): This is vulnerable to symbol renaming.
      element_type = type_arg
      element_shape = shape_arg

      target_symbol = anno.getanno(target_arg, anno.Basic.QN)
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(node, 'element_shape', element_shape)
      anno.setanno(definition, 'element_type', element_type)
      anno.setanno(definition, 'element_shape', element_shape)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def visit(self, cfg_node):
  """Propagates dataflow information backwards through the CFG.

  The out-set of `cfg_node` is the combination (via `self.transfer_fn`) of the
  in-sets of those successors that already carry one. The in-set is
  `(out - kill) | gen`. Whenever the in-set changes, the predecessors are
  revisited, so the walk continues until a fixed point is reached.

  Args:
    cfg_node: The CFG node to visit. `cfg_node.value` is the AST node holding
      the dataflow annotations, or None for the exit node.
  """
  # cfg_node.value is None for the exit node, which will be visited only once.
  if not cfg_node.value:
    for pred in cfg_node.prev:
      self.visit(pred)
    return

  # Remember the previous set itself rather than its hash. Distinct sets can
  # share a hash (e.g. hash(-1) == hash(-2) in CPython), which would end the
  # propagation before the fixed point is reached.
  if anno.hasanno(cfg_node.value, self.in_label):
    before = anno.getanno(cfg_node.value, self.in_label)
  else:
    before = None

  succs = [
      anno.getanno(succ.value, self.in_label)
      for succ in cfg_node.next
      if anno.hasanno(succ.value, self.in_label)
  ]
  if succs:
    incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
  else:
    incoming = frozenset()
  anno.setanno(cfg_node.value, self.out_label, incoming)

  gen, kill = self.get_gen_kill(cfg_node, incoming)
  anno.setanno(cfg_node.value, self.gen_label, gen)
  anno.setanno(cfg_node.value, self.kill_label, kill)
  after = (incoming - kill) | gen
  anno.setanno(cfg_node.value, self.in_label, after)

  if after != before:
    for pred in cfg_node.prev:
      self.visit(pred)
def visit_Call(self, node):
  """Applies the element type requested by the type annotation marker.

  For a call `type_annotation_func(symbol, type)`, the element type is
  determined as follows:
    - a string literal `type` contributes its string value;
    - a number literal `type` contributes its numeric value;
    - anything else contributes its 'live_val' annotation.
  The call node and the definition of `symbol` (from `self.scope.getval`) are
  both annotated with 'element_type'. Every call is then visited generically.

  Raises:
    ValueError: if the marker call does not have exactly two arguments, if its
      first argument has no qualified name, or if a non-literal second
      argument has no 'live_val' annotation.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if (anno.getanno(node.func, 'live_val') is
        self.context.type_annotation_func):

      if len(node.args) != 2:
        raise ValueError('"%s" must have exactly two parameters' %
                         self.context.type_annotation_func)
      target_arg, type_arg = node.args
      if not anno.hasanno(target_arg, anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         self.context.type_annotation_func)
      if isinstance(type_arg, gast.Str):
        element_type = type_arg.s
      elif isinstance(type_arg, gast.Num):
        element_type = type_arg.n
      else:
        if not anno.hasanno(type_arg, 'live_val'):
          raise ValueError(
              'the second argument of "%s" must be statically resolvable' %
              self.context.type_annotation_func)
        element_type = anno.getanno(type_arg, 'live_val')

      target_symbol = anno.getanno(target_arg, anno.Basic.QN)
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(definition, 'element_type', element_type)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def visit_Call(self, node):
  """Converts a call according to what its callee resolves to.

  Calls to one of `self.nocompile_decorators` mark their first argument as
  graph ready. After the children are visited, calls whose callee has a
  static value ('live_val') are handled as follows:
    - compilable callees are renamed to their compiled version;
    - known NumPy functions are wrapped in a py_func;
    - anything else is rejected.
  Unresolved callees get a dynamic conversion in recursive mode and are left
  unchanged otherwise.

  Returns:
    The (possibly replaced) call node.

  Raises:
    ValueError: if a marker decorator is called without positional arguments.
    NotImplementedError: if a resolved callee is neither compilable nor a
      known NumPy function.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if not node.args:
        # Fill the "%s" placeholder with the offending decorator.
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  elif self.context.recursive:
    node = self._insert_dynamic_conversion(node)
  # Unresolved functions are allowed in non-recursive mode.
  return node
def visit(self, node):
  """Depth-first walking the CFG, applying dataflow information propagation.

  The in-set of `node` is the combination (via `self.transfer_fn`) of the
  out-sets of those predecessors that already carry one. The out-set is
  `(in - kill) | gen`. Successors are revisited whenever the out-set changes,
  until a fixed point is reached.

  Args:
    node: The CFG node to visit. `node.value` is the AST node holding the
      dataflow annotations; it is None only for the exit node.
  """
  # node.value is None only for the exit CfgNode.
  if not node.value:
    return

  # Remember the previous set itself rather than its hash. Distinct sets can
  # share a hash (e.g. hash(-1) == hash(-2) in CPython), which would end the
  # propagation before the fixed point is reached.
  if anno.hasanno(node.value, self.out_label):
    before = anno.getanno(node.value, self.out_label)
  else:
    before = None

  preds = [
      anno.getanno(pred.value, self.out_label)
      for pred in node.prev
      if anno.hasanno(pred.value, self.out_label)
  ]
  if preds:
    incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
  else:
    incoming = frozenset()
  anno.setanno(node.value, self.in_label, incoming)

  gen, kill = self.get_gen_kill(node, incoming)
  anno.setanno(node.value, self.gen_label, gen)
  anno.setanno(node.value, self.kill_label, kill)
  after = (incoming - kill) | gen
  anno.setanno(node.value, self.out_label, after)

  if after != before:
    for succ in node.next:
      self.visit(succ)
def _rename_compilable_function(self, node):
  """Points a call at the compiled counterpart of its callee.

  The callee's static value and fully qualified name are read from the
  'live_val' and 'fqn' annotations of `node.func`. Constructors get a
  compiled class name. Other callees get a compiled function name, which
  the namer may decline to apply.

  Args:
    node: A call node whose `func` carries 'live_val' and 'fqn' annotations.

  Returns:
    `node`, possibly with its `func` (and, for bound methods, its `args`)
    replaced.
  """
  assert anno.hasanno(node.func, 'live_val')
  assert anno.hasanno(node.func, 'fqn')
  target_entity = anno.getanno(node.func, 'live_val')
  target_fqn = anno.getanno(node.func, 'fqn')

  if not self._should_compile(node, target_fqn):
    return node

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.context.namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    owner_type = (
        anno.getanno(node.func, 'parent_type')
        if anno.hasanno(node.func, 'parent_type')
        # Fallback - not reliable.
        else inspect_utils.getmethodclass(target_entity))
    new_name, do_rename = self.context.namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function, so the
    # receiver is passed explicitly as the first argument. This presumes
    # node.func is an attribute access such as `obj.method`.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [node.func.value] + node.args
  node.func = templates.replace('func_name', func_name=new_name)[0]
  return node
def visit_Call(self, node):
  """Converts a call according to what its callee resolves to.

  Calls to one of `self.nocompile_decorators` mark their first argument as
  graph ready. After the children are visited, calls whose callee has a
  static value ('live_val') are handled as follows:
    - compilable callees are renamed to their compiled version;
    - known NumPy functions are wrapped in a py_func;
    - anything else is rejected.
  Unresolved callees get a dynamic conversion in recursive mode and are left
  unchanged otherwise.

  Returns:
    The (possibly replaced) call node.

  Raises:
    ValueError: if a marker decorator is called without positional arguments.
    NotImplementedError: if a resolved callee is neither compilable nor a
      known NumPy function.
  """
  # If the function is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.nocompile_decorators:
      if not node.args:
        # Fill the "%s" placeholder with the offending decorator.
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  elif self.context.recursive:
    node = self._insert_dynamic_conversion(node)
  # Unresolved functions are allowed in non-recursive mode.
  return node
def visit_Attribute(self, node):
  """Statically resolves an attribute access when its parent is known.

  If the parent expression has a static value ('live_val'), the attribute is
  looked up on that value. If it only has a static 'type', the attribute is
  looked up on the type. Either way, the node gets 'parent_type', 'live_val'
  and 'fqn' annotations. Unresolved parents that are plain names must be
  local symbols.

  Raises:
    AttributeError: if the parent value has no such attribute.
  """
  self.generic_visit(node)
  parent = node.value

  if anno.hasanno(parent, 'live_val'):
    assert anno.hasanno(parent, 'fqn')
    parent_object = anno.getanno(parent, 'live_val')
    if not hasattr(parent_object, node.attr):
      raise AttributeError('%s has no attribute %s' %
                           (parent_object, node.attr))
    anno.setanno(node, 'parent_type', type(parent_object))
    anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
    anno.setanno(node, 'fqn', anno.getanno(parent, 'fqn') + (node.attr,))
  # TODO(mdan): Investigate the role built-in annotations can play here.
  elif anno.hasanno(parent, 'type'):
    parent_type = anno.getanno(parent, 'type')
    # Static members (e.g. methods) are found on the type; dynamic members
    # (e.g. function attributes) are not, and such nodes are deliberately
    # left unannotated for downstream consumers to handle.
    if hasattr(parent_type, node.attr):
      anno.setanno(node, 'parent_type', parent_type)
      anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
      anno.setanno(node, 'fqn',
                   anno.getanno(parent, 'type_fqn') + (node.attr,))
  elif isinstance(parent, gast.Name):
    # All nonlocal symbols should be fully resolved.
    assert anno.hasanno(parent, NodeAnno.IS_LOCAL), parent
    # TODO(mdan): Figure out what to do when calling attribute on local object
    # Maybe just leave as-is?
  return node
def visit_Call(self, node):
  """Applies the element type/shape requested by the type annotation marker.

  A call `type_annotation_func(symbol, type[, shape])` annotates two nodes
  with 'element_type' and 'element_shape': the call node itself and the
  definition of `symbol` returned by `self.scope.getval`. The values stored
  are the argument AST nodes; a missing shape defaults to the parsed
  expression `None`. Every call is then visited generically.

  Raises:
    ValueError: if the marker call does not have two or three arguments, or
      if its first argument has no qualified name.
  """
  if anno.hasanno(node.func, 'live_val'):
    # Symbols targeted by the "set_type" marker function are assigned the data
    # type that it specified.
    if (anno.getanno(node.func, 'live_val') is
        self.context.type_annotation_func):

      if not 2 <= len(node.args) <= 3:
        raise ValueError(
            '"%s" must have either two or three parameters' %
            self.context.type_annotation_func)
      if len(node.args) == 2:
        target_arg, type_arg = node.args
        shape_arg = parser.parse_expression('None')
      else:
        target_arg, type_arg, shape_arg = node.args
      if not anno.hasanno(target_arg, anno.Basic.QN):
        raise ValueError(
            'the first argument of "%s" must be a symbol' %
            self.context.type_annotation_func)
      # TODO(mdan): This is vulnerable to symbol renaming.
      element_type = type_arg
      element_shape = shape_arg

      target_symbol = anno.getanno(target_arg, anno.Basic.QN)
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(node, 'element_shape', element_shape)
      anno.setanno(definition, 'element_type', element_type)
      anno.setanno(definition, 'element_shape', element_shape)
      # TODO(mdan): Should we update references between definition and here?
  return self.generic_visit(node)
def test_copyanno(self):
  """copyanno copies a present annotation and skips one that is absent."""
  node_1 = ast.Name()
  anno.setanno(node_1, 'foo', 3)
  node_2 = ast.Name()
  anno.copyanno(node_1, node_2, 'foo')
  # 'bar' is not set on node_1, so copying it must not create it on node_2.
  anno.copyanno(node_1, node_2, 'bar')
  self.assertTrue(anno.hasanno(node_2, 'foo'))
  # Presence alone does not prove the value was copied; check it too.
  self.assertEqual(3, anno.getanno(node_2, 'foo'))
  self.assertFalse(anno.hasanno(node_2, 'bar'))
def test_copy(self):
  """copyanno copies a present annotation and skips one that is absent."""
  node_1 = ast.Name()
  anno.setanno(node_1, 'foo', 3)
  node_2 = ast.Name()
  anno.copyanno(node_1, node_2, 'foo')
  # 'bar' is not set on node_1, so copying it must not create it on node_2.
  anno.copyanno(node_1, node_2, 'bar')
  self.assertTrue(anno.hasanno(node_2, 'foo'))
  # Presence alone does not prove the value was copied; check it too.
  self.assertEqual(3, anno.getanno(node_2, 'foo'))
  self.assertFalse(anno.hasanno(node_2, 'bar'))
def _try_resolve_target(self, node):
  """Works for methods of objects of known type.

  Returns:
    The node's 'live_val' annotation if present. Otherwise, for an Attribute
    node with a 'type' annotation, the attribute looked up on that type.
    Otherwise None.

  Raises:
    ValueError: if the annotated type has no such attribute (for example,
      because the attribute is only set dynamically).
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    # NOTE(review): the owner type is read from the Attribute node's own
    # 'type' annotation, not from `node.value` - confirm this matches how
    # 'type' is assigned.
    owner_type = anno.getanno(node, 'type')
    if hasattr(owner_type, node.attr):
      return getattr(owner_type, node.attr)
    else:
      raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                       (owner_type, node.attr))
  return None
def _try_resolve_target(self, node):
  """Works for methods of objects of known type.

  Returns:
    The node's 'live_val' annotation if present. Otherwise, for an Attribute
    node with a 'type' annotation, the attribute looked up on that type.
    Otherwise None.

  Raises:
    ValueError: if the annotated type has no such attribute (for example,
      because the attribute is only set dynamically).
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    # NOTE(review): the owner type is read from the Attribute node's own
    # 'type' annotation, not from `node.value` - confirm this matches how
    # 'type' is assigned.
    owner_type = anno.getanno(node, 'type')
    if hasattr(owner_type, node.attr):
      return getattr(owner_type, node.attr)
    else:
      raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                       (owner_type, node.attr))
  return None
def visit_Call(self, node):
  """Converts a call according to what its callee resolves to.

  Calls to one of the program's AutoGraph decorators mark their first
  argument as graph ready. After the children are visited, the call is
  handled as follows:
    - resolved callees ('live_val') are renamed to their compiled version or
      wrapped in a py_func for known NumPy functions;
    - any other resolved callee is rejected;
    - `dict(...)` and `super(...)` calls are left alone;
    - remaining calls get a dynamic conversion in recursive mode.

  Returns:
    The (possibly replaced) call node.

  Raises:
    ValueError: if a marker decorator is called without positional arguments.
    NotImplementedError: if a resolved callee is neither compilable nor a
      known NumPy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    decorator = anno.getanno(node.func, 'live_val')
    if decorator in self.ctx.program.autograph_decorators:
      if not node.args:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' % decorator)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)

  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (anno.getanno(node.func, 'fqn')
                  if anno.hasanno(node.func, 'fqn') else None)
    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')

  if anno.hasanno(node.func, anno.Basic.QN):
    # Special-case a few builtins that otherwise go undetected. This
    # normally doesn't pose a problem, but the dict built-in doesn't
    # work with inspect.getargspec which is required for dynamic functions.
    # Note: expecting this is resilient to aliasing (e.g.
    # dict = an_evil_dict), because in those cases the regular mechanisms
    # process a simple user function.
    callee_qn = anno.getanno(node.func, anno.Basic.QN)
    # Add items to this list as needed.
    if str(callee_qn) in ('dict',):
      return node

  if ast_util.matches(node, 'super(_)'):
    # super() calls are preserved. The class conversion mechanism will
    # ensure that they return the correct value.
    return node

  if self.ctx.program.recursive:
    node = self._insert_dynamic_conversion(node)
  return node
def visit_Call(self, node):
  """Converts a call according to what its callee resolves to.

  Calls to one of the program's AutoGraph decorators mark their first
  argument as graph ready. After the children are visited, the call is
  handled as follows:
    - resolved callees ('live_val') are renamed to their compiled version or
      wrapped in a py_func for known NumPy functions;
    - any other resolved callee is rejected;
    - `dict(...)` and `super(...)` calls are left alone;
    - remaining calls get a dynamic conversion in recursive mode.

  Returns:
    The (possibly replaced) call node.

  Raises:
    ValueError: if a marker decorator is called without positional arguments.
    NotImplementedError: if a resolved callee is neither compilable nor a
      known NumPy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    decorator = anno.getanno(node.func, 'live_val')
    if decorator in self.ctx.program.autograph_decorators:
      if not node.args:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' % decorator)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)

  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (anno.getanno(node.func, 'fqn')
                  if anno.hasanno(node.func, 'fqn') else None)
    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')

  if anno.hasanno(node.func, anno.Basic.QN):
    # Special-case a few builtins that otherwise go undetected. This
    # normally doesn't pose a problem, but the dict built-in doesn't
    # work with inspect.getargspec which is required for dynamic functions.
    # Note: expecting this is resilient to aliasing (e.g.
    # dict = an_evil_dict), because in those cases the regular mechanisms
    # process a simple user function.
    callee_qn = anno.getanno(node.func, anno.Basic.QN)
    # Add items to this list as needed.
    if str(callee_qn) in ('dict',):
      return node

  if ast_util.matches(node, 'super(_)'):
    # super() calls are preserved. The class conversion mechanism will
    # ensure that they return the correct value.
    return node

  if self.ctx.program.recursive:
    node = self._insert_dynamic_conversion(node)
  return node
def test_nested_assignment(self):
  """Names bound by nested tuple unpacking currently get no live value."""

  def test_fn(foo):
    a, (b, c) = foo
    return a, b, c

  node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)})
  # Elements of the `return a, b, c` tuple, the second statement of test_fn.
  returned_names = node.body[0].body[1].value.elts

  # TODO(mdan): change these once we have the live values propagating
  # correctly
  for name_node in (returned_names[0], returned_names[1], returned_names[2]):
    self.assertFalse(anno.hasanno(name_node, 'live_val'))
def test_basic(self):
  """Round-trips an annotation through setanno, getanno and delanno."""
  node = ast.Name()

  def assert_foo_absent():
    self.assertFalse(anno.hasanno(node, 'foo'))
    with self.assertRaises(AttributeError):
      anno.getanno(node, 'foo')

  assert_foo_absent()

  anno.setanno(node, 'foo', 3)
  self.assertTrue(anno.hasanno(node, 'foo'))
  self.assertEqual(3, anno.getanno(node, 'foo'))

  anno.delanno(node, 'foo')
  assert_foo_absent()
def test_duplicate(self):
  """anno.dup adds 'eggs' from 'spam' while keeping existing annotations."""
  node = ast.If(
      test=ast.Num(1),
      body=[ast.Expr(ast.Name('bar', ast.Load()))],
      orelse=[])
  for key in ('spam', 'ham'):
    anno.setanno(node, key, 1)
  anno.setanno(node.body[0], 'ham', 1)

  anno.dup(node, {'spam': 'eggs'})

  for key in ('spam', 'ham', 'eggs'):
    self.assertTrue(anno.hasanno(node, key))
  self.assertFalse(anno.hasanno(node.body[0], 'eggs'))
def test_basic(self):
  """Round-trips an annotation through setanno, getanno and delanno."""
  node = ast.Name()

  def assert_foo_absent():
    self.assertFalse(anno.hasanno(node, 'foo'))
    with self.assertRaises(AttributeError):
      anno.getanno(node, 'foo')

  assert_foo_absent()

  anno.setanno(node, 'foo', 3)
  self.assertTrue(anno.hasanno(node, 'foo'))
  self.assertEqual(3, anno.getanno(node, 'foo'))

  anno.delanno(node, 'foo')
  assert_foo_absent()
def visit_Name(self, node):
  """Annotates loaded names with statically known values.

  Only names in a Load context are considered. Names that are neither local
  nor parameters are looked up in `self.literals`, then in
  `self.context.namespace`; a hit sets 'live_val' and, where a name can be
  derived, 'fqn'. Local or parameter names that were not modified since
  entry and appear in `self.context.arg_values` get that value as 'live_val'
  and its class name as 'fqn'.
  """
  self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node

  assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
  symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
  assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
  symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
  assert anno.hasanno(node, NodeAnno.IS_PARAM), node
  symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)

  if symbol_is_local or symbol_is_param:
    # TODO(mdan): Attempt to trace its value through the local chain.
    # TODO(mdan): Use type annotations as fallback.
    if not symbol_is_modified and node.id in self.context.arg_values:
      arg_value = self.context.arg_values[node.id]
      anno.setanno(node, 'live_val', arg_value)
      anno.setanno(node, 'fqn', (arg_value.__class__.__name__,))
    return node

  if node.id in self.literals:
    anno.setanno(node, 'live_val', self.literals[node.id])
  elif node.id in self.context.namespace:
    value = self.context.namespace[node.id]
    anno.setanno(node, 'live_val', value)
    if hasattr(value, '__name__'):
      anno.setanno(node, 'fqn', (value.__name__,))
    elif hasattr(value, '__class__'):
      value_class = value.__class__
      anno.setanno(node, 'fqn',
                   (value_class.__module__, value_class.__name__))
    # Otherwise, e.g. for some primitives, the value has no name to record.
  # TODO(mdan): Should we raise an error if the name is found nowhere?
  # This happens when a symbol truly lacks a reference, or when it is new,
  # like the new name of a function that was just renamed.
  return node
def visit(self, node):
  """Visits `node`, re-raising analysis errors as AutographParseError.

  Function, class and lambda nodes are pushed onto
  `self._enclosing_entities` for the duration of the visit. When source code
  is available, the node's position is recorded so that errors can point at
  the most recently recorded line. Nodes annotated with
  anno.Basic.SKIP_PROCESSING are returned unchanged.

  Raises:
    AutographParseError: wrapping any ValueError, AttributeError, KeyError,
      NotImplementedError or AssertionError raised while visiting.
  """
  source_code = self.context.source_code
  source_file = self.context.source_file
  pushed_entity = False

  try:
    if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
      self._enclosing_entities.append(node)
      pushed_entity = True

    if source_code and hasattr(node, 'lineno'):
      self._lineno = node.lineno
      self._col_offset = node.col_offset

    skip = anno.hasanno(node, anno.Basic.SKIP_PROCESSING)
    return node if skip else super(Base, self).visit(node)

  except (ValueError, AttributeError, KeyError, NotImplementedError,
          AssertionError) as e:
    msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
        e.__class__.__name__, str(e), try_ast_to_source(node),
        pretty_printer.fmt(node, color=False))
    line = (source_code.splitlines()[self._lineno - 1]
            if source_code else '<no source available>')
    location = (source_file, self._lineno, self._col_offset + 1, line)
    six.reraise(AutographParseError, AutographParseError(msg, location),
                sys.exc_info()[2])

  finally:
    if pushed_entity:
      self._enclosing_entities.pop()
def _visit_and_reindent(self, nodes):
  """Visits a list of statements, honoring "indent block remainder" requests.

  Statements are visited in order, and each result is appended (or, for a
  list/tuple result, extended) to the current destination list. The initial
  destination is the returned list.

  A visited statement may carry an anno.Basic.INDENT_BLOCK_REMAINDER
  annotation holding a (new_dest, new_alias_map) pair. In that case:
    - the annotation is removed;
    - all following statements are placed into `new_dest` instead;
    - following statements are renamed through the accumulated alias map.

  Args:
    nodes: the statements to visit.

  Returns:
    The list of top-level visited statements. Statements redirected into a
    `new_dest` list are not in it directly. Presumably `new_dest` belongs to
    one of the returned statements - TODO confirm with the producers of
    INDENT_BLOCK_REMAINDER.

  Raises:
    ValueError: if a re-indent was requested but the final destination list
      is empty, i.e. nothing followed the requesting statement.
  """
  new_nodes = []
  # List that visited statements are currently appended to.
  current_dest = new_nodes
  # Accumulated symbol renames, applied to every statement visited after a
  # re-indent request.
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    # The requesting statement itself stays in the previous destination; only
    # the statements after it are redirected.
    # NOTE(review): when `n` is a list/tuple, the annotation is looked up on
    # the container itself - confirm that is intended.
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # Earlier aliases take precedence over the new ones. Note that this
      # mutates the map object stored in the (now deleted) annotation.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError('Unable to insert statement into the computation flow: '
                     'it is not followed by any computation which '
                     'the statement could gate.')
  return new_nodes
def visit_Call(self, node):
  """Replaces calls to supported Python builtins with their converted form.

  The call's children are visited first. A call whose callee has a static
  value ('live_val') listed in `py_builtins.SUPPORTED_BUILTINS` is replaced
  by the result of `self._convert_builtin`. Any other call is returned
  as visited.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.func, 'live_val'):
    return node
  builtin = anno.getanno(node.func, 'live_val')
  if builtin not in py_builtins.SUPPORTED_BUILTINS:
    return node
  return self._convert_builtin(builtin, node.args, as_expression=True)
def _node_sets_self_attribute(self, node):
  """Tells whether `node` names a direct attribute of `self`, e.g. `self.x`.

  Returns:
    True if `node` has a qualified name that is an attribute whose parent is
    exactly `self`; False otherwise, including when it has no qualified name.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qn = anno.getanno(node, anno.Basic.QN)
  # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
  return bool(qn.has_attr and qn.parent.qn == ('self',))
def visit_node(self, node):
  """Recomputes the live-in and live-out sets of one CFG node.

  For nodes with a scope annotation:
    - live_out is the union of the successors' live-in sets;
    - live_in is `gen | (live_out - kill)`, where gen is the used symbols
      (plus any `self.extra_gen` entries) and kill is the modified symbols.
  Nodes without a scope keep their previous live-in, and live_out equals it.

  Returns:
    True if the node's live-in set changed.
  """
  prev_live_in = self.in_[node]

  if not anno.hasanno(node.ast_node, anno.Static.SCOPE):
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols. The Name case covers literal names, e.g. False.
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Continue, gast.Break)), type(
                          node.ast_node)
    live_in = prev_live_in
    live_out = live_in
  else:
    scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
    gen = scope.used | self.extra_gen.get(node.ast_node, frozenset())
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. if x.y is live whether x needs to be added. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    kill = scope.modified

    live_out = set()
    for successor in node.next:
      live_out |= self.in_[successor]
    live_in = gen | (live_out - kill)

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def visit_Attribute(self, node):
  """Gives an attribute node a qualified name built from its parent's.

  After visiting the children, if the parent expression has a QN annotation,
  the node is annotated with `QN(parent_qn, attr=node.attr)`.
  """
  node = self.generic_visit(node)
  parent = node.value
  if anno.hasanno(parent, anno.Basic.QN):
    parent_qn = anno.getanno(parent, anno.Basic.QN)
    anno.setanno(node, anno.Basic.QN, QN(parent_qn, attr=node.attr))
  return node
def visit_node(self, node):
  """Recomputes the live-in and live-out sets of one CFG node.

  For nodes with a scope annotation:
    - live_out is the union of the successors' live-in sets;
    - live_in is `gen | (live_out - kill)`, where gen is the used symbols
      (plus any `self.extra_gen` entries) and kill is the modified symbols.
  Nodes without a scope keep their previous live-in, and live_out equals it.

  Returns:
    True if the node's live-in set changed.
  """
  prev_live_in = self.in_[node]

  if not anno.hasanno(node.ast_node, anno.Static.SCOPE):
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols. The Name case covers literal names, e.g. False.
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Continue, gast.Break)), type(
                          node.ast_node)
    live_in = prev_live_in
    live_out = live_in
  else:
    scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
    gen = scope.used | self.extra_gen.get(node.ast_node, frozenset())
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. if x.y is live whether x needs to be added. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    kill = scope.modified

    live_out = set()
    for successor in node.next:
      live_out |= self.in_[successor]
    live_in = gen | (live_out - kill)

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def test_parameter_class_members(self):
  """A method called on a function parameter gets no `live_val`."""

  def test_fn(opt):
    opt.minimize(0)

  fn_node = self._parse_and_analyze(test_fn, {})
  # The only statement of test_fn is the `opt.minimize(0)` expression.
  call_stmt = fn_node.body[0].body[0]
  self.assertFalse(anno.hasanno(call_stmt.value.func, 'live_val'))
def visit(self, node):
  """Visits `node`, re-raising common conversion errors as parse errors.

  When source code is available, the position of the last visited node that
  has a `lineno` is remembered. A failure deeper in the tree can then be
  reported against that position. Nodes annotated with
  `anno.Basic.SKIP_PROCESSING` are returned unchanged.

  Raises:
    AutographParseError: wrapping any ValueError, AttributeError, KeyError,
      NotImplementedError or AssertionError raised during the visit. The
      original traceback is passed along.
  """
  source_code = self.context.source_code
  source_file = self.context.source_file

  try:
    if source_code and hasattr(node, 'lineno'):
      self._lineno = node.lineno
      self._col_offset = node.col_offset
    skip = anno.hasanno(node, anno.Basic.SKIP_PROCESSING)
    return node if skip else super(Base, self).visit(node)
  except (ValueError, AttributeError, KeyError, NotImplementedError,
          AssertionError) as err:
    error_name = err.__class__.__name__
    msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
        error_name, str(err), try_ast_to_source(node),
        pretty_printer.fmt(node, color=False))
    if not source_code:
      line = '<no source available>'
    else:
      line = source_code.splitlines()[self._lineno - 1]
    position = (source_file, self._lineno, self._col_offset + 1, line)
    six.reraise(AutographParseError, AutographParseError(msg, position),
                sys.exc_info()[2])
def _track_symbol(self, node):
  """Registers the access to `node`'s symbol with the current scope.

  Depending on `node.ctx`, the symbol is marked as created and as a parameter
  (Param), as written (Store), or as read (Load). The node is then annotated
  with IS_LOCAL, IS_MODIFIED_SINCE_ENTRY and IS_PARAM as reported by the
  scope. Inside a return statement the symbol is also marked as returned.

  Raises:
    ValueError: if `node.ctx` is not a Store, Load or Param context.
  """
  # Attributes or subscripts on a call result (e.g. `a().b`) carry no QN, so
  # there is nothing to track.
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  if not isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)):
    raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

  if isinstance(node.ctx, gast.Param):
    # Param contexts appear in function definitions, where they define a
    # variable.
    # TODO(mdan): This may be incorrect with nested functions.
    # For nested functions, we'll have to add the notion of hiding args from
    # the parent scope, not writing to them.
    self.scope.mark_creation(qn)
    self.scope.mark_param(qn)
  elif isinstance(node.ctx, gast.Store):
    self.scope.mark_write(qn)
  else:
    self.scope.mark_read(qn)

  anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
  anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
               self.scope.is_modified_since_entry(qn))
  anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))

  if self._in_return_statement:
    self.scope.mark_returned(qn)
def _track_symbol(self,
                  node,
                  composite_writes_alter_parent=False,
                  writes_create_symbol=False):
  """Registers the access to `node`'s symbol with the current scope.

  Args:
    node: AST node whose `anno.Basic.QN` annotation names the symbol. Nodes
      without that annotation are ignored.
    composite_writes_alter_parent: if True, a store to a composite name also
      marks its parent (`qn.parent`) as written.
    writes_create_symbol: if True, a store also marks the symbol as created.

  Raises:
    ValueError: if `node.ctx` is not a Store, Load or Param context.
  """
  # Attributes or subscripts on a call result (e.g. `a().b`) carry no QN.
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  if not isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)):
    raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

  if isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    # Parameters appear in function definitions, where they define a
    # variable of the enclosing entity.
    self.scope.mark_write(qn)
    self.scope.mark_param(qn, self.enclosing_entities[-1])
  else:
    self.scope.mark_write(qn)
    if qn.is_composite and composite_writes_alter_parent:
      self.scope.mark_write(qn.parent)
    if writes_create_symbol:
      self.scope.mark_creation(qn, writes_create_symbol=True)
    # Presumably set while visiting an augmented-assignment target
    # (`x += ...`), which also reads the symbol.
    if self._in_aug_assign:
      self.scope.mark_read(qn)

  anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

  if self._in_return_statement:
    self.scope.mark_returned(qn)
def _node_sets_self_attribute(self, node):
  """Returns True if `node` is annotated with a QN of the form `self.<attr>`.

  Nodes without a QN annotation yield False.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qn = anno.getanno(node, anno.Basic.QN)
  # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
  return True if (qn.has_attr and qn.parent.qn == ('self',)) else False
def _node_sets_self_attribute(self, node):
  """Returns True if `node` is annotated with a QN of the form `self.<attr>`.

  Always returns a bool. Nodes without a QN annotation, or whose QN is not a
  direct attribute of `self`, yield False.
  """
  if anno.hasanno(node, anno.Basic.QN):
    qn = anno.getanno(node, anno.Basic.QN)
    # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
    if qn.has_attr and qn.parent.qn == ('self',):
      return True
  # Explicit False instead of an implicit None, so the result is a real bool.
  return False
def _process_variable_assignment(self, source, targets):
  """Records the value assigned to each of `targets` in the current scope.

  If `source` calls a class (per its `live_val` annotation), it is first
  annotated as a constructor call (`is_constructor`, `type`, `type_fqn`).
  Each target is then bound to its value. Names and attributes are bound
  directly to `source`; tuple targets are unpacked element by element
  (recursively).

  Args:
    source: AST node of the assigned value.
    targets: sequence of target AST nodes. Multiple targets mean multiple
      assignment, e.g. `a = b = c`.

  Raises:
    ValueError: if a target is not a Tuple, Name or Attribute node.
  """
  # Special case: constructors.
  if isinstance(source, gast.Call):
    func = source.func
    if anno.hasanno(func, 'live_val'):
      func_obj = anno.getanno(func, 'live_val')
      if tf_inspect.isclass(func_obj):
        anno.setanno(source, 'is_constructor', True)
        anno.setanno(source, 'type', func_obj)
        anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
        # TODO(mdan): Raise an error if constructor has side effects.
        # We can have a whitelist of no-side-effects constructors.
        # We can also step inside the constructor and further analyze.

  # Multiple targets mean multiple assignment.
  for target in targets:
    # Tuple target means unpacking.
    if isinstance(target, gast.Tuple):
      for i, target_item in enumerate(target.elts):
        # Two cases here:
        #   1. Static unpacking, e.g. a, b = c, d
        #   2. Dynamic unpacking, e.g. a, b = c
        # The former case is optimized away.
        if isinstance(source, (gast.Tuple, gast.List)):
          source_item = source.elts[i]
        else:
          source_item = gast.Subscript(source, gast.Index(i), ctx=None)
        self._process_variable_assignment(source_item, (target_item,))
    elif isinstance(target, (gast.Name, gast.Attribute)):
      target_symbol = anno.getanno(target, anno.Basic.QN)
      self.scope.setval(target_symbol, source)
    else:
      # Report the offending target itself. `target_item` is only bound by
      # the tuple branch, so referencing it here raised NameError/
      # UnboundLocalError (or reported a stale node) instead of this error.
      raise ValueError(
          'assignment target has unknown type: %s' % target)
def test_local_scope_info_stack(self):
  # Checks that values stored with set_local() are confined to their
  # enter_local_scope()/exit_local_scope() pair. Each loop node is annotated
  # only with the string constants collected inside its own scope.

  class TestTransformer(transformer.Base):

    # Extract all string constants from the block.
    def visit_Str(self, node):
      self.set_local('string', self.get_local('string', default='') + node.s)
      return self.generic_visit(node)

    def _annotate_result(self, node):
      # Collect the strings of this node's subtree in a fresh local scope and
      # store them in the node's 'test' annotation before leaving the scope.
      self.enter_local_scope()
      node = self.generic_visit(node)
      anno.setanno(node, 'test', self.get_local('string'))
      self.exit_local_scope()
      return node

    def visit_While(self, node):
      return self._annotate_result(node)

    def visit_For(self, node):
      return self._annotate_result(node)

  tr = TestTransformer(self._context_for_testing())

  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while True:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # The index paths depend on the exact statement layout of test_function.
  # body[2] is the `for` loop. The `while` is the second statement in the
  # `else` branch of the `if` inside that loop.
  for_node = node.body[0].body[2]
  while_node = for_node.body[1].orelse[1]

  # The 'string' local must not appear as an annotation. The `for` loop's
  # result excludes '1', which was collected in the nested `while` scope.
  self.assertFalse(anno.hasanno(for_node, 'string'))
  self.assertEqual('abc', anno.getanno(for_node, 'test'))
  self.assertFalse(anno.hasanno(while_node, 'string'))
  self.assertEqual('1', anno.getanno(while_node, 'test'))
def test_local_scope_info_stack(self):
  # Checks that values stored with set_local() are confined to their
  # enter_local_scope()/exit_local_scope() pair. Each loop node is annotated
  # only with the string constants collected inside its own scope.

  class TestTransformer(transformer.Base):

    # Extract all string constants from the block.
    def visit_Str(self, node):
      self.set_local('string', self.get_local('string', default='') + node.s)
      return self.generic_visit(node)

    def _annotate_result(self, node):
      # Collect the strings of this node's subtree in a fresh local scope and
      # store them in the node's 'test' annotation before leaving the scope.
      self.enter_local_scope()
      node = self.generic_visit(node)
      anno.setanno(node, 'test', self.get_local('string'))
      self.exit_local_scope()
      return node

    def visit_While(self, node):
      return self._annotate_result(node)

    def visit_For(self, node):
      return self._annotate_result(node)

  tr = TestTransformer(self._simple_source_info())

  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while True:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function)
  node = tr.visit(node)

  # The index paths depend on the exact statement layout of test_function.
  # body[2] is the `for` loop. The `while` is the second statement in the
  # `else` branch of the `if` inside that loop.
  for_node = node.body[0].body[2]
  while_node = for_node.body[1].orelse[1]

  # The 'string' local must not appear as an annotation. The `for` loop's
  # result excludes '1', which was collected in the nested `while` scope.
  self.assertFalse(anno.hasanno(for_node, 'string'))
  self.assertEqual('abc', anno.getanno(for_node, 'test'))
  self.assertFalse(anno.hasanno(while_node, 'string'))
  self.assertEqual('1', anno.getanno(while_node, 'test'))
def test_nested_members(self):
  """A method call on a member of a constructed object gets no `live_val`."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  fn_node = self._parse_and_analyze(test_fn, {'training': training})
  # The second statement of test_fn is the `foo.bar.baz()` expression.
  call_expr = fn_node.body[0].body[1].value
  self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
def test_basic(self):
  """Exercises the set/get/has/delete round trip of node annotations."""
  name_node = ast.Name()
  key = 'foo'

  # A fresh node has no annotation, and reading one raises.
  self.assertFalse(anno.hasanno(name_node, key))
  with self.assertRaises(AttributeError):
    anno.getanno(name_node, key)

  # Once set, the annotation is visible. Missing keys fall back to a default.
  anno.setanno(name_node, key, 3)
  self.assertTrue(anno.hasanno(name_node, key))
  self.assertEqual(anno.getanno(name_node, key), 3)
  self.assertEqual(anno.getanno(name_node, 'bar', default=7), 7)

  # Deleting the annotation restores the unannotated behavior.
  anno.delanno(name_node, key)
  self.assertFalse(anno.hasanno(name_node, key))
  with self.assertRaises(AttributeError):
    anno.getanno(name_node, key)
  self.assertIsNone(anno.getanno(name_node, key, default=None))
def test_basic(self):
  """Exercises the set/get/has/delete round trip of node annotations."""
  name_node = ast.Name()
  key = 'foo'

  # A fresh node has no annotation, and reading one raises.
  self.assertFalse(anno.hasanno(name_node, key))
  with self.assertRaises(AttributeError):
    anno.getanno(name_node, key)

  # Once set, the annotation is visible. Missing keys fall back to a default.
  anno.setanno(name_node, key, 3)
  self.assertTrue(anno.hasanno(name_node, key))
  self.assertEqual(anno.getanno(name_node, key), 3)
  self.assertEqual(anno.getanno(name_node, 'bar', default=7), 7)

  # Deleting the annotation restores the unannotated behavior.
  anno.delanno(name_node, key)
  self.assertFalse(anno.hasanno(name_node, key))
  with self.assertRaises(AttributeError):
    anno.getanno(name_node, key)
  self.assertIsNone(anno.getanno(name_node, key, default=None))
def visit_Expr(self, node):
  """Converts a function call used as an expression statement.

  Consider a call to a function with a `live_val` that is not compilable
  but has a known `fqn`. It is left unchanged if `_should_compile` rejects
  it; otherwise it is replaced by `_wrap_to_py_func_no_return`, because the
  result of an expression statement is discarded. Every other call is passed
  to `self.visit`, and the node that comes back replaces `node.value`.
  Non-call expressions are visited generically.
  """
  if isinstance(node.value, gast.Call):
    call = node.value
    if anno.hasanno(call.func, 'live_val'):
      target_entity = anno.getanno(call.func, 'live_val')
      if not self._function_is_compilable(target_entity):
        if anno.hasanno(call.func, 'fqn'):
          target_fqn = anno.getanno(call.func, 'fqn')
          if not self._should_compile(call, target_fqn):
            return node
          return self._wrap_to_py_func_no_return(call)
    # Only the case of py_func with no return value is special.
    # Everything else is processed by visit_Call. A NodeTransformer visitor
    # may return a replacement node instead of mutating in place. Store the
    # result back so such a conversion is not silently dropped.
    node.value = self.visit(call)
  else:
    self.generic_visit(node)
  return node
def visit_Expr(self, node):
  """Converts a function call used as an expression statement.

  Consider a call to a function with a `live_val` that is not compilable
  but has a known `fqn`. It is left unchanged if `_should_compile` rejects
  it; otherwise it is replaced by `_wrap_to_py_func_no_return`, because the
  result of an expression statement is discarded. Every other call is passed
  to `self.visit`, and the node that comes back replaces `node.value`.
  Non-call expressions are visited generically.
  """
  if isinstance(node.value, gast.Call):
    call = node.value
    if anno.hasanno(call.func, 'live_val'):
      target_entity = anno.getanno(call.func, 'live_val')
      if not self._function_is_compilable(target_entity):
        if anno.hasanno(call.func, 'fqn'):
          target_fqn = anno.getanno(call.func, 'fqn')
          if not self._should_compile(call, target_fqn):
            return node
          return self._wrap_to_py_func_no_return(call)
    # Only the case of py_func with no return value is special.
    # Everything else is processed by visit_Call. A NodeTransformer visitor
    # may return a replacement node instead of mutating in place. Store the
    # result back so such a conversion is not silently dropped.
    node.value = self.visit(call)
  else:
    self.generic_visit(node)
  return node
def visit_Name(self, node):
  """Propagates type annotations to a name from its recorded definition.

  Parameters are handed to `_process_function_arg`. For a load of a name
  whose value is recorded in the current scope, the definition's annotations
  are copied onto `node`: `type` together with `type_fqn`, and
  `element_type`. For example, after `a = b`, a later read of `a` has
  definition `b`.
  """
  self.generic_visit(node)
  qn = anno.getanno(node, anno.Basic.QN)

  if isinstance(node.ctx, gast.Param):
    self._process_function_arg(qn)
    return node

  if not (isinstance(node.ctx, gast.Load) and self.scope.hasval(qn)):
    return node

  definition = self.scope.getval(qn)
  if anno.hasanno(definition, 'type'):
    # The type and its fully qualified name always travel together.
    for key in ('type', 'type_fqn'):
      anno.setanno(node, key, anno.getanno(definition, key))
  if anno.hasanno(definition, 'element_type'):
    element_type = anno.getanno(definition, 'element_type')
    anno.setanno(node, 'element_type', element_type)
  return node
def test_parameter_class_members(self):
  """A method called on a function parameter gets no `live_val`."""

  def test_fn(opt):
    opt.minimize(0)

  fn_node = self._parse_and_analyze(test_fn, {})
  # The only statement of test_fn is the `opt.minimize(0)` expression.
  call_stmt = fn_node.body[0].body[0]
  self.assertFalse(anno.hasanno(call_stmt.value.func, 'live_val'))
def test_inner_scope(self):
  """Element types set before and inside a loop are both recorded."""

  def test_fn():
    a = []
    utils.set_element_type(a, 1)
    for _ in a:
      b = []
      utils.set_element_type(b, 2)
      return a, b

  node = self._parse_and_analyze(test_fn, {'utils': utils})
  # body[2] is the `for` loop, and its body[2] is the `return a, b` statement.
  a, b = node.body[0].body[2].body[2].value.elts
  # assertEqual replaces the deprecated assertEquals alias, which was removed
  # from unittest.TestCase in Python 3.12.
  self.assertEqual(1, anno.getanno(a, 'element_type'))
  self.assertEqual(2, anno.getanno(b, 'element_type'))
  # Only the element type is propagated: no constructor type, no live value.
  self.assertFalse(anno.hasanno(a, 'type'))
  self.assertFalse(anno.hasanno(b, 'type'))
  self.assertFalse(anno.hasanno(a, 'live_val'))
  self.assertFalse(anno.hasanno(b, 'live_val'))
def visit_For(self, node):
  """Rewrites a `for` loop as a call to `ag__.for_stmt`.

  The loop body and the (optional) extra test become local functions that
  receive the loop state explicitly. The loop state is the set of symbols
  that the body modifies but does not create. Returns the result of
  `templates.replace`.
  """
  # Process the loop's children first.
  self.generic_visit(node)

  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  # Loop state: symbols modified in the body that existed before it.
  body_closure = body_scope.modified - body_scope.created
  # Passed to the namer, presumably so fresh names avoid body symbols.
  all_referenced = body_scope.referenced

  # NOTE(review): the order of the state variables follows set iteration
  # order. If QN hashing depends on string hashing, the generated code may
  # differ between runs. Confirm whether that matters.
  state = list(body_closure)

  # One fresh name per state variable. These names become the `state_ssf`
  # parameters of the generated functions.
  state_ssf = [
      self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Rename map for the body and the extra test. Identity renames are skipped.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # With exactly one state variable, plain values are used instead of
  # lists/tuples.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  # Loops without an 'extra_test' annotation get a constant `True` test.
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    extra_test = parser.parse_expression('True')

  template = """
    def extra_test_name(state_ssf):
      return extra_test_expr
    def body_name(loop_vars, state_ssf):
      # Workaround for PEP-3113
      iterate = loop_vars
      body
      return state_ssf,
    state_ast_tuple = ag__.for_stmt(
        iter_, extra_test_name, body_name, (state,))
  """
  # NOTE(review): the namer is called in a fixed order: the state symbols
  # above, then 'extra_test', then 'loop_body' (keyword arguments are
  # evaluated left to right). If the namer tracks names it has issued,
  # reordering these calls changes the generated names.
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
      extra_test_expr=extra_test,
      body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def visit_For(self, node):
  """Rewrites a `for` loop as a call to `ag__.for_stmt`.

  The loop body and the (optional) extra test become local functions that
  receive the loop state explicitly. The loop state is the set of symbols
  that the body modifies but does not create. Returns the result of
  `templates.replace`.
  """
  # Process the loop's children first.
  self.generic_visit(node)

  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  # Loop state: symbols modified in the body that existed before it.
  body_closure = body_scope.modified - body_scope.created
  # Passed to the namer, presumably so fresh names avoid body symbols.
  all_referenced = body_scope.referenced

  # NOTE(review): the order of the state variables follows set iteration
  # order. If QN hashing depends on string hashing, the generated code may
  # differ between runs. Confirm whether that matters.
  state = list(body_closure)

  # One fresh name per state variable. These names become the `state_ssf`
  # parameters of the generated functions.
  state_ssf = [
      self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Rename map for the body and the extra test. Identity renames are skipped.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # With exactly one state variable, plain values are used instead of
  # lists/tuples.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  # Loops without an 'extra_test' annotation get a constant `True` test.
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    extra_test = parser.parse_expression('True')

  template = """
    def extra_test_name(state_ssf):
      return extra_test_expr
    def body_name(loop_vars, state_ssf):
      # Workaround for PEP-3113
      iterate = loop_vars
      body
      return state_ssf,
    state_ast_tuple = ag__.for_stmt(
        iter_, extra_test_name, body_name, (state,))
  """
  # NOTE(review): the namer is called in a fixed order: the state symbols
  # above, then 'extra_test', then 'loop_body' (keyword arguments are
  # evaluated left to right). If the namer tracks names it has issued,
  # reordering these calls changes the generated names.
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
      extra_test_expr=extra_test,
      body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def test_inner_scope(self):
  """Element types set before and inside a loop are both recorded."""

  def test_fn():
    a = []
    utils.set_element_type(a, 1)
    for _ in a:
      b = []
      utils.set_element_type(b, 2)
      return a, b

  node = self._parse_and_analyze(test_fn, {'utils': utils})
  # body[2] is the `for` loop, and its body[2] is the `return a, b` statement.
  a, b = node.body[0].body[2].body[2].value.elts
  # assertEqual replaces the deprecated assertEquals alias, which was removed
  # from unittest.TestCase in Python 3.12.
  self.assertEqual(1, anno.getanno(a, 'element_type'))
  self.assertEqual(2, anno.getanno(b, 'element_type'))
  # Only the element type is propagated: no constructor type, no live value.
  self.assertFalse(anno.hasanno(a, 'type'))
  self.assertFalse(anno.hasanno(b, 'type'))
  self.assertFalse(anno.hasanno(a, 'live_val'))
  self.assertFalse(anno.hasanno(b, 'live_val'))
def test_nested_members(self):
  """A method call on a member of a constructed object gets no `live_val`."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  fn_node = self._parse_and_analyze(test_fn, {'training': training})
  # The second statement of test_fn is the `foo.bar.baz()` expression.
  call_expr = fn_node.body[0].body[1].value
  self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
def _expect_simple_symbol(self, operand):
  """Checks that `operand` is a plain name or an explicitly safe operand.

  Raises:
    NotImplementedError: if `operand` is neither a `gast.Name` nor annotated
      with SAFE_BOOLEAN_OPERAND.
  """
  is_simple = (isinstance(operand, gast.Name) or
               anno.hasanno(operand, SAFE_BOOLEAN_OPERAND))
  if not is_simple:
    raise NotImplementedError(
        'only simple local variables are supported in logical and compound '
        'comparison expressions; for example, we support "a or b" but not '
        '"a.x or b"; for a workaround, assign the expression to a local '
        'variable and use that instead, for example "tmp = a.x", "tmp or b"')