def visit_Compare(self, node):
  """Rewrites a comparison into calls to the matching comparison functions.

  A chained comparison such as `a < b < c` is expanded into pairwise
  comparisons that are combined with `tf.logical_and`.

  Args:
    node: the Compare node being visited.

  Returns:
    `node` unchanged when it has a single operator with no matching function;
    otherwise the expression built from `_as_function` calls.

  Raises:
    NotImplementedError: if a chained comparison contains at least one
      operator without a matching function.
  """
  node = self.generic_visit(node)

  if not all(self._has_matching_func(op) for op in node.ops):
    # Basic expressions are safe to leave as they are.
    if len(node.ops) == 1:
      return node
    raise NotImplementedError(
        'compound expression with at least one unsupported '
        'operator: {}'.format(node.ops))

  # Repeated comparisons are converted to conjunctions:
  #   a < b < c -> a < b and b < c
  pairs = list(zip(node.ops, node.comparators))
  conjunction = None
  left = node.left
  for op, right in pairs:
    comparison = self._as_function(self._matching_func(op), (left, right))
    if isinstance(left, gast.Name) and isinstance(right, gast.Name):
      anno.setanno(comparison, SAFE_BOOLEAN_OPERAND, True)
    if not conjunction:
      conjunction = comparison
    else:
      self._expect_simple_symbol(right)
      conjunction = self._as_function('tf.logical_and',
                                      (comparison, conjunction))
    left = right

  assert conjunction is not None
  return conjunction
def _track_symbol(self,
                  node,
                  composite_writes_alter_parent=False,
                  writes_create_symbol=False):
  """Records how a symbol node is used (write, read or param) in the scope.

  Args:
    node: a node with a `ctx` attribute (gast.Store, gast.Load or gast.Param).
      Nodes without an anno.Basic.QN annotation are ignored.
    composite_writes_alter_parent: if True, a write to a composite qualified
      name is also recorded as a write to `qn.parent` (presumably the
      enclosing object, e.g. `a` for `a.b`).
    writes_create_symbol: if True, a write is also recorded through
      `scope.mark_creation`.

  Raises:
    ValueError: if `node.ctx` is not Store, Load or Param.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  if isinstance(node.ctx, gast.Store):
    self.scope.mark_write(qn)
    if qn.is_composite and composite_writes_alter_parent:
      self.scope.mark_write(qn.parent)
    if writes_create_symbol:
      self.scope.mark_creation(qn, writes_create_symbol=True)
    # An augmented assignment (e.g. `x += 1`) also reads its target.
    if self._in_aug_assign:
      self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Load):
    self.scope.mark_read(qn)
  elif isinstance(node.ctx, gast.Param):
    # Param contexts appear in function defs, so they have the meaning of
    # defining a variable.
    self.scope.mark_write(qn)
    self.scope.mark_param(qn, self.enclosing_entities[-1])
  else:
    raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

  # Record whether the current scope knows this symbol.
  anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

  if self._in_return_statement:
    self.scope.mark_returned(qn)
def visit_Compare(self, node):
  """Rewrites a comparison into calls to the matching comparison functions.

  A chained comparison such as `a < b < c` becomes pairwise comparisons joined
  with `ag__.and_`, built with `args_as_lambda=True`.

  Args:
    node: the Compare node being visited.

  Returns:
    The expression built from `_as_function` calls.
  """
  node = self.generic_visit(node)

  # Repeated comparisons are converted to conjunctions:
  #   a < b < c -> a < b and b < c
  pairs = list(zip(node.ops, node.comparators))
  conjunction = None
  left = node.left
  for op, right in pairs:
    comparison = self._as_function(self._matching_func(op), (left, right))
    if isinstance(left, gast.Name) and isinstance(right, gast.Name):
      anno.setanno(comparison, SAFE_BOOLEAN_OPERAND, True)
    if not conjunction:
      conjunction = comparison
    else:
      self._expect_simple_symbol(right)
      conjunction = self._as_function(
          'ag__.and_', (conjunction, comparison), args_as_lambda=True)
    left = right

  assert conjunction is not None
  return conjunction
def visit_Lambda(self, node):
  """Visits a lambda while the `_Lambda` state is entered.

  After its children are visited, the current scope is attached to the lambda
  node under anno.Static.SCOPE.
  """
  assert not self._in_function_def_args
  lambda_state = self.state[_Lambda]
  lambda_state.enter()

  node = self.generic_visit(node)
  anno.setanno(node, anno.Static.SCOPE, self.scope)

  lambda_state.exit()
  return node
def visit_Attribute(self, node): parent_types = self.visit(node.value) # Attempt to use the static value if known. parent_value = anno.Static.VALUE.of(node.value, None) if parent_value is not None: static_value = getattr(parent_value, node.attr, None) else: # Fall back to the type if that is known. if parent_types is None: return None inferred_values = [ getattr(t, node.attr, None) for t in parent_types ] if not inferred_values: return None static_value = inferred_values[0] if static_value is None: return None if any(v is not static_value for v in inferred_values[1:]): # Static value not stable, assume it's dynamic. return None types = self.resolver.res_value(self.namespace, static_value) anno.setanno(node, anno.Static.VALUE, static_value) if __debug__: self._check_set(types) return types
def visit_Attribute(self, node):
  """Propagates a "static_value" annotation from an object to its attribute.

  If the accessed object carries a "static_value" annotation and that value
  has the accessed attribute, the attribute's value becomes this node's
  "static_value" annotation.
  """
  node = self.generic_visit(node)
  owner = anno.getanno(node.value, "static_value", default=None)
  if owner is None or not hasattr(owner, node.attr):
    return node
  anno.setanno(node, "static_value", getattr(owner, node.attr))
  return node
def visit_With(self, node):
  """Visits a with statement and marks it if its body definitely returns."""
  node.items = self.visit_block(node.items)
  body, body_returns = self._visit_statement_block(node, node.body)
  node.body = body
  if body_returns:
    anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
  return node
def _block_statement_live_out(self, node):
  """Annotates `node` with the union of its CFG successors' live-in sets.

  The result is stored as a frozenset under anno.Static.LIVE_VARS_OUT.
  """
  analyzer = self.current_analyzer
  successors = analyzer.graph.stmt_next[node]
  live_out = frozenset().union(*(analyzer.in_[s] for s in successors))
  anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
  return node
def test_create_source_map_multiple_nodes(self):
  source = textwrap.dedent("""
      from __future__ import print_function
      def test_fn(x):
        return x + 1
      """)
  nodes = parser.parse_str(source, single_node=False)

  def make_fake_origin():
    return origin_info.OriginInfo(
        loc=origin_info.Location('fake_filename', 3, 7),
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)

  # Two distinct (but equal-valued) objects, so assertIs tells them apart.
  import_origin = make_fake_origin()
  function_origin = make_fake_origin()
  anno.setanno(nodes[0], anno.Basic.ORIGIN, import_origin)
  anno.setanno(nodes[1], anno.Basic.ORIGIN, function_origin)

  source_map = origin_info.create_source_map(nodes, source, 'test_filename')

  for line, expected in ((2, import_origin), (3, function_origin)):
    loc = origin_info.LineLocation('test_filename', line)
    self.assertIn(loc, source_map)
    self.assertIs(source_map[loc], expected)
def visit_Lambda(self, node):
  """Wraps a top-level lambda's body in `ag__.with_function_scope`.

  Lambdas whose `_Function` state level exceeds 2 are instead wrapped as a
  whole in `ag__.autograph_artifact(...)`. For the others, a fresh 'lscope'
  symbol is created, stored on the state and on the node (annotation
  'function_context_name'), and the body is wrapped in a function-scope call.

  Args:
    node: the Lambda node being visited.

  Returns:
    The (possibly replaced) lambda expression.
  """
  with self.state[_Function] as fn_scope:
    node = self.generic_visit(node)

    # Only wrap the top-level function. Theoretically, we can and should wrap
    # everything, but that can lead to excessive boilerplate when lambdas are
    # nested.
    # TODO(mdan): Looks more closely for use cases that actually require this.
    if fn_scope.level > 2:
      return templates.replace_as_expression(
          'ag__.autograph_artifact(l)', l=node)

    scope = anno.getanno(node, anno.Static.SCOPE)
    # The new name avoids clashing with symbols referenced in the lambda.
    function_context_name = self.ctx.namer.new_symbol(
        'lscope', scope.referenced)
    fn_scope.context_name = function_context_name
    anno.setanno(node, 'function_context_name', function_context_name)

    template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    node.body = templates.replace_as_expression(
        template,
        options=self._function_scope_options(fn_scope).to_ast(),
        function_context=function_context_name,
        function_context_name=gast.Constant(function_context_name, kind=None),
        body=node.body)

    return node
def visit_Attribute(self, node):
  """Resolves attributes of statically known modules.

  When `node.value` has a STATIC_VALUE annotation that is a module, and that
  module has the accessed attribute, the attribute's value is recorded as the
  STATIC_VALUE annotation of `node`.
  """
  node = self.generic_visit(node)
  module = anno.getanno(node.value, STATIC_VALUE, default=None)
  is_module = module is not None and tf_inspect.ismodule(module)
  if is_module and hasattr(module, node.attr):
    anno.setanno(node, STATIC_VALUE, getattr(module, node.attr))
  return node
def visit_FunctionDef(self, node):
  """Analyzes a function definition using three nested scopes.

  1. A scope for the decorators and the function's own name (which is marked
     as modified); stored on the node under anno.Static.SCOPE.
  2. A scope entered with `_enter_scope(True)` in which the arguments are
     visited.
  3. A scope inside (2) for the body; stored under NodeAnno.BODY_SCOPE.

  Args:
    node: the FunctionDef node being visited.

  Returns:
    The visited node.
  """
  # The FunctionDef node itself has a Scope object that tracks the creation
  # of its name, along with the usage of any decorators accompanying it.
  self._enter_scope(False)
  node.decorator_list = self.visit_block(node.decorator_list)
  self.scope.mark_modified(qual_names.QN(node.name))
  anno.setanno(node, anno.Static.SCOPE, self.scope)
  self._exit_scope()

  # A separate Scope tracks the actual function definition.
  self._enter_scope(True)
  # Function defs may not start while visiting another def's args or a lambda.
  assert not (self._in_function_def_args or self.state[_Lambda].level)
  self._in_function_def_args = True
  node.args = self.visit(node.args)
  self._in_function_def_args = False

  # Track the body separately. This is for compatibility reasons, it may not
  # be strictly needed.
  self._enter_scope(False)
  node.body = self.visit_block(node.body)
  anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
  self._exit_scope()

  self._exit_scope()
  return node
def visit_Compare(self, node):
  """Rewrites a comparison into calls to the matching comparison functions.

  A chained comparison such as `a < b < c` is expanded into pairwise
  comparisons that are combined with `tf.logical_and`.

  Args:
    node: the Compare node being visited.

  Returns:
    `node` unchanged when it has a single operator with no matching function;
    otherwise the expression built from `_as_function` calls.

  Raises:
    NotImplementedError: if a chained comparison contains at least one
      operator without a matching function.
  """
  node = self.generic_visit(node)

  all_supported = all(self._has_matching_func(op) for op in node.ops)
  if not all_supported:
    # Basic expressions are safe to leave as they are.
    if len(node.ops) == 1:
      return node
    raise NotImplementedError(
        'compound expression with at least one unsupported '
        'operator: {}'.format(node.ops))

  # Repeated comparisons are converted to conjunctions:
  #   a < b < c -> a < b and b < c
  pairs = list(zip(node.ops, node.comparators))
  result = None
  lhs = node.left
  for op, rhs in pairs:
    comparison = self._as_function(self._matching_func(op), (lhs, rhs))
    if isinstance(lhs, gast.Name) and isinstance(rhs, gast.Name):
      anno.setanno(comparison, SAFE_BOOLEAN_OPERAND, True)
    if result:
      self._expect_simple_symbol(rhs)
      result = self._as_function('tf.logical_and', (comparison, result))
    else:
      result = comparison
    lhs = rhs

  assert result is not None
  return result
def visit_Print(self, node):
  """Visits the values of a print statement inside a fresh scope.

  That scope is stored on the node under both anno.Static.SCOPE and
  NodeAnno.ARGS_SCOPE.
  """
  self._enter_scope(False)
  node.values = self.visit_block(node.values)
  print_scope = self.scope
  for key in (anno.Static.SCOPE, NodeAnno.ARGS_SCOPE):
    anno.setanno(node, key, print_scope)
  self._exit_scope()
  return node
def visit_Lambda(self, node):
  """Wraps a top-level lambda's body in `ag__.with_function_scope`.

  Lambdas whose `_Function` state level exceeds 2 are instead wrapped as a
  whole in `ag__.autograph_artifact(...)`. For the others, a fresh 'lscope'
  symbol is created, stored on the state and on the node (annotation
  'function_context_name'), and the body is wrapped in a function-scope call.

  Args:
    node: the Lambda node being visited.

  Returns:
    The (possibly replaced) lambda expression.
  """
  with self.state[_Function] as fn_scope:
    node = self.generic_visit(node)

    # TODO(mdan): Fix the tests so that we can always add this decorator.
    if fn_scope.level > 2:
      return templates.replace_as_expression(
          'ag__.autograph_artifact(l)', l=node)

    scope = anno.getanno(node, anno.Static.SCOPE)
    # The new name avoids clashing with symbols referenced in the lambda.
    function_context_name = self.ctx.namer.new_symbol('lscope',
                                                      scope.referenced)
    fn_scope.context_name = function_context_name
    anno.setanno(node, 'function_context_name', function_context_name)

    template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    node.body = templates.replace_as_expression(
        template,
        options=self._function_scope_options(fn_scope).to_ast(),
        function_context=function_context_name,
        function_context_name=gast.Constant(function_context_name, kind=None),
        body=node.body)

    return node
def visit_Attribute(self, node):
  """Gives an attribute access a qualified name derived from its object.

  If `node.value` carries a qualified name, `node` is annotated with
  `QN(<that name>, attr=node.attr)`; otherwise it is left unannotated.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.value, anno.Basic.QN):
    return node
  parent_qn = anno.getanno(node.value, anno.Basic.QN)
  anno.setanno(node, anno.Basic.QN, QN(parent_qn, attr=node.attr))
  return node
def visit_For(self, node):
  """Lowers `break` statements inside a for loop.

  If the body uses no `break`, the loop is rebuilt with its else block emitted
  after the loop. Otherwise a fresh 'break_' variable is initialized to False,
  the else block is guarded on it, and `ag__.not_(<break var>)` is attached to
  the new loop as anno.Basic.EXTRA_LOOP_TEST.

  Args:
    node: the For node being visited.

  Returns:
    A list of statements produced by `templates.replace`, or `node` unchanged
    if no template applied.
  """
  original_node = node
  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  # Fresh name that does not clash with symbols referenced in the body.
  break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if not break_used:
    # Without a break, a for-else's else block always runs after normal loop
    # completion, so it can be emitted as plain statements after the loop.
    template = """
      for target in iter_:
        body
      orelse
    """
    node = templates.replace(
        template,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=node.orelse)

    new_for_node = node[0]
    anno.copyanno(original_node, new_for_node, anno.Basic.EXTRA_LOOP_TEST)
    anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

    return node

  # Python's else clause only triggers if the loop exited cleanly (e.g.
  # break did not trigger).
  guarded_orelse = self._guard_if_present(node.orelse, break_var)
  extra_test = templates.replace_as_expression(
      'ag__.not_(var_name)', var_name=break_var)

  # The extra test is hidden in the AST, which will confuse the static
  # analysis. To mitigate that, we insert a no-op statement that ensures
  # the control variable is marked as used.
  # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
  template = """
    var_name = False
    for target in iter_:
      (var_name,)
      body
    orelse
  """
  node = templates.replace(
      template,
      var_name=break_var,
      iter_=node.iter,
      target=node.target,
      body=node.body,
      orelse=guarded_orelse)

  # node[0] is the `var_name = False` assignment; node[1] is the loop.
  new_for_node = node[1]
  anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
  anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

  return node
def _aggregate_predecessors_defined_in(self, node):
  """Annotates `node` with the symbols defined on exit from its predecessors.

  The keys of `analyzer.out[p].value` are unioned over every CFG predecessor
  `p` of `node` and stored as a frozenset under anno.Static.DEFINED_VARS_IN.
  """
  analyzer = self.current_analyzer
  predecessors = analyzer.graph.stmt_prev[node]
  defined_in = frozenset().union(
      *(analyzer.out[p].value.keys() for p in predecessors))
  anno.setanno(node, anno.Static.DEFINED_VARS_IN, defined_in)
def visit_Lambda(self, node):
  """Wraps a top-level lambda's body in `ag__.with_function_scope`.

  The `_Function` state is entered for the duration of the visit and is
  always exited, including when visiting raises. Lambdas at state level > 2
  are returned unchanged (after visiting their children). For the others, a
  fresh 'lscope' symbol is created, stored on the state and on the node
  (annotation 'function_context_name'), and the body is wrapped in a
  function-scope call.

  Args:
    node: the Lambda node being visited.

  Returns:
    The visited lambda node.
  """
  self.state[_Function].enter()
  # try/finally guarantees a single, unconditional exit on every path (the
  # early return, the normal return, and any exception).
  try:
    node = self.generic_visit(node)

    # Only wrap the top-level function. Theoretically, we can and should wrap
    # everything, but that can lead to excessive boilerplate when lambdas are
    # nested.
    # TODO(mdan): Looks more closely for use cases that actually require this.
    if self.state[_Function].level > 2:
      return node

    scope = anno.getanno(node, anno.Static.SCOPE)
    # The new name avoids clashing with symbols referenced in the lambda.
    function_context_name = self.ctx.namer.new_symbol(
        'lscope', scope.referenced)
    self.state[_Function].context_name = function_context_name
    anno.setanno(node, 'function_context_name', function_context_name)

    template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    node.body = templates.replace_as_expression(
        template,
        options=self.ctx.program.options.to_ast(),
        function_context=function_context_name,
        function_context_name=gast.Str(function_context_name),
        body=node.body)

    return node
  finally:
    self.state[_Function].exit()
def visit_Name(self, node):
  """Annotates loaded names that appear in the entity's namespace.

  The namespace value is stored on the node as its "static_value" annotation.
  """
  node = self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node
  namespace = self.ctx.info.namespace
  if node.id in namespace:
    anno.setanno(node, "static_value", namespace[node.id])
  return node
def _as_function(self, func_name, args):
  """Builds the call expression `func_name(args)` from a template.

  Args:
    func_name: str, source text of the callee, parsed with
      `parser.parse_expression`.
    args: the argument node(s) substituted for `args` in the template.

  Returns:
    The call expression, annotated with SAFE_BOOLEAN_OPERAND = True.
  """
  callee = parser.parse_expression(func_name)
  template = """
    func_name(args)
  """
  call = templates.replace_as_expression(
      template, func_name=callee, args=args)
  anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
  return call
def visit(self, node):
  """Visits `node` and records its inferred types, when there are any.

  Returns:
    The types produced by the superclass visitor (possibly None).
  """
  types = super().visit(node)
  if __debug__:
    self._check_set(types)
  if types is None:
    return types
  # TODO(mdan): Normalize by removing subtypes.
  anno.setanno(node, anno.Static.TYPES, tuple(types))
  return types
def _aggregate_successors_live_in(self, node):
  """Stores the union of the successors' live-in sets, then visits children.

  The union over CFG successors of `analyzer.in_` is stored as a frozenset
  under anno.Static.LIVE_VARS_OUT before `generic_visit` runs.
  """
  analyzer = self.current_analyzer
  successors = analyzer.graph.stmt_next[node]
  live_out = frozenset().union(*(analyzer.in_[s] for s in successors))
  anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
  return self.generic_visit(node)
def visit(self, node):
  """Visits `node`; statements found in the current CFG get LIVE_VARS_IN.

  The annotation is the frozenset of the analyzer's `in_` result for the
  statement's CFG node.
  """
  node = super(Annotator, self).visit(node)

  analyzer = self.current_analyzer
  if analyzer is None or not isinstance(node, gast.stmt):
    return node

  cfg_index = analyzer.graph.index
  if node in cfg_index:
    live_in = frozenset(analyzer.in_[cfg_index[node]])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
  return node
def visit_Name(self, node):
  """Resolves loaded names that have no reaching definitions.

  If no definitions reach the name and it appears in the entity namespace,
  the namespace value is stored as the node's STATIC_VALUE annotation.
  """
  node = self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node

  reaching_defs = anno.getanno(node, anno.Static.DEFINITIONS, ())
  if not reaching_defs and node.id in self.ctx.info.namespace:
    anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
  return node
def visit(self, node):
  """Visits `node`, tracking and propagating origin information.

  While `node` is processed, `self.ctx.current_origin` is set to the node's
  ORIGIN annotation (if it has one) and is restored afterwards. Replacement
  nodes that lack an ORIGIN annotation inherit the replaced node's origin,
  or the enclosing origin if the replaced node has none.

  Args:
    node: gast.AST, the node to visit.

  Returns:
    The visitor's result: a node, a list/tuple of nodes, or None. Nodes
    marked with anno.Basic.SKIP_PROCESSING are returned unvisited.

  Raises:
    ValueError: if `node` is not a gast.AST (for example, a list of
      statements).
  """
  if not isinstance(node, gast.AST):
    # This is a fairly common mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error needs to be raised before the exception handler
    # below is installed, because said handler will mess up if `node` is not,
    # in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
    return node

  parent_origin = self.ctx.current_origin
  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  try:
    processing_expr_node = isinstance(node, gast.Expr)
    if processing_expr_node:
      entry_expr_value = node.value

    result = super(Base, self).visit(node)

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if (processing_expr_node and isinstance(result, gast.Expr) and
        (result.value is not entry_expr_value)):
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value,
                    (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

    # By default, all replacements receive the origin info of the replaced
    # node.
    if result is not node and result is not None:
      inherited_origin = anno.getanno(
          node, anno.Basic.ORIGIN, default=parent_origin)
      if inherited_origin is not None:
        # Treat a single replacement node as a one-element sequence.
        if isinstance(result, (list, tuple)):
          nodes_to_adjust = result
        else:
          nodes_to_adjust = (result,)
        for n in nodes_to_adjust:
          if not anno.hasanno(n, anno.Basic.ORIGIN):
            anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
  finally:
    self.ctx.current_origin = parent_origin

  return result
def visit_Call(self, node):
  """Visits call arguments in their own scope, then the callee.

  The scope covering positional and keyword arguments is stored on the node
  as NodeAnno.ARGS_SCOPE. The callee is visited after that scope is exited.
  """
  self._enter_scope(False)
  node.args = self.visit_block(node.args)
  node.keywords = self.visit_block(node.keywords)
  # TODO(mdan): Account starargs, kwargs
  args_scope = self.scope
  anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
  self._exit_scope()

  node.func = self.visit(node.func)
  return node
def _process_statement_directive(self, call_node, directive):
  """Records `directive` and its mapped arguments on the enclosing loop.

  The arguments from `_map_args(call_node, directive)` are stored in the
  converter.AgAnno.DIRECTIVES dict of the node held in ENCLOSING_LOOP.

  Returns:
    `call_node`, unchanged.

  Raises:
    ValueError: if used outside of a statement (local scope level < 1).
  """
  if self.local_scope_level < 1:
    raise ValueError(
        '"%s" must be used inside a statement' % directive.__name__)

  loop_node = self.get_local(ENCLOSING_LOOP)
  directives = anno.getanno(loop_node, converter.AgAnno.DIRECTIVES, {})
  directives[directive] = _map_args(call_node, directive)
  anno.setanno(loop_node, converter.AgAnno.DIRECTIVES, directives)
  return call_node
def test_rename_symbols_annotations(self):
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  expected = anno.getanno(tree, 'foo')

  renamed = ast_util.rename_symbols(
      tree, {qual_names.QN('a'): qual_names.QN('b')})

  # The very same annotation object must survive renaming.
  self.assertIs(anno.getanno(renamed, 'foo'), expected)
def visit_For(self, node):
  """Visits a for loop.

  The target and iterable are visited in a fresh scope, which is attached to
  `node.iter` as anno.Static.SCOPE. The body and else blocks are then passed
  to `_process_parallel_blocks` under BODY_SCOPE / ORELSE_SCOPE.
  """
  self._enter_scope(False)
  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
  self._exit_scope()

  blocks = (
      (node.body, NodeAnno.BODY_SCOPE),
      (node.orelse, NodeAnno.ORELSE_SCOPE),
  )
  return self._process_parallel_blocks(node, blocks)
def visit_While(self, node):
  """Visits a while loop.

  The loop test is visited in a fresh scope, recorded both on the loop as
  NodeAnno.COND_SCOPE and on the test as anno.Static.SCOPE. The body and else
  blocks are then passed to `_process_parallel_blocks`.
  """
  self._enter_scope(False)
  node.test = self.visit(node.test)
  cond_scope = self.scope
  anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
  anno.setanno(node.test, anno.Static.SCOPE, cond_scope)
  self._exit_scope()

  blocks = (
      (node.body, NodeAnno.BODY_SCOPE),
      (node.orelse, NodeAnno.ORELSE_SCOPE),
  )
  return self._process_parallel_blocks(node, blocks)
def test_copy(self):
  source = ast.Name()
  destination = ast.Name()
  anno.setanno(source, 'foo', 3)

  # 'foo' exists on the source; 'bar' does not and must not appear on copy.
  for key in ('foo', 'bar'):
    anno.copyanno(source, destination, key)

  self.assertTrue(anno.hasanno(destination, 'foo'))
  self.assertFalse(anno.hasanno(destination, 'bar'))
def visit_If(self, node):
  """Visits the test in its own scope, then processes both branches.

  The test's scope, as returned by `_exit_and_record_scope`, is stored on the
  node as NodeAnno.COND_SCOPE.
  """
  self._enter_scope(False)
  node.test = self.visit(node.test)
  cond_scope = self._exit_and_record_scope(node.test)
  anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)

  branches = (
      (node.body, NodeAnno.BODY_SCOPE),
      (node.orelse, NodeAnno.ORELSE_SCOPE),
  )
  return self._process_parallel_blocks(node, branches)
def test_rename_symbols_annotations(self):
  renames = {qual_names.QN('a'): qual_names.QN('b')}
  tree = qual_names.resolve(parser.parse_str('a[i]'))
  anno.setanno(tree, 'foo', 'bar')
  annotation_before = anno.getanno(tree, 'foo')

  tree = ast_util.rename_symbols(tree, renames)

  # Identity (not just equality) of the annotation must be preserved.
  self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
def visit_While(self, node):
  """Visits a while loop.

  The condition is visited in its own scope, which is attached to the loop
  (NodeAnno.COND_SCOPE) and to the test expression (anno.Static.SCOPE).
  Body and else blocks are then handed to `_process_parallel_blocks`.
  """
  self._enter_scope(False)
  node.test = self.visit(node.test)
  test_scope = self.scope
  anno.setanno(node, NodeAnno.COND_SCOPE, test_scope)
  anno.setanno(node.test, anno.Static.SCOPE, test_scope)
  self._exit_scope()

  body_block = (node.body, NodeAnno.BODY_SCOPE)
  orelse_block = (node.orelse, NodeAnno.ORELSE_SCOPE)
  return self._process_parallel_blocks(node, (body_block, orelse_block))
def resolve(nodes, source, function=None):
  """Adds origin information to all nodes inside the body of a function.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST, ...]], the node(s) to annotate.
      Every descendant (as yielded by gast.walk) that has a `lineno` gets an
      OriginInfo annotation under the key anno.Basic.ORIGIN.
    source: Text, the source code string for the function whose body nodes
      will be annotated. Node line numbers are interpreted relative to it.
    function: Callable, the function that `source` belongs to. If it is None,
      the file path and function name in the annotation are None and the
      line numbers are taken as-is from the nodes. The source line and its
      comment are recorded in either case.

  Returns:
    None. The nodes are annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps line numbers within `source` to the text of the comment on that line.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # NOTE(review): confirm the offset convention between
        # getsourcelines' first line and lineno_in_body (possible off-by-one).
        source_lineno = function_lineno + lineno_in_body
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      # comment_map is keyed by rows of `source`, so look it up with the same
      # line number used for `source_code_line`, not the file-relative one.
      origin = OriginInfo(location, function_name, source_code_line,
                          comment_map.get(lineno_in_body))
      anno.setanno(n, anno.Basic.ORIGIN, origin)
def _block_statement_live_in(self, node, entry_node):
  """Copies the live-in set of `entry_node` onto `node` as LIVE_VARS_IN.

  If `entry_node` is in the CFG index, the set is read from the analyzer's
  `in_` results; otherwise `entry_node` must already carry LIVE_VARS_IN.
  """
  graph_index = self.current_analyzer.graph.index
  if entry_node not in graph_index:
    assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
        'If not matching a CFG node, must be a block statement:'
        ' {}'.format(entry_node))
    live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
  else:
    cfg_node = graph_index[entry_node]
    live_in = frozenset(self.current_analyzer.in_[cfg_node])
  anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
  return node
def test_copy_clean_preserves_annotations(self):
  node = parser.parse_str(
      textwrap.dedent("""
        def f(a):
          return a + 1
      """))
  fn_def = node.body[0]
  anno.setanno(fn_def, 'foo', 'bar')
  anno.setanno(fn_def, 'baz', 1)

  new_node = ast_util.copy_clean(node, preserve_annos={'foo'})

  # Only annotations listed in preserve_annos survive the copy.
  copied_fn_def = new_node.body[0]
  self.assertEqual(anno.getanno(copied_fn_def, 'foo'), 'bar')
  self.assertFalse(anno.hasanno(copied_fn_def, 'baz'))
def _process_function_arg(self, arg_node):
  """Registers a function argument in the current scope.

  For arguments of the outermost entity that have an entry in
  `entity_info.arg_types`, the node is also annotated with 'type' (the type
  object) and 'type_fqn' (the dotted type string split into a tuple).
  """
  qn = anno.getanno(arg_node, anno.Basic.QN)
  arg_name = str(qn)
  self.scope.setval(qn, arg_node)

  is_top_level = len(self.enclosing_entities) == 1
  if not (is_top_level and arg_name in self.entity_info.arg_types):
    return

  # Forge a node to hold the type information, so that method calls on
  # it can resolve the type.
  type_string, type_obj = self.entity_info.arg_types[arg_name]
  anno.setanno(arg_node, 'type', type_obj)
  anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))
def resolve(node, source, function=None):
  """Adds origin information to node and its subnodes.

  This allows us to map the original source code line numbers to generated
  source code.

  Args:
    node: gast.AST node. Should be a gast.FunctionDef. This is the node we
      annotate with origin information.
    source: Text, the source code. Should satisfy relationship
      `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
      unreliable.
    function: The original function. If it is None, the file path and
      function name in the annotation are None and the line numbers are taken
      as-is from the nodes. The source line and its comment are recorded in
      either case.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps line numbers within `source` to the text of the comment on that line.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for n in gast.walk(node):
    if not hasattr(n, 'lineno'):
      continue

    within_body_offset = n.lineno - node.lineno

    source_code_line = source_lines[n.lineno - 1]
    if function:
      source_lineno = function_lineno + within_body_offset
      function_name = function.__name__
    else:
      source_lineno = n.lineno
      function_name = None

    location = Location(function_filepath, source_lineno, n.col_offset)
    # comment_map is keyed by rows of `source`, the same numbering used for
    # `source_code_line`; the file-relative `source_lineno` would mismatch.
    origin = OriginInfo(location, function_name, source_code_line,
                        comment_map.get(n.lineno))
    anno.setanno(n, anno.Basic.ORIGIN, origin)
def visit_Call(self, node):
  """Converts a function call.

  Before the children are visited, the first argument of a call to a function
  listed in `options.strip_decorators` is marked 'graph_ready'. Afterwards:
    * calls with a 'live_val' annotation are renamed when compilable, wrapped
      in a single-return py_func when they are known numpy functions, and
      rejected otherwise;
    * other calls are left alone for `dict` and `super(...)`, and wrapped for
      dynamic conversion when `options.recursive` is set.

  Args:
    node: the Call node being visited.

  Returns:
    The (possibly replaced) call node.

  Raises:
    ValueError: if a strip-decorator function is called with no positional
      arguments.
    NotImplementedError: if a call with a 'live_val' annotation is neither
      compilable nor a known numpy function.
  """
  # If the function call is wrapped by one of the marker decorators,
  # consider it graph ready.
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if target_entity in self.ctx.program.options.strip_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            target_entity)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)
  if anno.hasanno(node.func, 'live_val'):
    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      node = self._rename_compilable_function(node)
    elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      node = self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    else:
      raise NotImplementedError(
          'py_func with return values (unknown function)')
  else:
    if anno.hasanno(node.func, anno.Basic.QN):
      # Special-case a few builtins that otherwise go undetected. This
      # normally doesn't pose a problem, but the dict built-in doesn't
      # work with inspect.getargspec which is required for dynamic functions.
      # Note: expecting this is resilient to aliasing (e.g.
      # dict = an_evil_dict), because in those cases the regular mechanisms
      # process a simple user function.
      qn = anno.getanno(node.func, anno.Basic.QN)
      # Add items to this list as needed.
      if str(qn) in ('dict',):
        return node

    if ast_util.matches(node, 'super(_)'):
      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      return node

    if self.ctx.program.options.recursive:
      node = self._insert_dynamic_conversion(node)
  return node
def test_basic(self):

  def test_fn():
    raise ValueError()

  node, ctx = self.prepare(test_fn, {})
  origin = origin_info.OriginInfo(None, 'test_function_name', 'test_code',
                                  'test_comment')
  anno.setanno(node, anno.Basic.ORIGIN, origin)

  node = error_handlers.transform(node, ctx)

  with self.compiled(node, {}) as result:
    with self.assertRaises(errors.GraphConstructionError):
      # Here we just assert that the handler works. Its correctness is
      # verified by errors_test.py.
      result.test_fn()
def visit_If(self, node):
  """Visits an if statement and records which branches definitely return.

  BODY_DEFINITELY_RETURNS / ORELSE_DEFINITELY_RETURNS are set per branch.
  When both branches definitely return, the enclosing _Block state is marked
  as definitely returning.
  """
  node.test = self.visit(node.test)

  node.body, body_returns = self._visit_statement_block(node, node.body)
  if body_returns:
    anno.setanno(node, BODY_DEFINITELY_RETURNS, True)

  node.orelse, orelse_returns = self._visit_statement_block(
      node, node.orelse)
  if orelse_returns:
    anno.setanno(node, ORELSE_DEFINITELY_RETURNS, True)

  if body_returns and orelse_returns:
    self.state[_Block].definitely_returns = True
  return node
def visit_Name(self, node):
  """Annotates a name with the definitions that reach it.

  Loaded names use the definitions flowing into the current CFG node; names
  in any other context use those flowing out of it.
  """
  analyzer = self.current_analyzer
  if analyzer is None:
    # Names may appear outside function defs - for example in class
    # definitions.
    return node

  cfg_node = self.current_cfg_node
  assert cfg_node is not None, 'name node outside of any statement?'

  qn = anno.getanno(node, anno.Basic.QN)
  flow = analyzer.in_ if isinstance(node.ctx, gast.Load) else analyzer.out
  reaching = flow[cfg_node].value.get(qn, ())
  anno.setanno(node, anno.Static.DEFINITIONS, tuple(reaching))
  return node
def test_create_source_map(self):
  source = textwrap.dedent("""
      def test_fn(x):
        return x + 1
      """)
  node = parser.parse_str(source)
  fake_origin = origin_info.OriginInfo(
      loc=origin_info.Location('fake_filename', 3, 7),
      function_name='fake_function_name',
      source_code_line='fake source line',
      comment=None)
  anno.setanno(node, anno.Basic.ORIGIN, fake_origin)

  source_map = origin_info.create_source_map(node, source, 'test_filename')

  # `source` starts with a newline, so the def is on line 2.
  expected_loc = origin_info.LineLocation('test_filename', 2)
  self.assertIn(expected_loc, source_map)
  self.assertIs(source_map[expected_loc], fake_origin)
def test_create_source_map(self):

  def test_fn(x):
    return x + 1

  node, _ = parser.parse_entity(test_fn)
  fn_node = node.body[0]
  fake_origin = origin_info.OriginInfo(
      loc=origin_info.Location('fake_filename', 3, 7),
      function_name='fake_function_name',
      source_code_line='fake source line',
      comment=None)
  anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
  converted_code = compiler.ast_to_source(fn_node)

  source_map = origin_info.create_source_map(
      fn_node, converted_code, 'test_filename', [0])

  expected_loc = origin_info.LineLocation('test_filename', 2)
  self.assertIn(expected_loc, source_map)
  self.assertIs(source_map[expected_loc], fake_origin)
def visit_For(self, node):
  """Visits a for loop, adding a return check to its 'extra_test'.

  If `self.state[_Return].used` is set after the body is visited, the loop's
  'extra_test' annotation becomes `ag__.not_(<do_return var>)`. Any existing
  extra test is combined with it through `ag__.and_`, with both operands
  passed as lambdas.

  Args:
    node: the For node being visited.

  Returns:
    The visited node.
  """
  node.iter = self.visit(node.iter)
  node.target = self.visit(node.target)

  # Add the check for return to the loop condition.
  node.body = self._visit_statement_block(node, node.body)
  if self.state[_Return].used:
    extra_test = anno.getanno(node, 'extra_test', default=None)
    if extra_test is not None:
      extra_test = templates.replace_as_expression(
          'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
          extra_test=extra_test,
          control_var=self.state[_Function].do_return_var_name)
    else:
      extra_test = templates.replace_as_expression(
          'ag__.not_(control_var)',
          control_var=self.state[_Function].do_return_var_name)
    anno.setanno(node, 'extra_test', extra_test)

  node.orelse = self._visit_statement_block(node, node.orelse)
  return node
def visit_For(self, node):
  """Lowers `break` statements inside a for loop.

  If the body uses `break`, the loop is rebuilt with a fresh 'break_' variable
  initialized to `tf.constant(False)`, the else block is guarded on it, and
  `not <break var>` is attached to the new loop as its 'extra_test'
  annotation. Otherwise the visited node is returned as-is.

  Args:
    node: the For node being visited.

  Returns:
    A list of statements from `templates.replace` when `break` is used,
    otherwise the visited node.
  """
  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  # Fresh name that does not clash with symbols referenced in the body.
  break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if break_used:
    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    extra_test = templates.replace_as_expression(
        'not var_name', var_name=break_var)

    # The extra test is hidden in the AST, which will confuse the static
    # analysis. To mitigate that, we insert a no-op statement that ensures
    # the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
      var_name = tf.constant(False)
      for target in iter_:
        (var_name,)
        body
      else:
        orelse
    """
    node = templates.replace(
        template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=guarded_orelse)

    # node[0] is the break variable initialization; node[1] is the loop.
    anno.setanno(node[1], 'extra_test', extra_test)

  return node