def test_converted_call_method_converts_recursively(self):
  """Checks that converting a bound method also converts the methods it calls.

  `test_method` is passed to converted_call with `recursive=True`; it calls
  `other_method`, whose `if` branches on a Tensor. The expected result is
  abs(-1) == 1.
  """

  class TestClass(object):

    def __init__(self, x):
      self.x = x

    def other_method(self):
      if self.x < 0:
        return -self.x
      return self.x

    def test_method(self):
      return self.other_method()

  tc = TestClass(constant_op.constant(-1))
  # NOTE(review): the second positional argument (None) presumably names the
  # owner object in this converted_call signature; a bound method needs none.
  x = api.converted_call(tc.test_method, None,
                         converter.ConversionOptions(recursive=True), (), {})
  self.assertEqual(1, self.evaluate(x))
def test_converted_call_synthetic_method(self):
  """Checks converted_call on a method synthesized with types.MethodType.

  `test_function` is defined outside `TestClass` and bound to an instance at
  runtime; the converted call is expected to return abs(-1) == 1.
  """

  class TestClass(object):

    def __init__(self, x):
      self.x = x

  def test_function(self):
    if self.x < 0:
      return -self.x
    return self.x

  tc = TestClass(constant_op.constant(-1))
  # Bind the free function to `tc` so it behaves like a method of that object.
  test_method = types.MethodType(test_function, tc)
  x = api.converted_call(test_method,
                         converter.ConversionOptions(recursive=True), (), {})
  self.assertEqual(1, self.evaluate(x))
def test_converted_call_constructor(self):
  """Checks converted_call on a class: the constructor is not converted.

  The returned object is still a plain `TestClass`, so calling its method,
  which uses a Tensor in a Python `if`, is expected to raise a TypeError.
  """

  class TestClass(object):

    def __init__(self, x):
      self.x = x

    def test_method(self):
      if self.x < 0:
        return -self.x
      return self.x

  tc = api.converted_call(TestClass, None, converter.ConversionOptions(),
                          (constant_op.constant(-1),), {})

  # tc is still a TestClass - constructors are whitelisted.
  # TODO(b/124016764): Support this use case.
  # The error below is specific to the `if` statement not being converted.
  with self.assertRaisesRegex(TypeError,
                              'Using a `tf.Tensor` as a Python `bool`'):
    tc.test_method()
def test_converted_call_no_user_code(self):
  """Checks that internal_convert_user_code=False disables user-code conversion.

  The user function `f` is then called unconverted, so its `len` on a Tensor
  fails; calling `len` directly through converted_call still works.
  """

  def f(x):
    return len(x)

  opts = converter.ConversionOptions(internal_convert_user_code=False)

  # f should not be converted, causing len to error out.
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
  # was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(Exception, 'len is not well defined'):
    api.converted_call(f, (constant_op.constant([0]),), None, options=opts)

  # len on the other hand should work fine.
  x = api.converted_call(len, (constant_op.constant([0]),), None, options=opts)
  # The constant has static shape so the result is a primitive not a Tensor.
  self.assertEqual(x, 1)
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` and runs AutoGraph's initial standard analysis on it.

  Side effect: `namespace` (the caller's dict) gains a 'ConversionOptions'
  entry before being used for naming and entity info.

  Args:
    test_fn: the function whose source is parsed.
    namespace: dict used by the Namer and EntityInfo; mutated in place.
    recursive: forwarded to ConversionOptions.

  Returns:
    A (node, ctx) pair: the analyzed AST node and its EntityContext.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions
  features = ('print_function', 'division')
  node, source = parser.parse_entity(test_fn, future_features=features)

  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(options=options, autograph_module=None)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  ctx = converter.EntityContext(naming.Namer(namespace), entity_info,
                                program_ctx)

  origin_info.resolve_entity(node, source, test_fn)
  return converter.standard_analysis(node, ctx, is_initial=True), ctx
def test_converted_call_callable_metaclass(self):
  """Checks converted_call on a class whose metaclass defines `__call__`.

  Calling the class runs `TestMetaclass.__call__`, which negates the negative
  class attribute `x` and returns the class itself; the test expects the
  result to be a TestMetaclass instance whose `x` evaluates to 1.
  """

  class TestMetaclass(type):

    x = constant_op.constant(-1)

    def __call__(cls):
      if cls.x < 0:
        cls.x = -cls.x
      return cls

  tc = TestMetaclass('TestClass', (), {})
  # This functools.partial will hide the class from the constructor
  # check. Not ideal. See b/120224672.
  tc = functools.partial(tc)
  converted_tc = api.converted_call(
      tc, converter.ConversionOptions(recursive=True), (), {})
  self.assertIsInstance(converted_tc, TestMetaclass)
  self.assertEqual(1, self.evaluate(converted_tc.x))
def test_converted_call_constructor(self):
  """Checks converted_call on a class using the legacy variadic signature.

  The constructor argument is passed directly after the options. The returned
  object is expected to be converted, so its Tensor-dependent `if` works and
  yields abs(-1) == 1.
  """

  class TestClass(object):

    def __init__(self, x):
      self.x = x

    def test_method(self):
      if self.x < 0:
        return -self.x
      return self.x

  # The session handle bound by `as sess` was never used; only the context
  # itself is kept active around the conversion and evaluation.
  with self.cached_session():
    tc = api.converted_call(TestClass, None, converter.ConversionOptions(),
                            constant_op.constant(-1))
    # tc is now a converted object.
    x = tc.test_method()
    self.assertEqual(1, self.evaluate(x))
def test_super_with_two_args(self):
  """Checks that `super(TestSubclass, self)` dispatches to the base class.

  `TestSubclass.plus_three` fails the test if it is ever called, so getting
  5 from `two_args(2)` shows `TestBase.plus_three` was used.
  """
  # Captured because `self` is shadowed inside the nested classes' methods.
  test_case_self = self

  class TestBase(object):

    def plus_three(self, x):
      return x + 3

  class TestSubclass(TestBase):

    def plus_three(self, x):
      test_case_self.fail('This should never be called.')

    def two_args(self, x):
      return super(TestSubclass, self).plus_three(x)

  # converted_call on the class is used to construct the instance.
  tc = api.converted_call(TestSubclass,
                          converter.ConversionOptions(recursive=True), (), {})
  self.assertEqual(5, tc.two_args(2))
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation=' ',
            experimental_optional_features=converter.Feature.ALL):
  """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    The converted code as string.
  """
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          optional_features=experimental_optional_features),
      # The module that defines `to_graph`, i.e. this API module.
      autograph_module=tf_inspect.getmodule(to_graph))
  nodes, _, _ = conversion.entity_to_graph(entity, program_ctx, arg_values,
                                           arg_types)

  code = compiler.ast_to_source(nodes, indentation)

  # Prepend the import lines the generated code requires.
  return program_ctx.required_imports + '\n\n' + code
def test_converted_call_method_as_object_attribute(self):
  """Checks converting a bound method stored as an attribute of another object.

  `tc.another_obj_method` holds `obj.method` (bound to an AnotherClass
  instance); since `another_class_attr` is 1, the expected result is 2.
  """

  class AnotherClass(object):

    def __init__(self):
      self.another_class_attr = constant_op.constant(1)

    def method(self):
      if self.another_class_attr > 0:
        return self.another_class_attr + 1
      return self.another_class_attr + 10

  class TestClass(object):

    def __init__(self, another_obj_method):
      self.another_obj_method = another_obj_method

  obj = AnotherClass()
  tc = TestClass(obj.method)

  # NOTE(review): the callable is given as an attribute name plus its owner
  # `tc`; presumably resolved as getattr(tc, 'another_obj_method').
  x = api.converted_call('another_obj_method', tc,
                         converter.ConversionOptions(), (), {})
  self.assertEqual(self.evaluate(x), 2)
def test_converted_call_defun_object_method(self): opts = converter.ConversionOptions(recursive=True) # pylint:disable=method-hidden class TestClass(object): def method(self): return 1 def prepare(self): self.method = function.defun(self.method) # pylint:enable=method-hidden tc = TestClass() tc.prepare() x = api.converted_call(tc.method, None, opts, (), {}) self.assertAllEqual(1, self.evaluate(x))
def test_super_with_one_arg(self):
  """Checks the one-argument `super(TestSubclass)` form bound via `__get__`.

  `TestSubclass.plus_three` fails the test if called, so getting 5 from
  `one_arg(2)` shows the base-class implementation was reached.
  """
  # Captured because `self` is shadowed inside the nested classes' methods.
  test_case_self = self

  class TestBase(object):

    def plus_three(self, x):
      return x + 3

  class TestSubclass(TestBase):

    def plus_three(self, x):
      test_case_self.fail('This should never be called.')

    def one_arg(self, x):
      test_base_unbound = super(TestSubclass)
      test_base = test_base_unbound.__get__(self, TestSubclass)
      return test_base.plus_three(x)

  tc = api.converted_call(TestSubclass,
                          converter.ConversionOptions(recursive=True), (), {})
  self.assertEqual(5, tc.one_arg(2))
def test_to_ast(self):
  """ConversionOptions should round-trip through its AST representation.

  The options are serialized with `to_ast`, spliced into a one-line function
  via a template, compiled, and the function is called to rebuild them.
  """
  opts = converter.ConversionOptions()
  opts_ast = opts.to_ast()

  template = ''' def test_fn(): return opts_ast '''
  opts_packed = templates.replace(template, opts_ast=opts_ast)

  reparsed, _, _ = compiler.ast_to_object(opts_packed)
  # Inject a stand-in `ag__` module into the compiled module's namespace;
  # presumably the generated options expression refers to it.
  reparsed.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)

  reparsed_opts = reparsed.test_fn()

  self.assertEqual(opts.recursive, reparsed_opts.recursive)
  # Assert on the reparsed options: checking `opts` (freshly constructed
  # above) did not exercise the round trip at all.
  self.assertEqual(False, reparsed_opts.user_requested)
  self.assertEqual(opts.internal_convert_user_code,
                   reparsed_opts.internal_convert_user_code)
  self.assertEqual(opts.optional_features, reparsed_opts.optional_features)
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation=' '):
  """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Optional[Set[Type]], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code.
  """
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          # AutoGraph's own decorators are always stripped from the output.
          strip_decorators=(convert, do_not_convert, converted_call)),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  # Called for its side effect of filling program_ctx.dependency_cache and
  # program_ctx.conversion_order.
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Emit every converted dependency, walking conversion_order in reverse.
  code = '\n'.join(
      compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
      for dep in reversed(program_ctx.conversion_order))

  return program_ctx.required_imports + '\n\n' + code
def test_to_ast(self):
  """ConversionOptions should round-trip through its AST representation.

  The options are serialized with `to_ast`, spliced into a one-line function,
  loaded as a module, and the function is called to rebuild them.
  """
  # types.ModuleType is what the removed (Python 3.12) imp.new_module returned.
  import types  # pylint: disable=g-import-not-at-top

  opts = converter.ConversionOptions()
  opts_ast = opts.to_ast()

  template = ''' def f(): return opts_ast '''
  opts_packed = templates.replace(template, opts_ast=opts_ast)

  reparsed, _, _ = loader.load_ast(opts_packed)
  # Stand-in `ag__` module exposing only the names the reparsed code may need.
  fake_ag = types.ModuleType('fake_ag')
  fake_ag.ConversionOptions = converter.ConversionOptions
  fake_ag.Feature = converter.Feature
  reparsed.ag__ = fake_ag

  reparsed_opts = reparsed.f()

  self.assertEqual(opts.recursive, reparsed_opts.recursive)
  # Assert on the reparsed options: checking `opts` (freshly constructed
  # above) did not exercise the round trip at all.
  self.assertEqual(False, reparsed_opts.user_requested)
  self.assertEqual(opts.internal_convert_user_code,
                   reparsed_opts.internal_convert_user_code)
  self.assertEqual(opts.optional_features, reparsed_opts.optional_features)
def test_to_ast(self):
  """Round-trips ConversionOptions through `to_ast(ctx)` and compares fields."""
  opts = converter.ConversionOptions()
  program_ctx = converter.ProgramContext(
      options=opts,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=())
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)
  ctx = converter.EntityContext(converter_testing.FakeNamer(), entity_info,
                                program_ctx)

  template = ''' def test_fn(): return opts_ast '''
  opts_packed = templates.replace(template, opts_ast=opts.to_ast(ctx))
  reparsed, _ = compiler.ast_to_object(opts_packed)
  # Provide the `ag__` names the reparsed function is expected to look up.
  reparsed.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)

  reparsed_opts = reparsed.test_fn()
  compared_fields = ('recursive', 'verbose', 'force_conversion',
                     'internal_convert_user_code', 'optional_features')
  for field_name in compared_fields:
    self.assertEqual(
        getattr(opts, field_name), getattr(reparsed_opts, field_name))
def test_method(self, x, s, a):
  # While the sum of x exceeds s, floor-divide x by the value obtained from
  # calling self.called_member(a) through converted_call with recursive
  # conversion enabled.
  while tf.reduce_sum(x) > s:
    x //= api.converted_call(
        self.called_member, None, converter.ConversionOptions(recursive=True),
        (a,), {})
  return x
def _basic_function_scope(self):
  """Builds a FunctionScope with fixed test names and default options."""
  function_name = 'test_function_name'
  # Note: this must match the name in the `with` statement.
  scope_name = 'test_scope'
  default_options = converter.ConversionOptions()
  return function_wrappers.FunctionScope(function_name, scope_name,
                                         default_options)
from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.layers import core from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.util import function_utils from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect tf = utils.fake_tf() global_n = 2 DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True) class TestResource(object): def __init__(self): self.x = 3 class ApiTest(test.TestCase): @contextlib.contextmanager def assertPrints(self, expected, not_expected): try: out_capturer = six.StringIO() sys.stdout = out_capturer
def _parse_and_analyze(f, autobatch_functions):
  """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which the
  `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze
    autobatch_functions: List of Python `str` names of autobatched functions.
      Arguments to these functions will be canonicalized to variable
      references, but others will not.

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    ctx: An AutoGraph `EntityContext` object wrapping the `EntityInfo` and
      program context built for `f`. Required for initializing
      `_AutoBatchingTransformer`.
  """
  namespace = {}

  # Get the AST of the function
  future_features = inspect_utils.getfutureimports(f)
  node, _ = parser.parse_entity(f, future_features=future_features)

  # Boilerplate for AutoGraph transforms
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file=None,
      future_features=future_features,
      namespace=namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=True),
      autograph_module=None)
  ctx = converter.EntityContext(
      namer=naming.Namer(namespace),
      entity_info=entity_info,
      program_ctx=program_ctx)

  # Canonicalize away break statements
  node = converter.standard_analysis(node, ctx, is_initial=True)
  node = break_statements.transform(node, ctx)

  # Canonicalize away continue statements
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = continue_statements.transform(node, ctx)

  # Force single returns
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = return_statements.transform(node, ctx, default_to_null_return=False)

  # Transform into ANF
  # Replacing if tests and autobatched function call arguments because
  # that's where divergence can happen.
  # Replacing all function calls because the downstream transformation
  # expects calls to lead directly to assignments.
  def maybe_replace_function_argument(parent, field_name, child):
    """Returns True only for arguments of calls to autobatched functions."""
    del field_name, child
    if not anno.hasanno(parent.func, anno.Basic.QN):
      return False
    func_name = anno.getanno(parent.func, anno.Basic.QN)
    return str(func_name) in autobatch_functions

  anf_config = [
      (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
      (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
      (anf.ASTEdgePattern(gast.Call, 'args', anf.ANY),
       maybe_replace_function_argument),
  ]
  node = anf.transform(node, ctx, config=anf_config)
  node = converter.standard_analysis(node, ctx, is_initial=False)

  return node, ctx
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    conversion_options = converter.ConversionOptions(
        recursive=recursive,
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=conversion_options,
        autograph_module=tf_inspect.getmodule(to_graph))
    return conversion.convert(entity, program_ctx)
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Log the full traceback, then re-raise as a ConversionError that names
    # the entity and the original exception type.
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    error_message = 'converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e))
    raise ConversionError(error_message)
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators. AutoGraph's own `convert`, `do_not_convert` and
      `converted_call` are always added; the caller's value is not modified.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  # Build a new tuple instead of using `+=`: when the caller passes a list,
  # `+=` would extend that list in place and leak these decorators into it.
  experimental_strip_decorators = tuple(experimental_strip_decorators or ()) + (
      convert, do_not_convert, converted_call)

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=experimental_verbose,
          strip_decorators=experimental_strip_decorators,
          optional_features=experimental_optional_features),
      partial_types=experimental_partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                  arg_values, arg_types)

  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  for key, val in program_ctx.additional_symbols.items():
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  if tf_inspect.isfunction(entity):
    compiled.__defaults__ = entity.__defaults__

  if hasattr(compiled, '__globals__'):
    # Remove self to avoid circular references. This will probably only work
    # so long as the function is not reentrant.
    del compiled.__globals__[name]

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_f is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    return autograph_artifact(conversion.convert(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Log the full traceback, then re-raise as a ConversionError that names
    # the entity and the original exception type.
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e)))
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Tuple[Callable], same as
      ConversionOptions.strip_decorators. AutoGraph's own `convert`,
      `do_not_convert` and `converted_call` are always added; the caller's
      value is not modified.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.)
    but its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted.
  """
  # Build a new tuple instead of using `+=`: when the caller passes a list,
  # `+=` would extend that list in place and leak these decorators into it.
  strip_decorators = tuple(strip_decorators or ()) + (
      convert, do_not_convert, converted_call)

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=verbose,
          strip_decorators=strip_decorators),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  # The compiled source text is not needed here.
  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
def graph_fn():
  # Convert and call `f` (taken from the enclosing scope), then pull the first
  # three elements out of what it returns.
  recursive_options = converter.ConversionOptions(recursive=True)
  returned = api.converted_call(f, None, recursive_options, (), {})
  it = iter(returned)
  first = next(it)
  second = next(it)
  third = next(it)
  return first, second, third
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Value hints for symbols (a dict mapping names to values).
      Currently accepted but discarded by this implementation.
    arg_types: Type hints for symbols (a dict mapping names to types).
      Currently accepted but discarded by this implementation.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    # TODO(b/129431421): Remove these args.
    del arg_values, arg_types
    conversion_options = converter.ConversionOptions(
        recursive=recursive,
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=conversion_options,
        autograph_module=tf_inspect.getmodule(to_graph))
    return conversion.convert(entity, program_ctx)
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    # Delegates reporting (and presumably re-raising) to the errors module.
    errors.report_internal_error(entity, e)
def test_converted_call_builtin(self):
  """converted_call on `range` should behave like calling range(3) directly."""
  default_options = converter.ConversionOptions()
  result = api.converted_call(range, None, default_options, (3,), {})
  self.assertEqual((0, 1, 2), tuple(result))
def _simple_program_ctx(self):
  """Returns a ProgramContext with recursive conversion and `api` as module."""
  recursive_options = converter.ConversionOptions(recursive=True)
  return converter.ProgramContext(
      options=recursive_options, autograph_module=api)
def _simple_program_ctx(self):
  """Returns a recursive ProgramContext with no partial types, using `api`."""
  recursive_options = converter.ConversionOptions(recursive=True)
  uncompiled = config.DEFAULT_UNCOMPILED_MODULES
  return converter.ProgramContext(
      options=recursive_options,
      partial_types=(),
      autograph_module=api,
      uncompiled_modules=uncompiled)
def test_method(self, x, s, a):
  # While the sum of x exceeds s, floor-divide x by the value obtained from
  # calling self.called_member through converted_call. This uses the older
  # variadic converted_call form: (f, owner, options, *args).
  # NOTE(review): `self` is passed explicitly alongside the bound method
  # self.called_member; presumably this converted_call version invokes the
  # underlying function with these args — TODO confirm.
  while tf.reduce_sum(x) > s:
    x //= api.converted_call(self.called_member, None,
                             converter.ConversionOptions(), self, a)
  return x