def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the source of the TensorFlow equivalent of a Python entity.

  The result is the configured compiled-code import statements, followed by
  the generated source of every entry in the conversion map's dependency
  cache. See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See `to_graph`.
    arg_values: See `to_graph`.
    arg_types: See `to_graph`.
    partial_types: See `to_graph`.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  skipped_decorators = (convert, do_not_convert, converted_call)
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=skipped_decorators,
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

  # Converted entities are rendered in the reverse of their cache order.
  converted_nodes = list(six.itervalues(conversion_map.dependency_cache))
  converted_nodes.reverse()
  rendered = [
      compiler.ast_to_source(node, indentation) for node in converted_nodes
  ]

  return '\n\n'.join((header, '\n'.join(rendered)))
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the source of the TensorFlow equivalent of a Python entity.

  The result is the program context's required imports, followed by the
  generated source of every entry in its dependency cache. See `to_graph`
  for more details.

  Args:
    e: A Python entity.
    recursive: See `to_graph`.
    arg_values: See `to_graph`.
    arg_types: See `to_graph`.
    partial_types: See `to_graph`.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, ctx, arg_values, arg_types)

  # Converted entities are rendered in the reverse of their cache order.
  converted_nodes = list(six.itervalues(ctx.dependency_cache))
  converted_nodes.reverse()
  rendered = [
      compiler.ast_to_source(node, indentation) for node in converted_nodes
  ]

  return '\n\n'.join((ctx.required_imports, '\n'.join(rendered)))
def test_entity_to_graph_class_hierarchy(self):
  """Converting a subclass also puts its base class in the dependency cache."""

  class TestBase(object):

    def __init__(self, x='base'):
      self.x = x

    def foo(self):
      return self.x

    def bar(self):
      return self.x

  class TestSubclass(TestBase):

    def __init__(self, y):
      super(TestSubclass, self).__init__('sub')
      self.y = y

    def foo(self):
      return self.y

    def baz(self):
      return self.y

  conversion_map = self._simple_conversion_map()
  conversion.entity_to_graph(TestSubclass, conversion_map, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(TestBase, conversion_map.dependency_cache)
  self.assertIn(TestSubclass, conversion_map.dependency_cache)
  self.assertEqual('TfTestBase',
                   conversion_map.dependency_cache[TestBase].body[-1].name)
  self.assertEqual(
      'TfTestSubclass',
      conversion_map.dependency_cache[TestSubclass].body[-1].name)
def test_entity_to_graph_class_hierarchy(self):
  """Converting a subclass also puts its base class in the dependency cache."""

  class TestBase(object):

    def __init__(self, x='base'):
      self.x = x

    def foo(self):
      return self.x

    def bar(self):
      return self.x

  class TestSubclass(TestBase):

    def __init__(self, y):
      super(TestSubclass, self).__init__('sub')
      self.y = y

    def foo(self):
      return self.y

    def baz(self):
      return self.y

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(TestBase, program_ctx.dependency_cache)
  self.assertIn(TestSubclass, program_ctx.dependency_cache)
  self.assertEqual('TfTestBase',
                   program_ctx.dependency_cache[TestBase].body[-1].name)
  self.assertEqual('TfTestSubclass',
                   program_ctx.dependency_cache[TestSubclass].body[-1].name)
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the source of the TensorFlow equivalent of a Python entity.

  The result is the configured compiled-code import statements, followed by
  the generated source of every entry in the conversion map's dependency
  cache. See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See `to_graph`.
    arg_values: See `to_graph`.
    arg_types: See `to_graph`.
    partial_types: See `to_graph`.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  skipped_decorators = (convert, do_not_convert, converted_call)
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=skipped_decorators,
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

  # Converted entities are rendered in the reverse of their cache order.
  converted_nodes = list(six.itervalues(conversion_map.dependency_cache))
  converted_nodes.reverse()
  rendered = [
      compiler.ast_to_source(node, indentation) for node in converted_nodes
  ]

  return '\n\n'.join((header, '\n'.join(rendered)))
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the source of the TensorFlow equivalent of a Python entity.

  The result is the program context's required imports, followed by the
  generated source of every entry in its dependency cache. See `to_graph`
  for more details.

  Args:
    e: A Python entity.
    recursive: See `to_graph`.
    arg_values: See `to_graph`.
    arg_types: See `to_graph`.
    partial_types: See `to_graph`.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, ctx, arg_values, arg_types)

  # Converted entities are rendered in the reverse of their cache order.
  converted_nodes = list(six.itervalues(ctx.dependency_cache))
  converted_nodes.reverse()
  rendered = [
      compiler.ast_to_source(node, indentation) for node in converted_nodes
  ]

  return '\n\n'.join((ctx.required_imports, '\n'.join(rendered)))
def test_ag_module_cached(self):
  """Separately converted functions share the same `ag__` module object."""
  def callee():
    return range(3)

  def caller(a):
    return a()

  conversion_map = self._simple_conversion_map()
  _, _, callee_ns = conversion.entity_to_graph(callee, conversion_map, None,
                                               None)
  _, _, caller_ns = conversion.entity_to_graph(caller, conversion_map, None,
                                               None)

  # assertIs shows both objects on failure, unlike assertTrue(a is b).
  self.assertIs(callee_ns['ag__'], caller_ns['ag__'])
def test_ag_module_cached(self):
  """Separately converted functions share the same `ag__` module object."""
  def callee():
    return range(3)

  def caller(a):
    return a()

  program_ctx = self._simple_program_ctx()
  _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None, None)
  _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None, None)

  # assertIs shows both objects on failure, unlike assertTrue(a is b).
  self.assertIs(callee_ns['ag__'], caller_ns['ag__'])
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Supported entities are functions and classes; a class is handled by
  converting all of its methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
      function may call.
    verbose: Whether to log the compiled source code.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
      entirely. Calls to member functions for these types will be renamed
      independently.

  Returns:
    The compiled counterpart of `e`, looked up by its converted name in the
    module built from the conversion output.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, entry_name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                             arg_types)

  # One synthetic module: the statements parsed from
  # config.COMPILED_IMPORT_STATEMENTS, then every cached converted entity.
  module = gast.Module([])
  for statement in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(statement).body)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_module, compiled_src = compiler.ast_to_object(module)

  # When `e` is a function, forward its namespace to the compiled module
  # without replacing names the module already defines (the transformed
  # entities).
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    module_globals = compiled_module.__dict__
    for symbol, value in inspect_utils.getnamespace(e).items():
      module_globals.setdefault(symbol, value)

  result = getattr(compiled_module, entry_name)
  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
  return result
def test_entity_to_graph_call_tree(self):
  """Converting `f` also converts `g`, and rewrites the call to `tf__g`."""
  def g(a):
    return a

  def f(a):
    return g(a)

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(f, program_ctx, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(f, program_ctx.dependency_cache)
  self.assertIn(g, program_ctx.dependency_cache)
  f_node = program_ctx.dependency_cache[f][0]
  g_node = program_ctx.dependency_cache[g][0]
  self.assertEqual('tf__f', f_node.name)
  # The call to g sits under statements the converter wraps around the body;
  # NOTE(review): the exact index path depends on the converter version.
  self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
  self.assertEqual('tf__g', g_node.name)
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Supported entities are functions and classes; a class is handled by
  converting all of its methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
      function may call.
    verbose: Whether to log the compiled source code.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
      entirely. Calls to member functions for these types will be renamed
      independently.

  Returns:
    The compiled counterpart of `e`, looked up by its converted name in the
    module built from the conversion output.
  """
  ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, entry_name, entry_namespace = conversion.entity_to_graph(
      e, ctx, arg_values, arg_types)

  # All converted entities go into one synthetic module, in the reverse of
  # their cache order; the required imports are passed as a source prefix.
  module = gast.Module([])
  module.body.extend(reversed(ctx.dependency_cache.values()))
  compiled_module, compiled_src = compiler.ast_to_object(
      module, source_prefix=ctx.required_imports)

  # Forward the entry entity's namespace to the compiled module without
  # replacing names it already defines (the transformed entities).
  # TODO(mdan): This might not work well if the call tree spans modules?
  module_globals = compiled_module.__dict__
  for symbol, value in entry_namespace.items():
    module_globals.setdefault(symbol, value)

  result = getattr(compiled_module, entry_name)
  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
  return result
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Supported entities are functions and classes; a class is handled by
  converting all of its methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
      function may call.
    verbose: Whether to log the compiled source code.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
      entirely. Calls to member functions for these types will be renamed
      independently.

  Returns:
    The compiled counterpart of `e`, looked up by its converted name in the
    module built from the conversion output.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, entry_name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                             arg_types)

  # One synthetic module: the statements parsed from
  # config.COMPILED_IMPORT_STATEMENTS, then every cached converted entity.
  module = gast.Module([])
  for statement in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(statement).body)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_module, compiled_src = compiler.ast_to_object(module)

  # When `e` is a function, forward its namespace to the compiled module
  # without replacing names the module already defines (the transformed
  # entities).
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    module_globals = compiled_module.__dict__
    for symbol, value in inspect_utils.getnamespace(e).items():
      module_globals.setdefault(symbol, value)

  result = getattr(compiled_module, entry_name)
  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
  return result
def test_entity_to_graph_call_tree(self):
  """Converting `f` also converts `g`, and rewrites the call to `tf__g`."""
  def g(a):
    return a

  def f(a):
    return g(a)

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(f, program_ctx, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(f, program_ctx.dependency_cache)
  self.assertIn(g, program_ctx.dependency_cache)
  f_node = program_ctx.dependency_cache[f][0]
  g_node = program_ctx.dependency_cache[g][0]
  self.assertEqual('tf__f', f_node.name)
  # The call to g sits under statements the converter wraps around the body;
  # NOTE(review): the exact index path depends on the converter version.
  self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
  self.assertEqual('tf__g', g_node.name)
def test_entity_to_graph_callable(self):
  """A plain function converts to a FunctionDef named `tf__<name>`."""
  def f(a):
    return a

  conversion_map = conversion.ConversionMap(True, (), (), None)
  # Named `node` rather than `ast` to avoid shadowing the stdlib module name.
  node, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
  self.assertIsInstance(node, gast.FunctionDef)
  self.assertEqual('tf__f', new_name)
def test_entity_to_graph_callable(self):
  """A plain function converts to a FunctionDef named `tf__<name>`."""
  def f(a):
    return a

  conversion_map = conversion.ConversionMap(True, (), (), None)
  # Named `node` rather than `ast` to avoid shadowing the stdlib module name.
  node, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
  self.assertIsInstance(node, gast.FunctionDef)
  self.assertEqual('tf__f', new_name)
def test_entity_to_graph_call_tree(self):
  """Converting `f` also converts `g`, and rewrites the call to `tf__g`."""
  def g(a):
    return a

  def f(a):
    return g(a)

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(f, program_ctx, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(f, program_ctx.dependency_cache)
  self.assertIn(g, program_ctx.dependency_cache)
  self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
  # need the extra .body[0] in order to step past the with tf.name_scope('f')
  # that is added automatically
  self.assertEqual(
      'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id)
  self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
def test_entity_to_graph_callable(self):
  """A closure converts to `tf__f` and its namespace exposes the free var."""
  b = 2
  def f(a):
    return a + b

  conversion_map = self._simple_conversion_map()
  # Named `node` rather than `ast` to avoid shadowing the stdlib module name.
  node, name, ns = conversion.entity_to_graph(f, conversion_map, None, None)
  self.assertIsInstance(node, gast.FunctionDef)
  self.assertEqual('tf__f', name)
  self.assertIs(ns['b'], b)
def test_entity_to_graph_call_tree(self):
  """Converting `f` also converts `g`, and rewrites the call to `tf__g`."""
  def g(a):
    return a

  def f(a):
    return g(a)

  conversion_map = conversion.ConversionMap(True, (), (), None)
  conversion.entity_to_graph(f, conversion_map, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(f, conversion_map.dependency_cache)
  self.assertIn(g, conversion_map.dependency_cache)
  self.assertEqual('tf__f', conversion_map.dependency_cache[f].name)
  # need the extra .body[0] in order to step past the with tf.name_scope('f')
  # that is added automatically
  self.assertEqual(
      'tf__g', conversion_map.dependency_cache[f].body[0].body[0].value.func.id)
  self.assertEqual('tf__g', conversion_map.dependency_cache[g].name)
def test_entity_to_graph_call_tree(self):
  """Converting `f` also converts `g`, and rewrites the call to `tf__g`."""
  def g(a):
    return a

  def f(a):
    return g(a)

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(f, program_ctx, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(f, program_ctx.dependency_cache)
  self.assertIn(g, program_ctx.dependency_cache)
  self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
  # need the extra .body[0] in order to step past the with tf.name_scope('f')
  # that is added automatically
  self.assertEqual(
      'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id)
  self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
  """A subclass of a whitelisted base is converted; the base itself is not."""

  class TestSubclass(training.Model):

    def __init__(self, y):
      super(TestSubclass, self).__init__()
      self.built = False

    def call(self, x):
      return 3 * x

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

  self.assertIn(TestSubclass, program_ctx.dependency_cache)
  self.assertNotIn(training.Model, program_ctx.dependency_cache)
  # The first cached node refers to the base by name (via `.names`,
  # presumably an import node).
  self.assertEqual(
      'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
  self.assertEqual('TfTestSubclass',
                   program_ctx.dependency_cache[TestSubclass][-1].name)
def test_entity_to_graph_callable(self):
  """A closure converts to `tf__f` and its namespace exposes the free var."""
  b = 2
  def f(a):
    return a + b

  conversion_map = self._simple_conversion_map()
  # Named `node` rather than `ast` to avoid shadowing the stdlib module name.
  node, name, ns = conversion.entity_to_graph(f, conversion_map, None, None)
  self.assertIsInstance(node, gast.FunctionDef)
  self.assertEqual('tf__f', name)
  self.assertIs(ns['b'], b)
def test_entity_to_graph_callable(self):
  """A closure converts to `tf__f` and its namespace exposes the free var."""
  b = 2
  def f(a):
    return a + b

  program_ctx = self._simple_program_ctx()
  # The first element of the result is a pair whose head is the function node.
  (fn_node, _), converted_name, namespace = conversion.entity_to_graph(
      f, program_ctx, None, None)

  self.assertIsInstance(fn_node, gast.FunctionDef)
  self.assertEqual('tf__f', converted_name)
  self.assertIs(namespace['b'], b)
def test_entity_to_graph_call_tree(self):
  """Converting `f` also converts `g`, and rewrites the call to `tf__g`."""
  def g(a):
    return a

  def f(a):
    return g(a)

  conversion_map = conversion.ConversionMap(True, (), (), None)
  conversion.entity_to_graph(f, conversion_map, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(f, conversion_map.dependency_cache)
  self.assertIn(g, conversion_map.dependency_cache)
  self.assertEqual('tf__f', conversion_map.dependency_cache[f].name)
  # need the extra .body[0] in order to step past the with tf.name_scope('f')
  # that is added automatically
  self.assertEqual(
      'tf__g', conversion_map.dependency_cache[f].body[0].body[0].value.func.id)
  self.assertEqual('tf__g', conversion_map.dependency_cache[g].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
  """A subclass of a whitelisted base is converted; the base itself is not."""

  class TestSubclass(training.Model):

    def __init__(self, y):
      super(TestSubclass, self).__init__()
      self.built = False

    def call(self, x):
      return 3 * x

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

  self.assertIn(TestSubclass, program_ctx.dependency_cache)
  self.assertNotIn(training.Model, program_ctx.dependency_cache)
  # The first body statement refers to the base by name (via `.names`,
  # presumably an import node).
  self.assertEqual(
      'Model', program_ctx.dependency_cache[TestSubclass].body[0].names[0].name)
  self.assertEqual('TfTestSubclass',
                   program_ctx.dependency_cache[TestSubclass].body[-1].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
  """A subclass of a whitelisted base is converted; the base itself is not."""

  class TestSubclass(training.Model):

    def __init__(self, y):
      super(TestSubclass, self).__init__()
      self.built = False

    def call(self, x):
      return 3 * x

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

  self.assertIn(TestSubclass, program_ctx.dependency_cache)
  self.assertNotIn(training.Model, program_ctx.dependency_cache)
  self.assertEqual(
      'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
  # The returned nodes will include:
  # <import nodes>, <class node>, <assignment node>
  self.assertEqual('TfTestSubclass',
                   program_ctx.dependency_cache[TestSubclass][-2].name)
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code: the required imports, followed by the source of
    every converted entity.
  """
  ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, ctx, arg_values, arg_types)

  # Converted entities are rendered in the reverse of their cache order.
  converted_nodes = list(six.itervalues(ctx.dependency_cache))
  converted_nodes.reverse()
  rendered = [
      compiler.ast_to_source(node, indentation) for node in converted_nodes
  ]

  return '\n\n'.join((ctx.required_imports, '\n'.join(rendered)))
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code: the required imports, followed by the source of
    every converted entity.
  """
  ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, ctx, arg_values, arg_types)

  # Converted entities are rendered in the reverse of their cache order.
  converted_nodes = list(six.itervalues(ctx.dependency_cache))
  converted_nodes.reverse()
  rendered = [
      compiler.ast_to_source(node, indentation) for node in converted_nodes
  ]

  return '\n\n'.join((ctx.required_imports, '\n'.join(rendered)))
def test_entity_to_graph_class_hierarchy(self):
  """Converting a subclass also puts its base class in the dependency cache."""

  class TestBase(object):

    def __init__(self, x='base'):
      self.x = x

    def foo(self):
      return self.x

    def bar(self):
      return self.x

  class TestSubclass(TestBase):

    def __init__(self, y):
      super(TestSubclass, self).__init__('sub')
      self.y = y

    def foo(self):
      return self.y

    def baz(self):
      return self.y

  program_ctx = self._simple_program_ctx()
  conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

  # assertIn reports the container on failure, unlike assertTrue(x in y).
  self.assertIn(TestBase, program_ctx.dependency_cache)
  self.assertIn(TestSubclass, program_ctx.dependency_cache)
  # The returned nodes will include:
  # <import nodes>, <class node>, <assignment node>
  self.assertEqual('TfTestBase',
                   program_ctx.dependency_cache[TestBase][-2].name)
  self.assertEqual('TfTestSubclass',
                   program_ctx.dependency_cache[TestSubclass][-2].name)
def test_entity_to_graph_unsupported_types(self):
  """Converting a non-function, non-class entity raises ValueError."""
  # Build the fixture outside assertRaises so that a ValueError raised during
  # setup cannot make this test pass by accident.
  conversion_map = conversion.ConversionMap(True, (), (), None)
  with self.assertRaises(ValueError):
    conversion.entity_to_graph('dummy', conversion_map, None, None)
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
      function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
      entirely. Calls to member functions for these types will be renamed
      independently.

  Returns:
    The compiled counterpart of `e` (looked up by its converted name in the
    compiled module) which, when executed, creates a TF graph that has the
    same functionality as the original entity. Its `ag_source_map` attribute
    is set to the module's source map.

  Raises:
    ValueError: If the compiled entity already has an `ag_source_map`
      attribute, which is reserved for AutoGraph. Conversion itself may also
      raise ValueError, e.g. for unsupported entity types.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Each cache entry is an iterable of nodes; entries are flattened, in the
  # reverse of their cache order, into one list of module-level statements.
  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)

  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled_fn = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled_fn, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_fn, source_map_attribute_name))
  setattr(compiled_fn, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, or if the converted
      entity already has an `ag_source_map` attribute, which is reserved for
      AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Each cache entry is an iterable of nodes; entries are flattened, in the
  # reverse of their cache order, into one list of module-level statements.
  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)

  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
def test_entity_to_graph_unsupported_types(self):
  """Converting a non-function, non-class entity raises ValueError."""
  # Build the fixture outside assertRaises so that a ValueError raised during
  # setup cannot make this test pass by accident.
  conversion_map = self._simple_conversion_map()
  with self.assertRaises(ValueError):
    conversion.entity_to_graph('dummy', conversion_map, None, None)
def test_entity_to_graph_unsupported_types(self):
  """Converting a non-function, non-class entity raises ValueError."""
  # Build the fixture outside assertRaises so that a ValueError raised during
  # setup cannot make this test pass by accident.
  conversion_map = conversion.ConversionMap(True, (), (), None)
  with self.assertRaises(ValueError):
    conversion.entity_to_graph('dummy', conversion_map, None, None)
def test_entity_to_graph_lambda(self):
  """Converting a lambda raises NotImplementedError."""
  f = lambda a: a
  # Build the fixture outside assertRaises so that only the conversion call
  # can satisfy the expected NotImplementedError.
  program_ctx = self._simple_program_ctx()
  with self.assertRaises(NotImplementedError):
    conversion.entity_to_graph(f, program_ctx, None, None)
def test_entity_to_graph_unsupported_types(self):
  """Converting a non-function, non-class entity raises ValueError."""
  # Build the fixture outside assertRaises so that a ValueError raised during
  # setup cannot make this test pass by accident.
  program_ctx = self._simple_program_ctx()
  with self.assertRaises(ValueError):
    conversion.entity_to_graph('dummy', program_ctx, None, None)
def test_entity_to_graph_unsupported_types(self):
  """Converting a non-function, non-class entity raises ValueError."""
  # Build the fixture outside assertRaises so that a ValueError raised during
  # setup cannot make this test pass by accident.
  program_ctx = self._simple_program_ctx()
  with self.assertRaises(ValueError):
    conversion.entity_to_graph('dummy', program_ctx, None, None)
def test_entity_to_graph_lambda(self):
  """Converting a lambda raises NotImplementedError."""
  f = lambda a: a
  # Build the fixture outside assertRaises so that only the conversion call
  # can satisfy the expected NotImplementedError.
  program_ctx = self._simple_program_ctx()
  with self.assertRaises(NotImplementedError):
    conversion.entity_to_graph(f, program_ctx, None, None)
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, or if the converted
      entity already has an `ag_source_map` attribute, which is reserved for
      AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Each cache entry is an iterable of nodes; entries are flattened, in the
  # reverse of their cache order, into one list of module-level statements.
  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)

  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
      function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
      entirely. Calls to member functions for these types will be renamed
      independently.

  Returns:
    The compiled counterpart of `e` (looked up by its converted name in the
    compiled module) which, when executed, creates a TF graph that has the
    same functionality as the original entity. Its `ag_source_map` attribute
    is set to the module's source map.

  Raises:
    ValueError: If the compiled entity already has an `ag_source_map`
      attribute, which is reserved for AutoGraph. Conversion itself may also
      raise ValueError, e.g. for unsupported entity types.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Each cache entry is an iterable of nodes; entries are flattened, in the
  # reverse of their cache order, into one list of module-level statements.
  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)

  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
def test_entity_to_graph_unsupported_types(self):
  """Converting a non-function, non-class entity raises ValueError."""
  # Build the fixture outside assertRaises so that a ValueError raised during
  # setup cannot make this test pass by accident.
  conversion_map = self._simple_conversion_map()
  with self.assertRaises(ValueError):
    conversion.entity_to_graph('dummy', conversion_map, None, None)
def test_entity_to_graph_lambda(self):
  """Converting a lambda raises NotImplementedError."""
  f = lambda a: a
  # Build the fixture outside assertRaises so that only the conversion call
  # can satisfy the expected NotImplementedError.
  conversion_map = self._simple_conversion_map()
  with self.assertRaises(NotImplementedError):
    conversion.entity_to_graph(f, conversion_map, None, None)