def _rename_compilable_function(self, node):
  """Points a call node at the compiled counterpart of its target.

  `node.func` must carry both the 'live_val' and 'fqn' annotations. The namer
  chooses the new name: for nodes annotated 'is_constructor' via
  `compiled_class_name` (always renamed), otherwise via
  `compiled_function_name`, which also decides whether to rename at all.
  When a method call is renamed, the receiver (`node.func.value`) is prepended
  to the positional arguments.

  Returns:
    `node`, with `func` (and possibly `args`) replaced when renaming happens.
  """
  func = node.func
  assert anno.hasanno(func, 'live_val')
  assert anno.hasanno(func, 'fqn')
  target_entity = anno.getanno(func, 'live_val')
  target_fqn = anno.getanno(func, 'fqn')

  if anno.hasanno(node, 'is_constructor'):
    new_name = self.ctx.namer.compiled_class_name(
        target_fqn, live_entity=target_entity)
    do_rename = True
  else:
    # Prefer the explicit 'parent_type' annotation; getmethodclass is only a
    # fallback and is not reliable.
    owner_type = (anno.getanno(func, 'parent_type')
                  if anno.hasanno(func, 'parent_type') else
                  inspect_utils.getmethodclass(target_entity))
    new_name, do_rename = self.ctx.namer.compiled_function_name(
        target_fqn, live_entity=target_entity, owner_type=owner_type)

  if not do_rename:
    return node

  if target_entity is not None and tf_inspect.ismethod(target_entity):
    # The renaming process will transform it into a regular function.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [func.value] + node.args
  node.func = templates.replace_as_expression('func_name', func_name=new_name)
  return node
def test_static_attribute_of_ambiguous_type(self):
  """An attribute of a name with two candidate types gets no static info.

  The resolver reports `tc` as either TestClass1 or TestClass2. The test
  expects the `tc` name to carry both types, and the `tc.a` node to carry
  neither a static type nor a static value.
  """
  test_self = self

  class TestClass1:
    a = 1

  class TestClass2:
    a = 2

  tc = TestClass1()

  class Resolver(type_inference.Resolver):

    def res_name(self, ns, types_ns, name):
      test_self.assertEqual(name, qual_names.QN('tc'))
      return {TestClass1, TestClass2}, None

    def res_value(self, ns, value):
      test_self.assertIn(value, (1, 2))
      return {str}

  def test_fn():
    return tc.a

  node, _ = TestTranspiler(Resolver).transform(test_fn, None)
  attr_node = node.body[0].value  # the `tc.a` expression being returned
  owner_node = attr_node.value  # the `tc` name

  self.assertTypes(owner_node, (TestClass1, TestClass2))
  self.assertFalse(anno.hasanno(attr_node, anno.Static.TYPES))
  self.assertFalse(anno.hasanno(owner_node, anno.Static.VALUE))
  self.assertFalse(anno.hasanno(attr_node, anno.Static.VALUE))
def test_dynamic_attribute_of_typed_value(self):
  """An attribute assigned only in `__init__` gets no static type or value.

  The resolver types `tc` as TestClass. The test expects the `tc.a` node to
  carry neither a static type nor a static value.
  """
  test_self = self

  class TestClass:

    def __init__(self):
      self.a = 1

  tc = TestClass()

  class Resolver(type_inference.Resolver):

    def res_name(self, ns, types_ns, name):
      test_self.assertEqual(name, qual_names.QN('tc'))
      return {TestClass}, None

  def test_fn():
    return tc.a

  node, _ = TestTranspiler(Resolver).transform(test_fn, None)
  attr_node = node.body[0].value  # the `tc.a` expression being returned
  owner_node = attr_node.value  # the `tc` name

  self.assertTypes(owner_node, TestClass)
  self.assertFalse(anno.hasanno(attr_node, anno.Static.TYPES))
  self.assertFalse(anno.hasanno(owner_node, anno.Static.VALUE))
  self.assertFalse(anno.hasanno(attr_node, anno.Static.VALUE))
def visit(self, node):
  """Visits `node`, maintaining origin info and the enclosing-entity stack.

  FunctionDef, ClassDef and Lambda nodes are pushed onto
  `self._enclosing_entities` for the duration of the visit. While visiting,
  `self.ctx.current_origin` is set from the node's ORIGIN annotation, if any.
  Nodes annotated with SKIP_PROCESSING are returned unchanged.

  Args:
    node: the gast.AST node to visit.

  Returns:
    The transformed node. This may be a list/tuple or an Assign/AugAssign
    when an Expr statement's value was replaced by one of those.

  Raises:
    ValueError: if `node` is not a gast.AST (e.g. a list of statements).
    AssertionError: if enter_local_scope/exit_local_scope calls made during
      the visit were not balanced.
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error is raised before any state below is modified.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  did_enter_function = False
  local_scope_size_at_entry = len(self._local_scope_state)
  processing_expr_node = False

  parent_origin = self.ctx.current_origin
  if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
    did_enter_function = True
  elif isinstance(node, gast.Expr):
    processing_expr_node = True

  if did_enter_function:
    self._enclosing_entities.append(node)

  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  if processing_expr_node:
    entry_expr_value = node.value

  # Skipped nodes are returned as-is. Initializing `result` avoids an
  # UnboundLocalError at the final return. The finally clause restores the
  # parent origin on every path, including the skip path and exceptions.
  result = node
  try:
    if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      result = super(Base, self).visit(node)
  finally:
    self.ctx.current_origin = parent_origin

  # Adjust for consistency: replacing the value of an Expr with
  # an Assign node removes the need for the Expr node.
  if processing_expr_node:
    if isinstance(result, gast.Expr) and result.value is not entry_expr_value:
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value, (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

  # If the visit above raised, the statements below are skipped, so the local
  # scope and enclosing-entity stacks are not guaranteed to be consistent.
  if did_enter_function:
    self._enclosing_entities.pop()

  if local_scope_size_at_entry != len(self._local_scope_state):
    raise AssertionError(
        'Inconsistent local scope stack. Before entering node %s, the'
        ' stack had length %d, after exit it has length %d. This'
        ' indicates enter_local_scope and exit_local_scope are not'
        ' well paired.' % (node, local_scope_size_at_entry,
                           len(self._local_scope_state)))
  return result
def visit(self, node):
  """Visits `node`, maintaining origin info and the enclosing-entity stack.

  FunctionDef, ClassDef and Lambda nodes are pushed onto
  `self._enclosing_entities` for the duration of the visit. While visiting,
  `self.ctx.current_origin` is set from the node's ORIGIN annotation, if any.
  Nodes annotated with SKIP_PROCESSING are returned unchanged.

  Args:
    node: the gast.AST node to visit.

  Returns:
    The transformed node. This may be a list/tuple or an Assign/AugAssign
    when an Expr statement's value was replaced by one of those.

  Raises:
    ValueError: if `node` is not a gast.AST (e.g. a list of statements).
    AssertionError: if enter_local_scope/exit_local_scope calls made during
      the visit were not balanced.
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error is raised before any state below is modified.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  did_enter_function = False
  local_scope_size_at_entry = len(self._local_scope_state)
  processing_expr_node = False

  parent_origin = self.ctx.current_origin
  if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
    did_enter_function = True
  elif isinstance(node, gast.Expr):
    processing_expr_node = True

  if did_enter_function:
    self._enclosing_entities.append(node)

  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  if processing_expr_node:
    entry_expr_value = node.value

  # Skipped nodes are returned as-is. Initializing `result` avoids an
  # UnboundLocalError at the final return. The finally clause restores the
  # parent origin on every path, including the skip path and exceptions.
  result = node
  try:
    if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      result = super(Base, self).visit(node)
  finally:
    self.ctx.current_origin = parent_origin

  # Adjust for consistency: replacing the value of an Expr with
  # an Assign node removes the need for the Expr node.
  if processing_expr_node:
    if isinstance(result, gast.Expr) and result.value is not entry_expr_value:
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value, (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

  # If the visit above raised, the statements below are skipped, so the local
  # scope and enclosing-entity stacks are not guaranteed to be consistent.
  if did_enter_function:
    self._enclosing_entities.pop()

  if local_scope_size_at_entry != len(self._local_scope_state):
    raise AssertionError(
        'Inconsistent local scope stack. Before entering node %s, the'
        ' stack had length %d, after exit it has length %d. This'
        ' indicates enter_local_scope and exit_local_scope are not'
        ' well paired.' % (node, local_scope_size_at_entry,
                           len(self._local_scope_state)))
  return result
def visit(self, node):
  """Visits `node`, propagating origin info to any replacement nodes.

  Nodes annotated with SKIP_PROCESSING are returned unchanged. While visiting,
  `self.ctx.current_origin` is set from the node's ORIGIN annotation, if any,
  and it is always restored afterwards. Replacement nodes that lack an ORIGIN
  annotation inherit the replaced node's origin, or the parent origin when the
  replaced node had none.

  Args:
    node: the gast.AST node to visit.

  Returns:
    The transformed node. This may be a list/tuple or an Assign/AugAssign
    when an Expr statement's value was replaced by one of those.

  Raises:
    ValueError: if `node` is not a gast.AST (e.g. a list of statements).
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error needs to be raised before the origin handling
    # below starts, because that code assumes `node` is, in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
    return node

  parent_origin = self.ctx.current_origin
  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  try:
    processing_expr_node = isinstance(node, gast.Expr)
    if processing_expr_node:
      entry_expr_value = node.value

    result = super(Base, self).visit(node)

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if (processing_expr_node and isinstance(result, gast.Expr) and
        (result.value is not entry_expr_value)):
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value, (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

    # By default, all replacements receive the origin info of the replaced
    # node.
    if result is not node and result is not None:
      inherited_origin = anno.getanno(
          node, anno.Basic.ORIGIN, default=parent_origin)
      if inherited_origin is not None:
        nodes_to_adjust = (
            result if isinstance(result, (list, tuple)) else (result,))
        for n in nodes_to_adjust:
          if not anno.hasanno(n, anno.Basic.ORIGIN):
            anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
  finally:
    self.ctx.current_origin = parent_origin

  return result
def test_copy(self):
  """copyanno copies an existing annotation and ignores a missing one."""
  source = ast.Name()
  anno.setanno(source, 'foo', 3)
  target = ast.Name()

  for key in ('foo', 'bar'):
    anno.copyanno(source, target, key)

  self.assertTrue(anno.hasanno(target, 'foo'))
  self.assertFalse(anno.hasanno(target, 'bar'))
def _try_resolve_target(self, node):
  """Works for methods of objects of known type.

  Args:
    node: an AST node naming a call target (e.g. a Call's `func`).

  Returns:
    The node's 'live_val' annotation if present. Otherwise, for a
    gast.Attribute carrying a 'type' annotation, the attribute named
    `node.attr` looked up on that type. Otherwise None.

  Raises:
    ValueError: if the annotated type has no attribute named `node.attr`.
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    owner_type = anno.getanno(node, 'type')
    if hasattr(owner_type, node.attr):
      return getattr(owner_type, node.attr)
    raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                     (owner_type, node.attr))
  return None
def _try_resolve_target(self, node):
  """Works for methods of objects of known type.

  Args:
    node: an AST node naming a call target (e.g. a Call's `func`).

  Returns:
    The node's 'live_val' annotation if present. Otherwise, for a
    gast.Attribute carrying a 'type' annotation, the attribute named
    `node.attr` looked up on that type. Otherwise None.

  Raises:
    ValueError: if the annotated type has no attribute named `node.attr`.
  """
  if anno.hasanno(node, 'live_val'):
    return anno.getanno(node, 'live_val')
  if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
    owner_type = anno.getanno(node, 'type')
    if hasattr(owner_type, node.attr):
      return getattr(owner_type, node.attr)
    raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                     (owner_type, node.attr))
  return None
def visit_Call(self, node):
  """Rewrites a call according to what is statically known about its target.

  When `node.func` has a 'live_val' annotation:
    - compilable targets are renamed to their compiled version if
      `_should_compile` agrees; otherwise only the call's children are visited;
    - targets whose fqn is in KNOWN_NUMPY_FUNCTIONS are wrapped in a py_func;
    - builtins are returned unchanged;
    - namedtuples have their children visited and are otherwise kept as-is.
  Calls with no live value are kept for `super(...)` and (inside a class
  conversion) `super(...).method(...)`. All other such calls are routed
  through `_insert_dynamic_conversion`.

  Raises:
    NotImplementedError: for a live target that matches none of the above.
  """
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')

    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      # NOTE(review): target_fqn may be None here and is still passed to
      # _should_compile below, which indexes it — confirm compilable
      # entities always carry an 'fqn' annotation.
      target_fqn = None

    if self._function_is_compilable(target_entity):
      if self._should_compile(node, target_fqn):
        node = self._rename_compilable_function(node)
      else:
        node = self.generic_visit(node)
        return node

    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    elif inspect_utils.isbuiltin(target_entity):
      # Note: Any builtin that passed the builtins converter is assumed to be
      # safe for graph mode.
      return node

    elif inspect_utils.isnamedtuple(target_entity):
      # Although not compilable, we assume they are safe for graph mode.
      node = self.generic_visit(node)
      return node

    else:
      # TODO(mdan): Insert dynamic conversion here instead.
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    # Special cases
    # TODO(mdan): These need a systematic review - there may be more.

    # 1. super() calls - these are preserved. The class conversion mechanism
    # will ensure that they return the correct value.
    if ast_util.matches(node, parser.parse_expression('super(_)')):
      return node

    # 2. super().method calls - these are preserved as well, when the
    # conversion processes the entire class.
    if (ast_util.matches(node, parser.parse_expression('super(_)._(_)')) and
        self.ctx.info.owner_type is not None):
      return node

    node = self._insert_dynamic_conversion(node)
  return node
def visit_Call(self, node):
  """Rewrites a call according to what is statically known about its target.

  When `node.func` has a 'live_val' annotation:
    - compilable targets are renamed to their compiled version if
      `_should_compile` agrees; otherwise only the call's children are visited;
    - targets whose fqn is in KNOWN_NUMPY_FUNCTIONS are wrapped in a py_func;
    - builtins are returned unchanged;
    - namedtuples have their children visited and are otherwise kept as-is.
  Calls with no live value are kept for `super(...)` and (inside a class
  conversion) `super(...).method(...)`. All other such calls are routed
  through `_insert_dynamic_conversion`.

  Raises:
    NotImplementedError: for a live target that matches none of the above.
  """
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')

    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      # NOTE(review): target_fqn may be None here and is still passed to
      # _should_compile below, which indexes it — confirm compilable
      # entities always carry an 'fqn' annotation.
      target_fqn = None

    if self._function_is_compilable(target_entity):
      if self._should_compile(node, target_fqn):
        node = self._rename_compilable_function(node)
      else:
        node = self.generic_visit(node)
        return node

    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    elif inspect_utils.isbuiltin(target_entity):
      # Note: Any builtin that passed the builtins converter is assumed to be
      # safe for graph mode.
      return node

    elif inspect_utils.isnamedtuple(target_entity):
      # Although not compilable, we assume they are safe for graph mode.
      node = self.generic_visit(node)
      return node

    else:
      # TODO(mdan): Insert dynamic conversion here instead.
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    # Special cases
    # TODO(mdan): These need a systematic review - there may be more.

    # 1. super() calls - these are preserved. The class conversion mechanism
    # will ensure that they return the correct value.
    if ast_util.matches(node, 'super(_)'):
      return node

    # 2. super().method calls - these are preserved as well, when the
    # conversion processes the entire class.
    if (ast_util.matches(node, 'super(_)._(_)') and
        self.ctx.info.owner_type is not None):
      return node

    node = self._insert_dynamic_conversion(node)
  return node
def test_duplicate(self):
  """dup adds the new key on the node itself but not on its children."""
  inner_stmt = ast.Expr(ast.Name('bar', ast.Load()))
  node = ast.If(test=ast.Num(1), body=[inner_stmt], orelse=[])

  for key in ('spam', 'ham'):
    anno.setanno(node, key, 1)
  anno.setanno(inner_stmt, 'ham', 1)

  anno.dup(node, {'spam': 'eggs'})

  for key in ('spam', 'ham', 'eggs'):
    self.assertTrue(anno.hasanno(node, key))
  self.assertFalse(anno.hasanno(inner_stmt, 'eggs'))
def visit_Call(self, node):
  """Rewrites a call according to what is statically known about its target.

  Calls to a function listed in `ctx.program.autograph_decorators` mark their
  first argument 'graph_ready'. After the children are visited, calls with a
  'live_val' target are either renamed (compilable targets) or wrapped in a
  py_func (targets whose fqn is in KNOWN_NUMPY_FUNCTIONS). Calls without a
  live value are left alone for the `dict` builtin and for `super(...)`.
  Otherwise they are dynamically converted when `ctx.program.recursive` is
  set.

  Raises:
    ValueError: if a decorator call has no positional arguments.
    NotImplementedError: for a live target that is neither compilable nor a
      known numpy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.ctx.program.autograph_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    if anno.hasanno(node.func, anno.Basic.QN):
      # Special-case a few builtins that otherwise go undetected. This
      # normally doesn't pose a problem, but the dict built-in doesn't
      # work with inspect.getargspec which is required for dynamic functions.
      # Note: expecting this is resilient to aliasing (e.g.
      # dict = an_evil_dict), because in those cases the regular mechanisms
      # process a simple user function.
      qn = anno.getanno(node.func, anno.Basic.QN)
      # Add items to this list as needed.
      if str(qn) in ('dict', ):
        return node

    if ast_util.matches(node, 'super(_)'):
      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      return node

    if self.ctx.program.recursive:
      node = self._insert_dynamic_conversion(node)
  return node
def visit_Call(self, node):
  """Rewrites a call according to what is statically known about its target.

  Calls to a function listed in `ctx.program.options.strip_decorators` mark
  their first argument 'graph_ready'. After the children are visited, calls
  with a 'live_val' target are either renamed (compilable targets) or wrapped
  in a py_func (targets whose fqn is in KNOWN_NUMPY_FUNCTIONS). Calls without
  a live value are left alone for the `dict` builtin and for `super(...)`.
  Otherwise they are dynamically converted when
  `ctx.program.options.recursive` is set.

  Raises:
    ValueError: if a decorator call has no positional arguments.
    NotImplementedError: for a live target that is neither compilable nor a
      known numpy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.ctx.program.options.strip_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    if anno.hasanno(node.func, anno.Basic.QN):
      # Special-case a few builtins that otherwise go undetected. This
      # normally doesn't pose a problem, but the dict built-in doesn't
      # work with inspect.getargspec which is required for dynamic functions.
      # Note: expecting this is resilient to aliasing (e.g.
      # dict = an_evil_dict), because in those cases the regular mechanisms
      # process a simple user function.
      qn = anno.getanno(node.func, anno.Basic.QN)
      # Add items to this list as needed.
      if str(qn) in ('dict',):
        return node

    if ast_util.matches(node, 'super(_)'):
      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      return node

    if self.ctx.program.options.recursive:
      node = self._insert_dynamic_conversion(node)
  return node
def test_no_inference_on_unknown_operand_types(self):
  """Operators over untyped operands receive no static type annotation."""

  # No information on types of a and b, see TestResolver.
  def magic_no_types(a, b):
    return a < b, a - b

  node, _ = TestTranspiler().transform(magic_no_types, None)
  elts = node.body[0].value.elts
  compare_op, binary_op = elts[0], elts[1]

  # With no information on operand types, the operators will assert nothing.
  self.assertFalse(anno.hasanno(compare_op, anno.Static.TYPES))
  self.assertFalse(anno.hasanno(binary_op, anno.Static.TYPES))
def test_duplicate(self):
  """dup adds the new key on the node itself but not on its children."""
  inner_stmt = ast.Expr(ast.Name('bar', ast.Load()))
  node = ast.If(test=ast.Num(1), body=[inner_stmt], orelse=[])

  for key in ('spam', 'ham'):
    anno.setanno(node, key, 1)
  anno.setanno(inner_stmt, 'ham', 1)

  anno.dup(node, {'spam': 'eggs'})

  for key in ('spam', 'ham', 'eggs'):
    self.assertTrue(anno.hasanno(node, key))
  self.assertFalse(anno.hasanno(inner_stmt, 'eggs'))
def test_local_scope_info_stack(self):
  """Local-scope state set inside a loop stays within that loop's scope.

  The transformer concatenates every Constant it visits into a 'string'
  local. Each For/While node opens its own local scope and records the
  concatenation in a 'test' annotation. The nested while loop's constants
  ('4', '1') are therefore expected to be excluded from the enclosing for
  loop's result ('3a2bc'). The literal 'This should not be counted' suggests
  constants outside any loop are expected not to appear in either result.
  """

  class TestTransformer(transformer.Base):

    # Extract all string constants from the block.
    def visit_Constant(self, node):
      # Appends this constant's str() to the current local scope's 'string'.
      self.set_local(
          'string', self.get_local('string', default='') + str(node.value))
      return self.generic_visit(node)

    def _annotate_result(self, node):
      # Visits `node` inside a fresh local scope and stores the collected
      # 'string' value on it as the 'test' annotation.
      self.enter_local_scope()
      node = self.generic_visit(node)
      anno.setanno(node, 'test', self.get_local('string'))
      self.exit_local_scope()
      return node

    def visit_While(self, node):
      return self._annotate_result(node)

    def visit_For(self, node):
      return self._annotate_result(node)

  tr = TestTransformer(self._simple_context())

  # The structure of this function is load-bearing: the indices used below
  # (body[2], body[1].orelse[1]) address its statements.
  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while 4:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function, future_features=())
  node = tr.visit(node)

  # body[0] is the docstring, body[1] the assert, body[2] the for loop.
  for_node = node.body[2]
  # for body[1] is the if; its orelse[1] is the while loop.
  while_node = for_node.body[1].orelse[1]

  self.assertFalse(anno.hasanno(for_node, 'string'))
  self.assertEqual('3a2bc', anno.getanno(for_node, 'test'))
  self.assertFalse(anno.hasanno(while_node, 'string'))
  self.assertEqual('41', anno.getanno(while_node, 'test'))
def visit_For(self, node):
  """Converts a `for` statement into one of three functional loop forms.

  The form depends on whether the loop has loop-carried state (as computed by
  `_get_loop_state`) and whether it has an 'extra_test' annotation (early
  stopping, e.g. break or return):
    - state + extra test   -> `_for_loop_with_extra_test`
    - state, no extra test -> `_for_loop_with_state`
    - no state             -> `_for_loop_without_state`

  Returns:
    A list: the assignments created for possibly-undefined symbols, followed
    by the nodes of the converted loop.
  """
  self.generic_visit(node)

  loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
  loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
      loop_state, reserved_symbols)
  node_body = ast_util.rename_symbols(node.body, ssf_map)
  body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

  has_extra_test = anno.hasanno(node, 'extra_test')
  if loop_state:
    if has_extra_test:
      # Loop with early stopping (e.g. break or return)
      extra_test = anno.getanno(node, 'extra_test')
      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
      extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                  reserved_symbols)
      loop_nodes = self._for_loop_with_extra_test(
          loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
          extra_test, body_name, node_body)
    else:
      # Loop with loop-carried state and no early stopping
      loop_nodes = self._for_loop_with_state(
          loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
  else:
    # Loop with no loop-carried state and no early stopping
    assert not has_extra_test, ('Early stopping (e.g. break and/or return) '
                                'should create state variables.')
    loop_nodes = self._for_loop_without_state(node, body_name, node_body)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)
  return undefined_assigns + loop_nodes
def can_ignore(self, node):
  """Returns True if the node can safely be assumed not to touch variables.

  That is the case for nodes annotated SKIP_PROCESSING, and for Break,
  Continue, Raise and Pass statements.
  """
  stmt = node.ast_node
  if not anno.hasanno(stmt, anno.Basic.SKIP_PROCESSING):
    return isinstance(stmt, (gast.Break, gast.Continue, gast.Raise, gast.Pass))
  return True
def visit_node(self, node):
  """Recomputes the live-in and live-out symbol sets of one CFG node.

  live_out is the union of the live-in sets of the node's successors. For
  nodes with a scope annotation, live_in = gen | (live_out - kill), where gen
  is the scope's used symbols plus any `extra_gen` entries and kill is the
  scope's modified symbols. Nodes without a scope annotation are assumed not
  to touch any symbol, so liveness passes straight through them
  (live_in = live_out). Previously such nodes kept their old live-in set and
  ignored their successors, which dropped liveness across e.g. `break`.

  Returns:
    True if the node's live-in set changed on this visit. Presumably the
    caller uses this to detect the dataflow fixed point.
  """
  prev_live_in = self.in_[node]

  live_out = set()
  for n in node.next:
    live_out |= self.in_[n]

  if anno.hasanno(node.ast_node, anno.Static.SCOPE):
    node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

    gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. if x.y is live whether x needs to be added. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    kill = node_scope.modified

    live_in = gen | (live_out - kill)
  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols, so whatever is live after them is live before them.
    # This Name node below is a literal name, e.g. False
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Continue, gast.Break)), type(
                          node.ast_node)
    live_in = live_out

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def _visit_and_reindent(self, nodes):
  """Visits a list of statements, honoring INDENT_BLOCK_REMAINDER requests.

  Each statement is visited and appended (or, if the visit returned a
  list/tuple, extended) into the current destination list. A visited result
  annotated with INDENT_BLOCK_REMAINDER supplies a new destination list and
  an alias map. All following statements go into that new destination, with
  symbols renamed per the accumulated alias map. The annotation is removed
  after it is consumed.

  Args:
    nodes: the statements to visit.

  Returns:
    The new top-level statement list. Later statements may have been nested
    into destinations supplied by earlier ones.

  Raises:
    ValueError: if a re-indent was requested but no statements followed to
      fill the final destination.
  """
  new_nodes = []
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    # NOTE(review): when `n` is a list/tuple this checks the container, not
    # its elements — confirm that anno.hasanno is simply False for lists.
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # Aliases accumulated so far take precedence over the new ones.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError('Unable to insert statement into the computation flow: '
                     'it is not followed by any computation which '
                     'the statement could gate.')
  return new_nodes
def test_no_inference_on_unknown_operand_types(self):
  """Operators over arguments the resolver cannot type get no static type."""

  class Resolver(type_inference.Resolver):

    def res_arg(self, ns, types_ns, f_name, name, type_anno):
      return None

  def test_fn(a, b):
    return a < b, a - b

  node, _ = TestTranspiler(Resolver).transform(test_fn, None)
  elts = node.body[0].value.elts
  compare_op, binary_op = elts[0], elts[1]

  # With no information on operand types, the operators will infer nothing.
  self.assertFalse(anno.hasanno(compare_op, anno.Static.TYPES))
  self.assertFalse(anno.hasanno(binary_op, anno.Static.TYPES))
def visit_Attribute(self, node):
  """Annotates `owner.attr` with a QN built from the owner's QN, if any."""
  node = self.generic_visit(node)
  owner = node.value
  if not anno.hasanno(owner, anno.Basic.QN):
    return node
  owner_qn = anno.getanno(owner, anno.Basic.QN)
  anno.setanno(node, anno.Basic.QN, QN(owner_qn, attr=node.attr))
  return node
def test_static_attribute_of_typed_value(self):
  """A class attribute of a singly-typed name receives a static value.

  The resolver types `tc` as TestClass and types the attribute value as
  `str`. The test expects that type on `tc.a` together with the static
  value 1, and no static value on the `tc` name.
  """
  test_self = self

  class TestClass:
    a = 1

  tc = TestClass()

  class Resolver(type_inference.Resolver):

    def res_name(self, ns, types_ns, name):
      test_self.assertEqual(name, qual_names.QN('tc'))
      return {TestClass}, None

    def res_value(self, ns, value):
      test_self.assertIs(value, tc.a)
      return {str}

  def test_fn():
    return tc.a

  node, _ = TestTranspiler(Resolver).transform(test_fn, None)
  attr_node = node.body[0].value  # the `tc.a` expression being returned
  owner_node = attr_node.value  # the `tc` name

  self.assertTypes(owner_node, TestClass)
  # The resolver is the source of truth: it reports str even though a == 1.
  self.assertTypes(attr_node, str)
  self.assertFalse(anno.hasanno(owner_node, anno.Static.VALUE))
  self.assertEqual(anno.getanno(attr_node, anno.Static.VALUE), 1)
def visit_For(self, node):
  """Adds a `for` statement to the CFG being built.

  Opens a lexical scope and a section for the loop. The iterable's children
  are visited before the loop section is entered. Then the optional
  EXTRA_LOOP_TEST annotation and the loop body are processed inside the loop
  section. The lexical scope is closed before the `orelse` statements, which
  are visited outside the loop section.
  """
  self.builder.begin_statement(node)
  self._enter_lexical_scope(node)

  self.builder.enter_section(node)

  # Note: Strictly speaking, this should be node.target + node.iter.
  # However, the activity analysis accounts for this inconsistency,
  # so dataflow analysis produces the correct values.
  self.generic_visit(node.iter)
  self.builder.enter_loop_section(node, node.iter)
  # Also include the "extra loop test" annotation, to capture things like the
  # control variable for return and break in for loops.
  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    self._process_basic_statement(
        anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
  for stmt in node.body:
    self.visit(stmt)
  self.builder.exit_loop_section(node)

  # Note: although the orelse is technically part of the loop node,
  # they don't count as loop bodies.  For example, a break in the loop's
  # orelse will affect the parent loop, not the current one.
  self._exit_lexical_scope(node)

  for stmt in node.orelse:
    self.visit(stmt)

  self.builder.exit_section(node)
  self.builder.end_statement(node)
def test_property_of_typed_value(self):
  """A property of a singly-typed name resolves to the property object.

  Looking the attribute up on the class yields the `property` object rather
  than its computed value. The test expects `tc.a` to be typed `property`,
  with the property itself as its static value.
  """
  test_self = self

  class TestClass:

    @property
    def a(self):
      return 1

  tc = TestClass()

  class Resolver(type_inference.Resolver):

    def res_name(self, ns, types_ns, name):
      test_self.assertEqual(name, qual_names.QN('tc'))
      return {TestClass}, None

    def res_value(self, ns, value):
      test_self.assertIs(value, TestClass.a)
      # Can't evaluate property of class.
      test_self.assertNotEqual(value, 1)
      return {property}

  def test_fn():
    return tc.a

  node, _ = TestTranspiler(Resolver).transform(test_fn, None)
  attr_node = node.body[0].value  # the `tc.a` expression being returned
  owner_node = attr_node.value  # the `tc` name

  self.assertTypes(owner_node, TestClass)
  self.assertTypes(attr_node, property)
  self.assertFalse(anno.hasanno(owner_node, anno.Static.VALUE))
  self.assertEqual(anno.getanno(attr_node, anno.Static.VALUE), TestClass.a)
def test_parameter_class_members(self):
  """A method called on a function parameter has no live value."""

  def test_fn(opt):
    opt.minimize(0)

  analyzed = self._parse_and_analyze(test_fn, {})
  # The callee `opt.minimize` of the single statement in test_fn.
  callee = analyzed.body[0].body[0].value.func
  self.assertFalse(anno.hasanno(callee, 'live_val'))
def _node_sets_self_attribute(self, node):
  """Returns True if `node`'s QN is an attribute directly on `self`."""
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qn = anno.getanno(node, anno.Basic.QN)
  # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
  is_self_attr = qn.has_attr and qn.parent.qn == ('self',)
  return True if is_self_attr else False
def _track_symbol(self,
                  node,
                  composite_writes_alter_parent=False,
                  writes_create_symbol=False):
  """Records a read, write or parameter definition of `node`'s QN in scope.

  The kind of access is derived from `node.ctx` (Store, Load or Param). The
  node is then annotated with NodeAnno.IS_LOCAL, and the symbol is marked
  returned when inside a return statement. Nodes without a QN annotation are
  ignored.

  Args:
    node: a Name/Attribute/Subscript-like node with a `ctx` field.
    composite_writes_alter_parent: if True, a store to a composite QN
      (e.g. `a.b`) also marks its parent (`a`) as written.
    writes_create_symbol: if True, a store also marks the QN as created.

  Raises:
    ValueError: if `node.ctx` is not Store, Load or Param.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  if isinstance(node.ctx, gast.Store):
    self.scope.mark_write(qn)
    if qn.is_composite and composite_writes_alter_parent:
      self.scope.mark_write(qn.parent)
    if writes_create_symbol:
      self.scope.mark_creation(qn, writes_create_symbol=True)
    # An augmented assignment (x += ...) also reads its target.
    if self._in_aug_assign:
      self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    # Param contexts appear in function defs, so they have the meaning of
    # defining a variable.
    self.scope.mark_write(qn)
    self.scope.mark_param(qn, self.enclosing_entities[-1])
  else:
    raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

  anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

  if self._in_return_statement:
    self.scope.mark_returned(qn)
def visit_Name(self, node):
  """Wraps loads of names from the original code in `ag__.ld(...)`.

  Only names carrying the ORIG_DEFINITIONS annotation, i.e. loads that
  existed in the original code, are overloaded. Everything else is returned
  as-is.
  """
  from_original_code = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
  if from_original_code and isinstance(node.ctx, gast.Load):
    return templates.replace_as_expression('ag__.ld(var_)', var_=node)
  return node
def _visit_and_reindent(self, nodes):
  """Visits a list of statements, honoring INDENT_BLOCK_REMAINDER requests.

  Each statement is visited and appended (or, if the visit returned a
  list/tuple, extended) into the current destination list. A visited result
  annotated with INDENT_BLOCK_REMAINDER supplies a new destination list and
  an alias map. All following statements go into that new destination, with
  symbols renamed per the accumulated alias map. The annotation is removed
  after it is consumed.

  Args:
    nodes: the statements to visit.

  Returns:
    The new top-level statement list. Later statements may have been nested
    into destinations supplied by earlier ones.

  Raises:
    ValueError: if a re-indent was requested but no statements followed to
      fill the final destination.
  """
  new_nodes = []
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    # NOTE(review): when `n` is a list/tuple this checks the container, not
    # its elements — confirm that anno.hasanno is simply False for lists.
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # Aliases accumulated so far take precedence over the new ones.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError(
        'Unable to insert statement into the computation flow: '
        'it is not followed by any computation which '
        'the statement could gate.')
  return new_nodes
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  A call is not compiled when any of the following holds:
    - its fqn (or a prefix of it) falls under `ctx.program.uncompiled_modules`;
    - the call node is annotated 'graph_ready';
    - the resolved target's module is uncompiled;
    - the target carries `__pyct_is_compile_decorator`;
    - the target is one of the strip_decorators;
    - the target is decorated with one of the strip_decorators.

  Args:
    node: the gast.Call node.
    fqn: the target's fully-qualified name as a tuple of components, or None
      when unknown. In that case the fqn-based checks are skipped; previously
      `fqn[0]` raised TypeError.

  Returns:
    bool, whether the call target should be converted.
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  if fqn:
    module_name = fqn[0]
    for mod in self.ctx.program.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.ctx.program.uncompiled_modules:
        return False

  # Check for local decorations
  if anno.hasanno(node, 'graph_ready'):
    return False

  # The decorators themselves are not to be converted.
  # If present, the decorators should appear as static functions.
  target_entity = self._try_resolve_target(node.func)

  if target_entity is not None:

    # This may be reached when "calling" a callable attribute of an object.
    # For example:
    #
    #   self.fc = tf.keras.layers.Dense()
    #   self.fc()
    #
    # Some callables have no __module__, or have it set to None; treat those
    # as not belonging to any uncompiled module instead of crashing.
    entity_module = getattr(target_entity, '__module__', None) or ''
    for mod in self.ctx.program.uncompiled_modules:
      if entity_module.startswith(mod[0] + '.'):
        return False

    # This attribute is set by the decorator itself.
    # TODO(mdan): This may not play nicely with other wrapping decorators.
    if hasattr(target_entity, '__pyct_is_compile_decorator'):
      return False

    if target_entity in self.ctx.program.options.strip_decorators:
      return False

    # Inspect the target function decorators. If any include a @convert
    # or @graph_ready annotation, then they must be called as they are.
    # TODO(mdan): This may be quite heavy.
    # To parse and re-analyze each function for every call site could be quite
    # wasteful. Maybe we could cache the parsed AST?
    try:
      target_node, _ = parser.parse_entity(target_entity)
      target_node = target_node.body[0]
    except (TypeError, IOError, OSError):
      # Functions whose source we cannot access are compilable (e.g. wrapped
      # to py_func). inspect-based source retrieval raises TypeError for
      # builtins and OSError/IOError when the source file is unavailable.
      return True

    for dec in target_node.decorator_list:
      decorator_fn = self._resolve_name(dec)
      if (decorator_fn is not None and
          decorator_fn in self.ctx.program.options.strip_decorators):
        return False

  return True
def visit_FunctionDef(self, node):
  """Converts a function definition's arguments, body and return annotation.

  The function's 'function_context_name' annotation is recorded in the
  _Function state. Decorators are then adjusted by nesting level: functions
  at level < 2 lose all decorators, and deeper ones get
  `ag__.do_not_convert_internal` appended.
  """
  self.state[_Function].enter()
  # Note: if the conversion process ever creates helper functions, this
  # assumption will no longer hold.
  assert anno.hasanno(node, 'function_context_name'), (
      'The function_scopes converter always creates a scope for functions.')
  context_name = anno.getanno(node, 'function_context_name')
  self.state[_Function].context_name = context_name

  node.args = self.visit(node.args)
  node.body = self.visit_block(node.body)

  is_inner_function = self.state[_Function].level >= 2
  if is_inner_function:
    # Inner functions are already converted by now; the marker decorator
    # avoids the overhead of converting them a second time.
    node.decorator_list.append(
        parser.parse_expression('ag__.do_not_convert_internal'))
  else:
    # Conversion is just-in-time, so by the time it runs the decorators of a
    # top-level function are already set to be applied; drop them.
    node.decorator_list = []

  if node.returns:
    node.returns = self.visit(node.returns)

  self.state[_Function].exit()
  return node
def visit_node(self, node):
  """Recomputes the live-in and live-out symbol sets of one CFG node.

  This is a backward dataflow step. live_out is the union of the live_in sets
  of the node's successors (`node.next`). For nodes with a scope annotation,
  live_in = gen | (live_out - kill), where gen is the scope's used symbols
  (plus any `extra_gen` entry for the AST node) and kill its modified
  symbols. Nodes without a scope touch no symbols, so live_in = live_out.

  Args:
    node: the CFG node to update; it exposes `ast_node` and `next`.

  Returns:
    True if the node's live_in set changed, i.e. the analysis has not yet
    reached a fixed point at this node.
  """
  prev_live_in = self.in_[node]

  if anno.hasanno(node.ast_node, anno.Static.SCOPE):
    node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

    gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. if x.y is live whether x needs to be added. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    kill = node_scope.modified

    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]
    live_in = gen | (live_out - kill)

  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols.
    # This Name node below is a literal name, e.g. False
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Continue, gast.Break)), type(
                          node.ast_node)
    # Such nodes neither generate nor kill symbols, so liveness must flow
    # through them unchanged. Reusing prev_live_in here would freeze the set
    # at its initial value and block propagation at e.g. break/continue.
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]
    live_in = live_out

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def _node_sets_self_attribute(self, node):
  """Returns whether `node`'s qualified name is an attribute of `self`.

  The node must carry an anno.Basic.QN annotation that has an attribute and
  whose parent QN is exactly ('self',). Only the literal name 'self' is
  recognized.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qn = anno.getanno(node, anno.Basic.QN)
  # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
  return bool(qn.has_attr and qn.parent.qn == ('self',))
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  Args:
    node: the call node being inspected; its `func` is resolved to a live
      target via `_try_resolve_target` when possible.
    fqn: the fully qualified name of the call target, indexed as a sequence
      of name components (`fqn[0]` is the module name, `fqn[:i]` prefixes are
      compared against `uncompiled_modules`).

  Returns:
    False if the target belongs to an uncompiled module, is annotated
    'graph_ready', carries the `__pyct_is_compile_decorator` marker, is one of
    `strip_decorators`, or is decorated with one of them; True otherwise
    (including when the target's source cannot be parsed).
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  module_name = fqn[0]
  for mod in self.ctx.program.uncompiled_modules:
    if module_name.startswith(mod[0] + '.'):
      return False

  for i in range(1, len(fqn)):
    if fqn[:i] in self.ctx.program.uncompiled_modules:
      return False

  # Check for local decorations
  if anno.hasanno(node, 'graph_ready'):
    return False

  # The decorators themselves are not to be converted.
  # If present, the decorators should appear as static functions.
  target_entity = self._try_resolve_target(node.func)

  if target_entity is not None:

    # This may be reached when "calling" a callable attribute of an object.
    # For example:
    #
    #   self.fc = tf.keras.layers.Dense()
    #   self.fc()
    #
    # Not every object has a usable __module__: builtin methods report None
    # and some objects lack the attribute entirely. Skip the module check for
    # those rather than failing with AttributeError.
    entity_module = getattr(target_entity, '__module__', None)
    if entity_module is not None:
      for mod in self.ctx.program.uncompiled_modules:
        if entity_module.startswith(mod[0] + '.'):
          return False

    # This attribute is set by the decorator itself.
    # TODO(mdan): This may not play nicely with other wrapping decorators.
    if hasattr(target_entity, '__pyct_is_compile_decorator'):
      return False

    if target_entity in self.ctx.program.options.strip_decorators:
      return False

    # Inspect the target function decorators. If any include a @convert
    # or @graph_ready annotation, then they must be called as they are.
    # TODO(mdan): This may be quite heavy.
    # To parse and re-analyze each function for every call site could be quite
    # wasteful. Maybe we could cache the parsed AST?
    try:
      target_node, _ = parser.parse_entity(target_entity)
      target_node = target_node.body[0]
    except TypeError:
      # Functions whose source we cannot access are compilable (e.g. wrapped
      # to py_func).
      # NOTE(review): source lookup can also fail with OSError; confirm
      # whether parser.parse_entity converts that to TypeError.
      return True

    for dec in target_node.decorator_list:
      decorator_fn = self._resolve_decorator_name(dec)
      if (decorator_fn is not None and
          decorator_fn in self.ctx.program.options.strip_decorators):
        return False

  return True
def visit_node(self, node):
  """Performs one backward liveness update for a CFG node.

  live_out is the union of the live_in sets of the node's successors. When
  the node's AST carries a scope annotation, live_in is
  gen | (live_out - kill), with gen = symbols read (plus any `extra_gen`
  entry for the AST node) and kill = symbols modified or deleted. Nodes
  without a scope annotation must satisfy `can_ignore` and pass liveness
  through unchanged.

  Args:
    node: the CFG node to update.

  Returns:
    True if the node's live_in set changed.
  """
  prev_live_in = self.in_[node]
  ast_node = node.ast_node
  has_scope = anno.hasanno(ast_node, anno.Static.SCOPE)

  if has_scope:
    scope = anno.getanno(ast_node, anno.Static.SCOPE)
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. whether x needs to be added if x.y is live. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    gen = scope.read | self.extra_gen.get(ast_node, frozenset())
    kill = scope.modified | scope.deleted
  else:
    assert self.can_ignore(node), (node.ast_node, node)

  live_out = set()
  for successor in node.next:
    live_out |= self.in_[successor]

  live_in = (gen | (live_out - kill)) if has_scope else live_out

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def visit_Call(self, node):
  """Replaces calls to supported Python builtins with their converted form.

  The call's children are visited first. If the callee carries a 'live_val'
  annotation that is one of py_builtins.SUPPORTED_BUILTINS, the call is
  rewritten via `_convert_builtin`; otherwise it is returned unchanged.

  Args:
    node: the call node.

  Returns:
    The converted expression, or the (visited) original call.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.func, 'live_val'):
    return node

  live_val = anno.getanno(node.func, 'live_val')
  try:
    is_supported = live_val in py_builtins.SUPPORTED_BUILTINS
  except TypeError:
    # Not everything in Python is hashable. If it isn't then it's definitely
    # not a supported built-in.
    return node

  if is_supported:
    node = self._convert_builtin(live_val, node.args, as_expression=True)
  return node
def _track_symbol(self,
                  node,
                  composite_writes_alter_parent=False,
                  writes_create_symbol=False):
  """Records in the current scope how the symbol annotated on `node` is used.

  Store contexts mark the symbol written. They optionally also mark the
  parent of a composite name written, record the symbol's creation, and mark
  it read inside augmented assignments. Load contexts mark it read. Param
  contexts mark it written and a parameter of the innermost enclosing entity.
  Afterwards the node is annotated with NodeAnno.IS_LOCAL, and the symbol is
  marked returned when inside a return statement.

  Args:
    node: an AST node with a `ctx` attribute. Nodes without an anno.Basic.QN
      annotation are ignored.
    composite_writes_alter_parent: if True, a write to a composite name
      (e.g. `a.b`) is also recorded as a write to its parent.
    writes_create_symbol: if True, writes are also recorded through
      `scope.mark_creation`.

  Raises:
    ValueError: if `node.ctx` is not a Store, Load or Param context.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  if isinstance(node.ctx, gast.Store):
    self.scope.mark_write(qn)
    if qn.is_composite and composite_writes_alter_parent:
      self.scope.mark_write(qn.parent)
    if writes_create_symbol:
      self.scope.mark_creation(qn, writes_create_symbol=True)
    # Augmented assignment (x += 1) both reads and writes the target.
    if self._in_aug_assign:
      self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    # Param contexts appear in function defs, so they have the meaning of
    # defining a variable.
    self.scope.mark_write(qn)
    self.scope.mark_param(qn, self.enclosing_entities[-1])
  else:
    raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

  anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

  if self._in_return_statement:
    self.scope.mark_returned(qn)
def visit_For(self, node):
  """Lowers a `for` statement using the `_create_for_loop_*` helpers.

  The helper is chosen by whether the loop has loop-carried state and
  whether it has an 'extra_test' annotation (early stopping, e.g. break or
  return). Loop-carried symbols are renamed via `ssf_map` in the body and
  in the extra test.

  Args:
    node: the `for` statement node.

  Returns:
    `undefined_assigns + node`: the assignments created for possibly
    undefined symbols, followed by the loop replacement. Presumably both are
    lists of statements; TODO confirm against the `_create_*` helpers.
  """
  self.generic_visit(node)

  loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
  loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
      loop_state, reserved_symbols)
  node_body = ast_util.rename_symbols(node.body, ssf_map)
  body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

  has_extra_test = anno.hasanno(node, 'extra_test')
  if loop_state:
    if has_extra_test:
      # Loop with early stopping (e.g. break or return)
      extra_test = anno.getanno(node, 'extra_test')
      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
      extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                  reserved_symbols)
      node = self._create_for_loop_early_stopping(
          loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
          extra_test, body_name, node_body)
    else:
      # Loop with loop-carried state and no early stopping
      node = self._create_for_loop_with_state(
          loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
  else:
    # Loop with no loop-carried state and no early stopping
    assert not has_extra_test, ('Early stopping (e.g. break and/or return) '
                                'should create state variables.')
    node = self._create_for_loop_without_state(node, body_name, node_body)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)
  return undefined_assigns + node
def test_local_scope_info_stack(self):
  """Checks that local scopes nest: each loop only sees its own strings.

  The For loop is expected to collect 'abc' and the inner While loop '1'.
  Strings outside any loop are not counted, and no 'string' annotation is
  left on either loop node.
  """

  class TestTransformer(transformer.Base):

    # Extract all string constants from the block.
    def visit_Str(self, node):
      self.set_local('string', self.get_local('string', default='') + node.s)
      return self.generic_visit(node)

    # Visits the loop in a fresh local scope and records the concatenated
    # strings it saw in the 'test' annotation.
    def _annotate_result(self, node):
      self.enter_local_scope()
      node = self.generic_visit(node)
      anno.setanno(node, 'test', self.get_local('string'))
      self.exit_local_scope()
      return node

    def visit_While(self, node):
      return self._annotate_result(node)

    def visit_For(self, node):
      return self._annotate_result(node)

  tr = TestTransformer(self._simple_context())

  # Fixture: the statement layout below is indexed directly (body[2],
  # body[1].orelse[1]), so it must not be restructured.
  def test_function(a):
    """Docstring."""
    assert a == 'This should not be counted'
    for i in range(3):
      _ = 'a'
      if i > 2:
        return 'b'
      else:
        _ = 'c'
        while True:
          raise '1'
    return 'nor this'

  node, _ = parser.parse_entity(test_function, future_features=())
  node = tr.visit(node)

  for_node = node.body[2]
  while_node = for_node.body[1].orelse[1]

  self.assertFalse(anno.hasanno(for_node, 'string'))
  self.assertEqual('abc', anno.getanno(for_node, 'test'))
  self.assertFalse(anno.hasanno(while_node, 'string'))
  self.assertEqual('1', anno.getanno(while_node, 'test'))
def visit_For(self, node):
  """Rewrites a `for` loop as a call to `ag__.for_stmt`.

  The loop body becomes a generated `loop_body` function. The loop's
  'extra_test' annotation, or the constant `True` when it is absent, becomes
  a generated `extra_test` function. Symbols in `ssf_map` are renamed in
  both. With loop-carried state, both functions thread the state and the
  statement assigns the returned state; without it, `for_stmt` is called
  with an empty state tuple.

  Args:
    node: the `for` statement node.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)

  loop_state, reserved_symbols = self._get_loop_state(node)
  loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
      loop_state, reserved_symbols)
  renamed_body = ast_util.rename_symbols(node.body, ssf_map)

  if not anno.hasanno(node, 'extra_test'):
    test_expr = parser.parse_expression('True')
  else:
    test_expr = ast_util.rename_symbols(
        anno.getanno(node, 'extra_test'), ssf_map)

  # Symbols are requested in the same order as before ('extra_test' first);
  # presumably the namer tracks names it has issued.
  namer = self.ctx.namer
  test_fn_name = namer.new_symbol('extra_test', reserved_symbols)
  body_fn_name = namer.new_symbol('loop_body', reserved_symbols)

  shared = dict(
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=test_fn_name,
      extra_test_expr=test_expr,
      body_name=body_fn_name,
      body=renamed_body)

  if not loop_state:
    template = """
      def extra_test_name():
        return extra_test_expr
      def body_name(loop_vars):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return ()
      ag__.for_stmt(iter_, extra_test_name, body_name, ())
    """
    return templates.replace(template, **shared)

  template = """
    def extra_test_name(state_ssf):
      return extra_test_expr
    def body_name(loop_vars, state_ssf):
      # Workaround for PEP-3113
      iterate = loop_vars
      body
      return state_ssf,
    state_ast_tuple = ag__.for_stmt(
        iter_, extra_test_name, body_name, (state,))
  """
  return templates.replace(
      template,
      state=loop_state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      **shared)
def _track_symbol(self, node, composite_writes_alter_parent=False):
  """Records in the current scope how the symbol annotated on `node` is used.

  Symbols that are (or are owned by) arguments of an enclosing lambda, or,
  on Python 3, targets of an enclosing comprehension, are ignored. Otherwise
  the context decides what is recorded:
    * Store: a comprehension target (Python 3, inside a comprehension), or
      modified (plus parent / read, see below).
    * Load: read.
    * Param: modified and a parameter of the innermost enclosing entity when
      in function def args; a lambda argument when inside a lambda.
    * Del: read and deleted.

  Args:
    node: an AST node with a `ctx` attribute. Nodes without an anno.Basic.QN
      annotation are ignored.
    composite_writes_alter_parent: if True, a write to a composite name
      (e.g. `a.b`) also marks its parent as modified.

  Raises:
    NotImplementedError: for a Param outside function args or a lambda.
    ValueError: for any other context type.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  # When inside a lambda, ignore any of the lambda's arguments.
  # This includes attributes or slices of those arguments.
  for l in self.state[_Lambda]:
    if qn in l.args:
      return
    if qn.owner_set & set(l.args):
      return

  # When inside a comprehension, ignore any of the comprehension's targets.
  # This includes attributes or slices of those arguments.
  # This is not true in Python2, which leaks symbols.
  if six.PY3:
    for l in self.state[_Comprehension]:
      if qn in l.targets:
        return
      if qn.owner_set & set(l.targets):
        return

  if isinstance(node.ctx, gast.Store):
    # In comprehensions, modified symbols are the comprehension targets.
    if six.PY3 and self.state[_Comprehension].level > 0:
      # Like a lambda's args, they are tracked separately in Python3.
      self.state[_Comprehension].targets.add(qn)
    else:
      self.scope.mark_modified(qn)
      if qn.is_composite and composite_writes_alter_parent:
        self.scope.mark_modified(qn.parent)
      # Augmented assignment (x += 1) both reads and writes the target.
      if self._in_aug_assign:
        self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    if self._in_function_def_args:
      # In function defs, params have the meaning of defining a variable.
      self.scope.mark_modified(qn)
      self.scope.mark_param(qn, self.enclosing_entities[-1])
    elif self.state[_Lambda].level:
      # In lambdas, they are tracked separately.
      self.state[_Lambda].args.add(qn)
    else:
      # TODO(mdan): Is this case possible at all?
      raise NotImplementedError(
          'Param "{}" outside a function arguments or lambda.'.
          format(qn))
  elif isinstance(node.ctx, gast.Del):
    # The read matches the Python semantics - attempting to delete an
    # undefined symbol is illegal.
    self.scope.mark_read(qn)
    self.scope.mark_deleted(qn)
  else:
    raise ValueError('Unknown context {} for node "{}".'.format(
        type(node.ctx), qn))
def visit_For(self, node):
  """Rewrites a `for` loop as a call to `ag__.for_stmt`.

  The loop body is moved into a generated `loop_body` function. The loop's
  'extra_test' annotation, or the constant `True` when it is absent, is
  moved into a generated `extra_test` function. Symbols in `ssf_map` are
  renamed in both. With loop-carried state, both functions thread the state
  and the statement assigns the returned state; otherwise `for_stmt` is
  called with an empty state tuple.

  Args:
    node: the `for` statement node.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)

  loop_state, reserved_symbols = self._get_loop_state(node)
  loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
      loop_state, reserved_symbols)
  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    extra_test = parser.parse_expression('True')

  # Keyword arguments below evaluate left to right, so new_symbol is called
  # for 'extra_test' before 'loop_body'. Presumably the namer tracks names it
  # has issued, so keep that order when restructuring (TODO confirm).
  if loop_state:
    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    node = templates.replace(template,
                             state=loop_state,
                             state_ssf=state_ssf,
                             state_ast_tuple=state_ast_tuple,
                             iter_=node.iter,
                             iterate=node.target,
                             extra_test_name=self.ctx.namer.new_symbol(
                                 'extra_test', reserved_symbols),
                             extra_test_expr=extra_test,
                             body_name=self.ctx.namer.new_symbol(
                                 'loop_body', reserved_symbols),
                             body=node_body)
  else:
    template = """
      def extra_test_name():
        return extra_test_expr
      def body_name(loop_vars):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return ()
      ag__.for_stmt(iter_, extra_test_name, body_name, ())
    """
    node = templates.replace(template,
                             iter_=node.iter,
                             iterate=node.target,
                             extra_test_name=self.ctx.namer.new_symbol(
                                 'extra_test', reserved_symbols),
                             extra_test_expr=extra_test,
                             body_name=self.ctx.namer.new_symbol(
                                 'loop_body', reserved_symbols),
                             body=node_body)

  return node
def test_parameter_class_members(self):
  """A method looked up on a plain function parameter gets no live value."""

  def test_fn(opt):
    opt.minimize(0)

  fn_module = self._parse_and_analyze(test_fn, {})
  minimize_attr = fn_module.body[0].body[0].value.func
  self.assertFalse(anno.hasanno(minimize_attr, 'live_val'))
def visit_Expr(self, node):
  """Processes an expression statement, special-casing bare calls.

  Consider a call whose callee has a 'live_val' annotation that is not
  compilable and that also has an 'fqn' annotation. It is left untouched if
  `_should_compile` returns False; otherwise the statement is replaced by
  `_wrap_to_py_func_no_return`. Every other call is delegated to the regular
  visitor. Non-call expressions are visited generically.

  Args:
    node: the expression statement node.

  Returns:
    The (possibly replaced) statement node.
  """
  if isinstance(node.value, gast.Call):
    if anno.hasanno(node.value.func, 'live_val'):
      target_entity = anno.getanno(node.value.func, 'live_val')
      if not self._function_is_compilable(target_entity):
        if anno.hasanno(node.value.func, 'fqn'):
          target_fqn = anno.getanno(node.value.func, 'fqn')
          if not self._should_compile(node.value, target_fqn):
            return node
          node = self._wrap_to_py_func_no_return(node.value)
          return node
    # Only the case of py_func with no return value is special.
    # Everything else is processed by visit_Call.
    # Keep the visitor's result: it may return a replacement node instead of
    # mutating the call in place, and discarding it would drop that rewrite.
    node.value = self.visit(node.value)
  else:
    self.generic_visit(node)
  return node
def _track_symbol(self, node, composite_writes_alter_parent=False):
  """Records in the current scope how the symbol annotated on `node` is used.

  Symbols that are (or are owned by) arguments of an enclosing lambda, or,
  on Python 3, targets of an enclosing comprehension, are ignored. Otherwise
  Store marks the symbol modified (or records it as a comprehension target),
  Load marks it read, Param marks it modified and a parameter (or records it
  as a lambda argument), and Del marks it read and deleted.

  Args:
    node: an AST node with a `ctx` attribute. Nodes without an anno.Basic.QN
      annotation are ignored.
    composite_writes_alter_parent: if True, a write to a composite name
      (e.g. `a.b`) also marks its parent as modified.

  Raises:
    NotImplementedError: for a Param outside function args or a lambda.
    ValueError: for any other context type.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  # When inside a lambda, ignore any of the lambda's arguments.
  # This includes attributes or slices of those arguments.
  for l in self.state[_Lambda]:
    if qn in l.args:
      return
    if qn.owner_set & set(l.args):
      return

  # When inside a comprehension, ignore any of the comprehension's targets.
  # This includes attributes or slices of those arguments.
  # This is not true in Python2, which leaks symbols.
  if six.PY3:
    for l in self.state[_Comprehension]:
      if qn in l.targets:
        return
      if qn.owner_set & set(l.targets):
        return

  if isinstance(node.ctx, gast.Store):
    # In comprehensions, modified symbols are the comprehension targets.
    if six.PY3 and self.state[_Comprehension].level > 0:
      # Like a lambda's args, they are tracked separately in Python3.
      self.state[_Comprehension].targets.add(qn)
    else:
      self.scope.mark_modified(qn)
      if qn.is_composite and composite_writes_alter_parent:
        self.scope.mark_modified(qn.parent)
      # Augmented assignment (x += 1) both reads and writes the target.
      if self._in_aug_assign:
        self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    if self._in_function_def_args:
      # In function defs, params have the meaning of defining a variable.
      self.scope.mark_modified(qn)
      self.scope.mark_param(qn, self.enclosing_entities[-1])
    elif self.state[_Lambda].level:
      # In lambdas, they are tracked separately.
      self.state[_Lambda].args.add(qn)
    else:
      # TODO(mdan): Is this case possible at all?
      raise NotImplementedError(
          'Param "{}" outside a function arguments or lambda.'.format(qn))
  elif isinstance(node.ctx, gast.Del):
    # The read matches the Python semantics - attempting to delete an
    # undefined symbol is illegal.
    self.scope.mark_read(qn)
    self.scope.mark_deleted(qn)
  else:
    raise ValueError('Unknown context {} for node "{}".'.format(
        type(node.ctx), qn))
def visit_For(self, node):
  """Rewrites a `for` loop as a call to `ag__.for_stmt` with explicit state.

  The loop state is the set of symbols the body modifies but does not
  create. Each state symbol gets a fresh SSF name; symbols whose SSF name
  differs are renamed in the body and in the extra test. The body and the
  extra test (the 'extra_test' annotation, or `True`) become generated
  functions.

  Args:
    node: the `for` statement node.

  Returns:
    The statements produced by `templates.replace`.
  """
  self.generic_visit(node)

  self._validate_no_live_vars_created(node)

  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  # NOTE(review): body_closure is a set, so the order of `state` (and thus of
  # the generated tuple) follows set iteration order.
  state = list(body_closure)

  # One new_symbol call per state symbol, in `state` order, before the
  # 'extra_test' and 'loop_body' names requested below.
  state_ssf = [
      self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  # A single state symbol is passed bare; otherwise the state is a tuple.
  # NOTE(review): an empty `state` also takes the tuple branch; confirm the
  # template handles an empty state.
  if len(state) == 1:
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  if anno.hasanno(node, 'extra_test'):
    extra_test = anno.getanno(node, 'extra_test')
    extra_test = ast_util.rename_symbols(extra_test, ssf_map)
  else:
    extra_test = parser.parse_expression('True')

  template = """
    def extra_test_name(state_ssf):
      return extra_test_expr
    def body_name(loop_vars, state_ssf):
      # Workaround for PEP-3113
      iterate = loop_vars
      body
      return state_ssf,
    state_ast_tuple = ag__.for_stmt(
        iter_, extra_test_name, body_name, (state,))
  """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      iter_=node.iter,
      iterate=node.target,
      extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
      extra_test_expr=extra_test,
      body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
      body=node_body)

  return node
def test_constructor_detection_builtin_class(self):
  """A call to a builtin class such as `zip` is not marked as a constructor."""

  def test_fn(x):
    res = zip(x)
    return res

  fn_module = self._parse_and_analyze(test_fn, {})
  zip_call = fn_module.body[0].body[0].value
  self.assertFalse(anno.hasanno(zip_call, 'is_constructor'))
def test_nested_members(self):
  """A call through nested attributes of a local object has no live value."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  fn_module = self._parse_and_analyze(test_fn, {'training': training})
  baz_attr = fn_module.body[0].body[1].value.func
  self.assertFalse(anno.hasanno(baz_attr, 'live_val'))
def _expect_simple_symbol(self, operand):
  """Raises NotImplementedError unless `operand` is a simple symbol.

  An operand is accepted if it is a `gast.Name` or carries the
  SAFE_BOOLEAN_OPERAND annotation.
  """
  is_plain_name = isinstance(operand, gast.Name)
  if is_plain_name or anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
    return
  raise NotImplementedError(
      'only simple local variables are supported in logical and compound '
      'comparison expressions; for example, we support "a or b" but not '
      '"a.x or b"; for a workaround, assign the expression to a local '
      'variable and use that instead, for example "tmp = a.x", "tmp or b"')
def test_copy_clean_preserves_annotations(self):
  """copy_clean keeps only the annotations named in preserve_annos."""
  source = textwrap.dedent("""
      def f(a):
        return a + 1
  """)
  node = parser.parse_str(source)
  fn_def = node.body[0]
  anno.setanno(fn_def, 'foo', 'bar')
  anno.setanno(fn_def, 'baz', 1)

  copied = ast_util.copy_clean(node, preserve_annos={'foo'})
  copied_fn_def = copied.body[0]
  self.assertEqual(anno.getanno(copied_fn_def, 'foo'), 'bar')
  self.assertFalse(anno.hasanno(copied_fn_def, 'baz'))
def test_nested_unpacking(self):
  """Checks the 'type' annotations of names bound by nested unpacking.

  Each target of `a, (b, c) = ...` should be annotated with the type of its
  corresponding constructor call, and none should get a 'live_val'.
  """

  class Foo(object):
    pass

  class Bar(object):
    pass

  def test_fn():
    a, (b, c) = (Foo(), (Bar(), Foo()))
    return a, b, c

  node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
  a, b, c = node.body[0].body[1].value.elts
  # assertEqual is the canonical name; the assertEquals alias no longer
  # exists in Python 3.12+.
  self.assertEqual(anno.getanno(a, 'type'), Foo)
  self.assertEqual(anno.getanno(b, 'type'), Bar)
  self.assertEqual(anno.getanno(c, 'type'), Foo)
  self.assertFalse(anno.hasanno(a, 'live_val'))
  self.assertFalse(anno.hasanno(b, 'live_val'))
  self.assertFalse(anno.hasanno(c, 'live_val'))
def _block_statement_live_in(self, node, entry_node):
  """Annotates `node` with the live-in symbols of `entry_node`.

  If `entry_node` has a node in the current analyzer's CFG index, its live-in
  set is read from the analyzer and frozen. Otherwise `entry_node` must
  already carry an anno.Static.LIVE_VARS_IN annotation, which is reused.

  Args:
    node: the node to annotate.
    entry_node: the AST node whose live-in set is looked up.

  Returns:
    `node`, with anno.Static.LIVE_VARS_IN set.
  """
  analyzer = self.current_analyzer
  cfg_index = analyzer.graph.index
  if entry_node not in cfg_index:
    assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
        'If not matching a CFG node, must be a block statement:'
        ' {}'.format(entry_node))
    live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
  else:
    live_in = frozenset(analyzer.in_[cfg_index[entry_node]])

  anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
  return node
def visit_Call(self, node):
  """Replaces calls to supported Python builtins with their converted form.

  The call's children are visited first. If the callee carries a 'live_val'
  annotation that is one of py_builtins.SUPPORTED_BUILTINS, the call is
  rewritten via `_convert_builtin`; otherwise it is returned unchanged.

  Args:
    node: the call node.

  Returns:
    The converted expression, or the (visited) original call.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.func, 'live_val'):
    return node

  live_val = anno.getanno(node.func, 'live_val')
  try:
    is_supported = live_val in py_builtins.SUPPORTED_BUILTINS
  except TypeError:
    # Not everything in Python is hashable. If it isn't then it's definitely
    # not a supported built-in.
    return node

  if is_supported:
    # Outside the try block on purpose: a TypeError raised by the conversion
    # itself is a real error and must not be silently ignored.
    node = self._convert_builtin(live_val, node.args, as_expression=True)
  return node
def test_constructor_data_dependent(self):
  """A method on an object assigned in both `if` branches has no live value."""

  def test_fn(x):
    if x > 0:
      opt = training.GradientDescentOptimizer(0.1)
    else:
      opt = training.GradientDescentOptimizer(0.01)
    opt.minimize(0)

  fn_module = self._parse_and_analyze(test_fn, {'training': training})
  minimize_attr = fn_module.body[0].body[1].value.func
  self.assertFalse(anno.hasanno(minimize_attr, 'live_val'))
def test_function_variables(self):
  """Calling a function through a local alias yields no live value."""

  def bar():
    pass

  def test_fn():
    foo = bar
    foo()

  fn_module = self._parse_and_analyze(test_fn, {'bar': bar})
  alias_callee = fn_module.body[0].body[1].value.func
  self.assertFalse(anno.hasanno(alias_callee, 'live_val'))
def test_basic(self):
  """Exercises the keys/hasanno/getanno/setanno/delanno round trip."""
  node = ast.Name()

  def check_unannotated():
    self.assertEqual(anno.keys(node), set())
    self.assertFalse(anno.hasanno(node, 'foo'))
    with self.assertRaises(AttributeError):
      anno.getanno(node, 'foo')

  check_unannotated()

  anno.setanno(node, 'foo', 3)
  self.assertEqual(anno.keys(node), {'foo'})
  self.assertTrue(anno.hasanno(node, 'foo'))
  self.assertEqual(anno.getanno(node, 'foo'), 3)
  self.assertEqual(anno.getanno(node, 'bar', default=7), 7)

  anno.delanno(node, 'foo')
  check_unannotated()
  self.assertIsNone(anno.getanno(node, 'foo', default=None))