def assertAstMatches(self, actual_node, expected_node_src):
  """Asserts that `actual_node` structurally matches the parsed expected source."""
  parsed_module = gast.parse(expected_node_src)
  expected_node = parsed_module.body[0]
  failure_msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
      pretty_printer.fmt(expected_node), pretty_printer.fmt(actual_node))
  did_match = ast_util.matches(actual_node, expected_node)
  self.assertTrue(did_match, failure_msg)
def assertAstMatches(self, actual_node, expected_node_src, expr=True):
  """Asserts that `actual_node` matches the AST parsed from `expected_node_src`.

  When `expr` is True the source is parenthesized before parsing so that
  multi-line expressions parse as a single expression node.
  """
  if expr:
    # Ensure multi-line expressions parse.
    wrapped_src = '({})'.format(expected_node_src)
    expected_node = gast.parse(wrapped_src).body[0].value
  else:
    expected_node = gast.parse(expected_node_src).body[0]
  failure_msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
      pretty_printer.fmt(expected_node), pretty_printer.fmt(actual_node))
  self.assertTrue(ast_util.matches(actual_node, expected_node), failure_msg)
def compiled(self, node, namespace, *symbols):
  """Compiles `node` and yields the resulting module with test mocks attached.

  Generator helper (presumably wrapped by contextlib.contextmanager at the
  decoration site — confirm): compiles the AST into a module, installs a fake
  `tf` module, a fake `ag__` module backed by the real operators/utils, copies
  `namespace` into the module, and yields it. On any exception, prints the
  offending AST or compiled source for debugging and re-raises.

  Args:
    node: AST node (or nodes) to compile via compiler.ast_to_object.
    namespace: dict of extra symbols injected into the compiled module.
    *symbols: names forwarded to the fake `tf` mock module.

  Yields:
    The compiled module object, with `tf` and `ag__` attributes set.
  """
  source = None
  # Records the arguments of every mocked dynamic (converted_call) invocation.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    # Fixed sentinel return value; tests assert against it.
    return 7

  try:
    result, source = compiler.ast_to_object(node, include_source_map=True)
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Back the fake autograph module with the real operator implementations.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__['utils'] = utils
    fake_ag.__dict__['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)
    fake_ag.__dict__['function_scope'] = function_wrapping.function_scope
    result.__dict__['ag__'] = fake_ag
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # `source` stays None if compilation itself failed; print whichever
    # artifact exists to aid debugging, then propagate.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted
            entity, keyed by their symbol name.

  Raises:
    NotImplementedError: if entity is of a type that is not yet supported.
  """
  logging.log(1, 'Converting %s', o)

  conversion_result = convert_func_to_ast(o, program_ctx)
  nodes, name, entity_info = conversion_result

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
  if logging.has_verbosity(4):
    for converted_node in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(converted_node, color=False))

  return nodes, name, entity_info
def test_unicode_bytes(self):
  """Checks that fmt handles bytes, unicode and plain string literals."""
  # NOTE(review): literal layout reconstructed from a collapsed source line —
  # confirm the dedented snippet matches the original file.
  source = textwrap.dedent('''
  def f():
    return b'b', u'u', 'depends_py2_py3'
  ''')
  node = ast.parse(source)
  # Only smoke-tests that formatting succeeds; the output is not inspected.
  self.assertIsNotNone(pretty_printer.fmt(node))
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted
            entity, keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object instance; direct object
        conversion is not yet supported.
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Functions and bound methods are converted identically, so the two
    # branches are merged.
    nodes, name, entity_info = function_to_graph(o, program_ctx, arg_values,
                                                 arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
def compiled(self, node, namespace, symbols=()):
  """Compiles `node` and yields the resulting module with test mocks attached.

  Generator helper (presumably used via contextlib.contextmanager at the
  decoration site — confirm): loads the AST into a module, installs a fake
  `tf` module, a fake `ag__` module backed by the real operators and special
  functions, attaches the source map, copies `namespace` into the module, and
  yields it. On any exception, prints the offending AST or source and
  re-raises.

  Args:
    node: AST node (or nodes) to load via loader.load_ast.
    namespace: dict of extra symbols injected into the compiled module.
    symbols: tuple of names forwarded to the fake `tf` mock module.

  Yields:
    The compiled module object, with `tf`, `ag__` and `ag_source_map__` set.
  """
  source = None
  # Records (args, kwargs) of every mocked converted_call invocation.
  self.dynamic_calls = []

  # See api.converted_call
  def converted_call(f, args, kwargs, unused_opts=None,
                     unused_function_ctx=None):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append((args, kwargs))
    if kwargs is None:
      kwargs = {}
    # Unlike earlier mocks, this one actually invokes the target.
    return f(*args, **kwargs)

  def fake_autograph_artifact(f):
    # Tags `f` so tests can detect that it passed through autograph machinery.
    setattr(f, 'fake_autograph_artifact', True)
    return f

  try:
    result, source, source_map = loader.load_ast(
        node, include_source_map=True)
    # TODO(mdan): Move the unparsing from converter into pyct and reuse here.
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Back the fake autograph module with the real implementations.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__.update(special_functions.__dict__)
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    fake_ag.utils = utils
    fake_ag.FunctionScope = function_wrappers.FunctionScope
    fake_ag.autograph_artifact = fake_autograph_artifact
    result.ag__ = fake_ag
    result.ag_source_map__ = source_map
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # `source` stays None if loading itself failed; print whichever artifact
    # exists to aid debugging, then propagate.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending source code:\n%s' % source)
    raise
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted
            entity, keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object instance; direct object
        conversion is not yet supported (converted_call should still handle
        such objects).
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif tf_inspect.ismethod(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif hasattr(o, '__class__'):
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
def test_format(self):
  """Smoke-tests fmt on a hand-built FunctionDef node."""
  arg_list = ast.arguments(
      args=[ast.Name(id='a', ctx=ast.Param())],
      vararg=None,
      kwarg=None,
      defaults=[])
  return_stmt = ast.Return(
      ast.BinOp(
          op=ast.Add(),
          left=ast.Name(id='a', ctx=ast.Load()),
          right=ast.Num(1)))
  node = ast.FunctionDef(
      name='f',
      args=arg_list,
      body=[return_stmt],
      decorator_list=[],
      returns=None)
  # Just checking for functionality; the color control characters make it
  # difficult to inspect the result.
  self.assertIsNotNone(pretty_printer.fmt(node))
def compiled(self, node, namespace, *symbols):
  """Compiles `node` and yields the resulting module with test mocks attached.

  Generator helper (presumably used via contextlib.contextmanager at the
  decoration site — confirm): compiles the AST into a module, installs a fake
  `tf` module and a fake `ag__` module backed by the real operators/utils,
  copies `namespace` into the module, and yields it. On any exception, prints
  the offending AST or compiled source and re-raises.

  Args:
    node: AST node (or nodes) to compile via compiler.ast_to_object.
    namespace: dict of extra symbols injected into the compiled module.
    *symbols: names forwarded to the fake `tf` mock module.

  Yields:
    The compiled module object, with `tf` and `ag__` attributes set.
  """
  source = None
  # Records the arguments of every mocked converted_call invocation.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args)
    # Fixed sentinel return value; tests assert against it.
    return 7

  class ConversionOptions(object):
    """Mock version of api.ConversionOptions."""

    def __init__(self, recursive):
      self.recursive = recursive

    @classmethod
    def new(cls, recursive):
      # Bug fix: the original omitted the return, so this factory always
      # produced None instead of the constructed options object.
      return cls(recursive)

  try:
    result, source = compiler.ast_to_object(node, include_source_map=True)

    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
    # Back the fake autograph module with the real operator implementations.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__['utils'] = utils
    fake_ag.__dict__['rewrite_graph_construction_error'] = (
        errors.rewrite_graph_construction_error)
    result.__dict__['ag__'] = fake_ag
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # `source` stays None if compilation itself failed; print whichever
    # artifact exists, then propagate.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def compiled(self, node, namespace, *symbols):
  """Compiles `node` and yields the resulting module with test mocks attached.

  Generator helper (presumably used via contextlib.contextmanager at the
  decoration site — confirm): compiles the AST into a module, installs a fake
  `tf` module and a fake `ag__` module backed by the real operators and
  special functions, attaches the source map, copies `namespace` into the
  module, and yields it. On any exception, prints the offending AST or
  compiled source and re-raises.

  Args:
    node: AST node (or nodes) to compile via compiler.ast_to_object.
    namespace: dict of extra symbols injected into the compiled module.
    *symbols: names forwarded to the fake `tf` mock module.

  Yields:
    The compiled module object, with `tf`, `ag__` and `ag_source_map__` set.
  """
  source = None
  # Records the call args of every mocked converted_call invocation.
  self.dynamic_calls = []

  def converted_call(*args):
    """Mock version of api.converted_call."""
    self.dynamic_calls.append(args[3:])  # args only; see api.converted_call
    return RESULT_OF_MOCK_CONVERTED_CALL

  try:
    result, source, source_map = compiler.ast_to_object(
        node, include_source_map=True)
    # TODO(mdan): Move the unparsing from converter into pyct and reuse here.
    # TODO(mdan): Move this into self.prepare()
    result.tf = self.make_fake_mod('fake_tf', *symbols)
    fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                 converter.ConversionOptions)
    # Back the fake autograph module with the real implementations.
    fake_ag.__dict__.update(operators.__dict__)
    fake_ag.__dict__.update(special_functions.__dict__)
    fake_ag.ConversionOptions = converter.ConversionOptions
    fake_ag.Feature = converter.Feature
    fake_ag.utils = utils
    fake_ag.rewrite_graph_construction_error = (
        errors.rewrite_graph_construction_error)
    fake_ag.function_scope = function_wrapping.function_scope
    result.ag__ = fake_ag
    result.ag_source_map__ = source_map
    for k, v in namespace.items():
      result.__dict__[k] = v
    yield result
  except Exception:  # pylint:disable=broad-except
    # `source` stays None if compilation itself failed; print whichever
    # artifact exists, then propagate.
    if source is None:
      print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
    else:
      print('Offending compiled code:\n%s' % source)
    raise
def debug_print(self, node):
  """Prints the AST of `node` for debugging, then returns `node` unchanged.

  Returning the node allows this call to be inserted inline in a chain.
  Printing is skipped when Python runs with optimizations (-O).
  """
  if __debug__:
    formatted = pretty_printer.fmt(node)
    print(formatted)
  return node
def create_source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
        nodes appear in code. The parser always returns a module when parsing
        code. This argument indicates the position in that module's body at
        which the corresponding of node should appear.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.
  """
  # The contract allows a bare int; the original code only accepted an
  # iterable and would raise TypeError on an int. Normalize up front.
  if isinstance(indices_in_code, int):
    indices_in_code = (indices_in_code,)

  reparsed_nodes = parser.parse_str(code)
  reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]
  for node in reparsed_nodes:
    resolve(node, code)

  result = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      line_loc = LineLocation(filename, final_info.loc.lineno)

      existing_origin = result.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Exception make decorated functions, where
        # both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      result[line_loc] = origin_info
  except ValueError:
    # parallel_walk raises ValueError when the two trees diverge; dump a
    # contextual diff at high verbosity before propagating.
    if logging.has_verbosity(3):
      for n, rn in zip(nodes, reparsed_nodes):
        nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
        reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
        diff = difflib.context_diff(
            nodes_str.split('\n'),
            reparsed_nodes_str.split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        diff = '\n'.join(diff)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
    raise

  return result
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  Note: this function assumes nodes nodes, code and filepath correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST modes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.
  """
  reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False)
  for node in reparsed_nodes:
    resolve(node, code, filepath, node.lineno, node.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Note: the keys are by line only, excluding the column offset.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Exception make decorated functions, where
        # both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of column overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info
  except ValueError:
    # parallel_walk raises ValueError when the two trees diverge; dump a
    # contextual diff at high verbosity before propagating.
    if logging.has_verbosity(3):
      for n, rn in zip(nodes, reparsed_nodes):
        nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
        reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
        diff = difflib.context_diff(
            nodes_str.split('\n'),
            reparsed_nodes_str.split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        diff = '\n'.join(diff)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
    raise

  return source_map
def visit(self, node):
  """Visits `node`, tracking scope state and enriching error reports.

  Wraps the standard NodeTransformer visit with: enclosing-entity tracking
  for function/class/lambda nodes, source position tracking, normalization of
  Expr replacements, a local-scope-stack consistency check, and re-raising of
  common transformer errors as AutographParseError with source context.
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error needs to be raised before the exception handler
    # below is installed, because said handler will mess up if `node` is not,
    # in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  source_code = self.entity_info.source_code
  source_file = self.entity_info.source_file
  did_enter_function = False
  # Snapshot so scope-stack balance can be verified after the visit.
  local_scope_size_at_entry = len(self._local_scope_state)
  processing_expr_node = False

  try:
    if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
      did_enter_function = True
    elif isinstance(node, gast.Expr):
      processing_expr_node = True

    if did_enter_function:
      self._enclosing_entities.append(node)

    if source_code and hasattr(node, 'lineno'):
      # Track position for error messages; only nodes with locations update it.
      self._lineno = node.lineno
      self._col_offset = node.col_offset

    if processing_expr_node:
      entry_expr_value = node.value

    # NOTE(review): if SKIP_PROCESSING is set, `result` is never bound and the
    # later `return result` would raise — presumably that annotation implies an
    # earlier return path or is never set here; confirm.
    if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      result = super(Base, self).visit(node)

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if processing_expr_node:
      if isinstance(result, gast.Expr) and result.value != entry_expr_value:
        # When the replacement is a list, it is assumed that the list came
        # from a template that contained a number of statements, which
        # themselves are standalone and don't require an enclosing Expr.
        if isinstance(result.value,
                      (list, tuple, gast.Assign, gast.AugAssign)):
          result = result.value

    # On exception, the local scope integrity is not guaranteed.
    if did_enter_function:
      self._enclosing_entities.pop()

    if local_scope_size_at_entry != len(self._local_scope_state):
      raise AssertionError(
          'Inconsistent local scope stack. Before entering node %s, the'
          ' stack had length %d, after exit it has length %d. This'
          ' indicates enter_local_scope and exit_local_scope are not'
          ' well paired.' % (node, local_scope_size_at_entry,
                             len(self._local_scope_state)))
    return result

  except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
    msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
        e.__class__.__name__, str(e), self._get_source(node),
        pretty_printer.fmt(node, color=False))
    if source_code:
      line = source_code.splitlines()[self._lineno - 1]
    else:
      line = '<no source available>'
    # TODO(mdan): Avoid the printing of the original exception.
    # In other words, we need to find how to suppress the "During handling
    # of the above exception, another exception occurred" message.
    six.reraise(AutographParseError,
                AutographParseError(
                    msg,
                    (source_file, self._lineno, self._col_offset + 1, line)),
                sys.exc_info()[2])
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted
            entity, keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object instance; direct object
        conversion is not yet supported.
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  elif tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. It should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(node))
  if logging.has_verbosity(4):
    for n in node:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  if program_ctx.options.recursive:
    # Drain name_map: keep converting entities until every name that the
    # conversions introduced has a corresponding entry in dependency_cache.
    while True:
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj not in program_ctx.dependency_cache:
          candidate = obj
          break
      if candidate is None:
        break
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns