def visit_Call(self, node):
  """Processes a call node, dispatching on what its callee resolves to.

  When `node.func` carries a 'live_val' annotation, the resolved entity
  decides the outcome:
    * compilable functions go to `_rename_compilable_function` when
      `_should_compile` agrees, otherwise only `generic_visit` is applied;
    * fully-qualified names listed in KNOWN_NUMPY_FUNCTIONS are wrapped with
      `_wrap_to_py_func_single_return`, using the registered dtype;
    * builtins are returned unchanged;
    * namedtuple types only get `generic_visit`;
    * anything else raises NotImplementedError.

  Without a 'live_val' annotation, `super(...)` calls are kept as-is, and so
  are `super(...).method(...)` calls when `self.ctx.info.owner_type` is set.
  Every other call is passed to `_insert_dynamic_conversion`.

  Args:
    node: the call node being visited.

  Returns:
    The resulting node (possibly the original one).

  Raises:
    NotImplementedError: if the resolved callee fits none of the categories
      above.
  """
  if not anno.hasanno(node.func, 'live_val'):
    # Special cases.
    # TODO(mdan): These need a systematic review - there may be more.

    # super() calls are preserved. The class conversion mechanism will ensure
    # that they return the correct value.
    if ast_util.matches(node, parser.parse_expression('super(_)')):
      return node

    # super().method calls are preserved as well, but only when the
    # conversion processes the entire class (an owner type is known).
    if (ast_util.matches(node, parser.parse_expression('super(_)._(_)')) and
        self.ctx.info.owner_type is not None):
      return node

    return self._insert_dynamic_conversion(node)

  callee = anno.getanno(node.func, 'live_val')
  callee_fqn = None
  if anno.hasanno(node.func, 'fqn'):
    callee_fqn = anno.getanno(node.func, 'fqn')

  if self._function_is_compilable(callee):
    if not self._should_compile(node, callee_fqn):
      return self.generic_visit(node)
    return self._rename_compilable_function(node)

  if callee_fqn and callee_fqn in KNOWN_NUMPY_FUNCTIONS:
    # TODO(mdan): Should we replace these with equivalent TF ops instead?
    return self._wrap_to_py_func_single_return(
        node, KNOWN_NUMPY_FUNCTIONS[callee_fqn].dtype)

  if inspect_utils.isbuiltin(callee):
    # Any builtin that passed the builtins converter is assumed to be safe
    # for graph mode.
    return node

  if inspect_utils.isnamedtuple(callee):
    # Although not compilable, these are assumed safe for graph mode.
    return self.generic_visit(node)

  # TODO(mdan): Insert dynamic conversion here instead.
  raise NotImplementedError('py_func with return values (unknown function)')
def visit_Call(self, node):
  """Processes a call node according to the entity its callee resolves to.

  Calls whose `func` has a 'live_val' annotation are handled as follows:
  compilable entities are routed to `_rename_compilable_function` (or just
  `generic_visit` when `_should_compile` declines), names found in
  KNOWN_NUMPY_FUNCTIONS are wrapped via `_wrap_to_py_func_single_return`,
  builtins pass through untouched, and namedtuple types are only visited.

  Calls without that annotation are preserved when they match `super(_)`,
  or `super(_)._(_)` while `self.ctx.info.owner_type` is not None. All other
  calls go through `_insert_dynamic_conversion`.

  Args:
    node: the call node being visited.

  Returns:
    The node to put in place of `node`.

  Raises:
    NotImplementedError: for an annotated callee of an unsupported kind.
  """
  func = node.func

  if anno.hasanno(func, 'live_val'):
    entity = anno.getanno(func, 'live_val')
    fqn = anno.getanno(func, 'fqn') if anno.hasanno(func, 'fqn') else None

    if self._function_is_compilable(entity):
      if self._should_compile(node, fqn):
        return self._rename_compilable_function(node)
      return self.generic_visit(node)
    elif fqn and fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[fqn].dtype)
    elif inspect_utils.isbuiltin(entity):
      # Builtins that survived the builtins converter are assumed to be safe
      # for graph mode.
      return node
    elif inspect_utils.isnamedtuple(entity):
      # Not compilable, but assumed safe for graph mode.
      return self.generic_visit(node)

    # TODO(mdan): Insert dynamic conversion here instead.
    raise NotImplementedError(
        'py_func with return values (unknown function)')

  # Special cases.
  # TODO(mdan): These need a systematic review - there may be more.

  # Plain super() calls are preserved; the class conversion mechanism will
  # ensure that they return the correct value.
  if ast_util.matches(node, 'super(_)'):
    return node

  # super().method calls are preserved too, when the whole class is being
  # converted.
  is_super_method_call = ast_util.matches(node, 'super(_)._(_)')
  if is_super_method_call and self.ctx.info.owner_type is not None:
    return node

  return self._insert_dynamic_conversion(node)
def assertAstMatches(self, actual_node, expected_node_src):
  """Asserts that an AST node matches the first statement of some source.

  Args:
    actual_node: the AST node under test.
    expected_node_src: source code; its first top-level statement, parsed
      with gast, is the expected node.
  """
  parsed_module = gast.parse(expected_node_src)
  expected_node = parsed_module.body[0]

  expected_text = pretty_printer.fmt(expected_node)
  actual_text = pretty_printer.fmt(actual_node)
  failure_message = 'AST did not match expected:\n{}\nActual:\n{}'.format(
      expected_text, actual_text)

  is_match = ast_util.matches(actual_node, expected_node)
  self.assertTrue(is_match, failure_message)
def _visit_and_reindent(self, nodes):
  """Visits a statement list, honoring requests to nest the remainder.

  Each statement is visited and then renamed using the alias map collected
  so far. The result (a single node, or a list/tuple of nodes) is appended to
  the current destination list.

  When a visited node carries an anno.Basic.INDENT_BLOCK_REMAINDER
  annotation, that annotation supplies a new destination list and alias map.
  Subsequent statements are appended to that list, and the previously
  collected aliases are merged into the new map, overriding clashing keys.

  Args:
    nodes: the statements to visit.

  Returns:
    The top-level list of resulting statements.

  Raises:
    ValueError: if a remainder was requested but the final destination is
      empty or holds only a `return`, `return ()`, `return []` or
      `return {}` statement.
  """
  result = []
  dest = result
  aliases = {}
  did_reindent = False

  for stmt in nodes:
    stmt = self.visit(stmt)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if aliases:
      stmt = ast_util.rename_symbols(stmt, aliases)
    if isinstance(stmt, (list, tuple)):
      dest.extend(stmt)
    else:
      dest.append(stmt)

    if not anno.hasanno(stmt, anno.Basic.INDENT_BLOCK_REMAINDER):
      continue

    did_reindent = True
    remainder_dest, remainder_aliases = anno.getanno(
        stmt, anno.Basic.INDENT_BLOCK_REMAINDER)
    anno.delanno(stmt, anno.Basic.INDENT_BLOCK_REMAINDER)
    # Aliases already in effect take precedence over the remainder's own.
    remainder_aliases.update(aliases)
    aliases = remainder_aliases
    dest = remainder_dest

  if did_reindent:
    nothing_to_gate = not dest
    if len(dest) == 1:
      for trivial_return in ('return', 'return ()', 'return []', 'return {}'):
        if ast_util.matches(dest[0], trivial_return):
          nothing_to_gate = True
    if nothing_to_gate:
      # TODO(mdan): There may still be something that could be done.
      raise ValueError(
          'Unable to insert statement into the computation flow: it is not'
          ' followed by any computation which the statement could gate.')

  return result
def _visit_and_reindent(self, nodes):
  """Visits statements, redirecting later ones into requested sub-blocks.

  Visited results (single nodes or lists/tuples of nodes) are renamed with
  the aliases accumulated so far and appended to the active destination
  list. A node annotated with anno.Basic.INDENT_BLOCK_REMAINDER switches the
  active destination to the list stored in that annotation, and extends the
  alias map with the one stored there. Previously collected aliases win when
  keys clash.

  Args:
    nodes: the statements to visit.

  Returns:
    The top-level list of resulting statements.

  Raises:
    ValueError: when a remainder was requested but nothing follows that could
      be gated: the final destination is empty, or consists of a single
      `return`, `return ()`, `return []` or `return {}`.
  """
  top_level = []
  target_list = top_level
  rename_map = {}
  remainder_seen = False

  for original in nodes:
    visited = self.visit(original)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if rename_map:
      visited = ast_util.rename_symbols(visited, rename_map)

    if isinstance(visited, (list, tuple)):
      target_list.extend(visited)
    else:
      target_list.append(visited)

    if anno.hasanno(visited, anno.Basic.INDENT_BLOCK_REMAINDER):
      remainder_seen = True
      next_list, next_map = anno.getanno(
          visited, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(visited, anno.Basic.INDENT_BLOCK_REMAINDER)
      next_map.update(rename_map)
      rename_map = next_map
      target_list = next_list

  if not remainder_seen:
    return top_level

  is_empty = not target_list
  only_trivial_return = False
  if len(target_list) == 1:
    # Evaluate every pattern, exactly as an unrolled check would.
    pattern_hits = [
        ast_util.matches(target_list[0], pattern)
        for pattern in ('return', 'return ()', 'return []', 'return {}')
    ]
    only_trivial_return = any(pattern_hits)

  if is_empty or only_trivial_return:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError(
        'Unable to insert statement into the computation flow: it is not'
        ' followed by any computation which the statement could gate.')
  return top_level
def assertAstMatches(self, actual_node, expected_node_src, expr=True):
  """Asserts that an AST node matches a node parsed from source.

  Args:
    actual_node: the AST node under test.
    expected_node_src: source code describing the expected node.
    expr: if True, the source is parsed as a parenthesized expression and the
      expression value is used. Otherwise the first top-level statement is
      used.
  """
  if not expr:
    expected_node = gast.parse(expected_node_src).body[0]
  else:
    # Wrapping in parentheses ensures multi-line expressions parse.
    wrapped_src = '({})'.format(expected_node_src)
    expression_stmt = gast.parse(wrapped_src).body[0]
    expected_node = expression_stmt.value

  failure_message = 'AST did not match expected:\n{}\nActual:\n{}'.format(
      pretty_printer.fmt(expected_node), pretty_printer.fmt(actual_node))
  self.assertTrue(
      ast_util.matches(actual_node, expected_node), failure_message)
def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive argument for a symbol.

  See lang/directives.py for details on directives.

  Example:

     # Given a directive in the code:
     ag.foo_directive(bar, baz=1)

     # One can write for an AST node Name(id='bar'):
     get_definition_directive(node, ag.foo_directive, 'baz')

  Args:
    node: ast.AST, the node representing the symbol for which the directive
      argument is needed.
    directive: Callable[..., Any], the directive to search.
    arg: str, the directive argument to return.
    default: Any

  Returns:
    The value of `arg` shared by every definition of `node` (from
    anno.Static.ORIG_DEFINITIONS) that sets it for `directive`, or `default`
    if none does.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not defs:
    return default

  values = [
      def_.directives[directive][arg]
      for def_ in defs
      if directive in def_.directives and arg in def_.directives[directive]
  ]
  if not values:
    return default

  # Every annotation that reaches the symbol must agree with the first one;
  # if they all do, the first is returned.
  first_value = values[0]
  for other_value in values[1:]:
    if ast_util.matches(first_value, other_value):
      continue
    qn = anno.getanno(node, anno.Basic.QN)
    raise ValueError(
        '%s has ambiguous annotations for %s(%s): %s, %s' %
        (qn, directive.__name__, arg,
         compiler.ast_to_source(other_value).strip(),
         compiler.ast_to_source(first_value).strip()))
  return first_value
def visit_Call(self, node):
  """Processes a call node after visiting its children.

  If the callee (via its 'live_val' annotation) is one of
  `self.ctx.program.options.strip_decorators`, the call's first positional
  argument is annotated as 'graph_ready'. Children are then visited with
  `generic_visit`.

  After that, an annotated callee is routed to `_rename_compilable_function`
  if compilable, or to `_wrap_to_py_func_single_return` if its fqn is in
  KNOWN_NUMPY_FUNCTIONS. Any other annotated callee raises.

  Calls without a 'live_val' annotation are returned unchanged when the
  callee's qualified name is 'dict' or the call matches `super(_)`. Otherwise
  they go through `_insert_dynamic_conversion` when
  `self.ctx.program.options.recursive` is set.

  Args:
    node: the call node being visited.

  Returns:
    The resulting node.

  Raises:
    ValueError: if a strip-decorator call has no positional arguments.
    NotImplementedError: for an annotated callee of an unsupported kind.
  """
  # Calls wrapped by one of the marker decorators are considered graph ready.
  if anno.hasanno(node.func, 'live_val'):
    decorator = anno.getanno(node.func, 'live_val')
    if decorator in self.ctx.program.options.strip_decorators:
      if len(node.args) == 0:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' % decorator)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)

  if not anno.hasanno(node.func, 'live_val'):
    if anno.hasanno(node.func, anno.Basic.QN):
      # Special-case a few builtins that otherwise go undetected. The dict
      # built-in doesn't work with inspect.getargspec, which dynamic functions
      # require. Aliasing (e.g. dict = an_evil_dict) is expected to be handled
      # by the regular mechanisms, which then see a simple user function.
      qn = anno.getanno(node.func, anno.Basic.QN)
      # Add items to this list as needed.
      if str(qn) in ('dict',):
        return node

    if ast_util.matches(node, 'super(_)'):
      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      return node

    if self.ctx.program.options.recursive:
      return self._insert_dynamic_conversion(node)
    return node

  callee = anno.getanno(node.func, 'live_val')
  callee_fqn = None
  if anno.hasanno(node.func, 'fqn'):
    callee_fqn = anno.getanno(node.func, 'fqn')

  if self._function_is_compilable(callee):
    return self._rename_compilable_function(node)
  if callee_fqn and callee_fqn in KNOWN_NUMPY_FUNCTIONS:
    # TODO(mdan): Should we replace these with equivalent TF ops instead?
    return self._wrap_to_py_func_single_return(
        node, KNOWN_NUMPY_FUNCTIONS[callee_fqn].dtype)
  raise NotImplementedError('py_func with return values (unknown function)')
def visit_Call(self, node):
  """Processes a call node after its children have been visited.

  Step 1: if the callee's 'live_val' is listed in
  `self.ctx.program.autograph_decorators`, the first positional argument is
  annotated 'graph_ready'.

  Step 2: `generic_visit` runs on the node.

  Step 3: an annotated callee is dispatched to `_rename_compilable_function`
  (compilable) or `_wrap_to_py_func_single_return` (fqn in
  KNOWN_NUMPY_FUNCTIONS). Any other annotated callee raises. Unannotated calls
  to 'dict' or matching `super(_)` are kept as-is. The rest go through
  `_insert_dynamic_conversion` when `self.ctx.program.recursive` is set.

  Args:
    node: the call node being visited.

  Returns:
    The resulting node.

  Raises:
    ValueError: if a marker-decorator call has no positional arguments.
    NotImplementedError: for an annotated callee of an unsupported kind.
  """
  func_has_live_val = anno.hasanno(node.func, 'live_val')
  if func_has_live_val:
    wrapped_by = anno.getanno(node.func, 'live_val')
    if wrapped_by in self.ctx.program.autograph_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' % wrapped_by)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)

  # generic_visit may have changed node.func, so the annotation is re-read.
  if anno.hasanno(node.func, 'live_val'):
    entity = anno.getanno(node.func, 'live_val')
    fqn = anno.getanno(node.func, 'fqn') if anno.hasanno(node.func,
                                                         'fqn') else None
    if self._function_is_compilable(entity):
      return self._rename_compilable_function(node)
    elif fqn and fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')

  if anno.hasanno(node.func, anno.Basic.QN):
    # Special-case a few builtins that otherwise go undetected: dict doesn't
    # work with inspect.getargspec, which dynamic functions require. Aliasing
    # (e.g. dict = an_evil_dict) is expected to be handled by the regular
    # mechanisms, which then process a simple user function.
    # Add items to this list as needed.
    if str(anno.getanno(node.func, anno.Basic.QN)) in ('dict', ):
      return node

  # super() calls are preserved. The class conversion mechanism will ensure
  # that they return the correct value.
  if ast_util.matches(node, 'super(_)'):
    return node

  if not self.ctx.program.recursive:
    return node
  return self._insert_dynamic_conversion(node)
def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive argument for a symbol.

  See lang/directives.py for details on directives.

  Example:

     # Given a directive in the code:
     ag.foo_directive(bar, baz=1)

     # One can write for an AST node Name(id='bar'):
     get_definition_directive(node, ag.foo_directive, 'baz')

  Args:
    node: ast.AST, the node representing the symbol for which the directive
      argument is needed.
    directive: Callable[..., Any], the directive to search.
    arg: str, the directive argument to return.
    default: Any

  Returns:
    The agreed-upon value of `arg`, or `default` when no reaching definition
    sets it for `directive`.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  reaching_defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not reaching_defs:
    return default

  found = []
  for definition in reaching_defs:
    if directive not in definition.directives:
      continue
    directive_args = definition.directives[directive]
    if arg not in directive_args:
      continue
    found.append(directive_args[arg])

  if not found:
    return default

  # Multiple annotations reaching the symbol must all match the first; when
  # they do, the first one is returned.
  reference = found[0]
  for candidate in found[1:]:
    if not ast_util.matches(reference, candidate):
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                       (qn, directive.__name__, arg,
                        compiler.ast_to_source(candidate).strip(),
                        compiler.ast_to_source(reference).strip()))
  return reference
def visit_Call(self, node):
  """Processes a call node after visiting its children.

  If the callee's 'live_val' is one of
  `self.ctx.program.options.strip_decorators`, the first positional argument
  is annotated 'graph_ready' before `generic_visit` runs.

  Afterwards, an annotated callee is handled as follows:
    * builtins are returned unchanged;
    * compilable entities go to `_rename_compilable_function`;
    * fqns in KNOWN_NUMPY_FUNCTIONS go to `_wrap_to_py_func_single_return`;
    * anything else raises.

  Unannotated calls matching `super(_)` are kept. Other unannotated calls go
  through `_insert_dynamic_conversion` when
  `self.ctx.program.options.recursive` is set.

  Args:
    node: the call node being visited.

  Returns:
    The resulting node.

  Raises:
    ValueError: if a strip-decorator call has no positional arguments.
    NotImplementedError: for an annotated callee of an unsupported kind.
  """
  # Calls wrapped by one of the marker decorators are considered graph ready.
  if anno.hasanno(node.func, 'live_val'):
    marker = anno.getanno(node.func, 'live_val')
    if marker in self.ctx.program.options.strip_decorators:
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' % marker)
      anno.setanno(node.args[0], 'graph_ready', True)

  self.generic_visit(node)

  if not anno.hasanno(node.func, 'live_val'):
    # super() calls are preserved. The class conversion mechanism will ensure
    # that they return the correct value.
    if ast_util.matches(node, 'super(_)'):
      return node
    if self.ctx.program.options.recursive:
      return self._insert_dynamic_conversion(node)
    return node

  entity = anno.getanno(node.func, 'live_val')
  if anno.hasanno(node.func, 'fqn'):
    entity_fqn = anno.getanno(node.func, 'fqn')
  else:
    entity_fqn = None

  if inspect_utils.isbuiltin(entity):
    # Any builtin that passed the builtins converter is assumed to be safe
    # for graph mode.
    return node
  if self._function_is_compilable(entity):
    return self._rename_compilable_function(node)
  if entity_fqn and entity_fqn in KNOWN_NUMPY_FUNCTIONS:
    # TODO(mdan): Should we replace these with equivalent TF ops instead?
    return self._wrap_to_py_func_single_return(
        node, KNOWN_NUMPY_FUNCTIONS[entity_fqn].dtype)
  raise NotImplementedError('py_func with return values (unknown function)')
def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive for a symbol, or a default if none exist.

  See lang/directives.py for details on directives.

  Args:
    node: ast.AST
    directive: Callable[..., Any]
    arg: str
    default: Any

  Returns:
    The value of `arg` for `directive`, taken from the definitions recorded
    on `node` under anno.Static.ORIG_DEFINITIONS. Returns `default` when no
    definition sets it. When several definitions set it, they must all match
    each other, and the first one is returned.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not defs:
    return default

  arg_values = []
  for def_ in defs:
    if (directive not in def_.directives or
        arg not in def_.directives[directive]):
      continue
    arg_value = def_.directives[directive][arg]
    # Each new value must agree with every value collected before it.
    for prev_value in arg_values:
      if not ast_util.matches(arg_value, prev_value):
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError(
            '%s has ambiguous annotations for %s(%s): %s, %s' %
            (qn, directive.__name__, arg,
             compiler.ast_to_source(arg_value).strip(),
             compiler.ast_to_source(prev_value).strip()))
    arg_values.append(arg_value)

  if not arg_values:
    return default
  # All collected values match, so any of them is representative. Unpacking
  # a single element here would instead fail with an unrelated
  # "too many values to unpack" error whenever several definitions agree.
  return arg_values[0]
def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive for a symbol, or a default if none exist.

  See lang/directives.py for details on directives.

  Args:
    node: ast.AST
    directive: Callable[..., Any]
    arg: str
    default: Any

  Returns:
    The value of `arg` for `directive`, taken from the definitions recorded
    on `node` under anno.Static.ORIG_DEFINITIONS. Returns `default` when no
    definition sets it. When several definitions set it, they must all match
    each other, and the first one is returned.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not defs:
    return default

  arg_values = []
  for def_ in defs:
    if (directive not in def_.directives or
        arg not in def_.directives[directive]):
      continue
    arg_value = def_.directives[directive][arg]
    # Each new value must agree with every value collected before it.
    for prev_value in arg_values:
      if not ast_util.matches(arg_value, prev_value):
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                         (qn, directive.__name__, arg,
                          compiler.ast_to_source(arg_value).strip(),
                          compiler.ast_to_source(prev_value).strip()))
    arg_values.append(arg_value)

  if not arg_values:
    return default
  # All collected values match, so any of them is representative. Unpacking
  # a single element here would instead fail with an unrelated
  # "too many values to unpack" error whenever several definitions agree.
  return arg_values[0]
def assertNoMatch(self, target_str, pattern_str):
  """Asserts that the expression `target_str` does not match `pattern_str`.

  Both strings are parsed with `parser.parse_expression`; the target is
  parsed first.
  """
  self.assertFalse(
      ast_util.matches(
          parser.parse_expression(target_str),
          parser.parse_expression(pattern_str)))
def assertNoMatch(self, target_str, pattern_str):
  """Fails if the parsed target expression matches the parsed pattern.

  Args:
    target_str: source of the expression under test.
    pattern_str: source of the pattern expression.
  """
  target_node = parser.parse_expression(target_str)
  pattern_node = parser.parse_expression(pattern_str)
  is_match = ast_util.matches(target_node, pattern_node)
  self.assertFalse(is_match)