Пример #1
0
    def test_converted_call_method_converts_recursively(self):
        """converted_call on a bound method also converts the methods it calls.

        `test_method` delegates to `other_method`, whose `if` compares a tensor
        attribute; with `recursive=True` the call is expected to evaluate to 1.
        """
        class TestClass(object):
            def __init__(self, x):
                self.x = x

            def other_method(self):
                if self.x < 0:
                    return -self.x
                return self.x

            def test_method(self):
                return self.other_method()

        instance = TestClass(constant_op.constant(-1))
        recursive_opts = converter.ConversionOptions(recursive=True)
        result = api.converted_call(instance.test_method, None, recursive_opts,
                                    (), {})
        self.assertEqual(1, self.evaluate(result))
Пример #2
0
  def test_converted_call_synthetic_method(self):
    """A function bound at runtime with types.MethodType is still converted."""

    class TestClass(object):

      def __init__(self, x):
        self.x = x

    def test_function(self):
      if self.x < 0:
        return -self.x
      return self.x

    obj = TestClass(constant_op.constant(-1))
    bound_method = types.MethodType(test_function, obj)

    opts = converter.ConversionOptions(recursive=True)
    result = api.converted_call(bound_method, opts, (), {})
    self.assertEqual(1, self.evaluate(result))
Пример #3
0
    def test_converted_call_constructor(self):
        """Calling a class through converted_call leaves its methods unconverted.

        Constructors are whitelisted, so the returned object is still a plain
        TestClass whose `test_method` uses a Python `if` on a tensor.
        """
        class TestClass(object):
            def __init__(self, x):
                self.x = x

            def test_method(self):
                if self.x < 0:
                    return -self.x
                return self.x

        opts = converter.ConversionOptions()
        ctor_args = (constant_op.constant(-1), )
        tc = api.converted_call(TestClass, None, opts, ctor_args, {})
        # TODO(b/124016764): Support this use case.
        # The `if` in test_method was not converted, so evaluating it on a
        # tensor raises the error below.
        expected_message = 'Using a `tf.Tensor` as a Python `bool`'
        with self.assertRaisesRegex(TypeError, expected_message):
            tc.test_method()
Пример #4
0
    def test_converted_call_no_user_code(self):
        """With internal_convert_user_code=False, user code is not converted.

        `f` is called as-is, so `len` on a tensor raises. Calling `len` directly
        through converted_call works, and because the constant has a static
        shape the result is the Python int 1.
        """
        def f(x):
            return len(x)

        opts = converter.ConversionOptions(internal_convert_user_code=False)

        # f should not be converted, causing len to error out.
        # Uses assertRaisesRegex: the `assertRaisesRegexp` alias is deprecated
        # and was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(Exception, 'len is not well defined'):
            api.converted_call(f, (constant_op.constant([0]), ),
                               None,
                               options=opts)

        # len on the other hand should work fine.
        x = api.converted_call(len, (constant_op.constant([0]), ),
                               None,
                               options=opts)
        # The constant has static shape so the result is a primitive not a Tensor.
        self.assertEqual(x, 1)
Пример #5
0
  def prepare(self, test_fn, namespace, recursive=True):
    """Parses `test_fn` and runs AutoGraph's standard analysis over it.

    Args:
      test_fn: the Python function to parse.
      namespace: dict of symbols visible to `test_fn`. `ConversionOptions` is
        inserted into it in place.
      recursive: value for `ConversionOptions.recursive` in the program
        context.

    Returns:
      A `(node, ctx)` tuple: the analyzed AST node and its `EntityContext`.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    features = ('print_function', 'division')
    node, source = parser.parse_entity(test_fn, future_features=features)

    options = converter.ConversionOptions(recursive=recursive)
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=None)
    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=features,
        namespace=namespace)
    ctx = converter.EntityContext(naming.Namer(namespace), info, program_ctx)

    origin_info.resolve_entity(node, source, test_fn)
    return converter.standard_analysis(node, ctx, is_initial=True), ctx
Пример #6
0
    def test_converted_call_callable_metaclass(self):
        """Calling a class whose metaclass defines __call__ converts that call.

        The metaclass `__call__` flips the sign of the tensor attribute `x`,
        which the test expects to read back as 1.
        """
        class TestMetaclass(type):

            x = constant_op.constant(-1)

            def __call__(cls):
                if cls.x < 0:
                    cls.x = -cls.x
                return cls

        tc = TestMetaclass('TestClass', (), {})
        # Wrapping in functools.partial hides the class from the constructor
        # check. Not ideal. See b/120224672.
        wrapped = functools.partial(tc)
        opts = converter.ConversionOptions(recursive=True)
        converted_tc = api.converted_call(wrapped, opts, (), {})
        self.assertIsInstance(converted_tc, TestMetaclass)
        self.assertEqual(1, self.evaluate(converted_tc.x))
  def test_converted_call_constructor(self):
    """Constructing a class via converted_call yields a converted object."""

    class TestClass(object):

      def __init__(self, x):
        self.x = x

      def test_method(self):
        if self.x < 0:
          return -self.x
        return self.x

    with self.cached_session():
      opts = converter.ConversionOptions()
      tc = api.converted_call(TestClass, None, opts, constant_op.constant(-1))
      # tc is now a converted object, so test_method can branch on a tensor.
      self.assertEqual(1, self.evaluate(tc.test_method()))
    def test_super_with_two_args(self):
        """Two-argument super() in a converted class reaches the base method."""
        test_case_self = self

        class TestBase(object):
            def plus_three(self, x):
                return x + 3

        class TestSubclass(TestBase):
            def plus_three(self, x):
                test_case_self.fail('This should never be called.')

            def two_args(self, x):
                return super(TestSubclass, self).plus_three(x)

        opts = converter.ConversionOptions(recursive=True)
        instance = api.converted_call(TestSubclass, opts, (), {})

        self.assertEqual(5, instance.two_args(2))
Пример #9
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL):
    """Similar to `to_graph`, but returns Python source code as a string.

    Also see: `tf.autograph.to_graph`.

    `to_code` returns the Python source code that can be used to generate a
    TensorFlow graph that is functionally identical to the input Python code.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including
        function arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      indentation: The string to use for indenting. Typically two or four
        spaces, or just the tab character.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.

    Returns:
      The converted code as a string, preceded by the program context's
      required imports and a blank line.
    """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    nodes, _, _ = conversion.entity_to_graph(entity, program_ctx, arg_values,
                                             arg_types)

    code = compiler.ast_to_source(nodes, indentation)

    return program_ctx.required_imports + '\n\n' + code
Пример #10
0
    def test_converted_call_method_as_object_attribute(self):
        """converted_call resolves a callable given by attribute name and owner.

        The bound `AnotherClass.method` is stored on a `TestClass` instance and
        passed to converted_call as the string 'another_obj_method' plus that
        instance; the test expects the tensor branch to produce 2.
        """
        class AnotherClass(object):
            def __init__(self):
                self.another_class_attr = constant_op.constant(1)

            def method(self):
                if self.another_class_attr > 0:
                    return self.another_class_attr + 1
                return self.another_class_attr + 10

        class TestClass(object):
            def __init__(self, another_obj_method):
                self.another_obj_method = another_obj_method

        source_obj = AnotherClass()
        holder = TestClass(source_obj.method)

        opts = converter.ConversionOptions()
        result = api.converted_call('another_obj_method', holder, opts, (), {})
        self.assertEqual(self.evaluate(result), 2)
Пример #11
0
    def test_converted_call_defun_object_method(self):
        """converted_call accepts a method that was replaced by a defun wrapper."""
        recursive_opts = converter.ConversionOptions(recursive=True)

        # pylint:disable=method-hidden
        class TestClass(object):
            def method(self):
                return 1

            def prepare(self):
                self.method = function.defun(self.method)

        # pylint:enable=method-hidden

        obj = TestClass()
        obj.prepare()

        result = api.converted_call(obj.method, None, recursive_opts, (), {})

        self.assertAllEqual(1, self.evaluate(result))
    def test_super_with_one_arg(self):
        """One-argument super() bound via __get__ reaches the base method."""
        test_case_self = self

        class TestBase(object):
            def plus_three(self, x):
                return x + 3

        class TestSubclass(TestBase):
            def plus_three(self, x):
                test_case_self.fail('This should never be called.')

            def one_arg(self, x):
                test_base_unbound = super(TestSubclass)
                test_base = test_base_unbound.__get__(self, TestSubclass)
                return test_base.plus_three(x)

        opts = converter.ConversionOptions(recursive=True)
        instance = api.converted_call(TestSubclass, opts, (), {})

        self.assertEqual(5, instance.one_arg(2))
    def test_to_ast(self):
        """ConversionOptions survive a round trip through their AST form."""
        opts = converter.ConversionOptions()
        opts_ast = opts.to_ast()

        template = '''
    def test_fn():
      return opts_ast
    '''
        opts_packed = templates.replace(template, opts_ast=opts_ast)

        reparsed, _, _ = compiler.ast_to_object(opts_packed)
        # The generated code refers to `ag__`; supply a fake module exposing
        # just ConversionOptions and Feature.
        fake_ag = self.make_fake_mod(
            'fake_ag', converter.ConversionOptions, converter.Feature)
        reparsed.__dict__['ag__'] = fake_ag

        roundtrip = reparsed.test_fn()

        self.assertEqual(opts.recursive, roundtrip.recursive)
        self.assertEqual(opts.user_requested, False)
        self.assertEqual(opts.internal_convert_user_code,
                         roundtrip.internal_convert_user_code)
        self.assertEqual(opts.optional_features, roundtrip.optional_features)
Пример #14
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Returns the equivalent code that uses TensorFlow ops.

    Also see: `to_graph`, `convert`

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
      partial_types: Set[Type], reserved for internal use.
      indentation: Text, what to use for each level of indentation.

    Returns:
      Text, the converted code: the program context's required imports, a
      blank line, then the source of every converted dependency in reverse
      conversion order, joined by newlines.
    """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=recursive,
                                            strip_decorators=(convert,
                                                              do_not_convert,
                                                              converted_call)),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

    code = '\n'.join(
        compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
        for dep in reversed(program_ctx.conversion_order))

    return program_ctx.required_imports + '\n\n' + code
Пример #15
0
    def test_to_ast(self):
        """ConversionOptions survive a round trip through their AST form.

        The options are serialized with `to_ast`, embedded in a template
        function, loaded back with `loader.load_ast`, and evaluated against a
        fake `ag__` module that exposes `ConversionOptions` and `Feature`.
        """
        # Local import: `types.ModuleType` replaces `imp.new_module`; the `imp`
        # module is deprecated and was removed in Python 3.12.
        import types

        opts = converter.ConversionOptions()
        opts_ast = opts.to_ast()

        template = '''
    def f():
      return opts_ast
    '''
        opts_packed = templates.replace(template, opts_ast=opts_ast)

        reparsed, _, _ = loader.load_ast(opts_packed)
        fake_ag = types.ModuleType('fake_ag')
        fake_ag.ConversionOptions = converter.ConversionOptions
        fake_ag.Feature = converter.Feature
        reparsed.ag__ = fake_ag

        reparsed_opts = reparsed.f()

        self.assertEqual(opts.recursive, reparsed_opts.recursive)
        self.assertEqual(opts.user_requested, False)
        self.assertEqual(opts.internal_convert_user_code,
                         reparsed_opts.internal_convert_user_code)
        self.assertEqual(opts.optional_features,
                         reparsed_opts.optional_features)
Пример #16
0
    def test_to_ast(self):
        """ConversionOptions survive a round trip through their AST form."""
        opts = converter.ConversionOptions()

        program_ctx = converter.ProgramContext(options=opts,
                                               partial_types=None,
                                               autograph_module=None,
                                               uncompiled_modules=())
        entity_info = transformer.EntityInfo(source_code='',
                                             source_file='<fragment>',
                                             namespace={},
                                             arg_values=None,
                                             arg_types={},
                                             owner_type=None)
        ctx = converter.EntityContext(converter_testing.FakeNamer(),
                                      entity_info, program_ctx)
        opts_ast = opts.to_ast(ctx)

        template = '''
    def test_fn():
      return opts_ast
    '''
        opts_packed = templates.replace(template, opts_ast=opts_ast)

        reparsed, _ = compiler.ast_to_object(opts_packed)
        reparsed.__dict__['ag__'] = self.make_fake_mod(
            'fake_ag', converter.ConversionOptions, converter.Feature)

        restored = reparsed.test_fn()

        # Every option field must come back unchanged.
        for field in ('recursive', 'verbose', 'force_conversion',
                      'internal_convert_user_code', 'optional_features'):
            self.assertEqual(getattr(opts, field), getattr(restored, field))
Пример #17
0
 def test_method(self, x, s, a):
     """Floor-divides `x` by `self.called_member(a)` while sum(x) exceeds `s`.

     The call goes through `api.converted_call` with recursive conversion
     options. `self.called_member` and `tf` are defined outside this view.
     NOTE(review): this body looks like it is itself fed to AutoGraph, so its
     statements are intentionally left untouched.
     """
     while tf.reduce_sum(x) > s:
         x //= api.converted_call(
             self.called_member, None,
             converter.ConversionOptions(recursive=True), (a, ), {})
     return x
Пример #18
0
 def _basic_function_scope(self):
     """Builds a FunctionScope with fixed names and default ConversionOptions."""
     default_opts = converter.ConversionOptions()
     # Note: the scope name must match the name in the `with` statement.
     scope_name = 'test_scope'
     return function_wrappers.FunctionScope('test_function_name', scope_name,
                                            default_opts)
Пример #19
0
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

# Fake `tf` module built by the test `utils` helpers; presumably it exposes
# only the TensorFlow symbols the test code needs — verify in utils.fake_tf.
tf = utils.fake_tf()

# Module-level global; presumably read by converted test functions to
# exercise global-variable access — TODO confirm against its uses.
global_n = 2

# Shared conversion options with recursive conversion enabled.
DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True)


class TestResource(object):
  """Plain test object exposing a single instance attribute, `x` (set to 3)."""

  def __init__(self):
    super(TestResource, self).__init__()
    self.x = 3


class ApiTest(test.TestCase):

  @contextlib.contextmanager
  def assertPrints(self, expected, not_expected):
    try:
      out_capturer = six.StringIO()
      sys.stdout = out_capturer
Пример #20
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze
      autobatch_functions: List of Python `str` names of autobatched functions.
        Arguments to these functions will be canonicalized to variable
        references, but others will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: An AutoGraph `converter.EntityContext` for `f`, bundling its
        `EntityInfo`, a namer and the program context.  Required for
        initializing `_AutoBatchingTransformer`.
    """
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Each canonicalization pass is preceded by a fresh standard analysis so
    # it sees up-to-date annotations from the previous transform.

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        """Returns True iff `parent` calls a name in `autobatch_functions`."""
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        return str(func_name) in autobatch_functions

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args',
                            anf.ANY), maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
Пример #21
0
def to_graph(entity, recursive=True, experimental_optional_features=None):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally, `tf.function`
    uses `to_graph`.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use converted
    code.

    Methods are converted into unbound functions that have an additional first
    argument called `self`.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of optional
        features in the conversion process.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ConversionError: If the entity could not be converted. A `ValueError`,
        `AttributeError`, `KeyError`, `NameError` or `AssertionError` raised
        during conversion is logged and re-raised as `ConversionError`.
    """
    try:
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(
                recursive=recursive,
                optional_features=experimental_optional_features),
            autograph_module=tf_inspect.getmodule(to_graph))
        return conversion.convert(entity, program_ctx)
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as e:
        logging.error(1, 'Error converting %s', entity, exc_info=True)
        raise ConversionError('converting {}: {}: {}'.format(
            entity, e.__class__.__name__, str(e)))
Пример #22
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally, `tf.function`
    uses `to_graph`.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use converted
    code.

    Methods are converted into unbound functions that have an additional first
    argument called `self`.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including
        function arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.
      experimental_strip_decorators: A tuple specifying decorators that should
        be excluded from the compiled output. By default, when converting a
        function before the decorators are applied, the compiled output will
        include those decorators. AutoGraph's own `convert`, `do_not_convert`
        and `converted_call` are always added; the caller's sequence is not
        modified.
      experimental_verbose: The level of printing verbosity to use, as a
        `tf.autograph.experimental.Verbosity` value.
      experimental_partial_types: A `set` of `type` values, reserved for
        internal use.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ValueError: If the entity could not be converted, including when the
        converted entity already has the `ag_source_map` attribute, which is
        reserved for AutoGraph.
    """
    if experimental_strip_decorators is None:
        experimental_strip_decorators = ()
    # Build a fresh tuple: `+=` on a caller-supplied list would extend that
    # list in place and leak AutoGraph's own decorators into it.
    experimental_strip_decorators = tuple(experimental_strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    nodes = []
    for dep in reversed(program_ctx.conversion_order):
        nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    for key, val in program_ctx.additional_symbols.items():
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    if tf_inspect.isfunction(entity):
        compiled.__defaults__ = entity.__defaults__

    if hasattr(compiled, '__globals__'):
        # Remove self to avoid circular references. This will probably only work
        # so long as the function is not reentrant.
        del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because it has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
Пример #23
0
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_f is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ConversionError: If the entity could not be converted. A `ValueError`,
      `AttributeError`, `KeyError`, `NameError` or `AssertionError` raised
      during conversion is logged and re-raised as `ConversionError`.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    return autograph_artifact(conversion.convert(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e)))
Пример #24
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Tuple[Callable], same as
      ConversionOptions.strip_decorators. AutoGraph's own `convert`,
      `do_not_convert` and `converted_call` are always added; the caller's
      sequence is not modified.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, including when the
      converted entity already has the `ag_source_map` attribute, which is
      reserved for AutoGraph.
  """
  if strip_decorators is None:
    strip_decorators = ()
  # Build a fresh tuple: `+=` on a caller-supplied list would extend that list
  # in place and leak AutoGraph's own decorators into it.
  strip_decorators = tuple(strip_decorators) + (convert, do_not_convert,
                                                converted_call)

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=verbose,
          strip_decorators=strip_decorators),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  # The compiled source text is not needed here; only the module is used.
  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
Пример #25
0
 def graph_fn():
   """Calls `f` via converted_call and returns three elements drawn from it."""
   recursive_opts = converter.ConversionOptions(recursive=True)
   ds = api.converted_call(f, None, recursive_opts, (), {})
   iterator = iter(ds)
   first = next(iterator)
   second = next(iterator)
   third = next(iterator)
   return first, second, third
Пример #26
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally, `tf.function`
    uses `to_graph`.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use
    converted code.

    Methods are converted into unbound functions that have an additional
    first argument called `self`.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Deprecated and ignored. Still accepted so existing callers
        keep working, but its value is discarded (see TODO(b/129431421)).
      arg_types: Deprecated and ignored. Still accepted so existing callers
        keep working, but its value is discarded (see TODO(b/129431421)).
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      A ValueError, AttributeError, KeyError, NameError or AssertionError
      raised while building the program context or converting `entity` is
      not propagated directly: it is passed to
      `errors.report_internal_error`, which is expected to raise an error
      describing the failed conversion.
    """
    # TODO(b/129431421): Remove these args. They are kept only for backward
    # compatibility of the signature and have no effect on the conversion.
    del arg_values
    del arg_types

    try:
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(
                recursive=recursive,
                optional_features=experimental_optional_features),
            autograph_module=tf_inspect.getmodule(to_graph))
        return conversion.convert(entity, program_ctx)
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as e:
        # Internal conversion failures are funneled through a single
        # reporting helper rather than surfacing the raw exception.
        # NOTE(review): if report_internal_error ever returns instead of
        # raising, this function implicitly returns None — confirm it raises.
        errors.report_internal_error(entity, e)
Пример #27
0
 def test_converted_call_builtin(self):
     """`converted_call` on the builtin `range` yields the plain range values."""
     default_options = converter.ConversionOptions()
     produced = api.converted_call(range, None, default_options, (3,), {})
     self.assertEqual((0, 1, 2), tuple(produced))
Пример #28
0
 def _simple_program_ctx(self):
     """Returns a ProgramContext that converts recursively, rooted at `api`."""
     recursive_options = converter.ConversionOptions(recursive=True)
     return converter.ProgramContext(
         options=recursive_options, autograph_module=api)
Пример #29
0
 def _simple_program_ctx(self):
     """Returns a recursive ProgramContext with no partial types.

     The context is rooted at the `api` module and uses the default set of
     uncompiled modules from `config`.
     """
     ctx_kwargs = dict(
         options=converter.ConversionOptions(recursive=True),
         partial_types=(),
         autograph_module=api,
         uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES,
     )
     return converter.ProgramContext(**ctx_kwargs)
Пример #30
0
 def test_method(self, x, s, a):
     # Repeatedly floor-divides `x` by the value produced by calling
     # `self.called_member` through `api.converted_call`, for as long as
     # `tf.reduce_sum(x)` is greater than `s`, then returns the final `x`.
     # The loop condition is a tensor expression, so the loop is presumably
     # meant to be converted by AutoGraph; keep its syntactic form as-is.
     while tf.reduce_sum(x) > s:
         # NOTE(review): other converted_call sites in this file pass
         # positional args as a tuple and kwargs as a dict after the options
         # argument; here `self, a` occupy those positions. This only works
         # if this API version takes `*args, **kwargs` instead — confirm.
         x //= api.converted_call(self.called_member, None,
                                  converter.ConversionOptions(),
                                  self, a)
     return x