def assert_same_ast(self, expected_node, node, msg=None):
    """Asserts that two AST nodes render to the same source text.

    Each node is converted with `compiler.ast_to_source`, then dedented and
    stripped, and the two resulting strings are compared.

    Args:
      expected_node: the AST node holding the expected code.
      node: the AST node under test.
      msg: optional failure message, forwarded to `assertEqual`.
    """
    expected_str, got_str = [
        textwrap.dedent(compiler.ast_to_source(tree, indentation=' ')).strip()
        for tree in (expected_node, node)
    ]
    self.assertEqual(expected_str, got_str, msg=msg)
def __repr__(self):
    """Returns a short, human-readable rendering of the wrapped AST node."""
    ast_node = self.ast_node
    # Definitions are summarized by their header rather than their full body.
    if isinstance(ast_node, gast.FunctionDef):
        return 'def %s' % ast_node.name
    if isinstance(ast_node, gast.ClassDef):
        return 'class %s' % ast_node.name
    # A `with` item is rendered through its context expression only.
    if isinstance(ast_node, gast.withitem):
        shown = ast_node.context_expr
    else:
        shown = ast_node
    return compiler.ast_to_source(shown).strip()
def get_definition_directive(self, node, directive, arg, default):
    """Returns the value of a directive argument attached to a symbol.

    The definitions reaching `node` are read from its
    `anno.Static.ORIG_DEFINITIONS` annotation; every definition that carries
    `directive` with argument `arg` contributes one value. See
    lang/directives.py for details on directives.

    Example:
      # Given a directive in the code:
      ag.foo_directive(bar, baz=1)

      # One can write for an AST node Name(id='bar'):
      get_definition_directive(node, ag.foo_directive, 'baz', default=None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no reaching definition supplies `arg`.

    Returns:
      The first value found for `arg`, or `default` when there is none.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    reaching_defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not reaching_defs:
        return default

    found = [
        d.directives[directive][arg]
        for d in reaching_defs
        if directive in d.directives and arg in d.directives[directive]
    ]
    if not found:
        return default

    # If multiple annotations reach the symbol, they must all match. When they
    # do, any of them may be returned; the first one is used.
    reference = found[0]
    for candidate in found[1:]:
        if ast_util.matches(reference, candidate):
            continue
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError(
            '%s has ambiguous annotations for %s(%s): %s, %s' %
            (qn, directive.__name__, arg,
             compiler.ast_to_source(candidate).strip(),
             compiler.ast_to_source(reference).strip()))
    return reference
def assert_body_anfs_as_expected(self, expected_fn, test_fn):
    """Checks that a "leave everything" ANF config does not change `test_fn`.

    Only the code bodies are tested. They are wrapped in functions so the
    syntax highlights nicely, but Python doesn't try to execute the
    statements.

    Args:
      expected_fn: not used by this check.
      test_fn: function whose parsed source is rendered before and after the
        transformation; both renderings must be equal.
    """

    def render(tree):
        rendered = compiler.ast_to_source(tree, indentation=' ')
        return textwrap.dedent(rendered).strip()

    node, _ = parser.parse_entity(test_fn, future_features=())
    orig_str = render(node)

    # Configuration to transform nothing.
    leave_everything = [(anf.ANY, anf.LEAVE)]
    node = anf.transform(
        node,
        self._simple_context(),
        config=leave_everything,
        gensym_source=DummyGensym)
    new_str = render(node)

    self.assertEqual(orig_str, new_str)
def get_definition_directive(self, node, directive, arg, default):
    """Looks up the unique value of a directive argument for a symbol.

    Values are collected from every definition listed in the
    `anno.Static.ORIG_DEFINITIONS` annotation of `node` that carries
    `directive` with argument `arg`. See lang/directives.py for details on
    directives.

    Example:
      # Given a directive in the code:
      ag.foo_directive(bar, baz=1)

      # One can write for an AST node Name(id='bar'):
      get_definition_directive(node, ag.foo_directive, 'baz', default=None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when the symbol has no definitions or none of
        them supplies `arg`.

    Returns:
      The first collected value, or `default` when nothing was collected.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not definitions:
        return default

    values = []
    for definition in definitions:
        directives = definition.directives
        if directive not in directives:
            continue
        if arg not in directives[directive]:
            continue
        values.append(directives[directive][arg])

    if not values:
        return default

    # If multiple annotations reach the symbol, they must all match. If they
    # do, return any of them (here: the first).
    head, tail = values[0], values[1:]
    for other in tail:
        if not ast_util.matches(head, other):
            qn = anno.getanno(node, anno.Basic.QN)
            other_src = compiler.ast_to_source(other).strip()
            head_src = compiler.ast_to_source(head).strip()
            raise ValueError(
                '%s has ambiguous annotations for %s(%s): %s, %s' %
                (qn, directive.__name__, arg, other_src, head_src))
    return head
def assertLambdaNodes(self, matching_nodes, expected_bodies):
    """Asserts that every node is a lambda whose body is an expected one.

    Args:
      matching_nodes: the nodes to check; must be as many as the bodies.
      expected_bodies: container of stripped source strings, one of which
        each lambda body must render to.
    """
    self.assertEqual(len(matching_nodes), len(expected_bodies))
    for lambda_node in matching_nodes:
        self.assertIsInstance(lambda_node, gast.Lambda)
        body_source = compiler.ast_to_source(
            lambda_node.body, include_encoding_marker=False)
        self.assertIn(body_source.strip(), expected_bodies)
def test_rename_symbols_attributes(self):
    """Renaming qualified name `b.c` rewrites each occurrence in the code."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    name_map = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
    tree = ast_util.rename_symbols(tree, name_map)

    rendered = compiler.ast_to_source(tree).strip()
    self.assertEqual(rendered, 'renamed_b_c = renamed_b_c.d')
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    Classes are handled by `class_to_graph`; functions and methods by
    `function_to_graph`. The result is logged at verbosity 2 (as source) and
    at verbosity 4 (as a pretty-printed AST, node by node).

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
        parameters.
      arg_types: A dict containing type hints for symbols like function
        parameters.

    Returns:
      A tuple (nodes, name, entity_info), exactly as produced by
      `class_to_graph` or `function_to_graph`:
        * nodes: an iterable of AST nodes for the converted entity.
        * name: presumably the symbol name under which the new entity can be
          found; confirm against the converters.
        * entity_info: additional information returned by the converter; its
          structure is not visible from this function.

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
        a `__class__` attribute (object conversion is not supported).
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods go through the same converter.
        nodes, name, entity_info = function_to_graph(
            o, program_ctx, arg_values, arg_types)
    elif hasattr(o, '__class__'):
        # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        compiled_source = compiler.ast_to_source(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)

    if logging.has_verbosity(4):
        for n in nodes:
            formatted = pretty_printer.fmt(n, color=False)
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, formatted)

    return nodes, name, entity_info
def _get_source(self, node):
    """Renders `node` as source code on a best-effort basis.

    This function is used for error reporting. If an exception occurs here,
    it is suppressed, in favor of emitting as informative a message about the
    original error as possible.

    Args:
      node: the AST node to render.

    Returns:
      The first element of the pair returned by `compiler.ast_to_source`, or a
      placeholder string if the conversion (or the unpacking) failed.
    """
    try:
        result = compiler.ast_to_source(node)
        source, _ = result
    except Exception:  # pylint: disable=broad-except
        return '<could not convert AST to source>'
    return source
def test_rename_symbols_basic(self):
    """Renaming `a` rewrites the name and keeps the identifier a `str`."""
    tree = qual_names.resolve(parser.parse_str('a + b'))
    name_map = {qual_names.QN('a'): qual_names.QN('renamed_a')}
    tree = ast_util.rename_symbols(tree, name_map)

    # The renamed identifier must be stored as a plain string.
    renamed_id = tree.body[0].value.left.id
    self.assertIsInstance(renamed_id, str)

    rendered = compiler.ast_to_source(tree).strip()
    self.assertEqual(rendered, 'renamed_a + b')
def test_source_map_no_origin(self):
    """A node without origin annotations yields an empty source map."""

    def test_fn(x):
        return x + 1

    parsed_node, _, _ = parser.parse_entity(test_fn)
    code = compiler.ast_to_source(parsed_node)

    mapping = origin_info.create_source_map(
        parsed_node, code, 'test_filename', [0])

    self.assertEqual(len(mapping), 0)
def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive for a symbol, or a default if none exist.

    The definitions reaching `node` are read from its
    `anno.Static.ORIG_DEFINITIONS` annotation. Every definition carrying
    `directive` with argument `arg` contributes one value, and all contributed
    values must match each other. See lang/directives.py for details on
    directives.

    Args:
      node: ast.AST, the node representing the symbol.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no reaching definition supplies `arg`.

    Returns:
      The first value found for `arg`, or `default` when there is none.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
        return default

    # TODO(mdan): Simplify this.
    arg_values = []
    for def_ in defs:
        if (directive not in def_.directives or
                arg not in def_.directives[directive]):
            continue
        arg_value = def_.directives[directive][arg]
        # Every new value must agree with all the values seen so far.
        for prev_value in arg_values:
            if not ast_util.matches(arg_value, prev_value):
                qn = anno.getanno(node, anno.Basic.QN)
                raise ValueError(
                    '%s has ambiguous annotations for %s(%s): %s, %s' %
                    (qn, directive.__name__, arg,
                     compiler.ast_to_source(arg_value).strip(),
                     compiler.ast_to_source(prev_value).strip()))
        arg_values.append(arg_value)

    if not arg_values:
        return default

    # Several matching annotations may reach the symbol; they were verified to
    # agree above, so any of them can be returned. Unpacking into a single
    # target here would raise an unrelated ValueError in that case.
    return arg_values[0]
def test_entity_to_graph_function_with_defaults(self):
    """Default argument values are rendered as `None` after conversion."""
    b = 2
    c = 1

    def f(a, d=c + 1):
        return a + b + d

    converted_nodes, converted_name, _ = conversion.entity_to_graph(
        f, self._simple_program_ctx(), None, None)
    fn_node, _ = converted_nodes

    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', converted_name)

    first_default = fn_node.args.defaults[0]
    self.assertEqual(compiler.ast_to_source(first_default).strip(), 'None')
def test_source_map_no_origin(self):
    """A function node without origin annotations yields an empty map."""

    def test_fn(x):
        return x + 1

    module_node, _ = parser.parse_entity(test_fn)
    # The parsed function is the first statement of the returned node.
    fn_node = module_node.body[0]
    code = compiler.ast_to_source(fn_node)

    mapping = origin_info.create_source_map(
        fn_node, code, 'test_filename', [0])

    self.assertEqual(len(mapping), 0)
def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive for a symbol, or a default if none exist.

    The definitions reaching `node` are read from its
    `anno.Static.ORIG_DEFINITIONS` annotation. Every definition carrying
    `directive` with argument `arg` contributes one value, and all contributed
    values must match each other. See lang/directives.py for details on
    directives.

    Args:
      node: ast.AST, the node representing the symbol.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no reaching definition supplies `arg`.

    Returns:
      The first value found for `arg`, or `default` when there is none.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
        return default

    # TODO(mdan): Simplify this.
    arg_values = []
    for def_ in defs:
        if (directive not in def_.directives or
                arg not in def_.directives[directive]):
            continue
        arg_value = def_.directives[directive][arg]
        # Every new value must agree with all the values seen so far.
        for prev_value in arg_values:
            if not ast_util.matches(arg_value, prev_value):
                qn = anno.getanno(node, anno.Basic.QN)
                raise ValueError(
                    '%s has ambiguous annotations for %s(%s): %s, %s' %
                    (qn, directive.__name__, arg,
                     compiler.ast_to_source(arg_value).strip(),
                     compiler.ast_to_source(prev_value).strip()))
        arg_values.append(arg_value)

    if not arg_values:
        return default

    # Several matching annotations may reach the symbol; they were verified to
    # agree above, so any of them can be returned. Unpacking into a single
    # target here would raise an unrelated ValueError in that case.
    return arg_values[0]
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation=' ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
    """Similar to `to_graph`, but returns Python source code as a string.

    Also see: `tf.autograph.to_graph`.

    `to_code` returns the Python source code that can be used to generate a
    TensorFlow graph that is functionally identical to the input Python code.
    The code of every entry in the program context's conversion order is
    emitted, in reverse order, after the required imports.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including function
        arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      indentation: The string to use for indenting. Typically two or four
        spaces, or just the tab character.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.
      experimental_partial_types: A `set` of `type` values, reserved for
        internal use.

    Returns:
      The converted code as string.
    """
    options = converter.ConversionOptions(
        recursive=recursive,
        verbose=converter.Verbosity.BRIEF,
        strip_decorators=(convert, do_not_convert, converted_call),
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=options,
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)

    # The conversion populates `program_ctx`; its return value is not needed.
    conversion.entity_to_graph(entity, program_ctx, arg_values, arg_types)

    rendered = []
    for dep in reversed(program_ctx.conversion_order):
        dep_nodes = program_ctx.dependency_cache[dep]
        rendered.append(compiler.ast_to_source(dep_nodes, indentation))
    code = '\n'.join(rendered)

    return program_ctx.required_imports + '\n\n' + code
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Classes are handled by `convert_class_to_ast`; functions and methods by
    `convert_func_to_ast`. The result is logged at verbosity 2 (as source) and
    at verbosity 4 (as a pretty-printed AST, node by node).

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.

    Returns:
      A tuple (nodes, name, entity_info), exactly as produced by
      `convert_class_to_ast` or `convert_func_to_ast`:
        * nodes: an iterable of AST nodes for the converted entity.
        * name: presumably the symbol name under which the new entity can be
          found; confirm against the converters.
        * entity_info: additional information returned by the converter; its
          structure is not visible from this function.

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
        a `__class__` attribute (object conversion is not supported).
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods go through the same converter.
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif hasattr(o, '__class__'):
        # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        compiled_source = compiler.ast_to_source(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)

    if logging.has_verbosity(4):
        for n in nodes:
            formatted = pretty_printer.fmt(n, color=False)
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, formatted)

    return nodes, name, entity_info
def test_convert_entity_to_ast_function_with_defaults(self):
    """Default argument values are rendered as `None` after conversion."""
    b = 2
    c = 1

    def f(a, d=c + 1):
        return a + b + d

    converted_nodes, converted_name, _ = conversion.convert_entity_to_ast(
        f, self._simple_program_ctx())
    fn_node, = converted_nodes

    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', converted_name)

    default_source = compiler.ast_to_source(
        fn_node.args.defaults[0], include_encoding_marker=False)
    self.assertEqual(default_source.strip(), 'None')
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation=' ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
    """Similar to `to_graph`, but returns Python source code as a string.

    Also see: `tf.autograph.to_graph`.

    `to_code` returns the Python source code that can be used to generate a
    TensorFlow graph that is functionally identical to the input Python code.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including function
        arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      indentation: The string to use for indenting. Typically two or four
        spaces, or just the tab character.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.
      experimental_partial_types: Deprecated, unused.

    Returns:
      The converted code as string.
    """
    # Accepted for backward compatibility only; the value is discarded.
    del experimental_partial_types

    options = converter.ConversionOptions(
        recursive=recursive,
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=options,
        autograph_module=tf_inspect.getmodule(to_graph))

    converted = conversion.entity_to_graph(
        entity, program_ctx, arg_values, arg_types)
    nodes, _, _ = converted

    code = compiler.ast_to_source(nodes, indentation)
    return program_ctx.required_imports + '\n\n' + code
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Classes are handled by `convert_class_to_ast`; functions and methods by
    `convert_func_to_ast`. The result is logged at verbosity 2 (as source) and
    at verbosity 4 (as a pretty-printed AST, node by node).

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.

    Returns:
      A tuple (nodes, name, entity_info), exactly as produced by
      `convert_class_to_ast` or `convert_func_to_ast`:
        * nodes: an iterable of AST nodes for the converted entity.
        * name: presumably the symbol name under which the new entity can be
          found; confirm against the converters.
        * entity_info: additional information returned by the converter; its
          structure is not visible from this function.

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
        a `__class__` attribute (object conversion is not supported).
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods go through the same converter.
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif hasattr(o, '__class__'):
        # Note: this should only be raised when attempting to convert the
        # object directly. converted_call should still support it.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        compiled_source = compiler.ast_to_source(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)

    if logging.has_verbosity(4):
        for n in nodes:
            formatted = pretty_printer.fmt(n, color=False)
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, formatted)

    return nodes, name, entity_info
def test_create_source_map(self):
    """An origin annotation on a statement shows up in the source map."""

    def test_fn(x):
        return x + 1

    parsed_node, _, _ = parser.parse_entity(test_fn)

    fake_location = origin_info.Location('fake_filename', 3, 7)
    fake_origin = origin_info.OriginInfo(
        loc=fake_location,
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    # Attach the fake origin to the first statement of the parsed node.
    anno.setanno(parsed_node.body[0], anno.Basic.ORIGIN, fake_origin)

    code = compiler.ast_to_source(parsed_node)
    mapping = origin_info.create_source_map(
        parsed_node, code, 'test_filename', [0])

    expected_loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_loc, mapping)
    self.assertIs(mapping[expected_loc], fake_origin)
def test_source_map(self):
    """Traced lines map back to their origin; untraced lines are absent."""

    def test_fn(x):
        if x > 0:
            x += 1
        return x

    def line(lineno):
        return origin_info.LineLocation('test_filename', lineno)

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    # Insert a traced line: it reuses the origin of the first statement.
    traced_stmt = parser.parse_str('x = abs(x)').body[0]
    anno.copyanno(fn_node.body[0], traced_stmt, anno.Basic.ORIGIN)
    fn_node.body.insert(0, traced_stmt)

    # Insert an untraced line.
    untraced_stmt = parser.parse_str('x = 0').body[0]
    fn_node.body.insert(0, untraced_stmt)

    modified_source = compiler.ast_to_source(fn_node)
    source_map = origin_info.source_map(
        fn_node, modified_source, 'test_filename', [0])

    header_origin = source_map[line(1)]
    self.assertEqual(header_origin.source_code_line, 'def test_fn(x):')
    self.assertEqual(header_origin.loc.lineno, 1)

    # The untraced line, inserted second.
    self.assertNotIn(line(2), source_map)

    # The traced line, inserted first.
    # NOTE(review): the leading whitespace of the expected source lines below
    # is reproduced as found; confirm it matches the indentation of `test_fn`.
    traced_origin = source_map[line(3)]
    self.assertEqual(traced_origin.source_code_line, ' if x > 0:')
    self.assertEqual(traced_origin.loc.lineno, 2)

    if_origin = source_map[line(4)]
    self.assertEqual(if_origin.source_code_line, ' if x > 0:')
    self.assertEqual(if_origin.loc.lineno, 2)
def test_ast_to_source(self):
    """Checks that `compiler.ast_to_source` renders an if/else statement."""
    # Hand-built AST for an `if 1:` statement that assigns `b` to `a` in its
    # body and the string 'c' to `a` in its else branch.
    node = gast.If(
        test=gast.Num(1),
        body=[
            gast.Assign(targets=[gast.Name('a', gast.Store(), None)],
                        value=gast.Name('b', gast.Load(), None))
        ],
        orelse=[
            gast.Assign(targets=[gast.Name('a', gast.Store(), None)],
                        value=gast.Str('c'))
        ])

    source = compiler.ast_to_source(node, indentation=' ')
    # NOTE(review): the expected literal below appears to have lost its line
    # breaks and indentation; it presumably spanned one line per statement.
    # Restore it from the original file before relying on this test.
    self.assertEqual(
        textwrap.dedent(""" if 1: a = b else: a = 'c' """).strip(),
        source.strip())
def test_create_source_map(self):
    """An origin annotation on a statement shows up in the source map."""

    def test_fn(x):
        return x + 1

    module_node, _ = parser.parse_entity(test_fn)

    fake_location = origin_info.Location('fake_filename', 3, 7)
    fake_origin = origin_info.OriginInfo(
        loc=fake_location,
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)

    # The parsed function is the first statement of the returned node; the
    # fake origin goes on the first statement of that function.
    fn_node = module_node.body[0]
    anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)

    code = compiler.ast_to_source(fn_node)
    mapping = origin_info.create_source_map(
        fn_node, code, 'test_filename', [0])

    expected_loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_loc, mapping)
    self.assertIs(mapping[expected_loc], fake_origin)
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
    """Returns the equivalent code that uses TensorFlow ops.

    The code of every entry in the program context's conversion order is
    emitted, in reverse order, after the required imports.

    Also see: `to_graph`, `convert`

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
      partial_types: Set[Type], reserved for internal use.
      indentation: Text, what to use for each level of indentation.

    Returns:
      Text, the converted code.
    """
    options = converter.ConversionOptions(
        recursive=recursive,
        strip_decorators=(convert, do_not_convert, converted_call))
    program_ctx = converter.ProgramContext(
        options=options,
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)

    # The conversion populates `program_ctx`; its return value is not needed.
    conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

    rendered = []
    for dep in reversed(program_ctx.conversion_order):
        dep_nodes = program_ctx.dependency_cache[dep]
        rendered.append(compiler.ast_to_source(dep_nodes, indentation))
    code = '\n'.join(rendered)

    return program_ctx.required_imports + '\n\n' + code
def test_ast_to_source(self):
    """Checks that `compiler.ast_to_source` renders an if/else statement."""
    # Hand-built AST for an `if 1:` statement that assigns `b` to `a` in its
    # body and the string 'c' to `a` in its else branch.
    node = gast.If(
        test=gast.Num(1),
        body=[
            gast.Assign(
                targets=[gast.Name('a', gast.Store(), None)],
                value=gast.Name('b', gast.Load(), None))
        ],
        orelse=[
            gast.Assign(
                targets=[gast.Name('a', gast.Store(), None)],
                value=gast.Str('c'))
        ])

    source = compiler.ast_to_source(node, indentation=' ')
    # NOTE(review): the expected literal below appears to have lost its line
    # breaks and indentation; it presumably spanned one line per statement.
    # Restore it from the original file before relying on this test.
    self.assertEqual(
        textwrap.dedent(""" if 1: a = b else: a = 'c' """).strip(),
        source.strip())
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
    """Returns the equivalent code that uses TensorFlow ops.

    The entity is converted through `conversion.entity_to_graph`; afterwards
    the cached AST of each entry in `program_ctx.conversion_order` is rendered
    (last entry first) and appended to the required imports.

    Also see: `to_graph`, `convert`

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
      partial_types: Set[Type], reserved for internal use.
      indentation: Text, what to use for each level of indentation.

    Returns:
      Text, the converted code.
    """
    conversion_options = converter.ConversionOptions(
        recursive=recursive,
        strip_decorators=(convert, do_not_convert, converted_call))
    program_ctx = converter.ProgramContext(
        options=conversion_options,
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)

    conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

    def render(dep):
        return compiler.ast_to_source(
            program_ctx.dependency_cache[dep], indentation)

    pieces = [render(dep) for dep in reversed(program_ctx.conversion_order)]
    code = '\n'.join(pieces)
    return program_ctx.required_imports + '\n\n' + code
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    When `program_ctx.options.recursive` is set, the function also compiles
    the entities listed in `program_ctx.name_map` that are not yet present in
    `program_ctx.dependency_cache`, by calling itself on each of them.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
        parameters.
      arg_types: A dict containing type hints for symbols like function
        parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to
              `o`, but which when executed it creates TF a graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted
              entity, keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
        a `__class__` attribute (object conversion is not supported).
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. It should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    # The template appends an empty `autograph_info__` attribute assignment for
    # the converted entity to the generated nodes.
    template = ''' entity.autograph_info__ = {} '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(node))

    if program_ctx.options.recursive:
        while True:
            # Pick the first referenced entity that has not been converted yet.
            candidate = None
            for obj in program_ctx.name_map.keys():
                if obj not in program_ctx.dependency_cache:
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and getattr(
                    candidate, 'im_class') not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                # NOTE(review): nothing visible here adds `candidate` to the
                # dependency cache before the scan restarts, so the same
                # candidate may be selected again - confirm this cannot spin.
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    When `program_ctx.options.recursive` is set, the function also compiles
    the entities listed in `program_ctx.name_map` that are not yet present in
    `program_ctx.dependency_cache`, by calling itself on each of them.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
        parameters.
      arg_types: A dict containing type hints for symbols like function
        parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to
              `o`, but which when executed it creates TF a graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted
              entity, keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
        a `__class__` attribute (object conversion is not supported).
      ValueError: if the entity type is not supported.
    """
    # Progress messages are only emitted in verbose mode.
    if program_ctx.options.verbose == converter.Verbosity.VERBOSE:
        logging.info('Converting {}'.format(o))

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. It should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    # The template appends an empty `autograph_info__` attribute assignment for
    # the converted entity to the generated nodes.
    template = ''' entity.autograph_info__ = {} '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if program_ctx.options.verbose == converter.Verbosity.VERBOSE:
        logging.info('Compiled output of {}:\n\n{}\n'.format(
            o, compiler.ast_to_source(node)))

    if program_ctx.options.recursive:
        while True:
            # Pick the first referenced entity that has not been converted yet.
            candidate = None
            for obj in program_ctx.name_map.keys():
                if obj not in program_ctx.dependency_cache:
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and
                    getattr(candidate, 'im_class') not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                # NOTE(review): nothing visible here adds `candidate` to the
                # dependency cache before the scan restarts, so the same
                # candidate may be selected again - confirm this cannot spin.
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
def debug_print_src(self, node):
    """Helper method useful for debugging. Prints the AST as code.

    Nothing is printed when `__debug__` is false.

    Args:
      node: the AST node to print.

    Returns:
      `node`, unchanged.
    """
    if not __debug__:
        return node
    rendered = compiler.ast_to_source(node)
    print(rendered)
    return node
def _mock_apply_fn(self, target, source):
    """Counts one invocation for the rendered (target, source) pair.

    Both nodes are rendered without an encoding marker and stripped; the
    resulting pair of strings is the key into `self._invocation_counts`.
    """
    target_text = compiler.ast_to_source(
        target, include_encoding_marker=False)
    source_text = compiler.ast_to_source(
        source, include_encoding_marker=False)
    key = (target_text.strip(), source_text.strip())
    self._invocation_counts[key] += 1
def assertFunctionDefNodes(self, matching_nodes, expected_bodies):
    """Asserts that every node is a function whose body is an expected one.

    Args:
      matching_nodes: the nodes to check; must be as many as the bodies.
      expected_bodies: container of stripped source strings, one of which
        each function body must render to.
    """
    self.assertEqual(len(matching_nodes), len(expected_bodies))
    for fn_node in matching_nodes:
        self.assertIsInstance(fn_node, gast.FunctionDef)
        body_source = compiler.ast_to_source(fn_node.body)
        self.assertIn(body_source.strip(), expected_bodies)
def assertFunctionDefNodes(self, matching_nodes, expected_bodies):
    """Checks that each node is a `FunctionDef` with one of the given bodies.

    Args:
      matching_nodes: the nodes to check; their count must equal the number
        of expected bodies.
      expected_bodies: container of stripped source strings that the rendered
        function bodies are looked up in.
    """
    node_count = len(matching_nodes)
    self.assertEqual(node_count, len(expected_bodies))
    for candidate in matching_nodes:
        self.assertIsInstance(candidate, gast.FunctionDef)
        rendered_body = compiler.ast_to_source(candidate.body).strip()
        self.assertIn(rendered_body, expected_bodies)
def _mock_apply_fn(self, target, source):
    """Counts one invocation for the rendered (target, source) pair.

    Both nodes are rendered to source and stripped; the resulting pair of
    strings is the key into `self._invocation_counts`.
    """
    target_text = compiler.ast_to_source(target)
    source_text = compiler.ast_to_source(source)
    key = (target_text.strip(), source_text.strip())
    self._invocation_counts[key] += 1
def __repr__(self):
    """Returns a short, human-readable rendering of the wrapped AST node."""
    ast_node = self.ast_node
    # Function definitions are summarized by their header only.
    if isinstance(ast_node, gast.FunctionDef):
        return 'def %s' % ast_node.name
    # A `with` item is rendered through its context expression only.
    if isinstance(ast_node, gast.withitem):
        shown = ast_node.context_expr
    else:
        shown = ast_node
    return compiler.ast_to_source(shown).strip()
def assertTransformedFirstLineIs(self, node, expected):
    """Asserts that the first line of the rendered `node` equals `expected`.

    The node is rendered without an encoding marker; the first line is the
    text before the first newline character.
    """
    rendered = compiler.ast_to_source(node, include_encoding_marker=False)
    first_line, _, _ = rendered.partition('\n')
    self.assertEqual(first_line, expected)
def assertTransformedFirstLineIs(self, node, expected):
    """Asserts that the first line of the rendered `node` equals `expected`.

    The first line is the text before the first newline character of the
    source produced by `compiler.ast_to_source`.
    """
    rendered = compiler.ast_to_source(node)
    first_line, _, _ = rendered.partition('\n')
    self.assertEqual(first_line, expected)