def compiled(self, node, namespace, *symbols):
    """Compiles `node` and yields the module wired to fake AutoGraph deps.

    The compiled module receives:
      * `tf`: a fake module exposing `symbols`;
      * `ag__`: a fake module whose `converted_call` appends its arguments to
        `self.dynamic_calls` and returns 7, extended with everything from
        `operators`, plus `utils` and `rewrite_graph_construction_error`;
      * every entry of `namespace` as a global (applied last, so it wins).

    This is a generator with a single `yield`; presumably it is wrapped as a
    context manager by a decorator not visible here. Any exception, whether
    from compilation or thrown in at the `yield`, is re-raised after printing
    the generated source (or the AST if no source was produced).
    """
    self.dynamic_calls = []

    def converted_call(*args):
        """Mock version of api.converted_call."""
        self.dynamic_calls.append(args)
        return 7

    source = None
    try:
        module, source = compiler.ast_to_object(node)
        module.tf = self.make_fake_mod('fake_tf', *symbols)
        ag_mod = self.make_fake_mod('fake_ag', converted_call)
        ag_globals = ag_mod.__dict__
        ag_globals.update(operators.__dict__)
        ag_globals['utils'] = utils
        ag_globals['rewrite_graph_construction_error'] = (
            errors.rewrite_graph_construction_error)
        module_globals = module.__dict__
        module_globals['ag__'] = ag_mod
        for sym_name, sym_value in namespace.items():
            module_globals[sym_name] = sym_value
        yield module
    except Exception:  # pylint:disable=broad-except
        if source is not None:
            print('Offending compiled code:\n%s' % source)
        else:
            print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
        raise
    def compiled(self, node, *symbols):
        """Compiles `node` and yields it wired to fake AutoGraph modules.

        The yielded module receives a fake `tf` exposing `symbols`, the real
        `utils` as `autograph_utils`, and a fake `autograph_api` whose
        `converted_call` records its arguments in `self.dynamic_calls` and
        returns 7. It also receives `operators` under the `__ops` global.

        Any exception (from compilation or thrown in at the `yield`) is
        re-raised after printing the generated source, or the AST when no
        source was produced.
        """
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            self.dynamic_calls.append(args)
            return 7

        source = None
        try:
            module, source = compiler.ast_to_object(node)
            module.tf = self.make_fake_mod('fake_tf', *symbols)
            module.autograph_utils = utils
            fake_api = self.make_fake_mod('fake_api', converted_call)
            module.autograph_api = fake_api
            module.__dict__['__ops'] = operators
            yield module
        except Exception:  # pylint:disable=broad-except
            if source is not None:
                print('Offending compiled code:\n%s' % source)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
    def compiled(self, node, *symbols):
        """Compiles `node` and yields the module wired to fake AutoGraph deps.

        The compiled module gets a fake `tf` module exposing `symbols` and an
        `ag__` global. `ag__` is a fake module whose `converted_call` appends
        its arguments to `self.dynamic_calls` and returns 7. It is extended
        with everything from `operators`, plus `utils` and
        `rewrite_graph_construction_error`.

        Any exception (from compilation or thrown in at the `yield`) is
        re-raised after printing the generated source, or the AST if no
        source was produced.
        """
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            self.dynamic_calls.append(args)
            return 7

        source = None
        try:
            module, source = compiler.ast_to_object(node)
            module.tf = self.make_fake_mod('fake_tf', *symbols)
            ag_mod = self.make_fake_mod('fake_ag', converted_call)
            ag_globals = ag_mod.__dict__
            ag_globals.update(operators.__dict__)
            ag_globals['utils'] = utils
            ag_globals['rewrite_graph_construction_error'] = (
                errors.rewrite_graph_construction_error)
            module.__dict__['ag__'] = ag_mod
            yield module
        except Exception:  # pylint:disable=broad-except
            if source is not None:
                print('Offending compiled code:\n%s' % source)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
Exemple #4
0
 def test_codegen_gens(self):
   """Randomly generated FunctionDef ASTs must all compile to an object."""
   # Seed numpy's RNG so the run is reproducible (assuming codegen draws from
   # np.random).
   np.random.seed(0)
   failure_msg = 'Generated invalid AST that could not convert to source.'
   for _ in range(1000):
     compiled = compiler.ast_to_object(codegen.generate_random_functiondef())
     self.assertIsNotNone(compiled, failure_msg)
  def test_ast_to_object(self):
    """Compiles a hand-built `f(a): return a + 1` AST and checks the result."""
    fn_args = gast.arguments(
        args=[gast.Name('a', gast.Param(), None)],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    fn_body = [
        gast.Return(
            gast.BinOp(
                op=gast.Add(),
                left=gast.Name('a', gast.Load(), None),
                right=gast.Num(1)))
    ]
    node = gast.FunctionDef(
        name='f',
        args=fn_args,
        body=fn_body,
        decorator_list=[],
        returns=None)

    module, source = compiler.ast_to_object(node)

    expected = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    # The file at module.__file__ must contain the same generated source.
    with open(module.__file__, 'r') as written:
      self.assertEqual(expected, written.read().strip())
Exemple #6
0
    def test_ast_to_object(self):
        """Compiles a hand-built `f(a): return a + 1` AST and checks it."""
        fn_args = gast.arguments(args=[gast.Name('a', gast.Param(), None)],
                                 vararg=None,
                                 kwonlyargs=[],
                                 kwarg=None,
                                 defaults=[],
                                 kw_defaults=[])
        returned = gast.BinOp(op=gast.Add(),
                              left=gast.Name('a', gast.Load(), None),
                              right=gast.Num(1))
        node = gast.FunctionDef(name='f',
                                args=fn_args,
                                body=[gast.Return(returned)],
                                decorator_list=[],
                                returns=None)

        module, source = compiler.ast_to_object(node)

        expected = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
        self.assertEqual(expected, source.strip())
        self.assertEqual(2, module.f(1))
        # The file at module.__file__ must contain the same source.
        with open(module.__file__, 'r') as written:
            self.assertEqual(expected, written.read().strip())
Exemple #7
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          converted function may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      The converted counterpart of `e`, looked up by name in the compiled
      module. For a function, it has the same signature as `e`, but when
      executed it creates a TF graph with the same functionality as the
      original entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    # Build one module: the configured import statements first, followed by
    # every converted dependency.
    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.extend(parser.parse_str(import_line).body)
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        for key, val in inspect_utils.getnamespace(e).items():
            # Avoid overwriting entities that have been transformed.
            if key not in compiled_node.__dict__:
                compiled_node.__dict__[key] = val
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
Exemple #8
0
 def test_basic(self):
   """The ANF transform must preserve the result of a trivial function."""

   def test_function():
     a = 0
     return a

   parsed, _ = parser.parse_entity(test_function)
   transformed = anf.transform(parsed.body[0], self._simple_source_info())
   module, _ = compiler.ast_to_object(transformed)
   expected = test_function()
   self.assertEqual(expected, module.test_function())
Exemple #9
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        converted function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted counterpart of `e`, looked up by name in the compiled
    module. For a function, it has the same signature as `e`, but when
    executed it creates a TF graph with the same functionality as the
    original entity.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  # Build one module: the configured import statements first, followed by
  # every converted dependency.
  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(import_line).body)
  for dep in conversion_map.dependency_cache.values():
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    for key, val in inspect_utils.getnamespace(e).items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_node.__dict__:
        compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Exemple #10
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        converted function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted counterpart of `e`, looked up by name in the compiled
    module. For a function, it has the same signature as `e`, but when
    executed it creates a TF graph with the same functionality as the
    original entity.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  module = gast.Module([])
  # Materialize the values first: reversed() on a mapping's values view only
  # works for OrderedDict (3.5+) or plain dict (3.8+); a list works always.
  for dep in reversed(list(program_ctx.dependency_cache.values())):
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(
      module, source_prefix=program_ctx.required_imports)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_node.__dict__:
      compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Exemple #11
0
 def _remover_wrapper(self, f, remove_decorators):
   """Runs the decorator-removal transform on `f` and compiles it.

   Args:
     f: The function to transform.
     remove_decorators: Passed through to `decorators.transform`.

   Returns:
     The attribute named `f.__name__` from the compiled module.
   """
   namespace = dict(
       self_removing_decorator=self_removing_decorator,
       simple_decorator=simple_decorator)
   analyzed = self.parse_and_analyze(f, namespace)
   transformed, _ = decorators.transform(
       analyzed, remove_decorators=remove_decorators)
   module, _ = compiler.ast_to_object(transformed)
   return getattr(module, f.__name__)
 def test_keywords_to_dict(self):
   """keywords_to_dict must produce a dict node that compiles and evaluates."""
   call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
   dict_node = ast_util.keywords_to_dict(call_keywords)
   # Assign the dict node to module-level `d` after `b = 3`, so compiling the
   # module evaluates it with `b` bound.
   module_node = parser.parse_str('b = 3')
   assign = ast.Assign([ast.Name(id='d', ctx=ast.Store())], dict_node)
   module_node.body += (assign,)
   module, _ = compiler.ast_to_object(module_node)
   self.assertDictEqual(module.d, {'a': 3, 'c': 1, 'd': 'e'})
 def test_keywords_to_dict(self):
   """keywords_to_dict must produce a dict node usable as a return value."""
   call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
   dict_node = ast_util.keywords_to_dict(call_keywords)
   # Return the dict node from `f(b)`, so calling f(3) evaluates it with b=3.
   fn_node = parser.parse_str('def f(b): pass').body[0]
   fn_node.body.append(ast.Return(dict_node))
   module, _ = compiler.ast_to_object(fn_node)
   self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
Exemple #14
0
 def test_keywords_to_dict(self):
     """keywords_to_dict must produce a dict node usable as a return value."""
     call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
     dict_node = ast_util.keywords_to_dict(call_keywords)
     # Return the dict node from `f(b)`; calling f(3) evaluates it with b=3.
     fn_node = parser.parse_str('def f(b): pass').body[0]
     fn_node.body.append(ast.Return(dict_node))
     module, _ = compiler.ast_to_object(fn_node)
     self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
 def test_keywords_to_dict(self):
     """keywords_to_dict must produce a dict node that compiles and evaluates."""
     call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
     dict_node = ast_util.keywords_to_dict(call_keywords)
     # Assign the dict node to module-level `d` after `b = 3`, so compiling
     # the module evaluates it with `b` bound.
     module_node = parser.parse_str('b = 3')
     assign = ast.Assign([ast.Name(id='d', ctx=ast.Store())], dict_node)
     module_node.body += (assign, )
     module, _ = compiler.ast_to_object(module_node)
     self.assertDictEqual(module.d, {'a': 3, 'c': 1, 'd': 'e'})
Exemple #16
0
    def test_basic(self):
        """The ANF transform must preserve the result of a trivial function."""
        def test_function():
            a = 0
            return a

        parsed, _ = parser.parse_entity(test_function)
        transformed = anf.transform(parsed.body[0],
                                    self._simple_source_info())
        module, _ = compiler.ast_to_object(transformed)
        expected = test_function()
        self.assertEqual(expected, module.test_function())
 def _remover_wrapper(self, f, remove_decorators):
   """Runs the decorator-removal transform on `f` and compiles it.

   Args:
     f: The function to transform.
     remove_decorators: Passed through to `decorators.transform`.

   Returns:
     The attribute named `f.__name__` from the compiled module.
   """
   namespace = dict(
       self_removing_decorator=self_removing_decorator,
       simple_decorator=simple_decorator)
   analyzed = self.parse_and_analyze(f, namespace)
   transformed, _ = decorators.transform(
       analyzed, remove_decorators=remove_decorators)
   module, _ = compiler.ast_to_object(transformed)
   return getattr(module, f.__name__)
  def test_replace_name_with_dict(self):
    """A template name can be replaced with a dict-literal expression."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn())
  def test_replace_tuple(self):
    """Replacing `b` with a tuple of names splices them into the return."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    self.assertEqual((2, 3), result.test_fn(2, 3))
Exemple #20
0
  def test_noop(self):
    """An undecorated identity function still behaves as the identity."""

    def test_fn(a):
      return a

    analyzed = self.parse_and_analyze(test_fn, {})
    transformed = decorators.transform(analyzed, self.ctx)
    module, _ = compiler.ast_to_object(transformed)

    self.assertEqual(1, module.test_fn(1))
    def test_replace_name_with_dict(self):
        """A template name can be replaced with a dict-literal expression."""
        template = """
      def test_fn():
        return foo['bar']
    """

        source = parser.parse_expression('{\'bar\': 3}')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        self.assertEqual(3, result.test_fn())
    def test_replace_tuple(self):
        """Replacing `b` with a tuple of names splices them into the return."""
        template = """
      def test_fn(a, c):
        return b,
    """

        node = templates.replace(template, b=('a', 'c'))[0]
        result, _ = compiler.ast_to_object(node)

        self.assertEqual((2, 3), result.test_fn(2, 3))
Exemple #23
0
    def test_noop(self):
        """With nothing to remove, no deps are reported and f is unchanged."""
        def test_fn(a):
            return a

        analyzed = self.parse_and_analyze(test_fn, {})
        transformed, deps = decorators.transform(analyzed,
                                                 remove_decorators=())
        module, _ = compiler.ast_to_object(transformed)

        self.assertFalse(deps)
        self.assertEqual(1, module.test_fn(1))
  def test_replace_variable(self):
    """Replacing `a` with 'b' renames every occurrence, parameter included."""
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    node = templates.replace(template, a='b')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))
  def test_replace_function_name(self):
    """The name of a templated function can be replaced."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))
  def test_noop(self):
    """With nothing to remove, no deps are reported and f is unchanged."""

    def test_fn(a):
      return a

    analyzed = self.parse_and_analyze(test_fn, {})
    transformed, deps = decorators.transform(analyzed, remove_decorators=())
    module, _ = compiler.ast_to_object(transformed)

    self.assertFalse(deps)
    self.assertEqual(1, module.test_fn(1))
    def test_replace_function_name(self):
        """The name of a templated function can be replaced."""
        template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

        node = templates.replace(template, fname='test_fn')[0]
        result, _ = compiler.ast_to_object(node)
        self.assertEqual(7, result.test_fn(2))
    def test_replace_variable(self):
        """Replacing `a` with 'b' renames every occurrence, parameter too."""
        template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

        node = templates.replace(template, a='b')[0]
        result, _ = compiler.ast_to_object(node)
        self.assertEqual(7, result.test_fn(2))
Exemple #29
0
    def test_parser_compile_idempotent(self):
        """Parsing then compiling a function reproduces its exact source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
        parsed_fn = parser.parse_entity(test_fn)[0].body[0]
        compiled_module = compiler.ast_to_object(parsed_fn)[0]
        self.assertEqual(original_source,
                         tf_inspect.getsource(compiled_module.test_fn))
    def test_replace_attribute(self):
        """An attribute name in a template can be replaced by another name."""
        # `imp` is deprecated and was removed in Python 3.12.
        import types

        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _ = compiler.ast_to_object(node)
        mod = types.ModuleType('test')
        mod.b = 3
        self.assertEqual(3, result.test_fn(mod))

        # Replacing the attribute with a non-string value is expected to raise.
        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
  def test_parser_compile_idempotent(self):
    """Parsing then compiling a function reproduces its exact source."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    original_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    parsed_fn = parser.parse_entity(test_fn)[0].body[0]
    compiled_module = compiler.ast_to_object(parsed_fn)[0]
    self.assertEqual(original_source,
                     tf_inspect.getsource(compiled_module.test_fn))
  def test_replace_attribute(self):
    """An attribute name in a template can be replaced by another name."""
    # `imp` is deprecated and was removed in Python 3.12.
    import types

    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = types.ModuleType('test')
    mod.b = 3
    self.assertEqual(3, result.test_fn(mod))

    # Replacing the attribute with a non-string value is expected to raise.
    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)
    def test_replace_name_with_call(self):
        """A template name can be replaced with the call expression f()(b)."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        source = parser.parse_expression('f()(b)')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        self.assertEqual(15, result.test_fn())
Exemple #34
0
 def _transform(self, f, autograph_decorators):
   """Runs the decorators converter on `f` and returns the compiled result.

   Args:
     f: The function to convert.
     autograph_decorators: Passed through to `self.prepare`.

   Returns:
     The attribute named `f.__name__` from the compiled module.
   """
   namespace = dict(
       self_transform_decorator=self_transform_decorator,
       simple_decorator=simple_decorator,
       converter_testing=converter_testing)
   node, ctx = self.prepare(
       f,
       namespace,
       recursive=False,
       autograph_decorators=autograph_decorators)
   converted = decorators.transform(node, ctx)
   # Prepend any extra imports the program context collected to the source.
   prefix = '\n'.join(ctx.program.additional_imports)
   module, _ = compiler.ast_to_object(converted, source_prefix=prefix)
   return getattr(module, f.__name__)
  def test_replace_name_with_call(self):
    """A template name can be replaced with the call expression f()(b)."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(15, result.test_fn())
  def test_replace_code_block(self):
    """A placeholder statement can be replaced by a list of statements."""
    template = """
      def test_fn(a):
        block
        return a
    """

    node = templates.replace(
        template,
        block=[
            gast.Assign([
                gast.Name('a', None, None)
            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
        ] * 2)[0]
    result, _ = compiler.ast_to_object(node)
    # Two `a = a + 1` statements: 1 -> 3.
    self.assertEqual(3, result.test_fn(1))
    def test_replace_call_keyword(self):
        """A call keyword in a template can be replaced by parsed keywords.

        Replacing it with an empty list or with a non-keyword value must
        raise ValueError.
        """
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parser.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _ = compiler.ast_to_object(node)
        self.assertEqual(9, result.test_fn())

        # One context per invalid replacement: inside a single `with`, the
        # second call would never run once the first one raised.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
    def test_replace_code_block(self):
        """A placeholder statement can be replaced by a list of statements."""
        template = """
      def test_fn(a):
        block
        return a
    """

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', None, None)],
                            gast.BinOp(gast.Name('a', None, None), gast.Add(),
                                       gast.Num(1))),
            ] * 2)[0]
        result, _ = compiler.ast_to_object(node)
        # Two `a = a + 1` statements: 1 -> 3.
        self.assertEqual(3, result.test_fn(1))
  def test_replace_call_keyword(self):
    """A call keyword in a template can be replaced by parsed keywords.

    Replacing it with an empty list or with a non-keyword value must raise
    ValueError.
    """
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(9, result.test_fn())

    # One context per invalid replacement: inside a single `with`, the second
    # call would never run once the first one raised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)
  def compiled(self, node, *symbols):
    """Compiles `node` and yields it wired to fake AutoGraph modules.

    The yielded module receives a fake `tf` exposing `symbols`, the real
    `utils` as `autograph_utils`, and a fake `autograph_api` whose
    `converted_call` records its arguments in `self.dynamic_calls` and
    returns 7. It also receives `operators` under the `__ops` global.

    Any exception (from compilation or thrown in at the `yield`) is re-raised
    after printing the generated source, or the AST when no source exists.
    """
    self.dynamic_calls = []

    def converted_call(*args):
      """Mock version of api.converted_call."""
      self.dynamic_calls.append(args)
      return 7

    source = None
    try:
      module, source = compiler.ast_to_object(node)
      module.tf = self.make_fake_mod('fake_tf', *symbols)
      module.autograph_utils = utils
      fake_api = self.make_fake_mod('fake_api', converted_call)
      module.autograph_api = fake_api
      module.__dict__['__ops'] = operators
      yield module
    except Exception:  # pylint:disable=broad-except
      if source is not None:
        print('Offending compiled code:\n%s' % source)
      else:
        print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
      raise
Exemple #41
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
        converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
    partial_types: Set[Type], reserved for internal use.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.)
    but its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, including when the
        converted entity already has the reserved `ag_source_map` attribute.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  # Materialize the values first: reversed() on a mapping's values view only
  # works for OrderedDict (3.5+) or plain dict (3.8+); a list works always.
  for dep in reversed(list(program_ctx.dependency_cache.values())):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
Exemple #42
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        converted function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted counterpart of `e`, looked up by name in the compiled
    module. For a function, it has the same signature as `e`, but when
    executed it creates a TF graph with the same functionality as the
    original entity.

  Raises:
    ValueError: If the converted entity already has the `ag_source_map`
        attribute, which is reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  # Materialize the values first: reversed() on a mapping's values view only
  # works for OrderedDict (3.5+) or plain dict (3.8+); a list works always.
  for dep in reversed(list(program_ctx.dependency_cache.values())):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
Exemple #43
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity (function or class) to convert.
    recursive: Whether to recursively convert any functions that the converted
        function may call.
    verbose: Whether to log the compiled source code.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which, when executed,
    creates a TF graph that has the same functionality as the original entity.

  Raises:
    ValueError: If the converted function defines or refers to symbol names that
    are reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Concatenate the AST nodes of every cached dependency, walking the
  # dependency cache in reverse.
  nodes = []
  for dep in reversed(program_ctx.dependency_cache.values()):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw. Names that
  # already exist in the compiled module are the transformed entities and must
  # not be overwritten by their originals, hence setdefault.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    compiled_module.__dict__.setdefault(key, val)
  compiled_fn = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled_fn, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_fn, source_map_attribute_name))
  setattr(compiled_fn, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Converts a Python entity into equivalent code that uses TensorFlow ops.

    Supported Python entities include:
      * functions
      * classes

    Classes are converted by converting all their methods into a new class.

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
          converted function may call.
      verbose: bool, whether to output the compiled code in the logs.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
          function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
          function arguments.
      partial_types: Set[Type], reserved for internal use.

    Returns:
      Union[Callable, Type], the converted entity, which is the same kind as e
      (that is, a function if e is a function, a class if e is a class, etc.)
      but its code has been converted to use TF ops.

    Raises:
      ValueError: If the entity could not be converted, e.g. because the
          compiled entity already carries the reserved 'ag_source_map'
          attribute.
    """
    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                    arg_types)

    # Concatenate the AST nodes of every cached dependency, walking the
    # dependency cache in reverse.
    nodes = []
    for dep in reversed(program_ctx.dependency_cache.values()):
        nodes.extend(dep)
    compiled_module, compiled_src = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw. Names that
    # already exist in the compiled module are the transformed entities and
    # must not be overwritten by their originals, hence setdefault.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        compiled_module.__dict__.setdefault(key, val)
    compiled = getattr(compiled_module, name)

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because it has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled