コード例 #1
0
  def compiled(self, node, namespace, *symbols):
    """Compiles `node` and yields the module with fake `tf`/`ag__` attached.

    Generator that yields once; presumably wrapped by a context-manager
    decorator that is not visible in this snippet (TODO confirm).

    Args:
      node: AST handed to `compiler.ast_to_object`.
      namespace: object whose `items()` pairs are copied into the module.
      *symbols: forwarded to `self.make_fake_mod` when building `fake_tf`.

    Yields:
      The compiled module, carrying `tf`, `ag__` and `ag_source_map__`.

    Raises:
      Exception: anything raised inside the try block (which includes the
        yield) is re-raised after the offending AST, or the generated code
        once it is available, has been printed.
    """
    generated_code = None
    self.dynamic_calls = []

    def converted_call(*args):
      """Mock version of api.converted_call."""
      # args only; see api.converted_call
      self.dynamic_calls.append(args[3:])
      return RESULT_OF_MOCK_CONVERTED_CALL

    try:
      module, generated_code, source_map = compiler.ast_to_object(
          node, include_source_map=True)
      # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

      # TODO(mdan): Move this into self.prepare()
      module.tf = self.make_fake_mod('fake_tf', *symbols)
      ag_mod = self.make_fake_mod('fake_ag', converted_call,
                                  converter.ConversionOptions)
      # Later updates win: operators, then special_functions, then the
      # explicit attributes assigned below.
      for donor in (operators, special_functions):
        ag_mod.__dict__.update(donor.__dict__)
      ag_mod.ConversionOptions = converter.ConversionOptions
      ag_mod.Feature = converter.Feature
      ag_mod.utils = utils
      ag_mod.function_scope = function_wrapping.function_scope
      module.ag__ = ag_mod
      module.ag_source_map__ = source_map
      module.__dict__.update(namespace.items())
      yield module
    except Exception:  # pylint:disable=broad-except
      if generated_code is not None:
        print('Offending compiled code:\n%s' % generated_code)
      else:
        print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
      raise
コード例 #2
0
  def compiled(self, node, namespace, *symbols):
    """Compiles `node` and yields the module with fake `tf`/`ag__` attached.

    Generator that yields once; presumably wrapped by a context-manager
    decorator that is not visible in this snippet (TODO confirm).

    Args:
      node: AST handed to `compiler.ast_to_object`.
      namespace: object whose `items()` pairs are copied into the module.
      *symbols: forwarded to `self.make_fake_mod` when building `fake_tf`.

    Yields:
      The compiled module.

    Raises:
      Exception: anything raised inside the try block (which includes the
        yield) is re-raised after the offending AST, or the generated code
        once it is available, has been printed.
    """
    generated_code = None
    self.dynamic_calls = []

    def converted_call(*args):
      """Mock version of api.converted_call."""
      self.dynamic_calls.append(args)
      return 7

    try:
      module, generated_code = compiler.ast_to_object(
          node, include_source_map=True)

      # TODO(mdan): Move this into self.prepare()
      module.tf = self.make_fake_mod('fake_tf', *symbols)
      ag_mod = self.make_fake_mod('fake_ag', converted_call,
                                  converter.ConversionOptions)
      # Later updates win: operators, then special_functions, then the
      # explicit entries below.
      for donor in (operators, special_functions):
        ag_mod.__dict__.update(donor.__dict__)
      ag_mod.__dict__.update({
          'utils': utils,
          'rewrite_graph_construction_error':
              errors.rewrite_graph_construction_error,
          'function_scope': function_wrapping.function_scope,
      })
      module.__dict__['ag__'] = ag_mod
      module.__dict__.update(namespace.items())
      yield module
    except Exception:  # pylint:disable=broad-except
      if generated_code is not None:
        print('Offending compiled code:\n%s' % generated_code)
      else:
        print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
      raise
コード例 #3
0
ファイル: compiler_test.py プロジェクト: AnishShah/tensorflow
  def test_ast_to_object(self):
    """Compiles a hand-built `f(a)` AST and checks source, result and file."""
    signature = gast.arguments(
        args=[gast.Name('a', gast.Param(), None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source = compiler.ast_to_object(node)

    # Normalise the expected text once and reuse it for both comparisons.
    expected_source = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected_source, source.strip())
    self.assertEqual(2, module.f(1))
    # The compiled module's backing file must hold the same source.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected_source, temp_output.read().strip())
コード例 #4
0
    def module(self):
        """Constructs an `instructions.Module` for this `Context`.

        The result is stored in `self._module`; once set, later calls return
        that stored object without rebuilding it.

        Returns:
          module: An `instructions.Module` representing the batched
            computation defined by all the functions decorated with `batch`
            in this `Context` so far.
        """
        if self._module is None:
            ab = dsl.ProgramBuilder()
            # Every tagged function is declared before any body is built.
            function_objects = [
                ab.declare_function(fn.__name__, type_inference)
                for fn, type_inference in self._tagged_functions
            ]
            for fn, _ in self._tagged_functions:
                fn_name = fn.__name__
                node, ctx = _parse_and_analyze(fn, self.function_names())
                known_functions = self.function_names()
                scoped_names = [
                    scoped for scoped, _ in _environment(fn, [fn_name])
                ]
                transformer = _AutoBatchingTransformer(
                    known_functions, scoped_names, ctx)
                node = transformer.visit(node)
                builder_module, _, _ = compiler.ast_to_object(node)
                # Make the function's environment visible to the compiled
                # builder.
                for scoped, val in _environment(fn, [fn_name]):
                    builder_module.__dict__[scoped] = val
                getattr(builder_module, fn_name)(ab, function_objects)
            self._module = ab.module()
        return self._module
コード例 #5
0
  def test_replace_code_block(self):
    """Replaces `block` with two `a = a + 1` statements and runs the result."""
    template = """
      def test_fn(a):
        block
        return a
    """

    class ShouldBeReplaced(object):
      pass

    def placeholder_name():
      # ShouldBeReplaced is a dummy ctx; presumably templates.replace
      # overwrites it (TODO confirm).
      return gast.Name(
          'a', ctx=ShouldBeReplaced, annotation=None, type_comment=None)

    increment = gast.Assign(
        [placeholder_name()],
        gast.BinOp(placeholder_name(), gast.Add(),
                   gast.Constant(1, kind=None)),
    )
    # The same statement object is deliberately listed twice.
    node = templates.replace(template, block=[increment] * 2)[0]
    module, _, _ = compiler.ast_to_object(node)
    self.assertEqual(3, module.test_fn(1))
コード例 #6
0
    def compiled(self, node, namespace, *symbols):
        """Compiles `node` and yields the module with fake `tf`/`ag__` attached.

        Generator that yields once; presumably wrapped by a context-manager
        decorator that is not visible in this snippet (TODO confirm).

        Args:
          node: AST handed to `compiler.ast_to_object`.
          namespace: object whose `items()` pairs are copied into the module.
          *symbols: forwarded to `self.make_fake_mod` for `fake_tf`.

        Yields:
          The compiled module.

        Raises:
          Exception: anything raised inside the try block (which includes the
            yield) is re-raised after the offending AST, or the generated
            code once it is available, has been printed.
        """
        generated_code = None
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            self.dynamic_calls.append(args)
            return 7

        try:
            module, generated_code = compiler.ast_to_object(
                node, include_source_map=True)

            # TODO(mdan): Move this into self.prepare()
            module.tf = self.make_fake_mod('fake_tf', *symbols)
            ag_mod = self.make_fake_mod('fake_ag', converted_call,
                                        converter.ConversionOptions)
            ag_mod.__dict__.update(operators.__dict__)
            # Explicit entries are applied last so they win over operators.
            ag_mod.__dict__.update({
                'utils': utils,
                'rewrite_graph_construction_error':
                    errors.rewrite_graph_construction_error,
                'function_scope': function_wrapping.function_scope,
            })
            module.__dict__['ag__'] = ag_mod
            module.__dict__.update(namespace.items())
            yield module
        except Exception:  # pylint:disable=broad-except
            if generated_code is not None:
                print('Offending compiled code:\n%s' % generated_code)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
コード例 #7
0
ファイル: codegen_test.py プロジェクト: Harryi0/tinyML
 def test_codegen_gens(self):
   """Compiles 1000 randomly generated function defs (NumPy seed fixed at 0)."""
   np.random.seed(0)
   for _ in range(1000):
     random_def = codegen.generate_random_functiondef()
     compiled = compiler.ast_to_object(random_def)
     self.assertIsNotNone(
         compiled, 'Generated invalid AST that could not convert to source.')
コード例 #8
0
  def test_ast_to_object(self):
    """Compiles a hand-built `f(a)` AST and checks source, result and file."""
    signature = gast.arguments(
        args=[gast.Name('a', gast.Param(), None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source, _ = compiler.ast_to_object(node)

    # Normalise the expected text once and reuse it for both comparisons.
    expected_source = textwrap.dedent("""
      # coding=utf-8
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected_source, source.strip())
    self.assertEqual(2, module.f(1))
    # The compiled module's backing file must hold the same source.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected_source, temp_output.read().strip())
コード例 #9
0
ファイル: anf_test.py プロジェクト: AnishShah/tensorflow
 def test_basic(self):
   """Checks an ANF-transformed function still returns the original value."""
   def test_function():
     a = 0
     return a
   parsed, _ = parser.parse_entity(test_function)
   anf_node = anf.transform(parsed.body[0], self._simple_source_info())
   module, _ = compiler.ast_to_object(anf_node)
   self.assertEqual(test_function(), module.test_function())
コード例 #10
0
 def test_basic(self):
   """Checks an ANF-transformed function still returns the original value."""
   def test_function():
     a = 0
     return a
   parsed, _, _ = parser.parse_entity(test_function, future_imports=())
   anf_node = anf.transform(parsed, self._simple_context())
   module, _ = compiler.ast_to_object(anf_node)
   self.assertEqual(test_function(), module.test_function())
コード例 #11
0
ファイル: ast_util_test.py プロジェクト: AnishShah/tensorflow
 def test_keywords_to_dict(self):
   """Converts call keywords to a dict node and checks it evaluates correctly."""
   call_node = parser.parse_expression('f(a=b, c=1, d=\'e\')')
   dict_node = ast_util.keywords_to_dict(call_node.keywords)
   # Prove the dict node is usable: return it from a real function and
   # compile the whole thing.
   fn_def = parser.parse_str('def f(b): pass').body[0]
   fn_def.body.append(ast.Return(dict_node))
   module, _ = compiler.ast_to_object(fn_def)
   self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
コード例 #12
0
ファイル: anf_test.py プロジェクト: d813s909q/tensortflow
    def test_basic(self):
        """Checks an ANF-transformed function still returns the same value."""
        def test_function():
            a = 0
            return a

        parsed, _ = parser.parse_entity(test_function)
        anf_node = anf.transform(parsed.body[0], self._simple_source_info())
        module, _ = compiler.ast_to_object(anf_node)
        self.assertEqual(test_function(), module.test_function())
コード例 #13
0
 def test_keywords_to_dict(self):
     """Converts call keywords to a dict node and checks it evaluates."""
     call_node = parser.parse_expression('f(a=b, c=1, d=\'e\')')
     dict_node = ast_util.keywords_to_dict(call_node.keywords)
     # Prove the dict node is usable: return it from a real function and
     # compile the whole thing.
     fn_def = parser.parse_str('def f(b): pass').body[0]
     fn_def.body.append(ast.Return(dict_node))
     module, _ = compiler.ast_to_object(fn_def)
     self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
コード例 #14
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_name_with_dict(self):
        """Replaces the name `foo` with a dict expression and evaluates it."""
        template = """
      def test_fn():
        return foo['bar']
    """

        source = parser.parse_expression('{\'bar\': 3}')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(3, result.test_fn())
コード例 #15
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_tuple(self):
        """Replaces the name `b` with the tuple of names `('a', 'c')`."""
        template = """
      def test_fn(a, c):
        return b,
    """

        node = templates.replace(template, b=('a', 'c'))[0]
        result, _ = compiler.ast_to_object(node)

        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual((2, 3), result.test_fn(2, 3))
コード例 #16
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_tuple(self):
    """Replaces the name `b` with the tuple of names `('a', 'c')`."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual((2, 3), result.test_fn(2, 3))
コード例 #17
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_name_with_dict(self):
    """Replaces the name `foo` with a dict expression and evaluates it."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(3, result.test_fn())
コード例 #18
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_function_name(self):
        """Renames the template function `fname` to `test_fn` and calls it."""
        template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

        node = templates.replace(template, fname='test_fn')[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(7, result.test_fn(2))
コード例 #19
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_variable(self):
        """Replaces the variable `a` with `b` throughout the template."""
        template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

        node = templates.replace(template, a='b')[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(7, result.test_fn(2))
コード例 #20
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_variable(self):
    """Replaces the variable `a` with `b` throughout the template."""
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    node = templates.replace(template, a='b')[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(7, result.test_fn(2))
コード例 #21
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_function_name(self):
    """Renames the template function `fname` to `test_fn` and calls it."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(7, result.test_fn(2))
コード例 #22
0
 def _transform(self, f, strip_decorators):
   """Runs the decorators transform on `f` and returns the compiled function.

   Args:
     f: function to prepare, transform and compile.
     strip_decorators: forwarded unchanged to `self.prepare`.

   Returns:
     The attribute of the compiled module named `f.__name__`.
   """
   prepared, ctx = self.prepare(
       f,
       {
           'self_transform_decorator': self_transform_decorator,
           'simple_decorator': simple_decorator,
           'converter_testing': converter_testing,
       },
       recursive=False,
       strip_decorators=strip_decorators)
   transformed = decorators.transform(prepared, ctx)
   # Extra imports requested during conversion are prepended to the source.
   prefix = '\n'.join(ctx.program.additional_imports)
   module, _ = compiler.ast_to_object(transformed, source_prefix=prefix)
   return getattr(module, f.__name__)
コード例 #23
0
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
  """Returns a (possibly cached) factory for the converted result of entity.

  Args:
    entity: the object to convert; its `__code__` is used as the cache key
      when it has one.
    program_ctx: conversion context; `program_ctx.options` is part of the
      cache subkey.
    free_nonglobal_var_names: iterable of names of free non-global variables;
      frozen into the cache subkey.

  Returns:
    A `_ConvertedEntityFactoryInfo`, either read from `_CACHE` or newly built
    and stored there.
  """
  # Keying by the code object (when the entity defines one) allows caching of
  # functions that are dynamically created e.g. in a loop; otherwise the
  # entity itself is the key.
  key = entity.__code__ if hasattr(entity, '__code__') else entity

  # The subkey encompasses any conversion options on which the generated code
  # may depend. The cached factory includes the definitions needed to
  # distinguish global from non-global free variables, so the names of the
  # free non-globals are part of the subkey as well.
  subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

  # The lookup and the fill both happen while _CACHE_LOCK is held.
  with _CACHE_LOCK:
    # The cache values are _ConvertedEntityFactoryInfo objects.
    if not _CACHE.has(key, subkey):
      logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity,
                  key, subkey)

      nodes, converted_name, entity_info = convert_entity_to_ast(
          entity, program_ctx)

      namer = naming.Namer(entity_info.namespace)
      factory_factory_name = namer.new_symbol(
          'create_converted_entity_factory', ())
      factory_name = namer.new_symbol('create_converted_entity', ())
      nodes = _wrap_into_dynamic_factory(
          nodes, converted_name, factory_factory_name, factory_name,
          free_nonglobal_var_names, entity_info.future_features)

      module, _, source_map = compiler.ast_to_object(
          nodes, include_source_map=True)

      factory_info = _ConvertedEntityFactoryInfo(
          module_name=module.__name__,
          converted_name=converted_name,
          factory_factory_name=factory_factory_name,
          source_map=source_map)
      _CACHE[key][subkey] = factory_info
    else:
      # TODO(mdan): Check whether the module is still loaded.
      factory_info = _CACHE[key][subkey]
      logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity,
                  key, subkey, factory_info)
    return factory_info
コード例 #24
0
    def test_parser_compile_idempotent(self):
        """Checks that parse + compile reproduces the function's source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        node, _, _ = parser.parse_entity(test_fn, future_imports=())

        original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
        compiled_module = compiler.ast_to_object([node])[0]
        self.assertEqual(original_source,
                         tf_inspect.getsource(compiled_module.test_fn))
コード例 #25
0
    def test_parser_compile_identity(self):
        """Checks that parse + compile reproduces the function's source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        parsed, _ = parser.parse_entity(test_fn, future_features=())
        compiled_module, _, _ = compiler.ast_to_object(parsed)

        original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
        roundtrip_source = tf_inspect.getsource(compiled_module.test_fn)
        self.assertEqual(original_source, roundtrip_source)
コード例 #26
0
    def test_parser_compile_idempotent(self):
        """Checks that parse + compile reproduces the function's source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        _, _, all_nodes = parser.parse_entity(test_fn)

        original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
        compiled_module = compiler.ast_to_object(all_nodes)[0]
        self.assertEqual(original_source,
                         tf_inspect.getsource(compiled_module.test_fn))
コード例 #27
0
 def _transform(self, f, strip_decorators):
     """Runs the decorators transform on `f` and returns the compiled function.

     Args:
       f: function to prepare, transform and compile.
       strip_decorators: forwarded unchanged to `self.prepare`.

     Returns:
       The attribute of the compiled module named `f.__name__`.
     """
     prepared, ctx = self.prepare(
         f,
         {
             'self_transform_decorator': self_transform_decorator,
             'simple_decorator': simple_decorator,
             'converter_testing': converter_testing,
         },
         recursive=False,
         strip_decorators=strip_decorators)
     transformed = decorators.transform(prepared, ctx)
     # Extra imports requested during conversion are prepended to the source.
     prefix = '\n'.join(ctx.program.additional_imports)
     module, _ = compiler.ast_to_object(transformed, source_prefix=prefix)
     return getattr(module, f.__name__)
コード例 #28
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_attribute(self):
        """Replaces attribute name `foo` with `b`; rejects a non-name value."""
        # Local import: types.ModuleType replaces the deprecated
        # imp.new_module (the imp module was removed in Python 3.12).
        import types  # pylint:disable=g-import-not-at-top

        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _ = compiler.ast_to_object(node)
        mod = types.ModuleType('test')
        mod.b = 3
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(3, result.test_fn(mod))

        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
コード例 #29
0
ファイル: compiler_test.py プロジェクト: AnishShah/tensorflow
  def test_parser_compile_idempotent(self):
    """Checks that parse + compile reproduces the function's source."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    # parse_entity's first result wraps the function; its first body entry
    # is what gets compiled.
    fn_node = parser.parse_entity(test_fn)[0].body[0]
    compiled_module = compiler.ast_to_object(fn_node)[0]
    self.assertEqual(original_source,
                     tf_inspect.getsource(compiled_module.test_fn))
コード例 #30
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_attribute(self):
    """Replaces attribute name `foo` with `b`; rejects a non-name value."""
    # Local import: types.ModuleType replaces the deprecated imp.new_module
    # (the imp module was removed in Python 3.12).
    import types  # pylint:disable=g-import-not-at-top

    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = types.ModuleType('test')
    mod.b = 3
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)
コード例 #31
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_name_with_call(self):
    """Replaces the name `foo` with the call expression `f()(b)`."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(15, result.test_fn())
コード例 #32
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_name_with_call(self):
        """Replaces the name `foo` with the call expression `f()(b)`."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        source = parser.parse_expression('f()(b)')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(15, result.test_fn())
コード例 #33
0
  def test_parser_compile_identity(self):
    """Checks that parse + compile reproduces the function's source."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    parsed, _ = parser.parse_entity(test_fn, future_features=())
    compiled_module, _, _ = compiler.ast_to_object(parsed)

    original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    roundtrip_source = tf_inspect.getsource(compiled_module.test_fn)
    self.assertEqual(original_source, roundtrip_source)
コード例 #34
0
ファイル: conversion.py プロジェクト: adit-chandra/tensorflow
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
  """Returns a (possibly cached) factory for the converted result of entity.

  Args:
    entity: the object to convert; its `__code__` is used as the cache key
      when it has one.
    program_ctx: conversion context; `program_ctx.options` is part of the
      cache subkey.
    free_nonglobal_var_names: iterable of names of free non-global variables;
      frozen into the cache subkey.

  Returns:
    A `_ConvertedEntityFactoryInfo`, either read from `_CACHE` or newly built
    and stored there.
  """
  # Keying by the code object (when the entity defines one) allows caching of
  # functions that are dynamically created e.g. in a loop; otherwise the
  # entity itself is the key.
  key = entity.__code__ if hasattr(entity, '__code__') else entity

  # The subkey encompasses any conversion options on which the generated code
  # may depend. The cached factory includes the definitions needed to
  # distinguish global from non-global free variables, so the names of the
  # free non-globals are part of the subkey as well.
  subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

  # The cache values are _ConvertedEntityFactoryInfo objects.
  if not _CACHE.has(key, subkey):
    logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity,
                key, subkey)

    nodes, converted_name, entity_info = convert_entity_to_ast(
        entity, program_ctx)

    namer = naming.Namer(entity_info.namespace)
    factory_factory_name = namer.new_symbol(
        'create_converted_entity_factory', ())
    factory_name = namer.new_symbol('create_converted_entity', ())
    nodes = _wrap_into_dynamic_factory(
        nodes, converted_name, factory_factory_name, factory_name,
        free_nonglobal_var_names, entity_info.future_features)

    module, _, source_map = compiler.ast_to_object(
        nodes, include_source_map=True)

    factory_info = _ConvertedEntityFactoryInfo(
        module_name=module.__name__,
        converted_name=converted_name,
        factory_factory_name=factory_factory_name,
        source_map=source_map)
    _CACHE[key][subkey] = factory_info
  else:
    # TODO(mdan): Check whether the module is still loaded.
    factory_info = _CACHE[key][subkey]
    logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity,
                key, subkey, factory_info)
  return factory_info
コード例 #35
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_code_block(self):
        """Replaces `block` with two `a = a + 1` statements and runs it."""
        template = """
      def test_fn(a):
        block
        return a
    """

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', None, None)],
                            gast.BinOp(gast.Name('a', None, None), gast.Add(),
                                       gast.Num(1))),
            ] * 2)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(3, result.test_fn(1))
コード例 #36
0
ファイル: templates_test.py プロジェクト: melisadr/tensorflow
    def test_replace_call_keyword(self):
        """Replaces the `kws` keyword with real keywords; rejects bad values."""
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parser.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated and was removed in
        # Python 3.12.
        self.assertEqual(9, result.test_fn())

        # One assertRaises block per call: inside a single block the first
        # raise ends the block, so a second call would never execute.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
コード例 #37
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_code_block(self):
    """Replaces `block` with two `a = a + 1` statements and runs the result."""
    template = """
      def test_fn(a):
        block
        return a
    """

    node = templates.replace(
        template,
        block=[
            gast.Assign([
                gast.Name('a', None, None)
            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
        ] * 2)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(3, result.test_fn(1))
コード例 #38
0
ファイル: templates_test.py プロジェクト: ziky90/tensorflow
  def test_replace_call_keyword(self):
    """Replaces the `kws` keyword with real keywords; rejects bad values."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(9, result.test_fn())

    # One assertRaises block per call: inside a single block the first raise
    # ends the block, so a second call would never execute.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)
コード例 #39
0
    def compiled(self, node, namespace, symbols=()):
        """Compiles `node` and yields the module with fake `tf`/`ag__` attached.

        Generator that yields once; presumably wrapped by a context-manager
        decorator that is not visible in this snippet (TODO confirm).

        Args:
          node: AST handed to `compiler.ast_to_object`.
          namespace: object whose `items()` pairs are copied into the module.
          symbols: iterable unpacked into `self.make_fake_mod` for `fake_tf`.

        Yields:
          The compiled module, carrying `tf`, `ag__` and `ag_source_map__`.

        Raises:
          Exception: anything raised inside the try block (which includes the
            yield) is re-raised after the offending AST, or the generated
            code once it is available, has been printed.
        """
        generated_code = None
        self.dynamic_calls = []

        # See api.converted_call
        def converted_call(f,
                           args,
                           kwargs,
                           unused_opts=None,
                           unused_function_ctx=None):
            """Mock version of api.converted_call."""
            # The call is recorded with kwargs as received (possibly None).
            self.dynamic_calls.append((args, kwargs))
            return f(*args, **({} if kwargs is None else kwargs))

        try:
            module, generated_code, source_map = compiler.ast_to_object(
                node, include_source_map=True)
            # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

            # TODO(mdan): Move this into self.prepare()
            module.tf = self.make_fake_mod('fake_tf', *symbols)
            ag_mod = self.make_fake_mod('fake_ag', converted_call,
                                        converter.ConversionOptions)
            # Later updates win: operators, then special_functions, then the
            # explicit attributes assigned below.
            for donor in (operators, special_functions):
                ag_mod.__dict__.update(donor.__dict__)
            ag_mod.ConversionOptions = converter.ConversionOptions
            ag_mod.Feature = converter.Feature
            ag_mod.utils = utils
            ag_mod.FunctionScope = function_wrappers.FunctionScope
            module.ag__ = ag_mod
            module.ag_source_map__ = source_map
            module.__dict__.update(namespace.items())
            yield module
        except Exception:  # pylint:disable=broad-except
            if generated_code is not None:
                print('Offending compiled code:\n%s' % generated_code)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
コード例 #40
0
    def compiled(self, node, namespace, *symbols):
        """Compiles `node` and yields the module with fake `tf`/`ag__` attached.

        Generator that yields once; presumably wrapped by a context-manager
        decorator that is not visible in this snippet (TODO confirm).

        Args:
          node: AST handed to `compiler.ast_to_object`.
          namespace: object whose `items()` pairs are copied into the module.
          *symbols: forwarded to `self.make_fake_mod` for `fake_tf`.

        Yields:
          The compiled module.

        Raises:
          Exception: anything raised inside the try block (which includes the
            yield) is re-raised after the offending AST, or the generated
            code once it is available, has been printed.
        """
        source = None

        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            self.dynamic_calls.append(args)
            return 7

        class ConversionOptions(object):
            """Mock version of api.ConversionOptions."""
            def __init__(self, recursive):
                self.recursive = recursive

            @classmethod
            def new(cls, recursive):
                # Return the new instance; previously it was built and
                # discarded, so `new` always returned None.
                return cls(recursive)

        try:
            result, source = compiler.ast_to_object(node,
                                                    include_source_map=True)

            result.tf = self.make_fake_mod('fake_tf', *symbols)
            fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                         ConversionOptions)
            fake_ag.__dict__.update(operators.__dict__)
            fake_ag.__dict__['utils'] = utils
            fake_ag.__dict__['rewrite_graph_construction_error'] = (
                errors.rewrite_graph_construction_error)
            result.__dict__['ag__'] = fake_ag
            for k, v in namespace.items():
                result.__dict__[k] = v
            yield result
        except Exception:  # pylint:disable=broad-except
            # Print whichever representation is available to help debugging.
            if source is None:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            else:
                print('Offending compiled code:\n%s' % source)
            raise
コード例 #41
0
    def compiled(self, node, namespace, *symbols):
        """Compiles `node` and yields the module with fake `tf`/`ag__` attached.

        Generator that yields once; presumably wrapped by a context-manager
        decorator that is not visible in this snippet (TODO confirm).

        Args:
          node: AST handed to `compiler.ast_to_object`.
          namespace: object whose `items()` pairs are copied into the module.
          *symbols: forwarded to `self.make_fake_mod` for `fake_tf`.

        Yields:
          The compiled module, carrying `tf`, `ag__` and `ag_source_map__`.

        Raises:
          Exception: anything raised inside the try block (which includes the
            yield) is re-raised after the offending AST, or the generated
            code once it is available, has been printed.
        """
        generated_code = None
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            # args only; see api.converted_call
            self.dynamic_calls.append(args[3:])
            return RESULT_OF_MOCK_CONVERTED_CALL

        try:
            module, generated_code, source_map = compiler.ast_to_object(
                node, include_source_map=True)
            # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

            # TODO(mdan): Move this into self.prepare()
            module.tf = self.make_fake_mod('fake_tf', *symbols)
            ag_mod = self.make_fake_mod('fake_ag', converted_call,
                                        converter.ConversionOptions)
            # Later updates win: operators, then special_functions, then the
            # explicit attributes assigned below.
            for donor in (operators, special_functions):
                ag_mod.__dict__.update(donor.__dict__)
            ag_mod.ConversionOptions = converter.ConversionOptions
            ag_mod.Feature = converter.Feature
            ag_mod.utils = utils
            ag_mod.rewrite_graph_construction_error = (
                errors.rewrite_graph_construction_error)
            ag_mod.function_scope = function_wrapping.function_scope
            module.ag__ = ag_mod
            module.ag_source_map__ = source_map
            module.__dict__.update(namespace.items())
            yield module
        except Exception:  # pylint:disable=broad-except
            if generated_code is not None:
                print('Offending compiled code:\n%s' % generated_code)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
コード例 #42
0
    def test_to_ast(self):
        """Round-trips `ConversionOptions` through `to_ast` and compares."""
        opts = converter.ConversionOptions()
        opts_ast = opts.to_ast()

        template = '''
    def test_fn():
      return opts_ast
    '''
        opts_packed = templates.replace(template, opts_ast=opts_ast)

        reparsed, _, _ = compiler.ast_to_object(opts_packed)
        # The generated code refers to `ag__`, so a stand-in is provided.
        fake_ag = self.make_fake_mod('fake_ag', converter.ConversionOptions,
                                     converter.Feature)
        reparsed.__dict__['ag__'] = fake_ag

        reparsed_opts = reparsed.test_fn()

        self.assertEqual(opts.recursive, reparsed_opts.recursive)
        self.assertEqual(opts.user_requested, False)
        for field in ('internal_convert_user_code', 'optional_features'):
            self.assertEqual(getattr(opts, field),
                             getattr(reparsed_opts, field))
コード例 #43
0
  def test_to_ast(self):
    """Round-trips `ConversionOptions` through `to_ast` and compares fields."""
    opts = converter.ConversionOptions()
    opts_ast = opts.to_ast()

    template = '''
    def test_fn():
      return opts_ast
    '''
    opts_packed = templates.replace(template, opts_ast=opts_ast)

    reparsed, _, _ = compiler.ast_to_object(opts_packed)
    # The generated code refers to `ag__`, so a stand-in is provided.
    fake_ag = self.make_fake_mod('fake_ag', converter.ConversionOptions,
                                 converter.Feature)
    reparsed.__dict__['ag__'] = fake_ag

    reparsed_opts = reparsed.test_fn()

    compared_fields = (
        'recursive',
        'force_conversion',
        'internal_convert_user_code',
        'optional_features',
    )
    for field in compared_fields:
      self.assertEqual(getattr(opts, field), getattr(reparsed_opts, field))
コード例 #44
0
  def test_to_ast(self):
    """Round-trips `ConversionOptions` through `to_ast` and compares fields."""
    opts = converter.ConversionOptions()

    # Minimal entity context needed by `to_ast`.
    ctx = converter.EntityContext(
        converter_testing.FakeNamer(),
        transformer.EntityInfo(
            source_code='',
            source_file='<fragment>',
            namespace={},
            arg_values=None,
            arg_types={},
            owner_type=None),
        converter.ProgramContext(
            options=opts,
            partial_types=None,
            autograph_module=None,
            uncompiled_modules=()))
    opts_ast = opts.to_ast(ctx)

    template = '''
    def test_fn():
      return opts_ast
    '''
    opts_packed = templates.replace(template, opts_ast=opts_ast)

    reparsed, _ = compiler.ast_to_object(opts_packed)
    # The generated code refers to `ag__`, so a stand-in is provided.
    fake_ag = self.make_fake_mod('fake_ag', converter.ConversionOptions,
                                 converter.Feature)
    reparsed.__dict__['ag__'] = fake_ag

    reparsed_opts = reparsed.test_fn()

    compared_fields = (
        'recursive',
        'verbose',
        'force_conversion',
        'internal_convert_user_code',
        'optional_features',
    )
    for field in compared_fields:
      self.assertEqual(getattr(opts, field), getattr(reparsed_opts, field))
コード例 #45
0
    def test_to_ast(self):
        """Round-trips `ConversionOptions` through `to_ast` and compares."""
        opts = converter.ConversionOptions()

        # Minimal entity context needed by `to_ast`.
        ctx = converter.EntityContext(
            converter_testing.FakeNamer(),
            transformer.EntityInfo(source_code='',
                                   source_file='<fragment>',
                                   namespace={},
                                   arg_values=None,
                                   arg_types={},
                                   owner_type=None),
            converter.ProgramContext(options=opts,
                                     partial_types=None,
                                     autograph_module=None,
                                     uncompiled_modules=()))
        opts_ast = opts.to_ast(ctx)

        template = '''
    def test_fn():
      return opts_ast
    '''
        opts_packed = templates.replace(template, opts_ast=opts_ast)

        reparsed, _ = compiler.ast_to_object(opts_packed)
        # The generated code refers to `ag__`, so a stand-in is provided.
        fake_ag = self.make_fake_mod('fake_ag', converter.ConversionOptions,
                                     converter.Feature)
        reparsed.__dict__['ag__'] = fake_ag

        reparsed_opts = reparsed.test_fn()

        compared_fields = (
            'recursive',
            'verbose',
            'force_conversion',
            'internal_convert_user_code',
            'optional_features',
        )
        for field in compared_fields:
            self.assertEqual(getattr(opts, field),
                             getattr(reparsed_opts, field))
コード例 #46
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
    """Converts a Python entity into equivalent code that uses TensorFlow ops.

    Supported Python entities include:
      * functions
      * classes

    Classes are converted by converting all their methods into a new class.

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
        converted function may call.
      verbose: bool, whether to output the compiled code in the logs.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
      partial_types: Set[Type], reserved for internal use.
      strip_decorators: Optional[Sequence[Callable]], same as
        ConversionOptions.strip_decorators. The argument itself is never
        modified; `convert`, `do_not_convert` and `converted_call` are
        appended to a copy.

    Returns:
      Union[Callable, Type], the converted entity, which is the same kind as e
      (that is, a function if e is a function, a class if e is a class, etc.)
      but its code has been converted to use TF ops.

    Raises:
      ValueError: If the entity could not be converted, or if the converted
        entity already has a non-None `ag_source_map` attribute.
    """
    if strip_decorators is None:
        strip_decorators = ()
    # Build a new tuple instead of using `+=`: with a list argument, `+=`
    # would extend the caller's object in place on every call.
    strip_decorators = tuple(strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=strip_decorators,
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                    arg_types)

    # Gather the converted nodes, walking the conversion order backwards.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
        nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, compiled_src = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because it has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled
コード例 #47
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally, `tf.function`
    uses `to_graph`.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use converted
    code.

    Methods are converted into unbound functions that have an additional first
    argument called `self`.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including
        function arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.
      experimental_strip_decorators: A tuple specifying decorators that should
        be excluded from the compiled output. By default, when converting a
        function before the decorators are applied, the compiled output will
        include those decorators. The argument itself is never modified;
        `convert`, `do_not_convert` and `converted_call` are appended to a
        copy.
      experimental_verbose: The level of printing verbosity to use, as a
        `tf.autograph.experimental.Verbosity` value.
      experimental_partial_types: A `set` of `type` values, reserved for
        internal use.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ValueError: If the entity could not be converted, or if the converted
        entity already has a non-None `ag_source_map` attribute.
    """
    if experimental_strip_decorators is None:
        experimental_strip_decorators = ()
    # Build a new tuple instead of using `+=`: with a list argument, `+=`
    # would extend the caller's object in place on every call.
    experimental_strip_decorators = tuple(experimental_strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Gather the converted nodes, walking the conversion order backwards.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
        nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    # Same rule for the extra symbols collected on the program context.
    for key, val in program_ctx.additional_symbols.items():
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Carry the original function's default argument values over to the
    # compiled function.
    if tf_inspect.isfunction(entity):
        compiled.__defaults__ = entity.__defaults__

    if hasattr(compiled, '__globals__'):
        # Remove self to avoid circular references. This will probably only work
        # so long as the function is not reentrant.
        del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because it has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
コード例 #48
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL):
    """Transpiles a Python entity into TensorFlow graph code.

    Related APIs: `tf.autograph.to_code`, `tf.function`.

    `to_graph` is the low-level counterpart of `tf.function`: it only rewrites
    Python code into TensorFlow graph code. It performs no caching and no
    variable management, creates no ops by itself, and does not wrap the
    result in a TensorFlow function or a Python callable. Use it when finer
    control over the generated graph is needed. `tf.function` relies on
    `to_graph` internally.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Entities that can be converted:
      * functions, which yield new functions with converted code;
      * classes, which yield a new class whose methods use converted code;
      * object methods, which yield unbound functions taking an extra first
        argument named `self`.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether functions called by the converted function are
        converted recursively as well.
      arg_values: Optional dict of value hints, mapping symbol names
        (including function arguments) to actual values. For example,
        `arg_values={'a': 1}` maps the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments; gives only the type of a variable instead of a value.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Selects the optional
        features used during conversion.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ValueError: If the entity could not be converted.
    """
    try:
        conversion_options = converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features)
        program_ctx = converter.ProgramContext(
            options=conversion_options,
            autograph_module=tf_inspect.getmodule(to_graph))
        nodes, name, namespace = conversion.entity_to_graph(
            entity, program_ctx, arg_values, arg_types)

        compiled_module, _ = compiler.ast_to_object(
            nodes,
            source_prefix=program_ctx.required_imports,
            include_source_map=True)
        module_symbols = compiled_module.__dict__

        # The compiled code should see everything the entry entity saw, but
        # symbols the compiled module already defines (the transformed
        # entities) must not be overwritten.
        # TODO(mdan): This might not work well if the call tree spans modules?
        for symbol, value in namespace.items():
            module_symbols.setdefault(symbol, value)
        compiled = getattr(compiled_module, name)

        if not hasattr(entity, '__defaults__'):
            logging.log(3, 'Default args mapping: %s has no __defaults__',
                        entity)
        else:
            logging.log(3, 'Default args mapping: %s has: %s', entity,
                        entity.__defaults__)
            compiled.__defaults__ = entity.__defaults__

        logging.log(3, 'Namespace of %s includes: %s', compiled,
                    module_symbols.keys())

        if hasattr(compiled, '__globals__'):
            # Drop the self-reference to avoid a circular reference. This
            # will probably only work so long as the function is not
            # reentrant.
            del compiled.__globals__[name]

        # The source map must be reachable from the compiled entity so the
        # context manager can use it for runtime errors.
        # compiler.ast_to_object attaches it to the compiled module under the
        # 'ag_source_map__' symbol.
        # TODO(mdan): Record this statically in the generated code.
        # TODO(mdan): Rename this attribute to 'autograph_info__'
        source_map_attribute_name = 'ag_source_map'
        existing_source_map = getattr(
            compiled, source_map_attribute_name, None)
        if existing_source_map is not None:
            # TODO(znado): change input problem errors into TransformError
            raise ValueError('cannot convert %s because is has an attribute '
                             '"%s", which is reserved for AutoGraph.' %
                             (compiled, source_map_attribute_name))
        setattr(compiled, source_map_attribute_name,
                module_symbols['ag_source_map__'])

        return compiled
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as e:
        errors.report_internal_error(entity, e)
コード例 #49
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, forwarded to ConversionOptions as its `verbose` option.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Optional[Sequence[Callable]], same as
      ConversionOptions.strip_decorators. The argument itself is never
      modified; `convert`, `do_not_convert` and `converted_call` are appended
      to a copy.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, or if the converted
      entity already has a non-None `ag_source_map` attribute.
  """
  if strip_decorators is None:
    strip_decorators = ()
  # Build a new tuple instead of using `+=`: with a list argument, `+=` would
  # extend the caller's object in place on every call.
  strip_decorators = tuple(strip_decorators) + (
      convert, do_not_convert, converted_call)

  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=verbose,
          strip_decorators=strip_decorators),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Gather the converted nodes, walking the conversion order backwards.
  nodes = []
  for dep in reversed(program_ctx.conversion_order):
    nodes.extend(program_ctx.dependency_cache[dep])

  # The generated source text is not used by this function.
  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  return compiled
コード例 #50
0
ファイル: api.py プロジェクト: kylin9872/tensorflow
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators. The argument itself is never modified; `convert`,
      `do_not_convert` and `converted_call` are appended to a copy.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.

  Note:
    ValueError, AttributeError, KeyError, NameError and AssertionError raised
    while converting are handed to `errors.report_internal_error`; how they
    surface to the caller depends on that helper.
  """
  try:
    if experimental_strip_decorators is None:
      experimental_strip_decorators = ()
    # Build a new tuple instead of using `+=`: with a list argument, `+=`
    # would extend the caller's object in place on every call.
    experimental_strip_decorators = tuple(experimental_strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Gather the converted nodes, walking the conversion order backwards.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
      nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    # Same rule for the extra symbols collected on the program context.
    for key, val in program_ctx.additional_symbols.items():
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Carry the entity's default argument values over to the compiled entity
    # whenever the entity exposes `__defaults__`.
    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because it has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)