Exemplo n.º 1
0
  def test_entity_to_graph_multiple_lambdas_ambiguous_definitions(self):
    # Both lambdas sit on one source line and share the parameter name `x`;
    # converting the first one is expected to fail with ValueError.
    a, b = 1, 2
    f, _ = (lambda x: a * x, lambda x: b * x)

    ctx = self._simple_program_ctx()
    self.assertRaises(ValueError, conversion.entity_to_graph, f, ctx, None,
                      None)
Exemplo n.º 2
0
    def test_entity_to_graph_class_hierarchy(self):
        # A user class deriving from another (non-whitelisted) user class is
        # expected to be rejected with NotImplementedError whose message
        # mentions whitelisted classes.
        class TestBase(object):
            def __init__(self, x='base'):
                self.x = x

            def foo(self):
                return self.x

            def bar(self):
                return self.x

        class TestSubclass(TestBase):
            def __init__(self, y):
                super(TestSubclass, self).__init__('sub')
                self.y = y

            def foo(self):
                return self.y

            def baz(self):
                return self.y

        ctx = self._simple_program_ctx()
        expected_message = 'classes.*whitelisted'
        with self.assertRaisesRegex(NotImplementedError, expected_message):
            conversion.entity_to_graph(TestSubclass, ctx, None, None)
Exemplo n.º 3
0
    def test_entity_to_graph_multiple_lambdas_ambiguous_definitions(self):
        # Two same-line lambdas that both take `x`: conversion of the first
        # one is expected to raise ValueError.
        a, b = 1, 2
        f, _ = (lambda x: a * x, lambda x: b * x)

        ctx = self._simple_program_ctx()
        self.assertRaises(ValueError, conversion.entity_to_graph, f, ctx,
                          None, None)
Exemplo n.º 4
0
    def test_entity_to_graph_class_hierarchy(self):
        # Converting a subclass is expected to also convert its user-defined
        # base class, and both end up in the dependency cache.
        class TestBase(object):
            def __init__(self, x='base'):
                self.x = x

            def foo(self):
                return self.x

            def bar(self):
                return self.x

        class TestSubclass(TestBase):
            def __init__(self, y):
                super(TestSubclass, self).__init__('sub')
                self.y = y

            def foo(self):
                return self.y

            def baz(self):
                return self.y

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

        # assertIn reports the container contents on failure, unlike
        # assertTrue(x in y).
        self.assertIn(TestBase, program_ctx.dependency_cache)
        self.assertIn(TestSubclass, program_ctx.dependency_cache)
        # The returned nodes will include:
        # <import nodes>, <class node>, <assignment node>
        # so the class definition is the second-to-last node.
        self.assertEqual('TfTestBase',
                         program_ctx.dependency_cache[TestBase][-2].name)
        self.assertEqual('TfTestSubclass',
                         program_ctx.dependency_cache[TestSubclass][-2].name)
Exemplo n.º 5
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
  """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    The converted code as string: the required imports, a blank line, then the
    source of every converted entity.
  """
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=converter.Verbosity.BRIEF,
          strip_decorators=(convert, do_not_convert, converted_call),
          optional_features=experimental_optional_features),
      partial_types=experimental_partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(entity, program_ctx, arg_values, arg_types)

  # Render every entity recorded during conversion, walking conversion_order
  # in reverse (presumably so dependencies precede their users -- TODO confirm).
  code = '\n'.join(
      compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
      for dep in reversed(program_ctx.conversion_order))

  return program_ctx.required_imports + '\n\n' + code
Exemplo n.º 6
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
    """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    The converted code as string: the required imports, a blank line, then the
    source of every converted entity.
  """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=converter.Verbosity.BRIEF,
            strip_decorators=(convert, do_not_convert, converted_call),
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(entity, program_ctx, arg_values, arg_types)

    # Render every converted entity, walking conversion_order in reverse
    # (presumably so dependencies precede their users -- TODO confirm).
    code = '\n'.join(
        compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
        for dep in reversed(program_ctx.conversion_order))

    return program_ctx.required_imports + '\n\n' + code
Exemplo n.º 7
0
  def test_ag_module_cached(self):
    # Two conversions sharing one ProgramContext must expose the very same
    # `ag__` object in their namespaces.
    def callee():
      return range(3)

    def caller(a):
      return a()

    program_ctx = self._simple_program_ctx()
    _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None,
                                                 None)
    _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None,
                                                 None)

    # assertIs gives a descriptive failure message, unlike assertTrue(a is b).
    self.assertIs(callee_ns['ag__'], caller_ns['ag__'])
Exemplo n.º 8
0
    def test_entity_to_graph_call_tree(self):
        # Converting `f` is expected to cache `f` but not its callee `g`.
        def g(a):
            return a

        def f(a):
            return g(a)

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(f, program_ctx, None, None)

        # assertIn/assertNotIn report the container on failure.
        self.assertIn(f, program_ctx.dependency_cache)
        self.assertNotIn(g, program_ctx.dependency_cache)
        f_node = program_ctx.dependency_cache[f][0]
        self.assertEqual('tf__f', f_node.name)
Exemplo n.º 9
0
    def test_ag_module_cached(self):
        # Two conversions sharing one ProgramContext must expose the very same
        # `ag__` object in their namespaces.
        def callee():
            return range(3)

        def caller(a):
            return a()

        program_ctx = self._simple_program_ctx()
        _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None,
                                                     None)
        _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None,
                                                     None)

        # assertIs gives a descriptive failure message, unlike
        # assertTrue(a is b).
        self.assertIs(callee_ns['ag__'], caller_ns['ag__'])
Exemplo n.º 10
0
  def test_entity_to_graph_call_tree(self):
    # Converting `f` is expected to cache `f` but not its callee `g`.

    def g(a):
      return a

    def f(a):
      return g(a)

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(f, program_ctx, None, None)

    # assertIn/assertNotIn report the container on failure.
    self.assertIn(f, program_ctx.dependency_cache)
    self.assertNotIn(g, program_ctx.dependency_cache)
    f_node = program_ctx.dependency_cache[f][0]
    self.assertEqual('tf__f', f_node.name)
Exemplo n.º 11
0
  def test_entity_to_graph_multiple_lambdas(self):
    # Same-line lambdas with distinct parameter names; the first one is
    # converted and closes over `a`.
    a, b = 1, 2
    f, _ = (lambda x: a * x, lambda y: b * y)

    ctx = self._simple_program_ctx()
    (fn_node, _), name, ns = conversion.entity_to_graph(f, ctx, None, None)
    # A converted lambda is emitted as an assignment whose value is a Lambda.
    self.assertIsInstance(fn_node, gast.Assign)
    self.assertIsInstance(fn_node.value, gast.Lambda)
    self.assertEqual('tf__lambda', name)
    self.assertIs(ns['a'], a)
Exemplo n.º 12
0
  def test_entity_to_graph_lambda(self):
    # A lone lambda with a conditional-expression body, closing over `b`.
    b = 2
    f = lambda x: b * x if x > 0 else -x

    ctx = self._simple_program_ctx()
    (fn_node, _), name, ns = conversion.entity_to_graph(f, ctx, None, None)
    # The lambda comes back as an assignment whose value is a Lambda node.
    self.assertIsInstance(fn_node, gast.Assign)
    self.assertIsInstance(fn_node.value, gast.Lambda)
    self.assertEqual('tf__lambda', name)
    self.assertIs(ns['b'], b)
Exemplo n.º 13
0
    def test_entity_to_graph_call_tree(self):
        # Converting `f` yields exactly one node: the converted `f` itself.
        def g(a):
            return a

        def f(a):
            return g(a)

        ctx = self._simple_program_ctx()
        (f_node,), _, _ = conversion.entity_to_graph(f, ctx, None, None)
        self.assertEqual('tf__f', f_node.name)
Exemplo n.º 14
0
    def test_entity_to_graph_multiple_lambdas(self):
        # Same-line lambdas with distinct parameter names; converting the
        # first yields a single assignment node and exposes `a` in the
        # entity's namespace.
        a, b = 1, 2
        f, _ = (lambda x: a * x, lambda y: b * y)

        ctx = self._simple_program_ctx()
        converted = conversion.entity_to_graph(f, ctx, None, None)
        (fn_node,), name, entity_info = converted
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual('tf__lambda', name)
        namespace = entity_info.namespace
        self.assertIs(namespace['a'], a)
Exemplo n.º 15
0
    def test_entity_to_graph_lambda(self):
        # A lone conditional lambda converts to a single assignment node; its
        # free variable `b` is exposed through the entity's namespace.
        b = 2
        f = lambda x: b * x if x > 0 else -x

        ctx = self._simple_program_ctx()
        converted = conversion.entity_to_graph(f, ctx, None, None)
        (fn_node,), name, entity_info = converted
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual('tf__lambda', name)
        namespace = entity_info.namespace
        self.assertIs(namespace['b'], b)
Exemplo n.º 16
0
  def test_entity_to_graph_callable(self):
    # A plain nested function closing over `b` converts to a FunctionDef.
    b = 2
    def f(a):
      return a + b

    ctx = self._simple_program_ctx()
    (fn_node, _), name, ns = conversion.entity_to_graph(f, ctx, None, None)
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', name)
    self.assertIs(ns['b'], b)
Exemplo n.º 17
0
  def test_entity_to_graph_callable(self):
    # The first of the two returned nodes is the converted FunctionDef, and
    # the closure variable `b` is carried into the returned namespace.
    b = 2
    def f(a):
      return a + b

    program_ctx = self._simple_program_ctx()
    converted_nodes, converted_name, namespace = conversion.entity_to_graph(
        f, program_ctx, None, None)
    fn_node, _ = converted_nodes
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', converted_name)
    self.assertIs(namespace['b'], b)
Exemplo n.º 18
0
    def test_entity_to_graph_lambda(self):
        # A conditional lambda closing over `b`; the lambda assignment is read
        # from the second-to-last returned node.
        b = 2
        f = lambda x: b * x if x > 0 else -x

        ctx = self._simple_program_ctx()
        converted_nodes, name, ns = conversion.entity_to_graph(
            f, ctx, None, None)
        fn_node = converted_nodes[-2]
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual('tf__lambda', name)
        self.assertIs(ns['b'], b)
Exemplo n.º 19
0
    def test_entity_to_graph_class_hierarchy_whitelisted(self):
        # The user subclass is converted, but its whitelisted base
        # `training.Model` is not; the base is referenced via an import node.
        class TestSubclass(training.Model):
            def __init__(self, y):
                super(TestSubclass, self).__init__()
                self.built = False

            def call(self, x):
                return 3 * x

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

        # assertIn/assertNotIn report the container on failure.
        self.assertIn(TestSubclass, program_ctx.dependency_cache)
        self.assertNotIn(training.Model, program_ctx.dependency_cache)
        self.assertEqual(
            'Model',
            program_ctx.dependency_cache[TestSubclass][0].names[0].name)
        # The returned nodes will include:
        # <import nodes>, <class node>, <assignment node>
        self.assertEqual('TfTestSubclass',
                         program_ctx.dependency_cache[TestSubclass][-2].name)
Exemplo n.º 20
0
  def test_entity_to_graph_function_with_defaults(self):
    # The default expression `c + 1` is expected to be rendered as `None` in
    # the converted function's signature.
    b = 2
    c = 1
    def f(a, d=c + 1):
      return a + b + d

    ctx = self._simple_program_ctx()
    (fn_node, _), name, _ = conversion.entity_to_graph(f, ctx, None, None)
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', name)
    first_default = fn_node.args.defaults[0]
    self.assertEqual(compiler.ast_to_source(first_default).strip(), 'None')
Exemplo n.º 21
0
  def test_entity_to_graph_lambda_code_with_garbage(self):
    # The lambda's source lines also contain unrelated tokens (tuple wrapper,
    # comments, trailing subscript); the converted result must still be the
    # lambda alone.
    # pylint:disable=g-long-lambda
    f = (  # intentional wrap
        lambda x: (x  # intentional wrap
                   + 1),)[0]
    # pylint:enable=g-long-lambda

    ctx = self._simple_program_ctx()
    (fn_node, _), name, _ = conversion.entity_to_graph(f, ctx, None, None)
    self.assertIsInstance(fn_node, gast.Assign)
    self.assertIsInstance(fn_node.value, gast.Lambda)
    self.assertEqual('tf__lambda', name)
Exemplo n.º 22
0
  def test_entity_to_graph_function_with_defaults(self):
    # The converted signature is expected to carry `None` in place of the
    # original default expression.
    b = 2
    c = 1
    def f(a, d=c + 1):
      return a + b + d

    program_ctx = self._simple_program_ctx()
    converted_nodes, converted_name, _ = conversion.entity_to_graph(
        f, program_ctx, None, None)
    fn_node, _ = converted_nodes
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', converted_name)
    default_src = compiler.ast_to_source(fn_node.args.defaults[0])
    self.assertEqual(default_src.strip(), 'None')
Exemplo n.º 23
0
  def test_entity_to_graph_class_hierarchy_whitelisted(self):
    # The user subclass is converted, but its whitelisted base
    # `training.Model` is not; the base is referenced via an import node.

    class TestSubclass(training.Model):

      def __init__(self, y):
        super(TestSubclass, self).__init__()
        self.built = False

      def call(self, x):
        return 3 * x

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    # assertIn/assertNotIn report the container on failure.
    self.assertIn(TestSubclass, program_ctx.dependency_cache)
    self.assertNotIn(training.Model, program_ctx.dependency_cache)
    self.assertEqual(
        'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
    # The returned nodes will include:
    # <import nodes>, <class node>, <assignment node>
    self.assertEqual('TfTestSubclass',
                     program_ctx.dependency_cache[TestSubclass][-2].name)
Exemplo n.º 24
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, the string to use for each level of indentation.

  Returns:
    Text, the converted code: the required imports, a blank line, then the
    source of every converted entity.
  """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=recursive,
                                            strip_decorators=(convert,
                                                              do_not_convert,
                                                              converted_call)),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

    # Render every converted entity, walking conversion_order in reverse
    # (presumably so dependencies precede their users -- TODO confirm).
    code = '\n'.join(
        compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
        for dep in reversed(program_ctx.conversion_order))

    return program_ctx.required_imports + '\n\n' + code
Exemplo n.º 25
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, the string to use for each level of indentation.

  Returns:
    Text, the converted code: the required imports, a blank line, then the
    source of every converted entity.
  """
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          strip_decorators=(convert, do_not_convert, converted_call)),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Render every converted entity, walking conversion_order in reverse
  # (presumably so dependencies precede their users -- TODO confirm).
  code = '\n'.join(
      compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
      for dep in reversed(program_ctx.conversion_order))

  return program_ctx.required_imports + '\n\n' + code
Exemplo n.º 26
0
    def test_entity_to_graph_lambda_code_with_garbage(self):
        # The lambda's source lines also hold unrelated tokens (tuple wrapper,
        # comments, trailing subscript); conversion must still isolate the
        # lambda expression.
        # pylint:disable=g-long-lambda
        f = (  # intentional wrap
            lambda x: (
                x  # intentional wrap
                + 1), )[0]
        # pylint:enable=g-long-lambda

        ctx = self._simple_program_ctx()
        (fn_node, _), name, _ = conversion.entity_to_graph(f, ctx, None, None)
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual('tf__lambda', name)
Exemplo n.º 27
0
  def test_entity_to_graph_nested_functions(self):
    # `f` defines and calls an inner function; the first returned node is the
    # converted `f`, and the closure variable `b` reaches the namespace.
    b = 2

    def f(x):
      def g(x):
        return b * x
      return g(x)

    ctx = self._simple_program_ctx()
    (fn_node, _), name, ns = conversion.entity_to_graph(f, ctx, None, None)
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual(fn_node.name, 'tf__f')
    self.assertEqual('tf__f', name)
    self.assertIs(ns['b'], b)
Exemplo n.º 28
0
  def test_entity_to_graph_class_hierarchy(self):
    # Converting a subclass is expected to also convert its user-defined base
    # class, and both end up in the dependency cache.

    class TestBase(object):

      def __init__(self, x='base'):
        self.x = x

      def foo(self):
        return self.x

      def bar(self):
        return self.x

    class TestSubclass(TestBase):

      def __init__(self, y):
        super(TestSubclass, self).__init__('sub')
        self.y = y

      def foo(self):
        return self.y

      def baz(self):
        return self.y

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    # assertIn reports the container contents on failure.
    self.assertIn(TestBase, program_ctx.dependency_cache)
    self.assertIn(TestSubclass, program_ctx.dependency_cache)
    # The returned nodes will include:
    # <import nodes>, <class node>, <assignment node>
    self.assertEqual('TfTestBase',
                     program_ctx.dependency_cache[TestBase][-2].name)
    self.assertEqual('TfTestSubclass',
                     program_ctx.dependency_cache[TestSubclass][-2].name)
Exemplo n.º 29
0
    def test_entity_to_graph_class_hierarchy_whitelisted(self):
        # Exactly two nodes come back: an import naming the whitelisted base
        # `Model`, followed by the converted class.
        class TestSubclass(training.Model):
            def __init__(self, y):
                super(TestSubclass, self).__init__()
                self.built = False

            def call(self, x):
                return 3 * x

        ctx = self._simple_program_ctx()
        converted_nodes, name, _ = conversion.entity_to_graph(
            TestSubclass, ctx, None, None)
        import_node, class_node = converted_nodes
        self.assertEqual(import_node.names[0].name, 'Model')
        self.assertEqual(name, 'TfTestSubclass')
        self.assertEqual(class_node.name, 'TfTestSubclass')
Exemplo n.º 30
0
    def test_entity_to_graph_nested_functions(self):
        # A function with an inner helper converts to a single FunctionDef,
        # and the closure variable `b` is exposed via the entity's namespace.
        b = 2

        def f(x):
            def g(x):
                return b * x

            return g(x)

        ctx = self._simple_program_ctx()
        converted = conversion.entity_to_graph(f, ctx, None, None)
        (fn_node,), name, entity_info = converted
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual(fn_node.name, 'tf__f')
        self.assertEqual('tf__f', name)
        namespace = entity_info.namespace
        self.assertIs(namespace['b'], b)
    def test_entity_to_graph_class_hierarchy_whitelisted(self):
        # A subclass of the whitelisted base `training.Model` converts to a
        # class named `TfTestSubclass`.
        class TestSubclass(training.Model):
            def __init__(self, y):
                super(TestSubclass, self).__init__()
                self.built = False

            def call(self, x):
                return 3 * x

        ctx = self._simple_program_ctx()
        converted_nodes, converted_name, _ = conversion.entity_to_graph(
            TestSubclass, ctx, None, None)
        class_node = converted_nodes[-2]  # TODO(mdan): This is brittle.

        self.assertEqual(converted_name, 'TfTestSubclass')
        self.assertEqual(class_node.name, 'TfTestSubclass')
Exemplo n.º 32
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
    """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_partial_types: Deprecated, unused.

  Returns:
    The converted code as string: the required imports, a blank line, then the
    source of the nodes returned for `entity`.
  """
    # Accepted only for backward compatibility; intentionally ignored.
    del experimental_partial_types

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    nodes, _, _ = conversion.entity_to_graph(entity, program_ctx, arg_values,
                                             arg_types)

    code = compiler.ast_to_source(nodes, indentation)

    return program_ctx.required_imports + '\n\n' + code
Exemplo n.º 33
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
    """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators. The caller's sequence is never modified.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
    # Build a new tuple instead of using `+=`: when the caller passes a list,
    # `list += tuple` extends that list in place, leaking AutoGraph's own
    # decorators into the caller's object (and growing it on every call).
    if experimental_strip_decorators is None:
        experimental_strip_decorators = ()
    experimental_strip_decorators = tuple(experimental_strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Gather the AST nodes of every converted entity, walking conversion_order
    # in reverse (presumably so dependencies come first -- TODO confirm).
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
        nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw, plus any
    # additional symbols registered during conversion. Names the compiled
    # module already defines (the transformed entities) are never overwritten.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for symbols in (namespace, program_ctx.additional_symbols):
        for key, val in symbols.items():
            if key not in compiled_module.__dict__:
                compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Carry the original function's default argument values over.
    if tf_inspect.isfunction(entity):
        compiled.__defaults__ = entity.__defaults__

    if hasattr(compiled, '__globals__'):
        # Remove self to avoid circular references. This will probably only work
        # so long as the function is not reentrant.
        del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because it has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
Exemplo n.º 34
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators. AutoGraph's own `convert`, `do_not_convert` and
      `converted_call` are always added; the caller's value is not modified.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    # Build a fresh tuple rather than using `+=`: if the caller passed a list,
    # `list += tuple` would extend that list in place, leaking AutoGraph's
    # own decorators into the caller's object (and growing it on every call).
    own_decorators = (convert, do_not_convert, converted_call)
    if experimental_strip_decorators is None:
      experimental_strip_decorators = own_decorators
    else:
      experimental_strip_decorators = (
          tuple(experimental_strip_decorators) + own_decorators)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Gather the converted AST of every entity visited during conversion.
    # NOTE(review): presumably reversed so dependencies are emitted before
    # their dependents -- confirm against conversion_order's producer.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
      nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    for key, val in program_ctx.additional_symbols.items():
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Carry over the original default argument values onto the converted
    # entity.
    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because is has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)
Exemplo n.º 35
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Tuple[Callable], same as
      ConversionOptions.strip_decorators. AutoGraph's own `convert`,
      `do_not_convert` and `converted_call` are always added; the caller's
      value is not modified.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, including when the
      compiled entity already carries the reserved `ag_source_map` attribute.
  """
  # Build a fresh tuple rather than using `+=`: if the caller passed a list,
  # `list += tuple` would extend that list in place, leaking AutoGraph's own
  # decorators into the caller's object (and growing it on every call).
  own_decorators = (convert, do_not_convert, converted_call)
  if strip_decorators is None:
    strip_decorators = own_decorators
  else:
    strip_decorators = tuple(strip_decorators) + own_decorators

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=verbose,
          strip_decorators=strip_decorators),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Gather the converted AST of every entity visited during conversion.
  # NOTE(review): presumably reversed so dependencies are emitted before their
  # dependents -- confirm against conversion_order's producer.
  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  # The generated source text is not used in this version, so discard it.
  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
Exemplo n.º 36
0
  def test_entity_to_graph_lambda(self):
    """Converting a bare lambda is expected to raise NotImplementedError."""
    f = lambda a: a

    # Build the fixture outside assertRaises so that only the conversion call
    # can satisfy the expected exception; a failure in setup must not pass.
    program_ctx = self._simple_program_ctx()
    with self.assertRaises(NotImplementedError):
      conversion.entity_to_graph(f, program_ctx, None, None)
Exemplo n.º 37
0
    def test_entity_to_graph_lambda(self):
        """Converting a bare lambda is expected to raise NotImplementedError."""
        f = lambda a: a

        # Build the fixture outside assertRaises so that only the conversion
        # call can satisfy the expected exception; a failure in setup must
        # not make the test pass.
        program_ctx = self._simple_program_ctx()
        with self.assertRaises(NotImplementedError):
            conversion.entity_to_graph(f, program_ctx, None, None)
Exemplo n.º 38
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    `to_graph` is a low-level transpiler: it rewrites Python code into
    TensorFlow graph code. Unlike `tf.function`, it does no caching, no
    variable management and creates no ops itself, and it does not wrap the
    result into a TensorFlow function. It suits cases that need finer control
    over the generated graph. Internally, `tf.function` uses `to_graph`.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Supported Python entities:
      * functions, which become new functions with converted code;
      * classes, which become a new class whose methods use converted code;
      * object methods, which become unbound functions taking an extra
        leading `self` argument.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to also convert the functions that the converted
        function calls.
      arg_values: Optional dict mapping symbol names (including function
        arguments) to concrete value hints; e.g. `arg_values={'a': 1}` maps
        the variable `a` to the value `1`.
      arg_types: Optional dict mapping symbol names (including function
        arguments) to type hints, when only the type and not a specific value
        is known.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value selecting which optional
        conversion features are enabled.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ValueError: If the entity could not be converted.
    """
    try:
        ctx = converter.ProgramContext(
            options=converter.ConversionOptions(
                recursive=recursive,
                optional_features=experimental_optional_features),
            autograph_module=tf_inspect.getmodule(to_graph))
        converted_nodes, entry_name, entry_namespace = (
            conversion.entity_to_graph(entity, ctx, arg_values, arg_types))

        module, _ = compiler.ast_to_object(
            converted_nodes,
            source_prefix=ctx.required_imports,
            include_source_map=True)
        module_symbols = module.__dict__

        # Expose every symbol the entry entity could see, but never replace
        # a definition the compiled (transformed) module already provides.
        # TODO(mdan): This might not work well if the call tree spans modules?
        for symbol, value in entry_namespace.items():
            module_symbols.setdefault(symbol, value)
        converted = getattr(module, entry_name)

        # Copy the original default argument values onto the result, when
        # the source entity has any.
        no_defaults = object()
        defaults = getattr(entity, '__defaults__', no_defaults)
        if defaults is no_defaults:
            logging.log(3, 'Default args mapping: %s has no __defaults__',
                        entity)
        else:
            logging.log(3, 'Default args mapping: %s has: %s', entity,
                        defaults)
            converted.__defaults__ = defaults

        logging.log(3, 'Namespace of %s includes: %s', converted,
                    module_symbols.keys())

        if hasattr(converted, '__globals__'):
            # Drop the self-reference to break the reference cycle; this is
            # only safe as long as the function is not reentrant.
            del converted.__globals__[entry_name]

        # Attach the source map so the runtime error context manager can
        # find it. compiler.ast_to_object stores it in the compiled module as
        # the 'ag_source_map__' symbol.
        # TODO(mdan): Record this statically in the generated code.
        # TODO(mdan): Rename this attribute to 'autograph_info__'
        attr_name = 'ag_source_map'
        if getattr(converted, attr_name, None) is not None:
            # TODO(znado): change input problem errors into TransformError
            raise ValueError('cannot convert %s because is has an attribute '
                             '"%s", which is reserved for AutoGraph.' %
                             (converted, attr_name))
        setattr(converted, attr_name, module_symbols['ag_source_map__'])

        return converted
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as err:
        errors.report_internal_error(entity, err)
Exemplo n.º 39
0
 def test_entity_to_graph_unsupported_types(self):
     """Converting a non-callable (a plain string) must be rejected."""
     # Build the fixture outside assertRaises so that only the conversion call
     # can satisfy the expected exception.
     program_ctx = self._simple_program_ctx()
     with self.assertRaises(NotImplementedError):
         conversion.entity_to_graph('dummy', program_ctx, None, None)
Exemplo n.º 40
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
    """Converts a Python entity into equivalent code that uses TensorFlow ops.

    Supported Python entities include:
      * functions
      * classes

    Classes are converted by converting all their methods into a new class.

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
          converted function may call.
      verbose: bool, whether to output the compiled code in the logs.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
          function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
          function arguments.
      partial_types: Set[Type], reserved for internal use.
      strip_decorators: Tuple[Callable], same as
          ConversionOptions.strip_decorators. AutoGraph's own `convert`,
          `do_not_convert` and `converted_call` are always added; the
          caller's value is not modified.

    Returns:
      Union[Callable, Type], the converted entity, which is the same kind as e
      (that is, a function if e is a function, a class if e is a class, etc.)
      but its code has been converted to use TF ops.

    Raises:
      ValueError: If the entity could not be converted, including when the
          compiled entity already carries the reserved `ag_source_map`
          attribute.
    """
    # Build a fresh tuple rather than using `+=`: if the caller passed a list,
    # `list += tuple` would extend that list in place, leaking AutoGraph's own
    # decorators into the caller's object (and growing it on every call).
    own_decorators = (convert, do_not_convert, converted_call)
    if strip_decorators is None:
        strip_decorators = own_decorators
    else:
        strip_decorators = tuple(strip_decorators) + own_decorators

    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=strip_decorators,
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                    arg_types)

    # Gather the converted AST of every entity visited during conversion.
    # NOTE(review): presumably reversed so dependencies are emitted before
    # their dependents -- confirm against conversion_order's producer.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
        nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, compiled_src = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because is has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled
Exemplo n.º 41
0
 def test_entity_to_graph_unsupported_types(self):
   """Converting a non-callable (a plain string) must be rejected."""
   # Build the fixture outside assertRaises so that only the conversion call
   # can satisfy the expected exception.
   program_ctx = self._simple_program_ctx()
   with self.assertRaises(NotImplementedError):
     conversion.entity_to_graph('dummy', program_ctx, None, None)