Ejemplo n.º 1
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Convert an entity and return the generated source code as a string.

    The entity is converted with `conversion.entity_to_graph`. Every node left
    in the conversion map's dependency cache is then rendered to source,
    walking the cache back to front, and the result is prefixed with the
    statements in `config.COMPILED_IMPORT_STATEMENTS`.

    See `to_graph` for more details.

    Args:
      e: A Python entity.
      recursive: See to_graph.
      arg_values: See to_graph.
      arg_types: See to_graph.
      partial_types: See to_graph.
      indentation: String, what to use for each level of indentation.

    Returns:
      String.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

    import_block = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    # Snapshot the cached nodes so they can be iterated in reverse.
    cached_nodes = tuple(six.itervalues(conversion_map.dependency_cache))
    rendered = [
        compiler.ast_to_source(node, indentation)
        for node in reversed(cached_nodes)
    ]

    return import_block + '\n\n' + '\n'.join(rendered)
Ejemplo n.º 2
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Convert an entity and return the generated source code as a string.

  The entity is converted through `conversion.entity_to_graph`. Each entry of
  the program context's dependency cache is then rendered with
  `compiler.ast_to_source`, walking the cache back to front, and the output is
  prefixed with `program_ctx.required_imports`.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Snapshot the cached entries so they can be iterated in reverse.
  cached_deps = tuple(six.itervalues(program_ctx.dependency_cache))
  sources = [
      compiler.ast_to_source(dep, indentation) for dep in reversed(cached_deps)
  ]
  code = '\n'.join(sources)

  return program_ctx.required_imports + '\n\n' + code
Ejemplo n.º 3
0
  def test_entity_to_graph_class_hierarchy(self):
    # Converting a subclass should also cache a converted version of its base.

    class TestBase(object):

      def __init__(self, x='base'):
        self.x = x

      def foo(self):
        return self.x

      def bar(self):
        return self.x

    class TestSubclass(TestBase):

      def __init__(self, y):
        super(TestSubclass, self).__init__('sub')
        self.y = y

      def foo(self):
        return self.y

      def baz(self):
        return self.y

    conversion_map = self._simple_conversion_map()
    conversion.entity_to_graph(TestSubclass, conversion_map, None, None)

    cache = conversion_map.dependency_cache
    # assertIn reports the missing key and the container on failure.
    self.assertIn(TestBase, cache)
    self.assertIn(TestSubclass, cache)
    self.assertEqual('TfTestBase', cache[TestBase].body[-1].name)
    self.assertEqual('TfTestSubclass', cache[TestSubclass].body[-1].name)
Ejemplo n.º 4
0
    def test_entity_to_graph_class_hierarchy(self):
        # Converting a subclass should also cache a converted version of its
        # base class.
        class TestBase(object):
            def __init__(self, x='base'):
                self.x = x

            def foo(self):
                return self.x

            def bar(self):
                return self.x

        class TestSubclass(TestBase):
            def __init__(self, y):
                super(TestSubclass, self).__init__('sub')
                self.y = y

            def foo(self):
                return self.y

            def baz(self):
                return self.y

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

        cache = program_ctx.dependency_cache
        # assertIn reports the missing key and the container on failure.
        self.assertIn(TestBase, cache)
        self.assertIn(TestSubclass, cache)
        self.assertEqual('TfTestBase', cache[TestBase].body[-1].name)
        self.assertEqual('TfTestSubclass', cache[TestSubclass].body[-1].name)
Ejemplo n.º 5
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Convert an entity and return the generated source code as a string.

  After `conversion.entity_to_graph` has run, the nodes held in the conversion
  map's dependency cache are rendered to source in reverse iteration order and
  appended to the statements in `config.COMPILED_IMPORT_STATEMENTS`.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

  rendered = []
  for dep in reversed(tuple(six.itervalues(conversion_map.dependency_cache))):
    rendered.append(compiler.ast_to_source(dep, indentation))
  body = '\n'.join(rendered)

  return header + '\n\n' + body
Ejemplo n.º 6
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Convert an entity and return the generated source code as a string.

  After `conversion.entity_to_graph` has run, the entries of the program
  context's dependency cache are rendered to source in reverse iteration order
  and appended to `program_ctx.required_imports`.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  rendered = []
  for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))):
    rendered.append(compiler.ast_to_source(dep, indentation))
  body = '\n'.join(rendered)

  return program_ctx.required_imports + '\n\n' + body
Ejemplo n.º 7
0
  def test_ag_module_cached(self):
    # Two conversions sharing one conversion map should expose the very same
    # `ag__` object in their namespaces.
    def callee():
      return range(3)

    def caller(a):
      return a()

    conversion_map = self._simple_conversion_map()
    _, _, callee_ns = conversion.entity_to_graph(
        callee, conversion_map, None, None)
    _, _, caller_ns = conversion.entity_to_graph(
        caller, conversion_map, None, None)

    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(callee_ns['ag__'], caller_ns['ag__'])
Ejemplo n.º 8
0
    def test_ag_module_cached(self):
        # Two conversions sharing one program context should expose the very
        # same `ag__` object in their namespaces.
        def callee():
            return range(3)

        def caller(a):
            return a()

        program_ctx = self._simple_program_ctx()
        _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None,
                                                     None)
        _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None,
                                                     None)

        # assertIs reports both objects on failure, unlike assertTrue(a is b).
        self.assertIs(callee_ns['ag__'], caller_ns['ag__'])
Ejemplo n.º 9
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    The converted nodes are assembled into a single module, preceded by the
    statements in `config.COMPILED_IMPORT_STATEMENTS`, and that module is
    compiled with `compiler.ast_to_object`.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          decorator function may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      The compiled counterpart of `e`, looked up by its converted name on the
      compiled module. It is meant to have a signature identical to `e` and,
      when executed, to create a TF graph that has the same functionality as
      the original entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.extend(parser.parse_str(import_line).body)
    module.body.extend(conversion_map.dependency_cache.values())
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        for key, val in inspect_utils.getnamespace(e).items():
            # setdefault leaves entities that have been transformed untouched.
            compiled_node.__dict__.setdefault(key, val)
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
Ejemplo n.º 10
0
    def test_entity_to_graph_call_tree(self):
        # Converting `f` should also convert and cache the `g` that it calls.
        def g(a):
            return a

        def f(a):
            return g(a)

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(f, program_ctx, None, None)

        cache = program_ctx.dependency_cache
        # assertIn reports the missing key and the container on failure.
        self.assertIn(f, cache)
        self.assertIn(g, cache)
        f_node = cache[f][0]
        g_node = cache[g][0]
        self.assertEqual('tf__f', f_node.name)
        self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
        self.assertEqual('tf__g', g_node.name)
Ejemplo n.º 11
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The compiled counterpart of `e`, looked up by its converted name on the
    compiled module. It is meant to have a signature identical to `e` and,
    when executed, to create a TF graph that has the same functionality as the
    original entity.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  module = gast.Module([])
  # Materialize the values before reversing: `reversed()` rejects dict views on
  # Python 3 versions older than 3.8. This matches what `to_code` does.
  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(
      module, source_prefix=program_ctx.required_imports)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_node.__dict__:
      compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Ejemplo n.º 12
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  The converted nodes are gathered into one module, after the statements in
  `config.COMPILED_IMPORT_STATEMENTS`, and compiled with
  `compiler.ast_to_object`.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The compiled counterpart of `e`, looked up by its converted name on the
    compiled module. It is meant to have a signature identical to `e` and,
    when executed, to create a TF graph that has the same functionality as the
    original entity.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(import_line).body)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    for key, val in inspect_utils.getnamespace(e).items():
      # setdefault leaves entities that have been transformed untouched.
      compiled_node.__dict__.setdefault(key, val)
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Ejemplo n.º 13
0
  def test_entity_to_graph_call_tree(self):
    # Converting `f` should also convert and cache the `g` that it calls.

    def g(a):
      return a

    def f(a):
      return g(a)

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(f, program_ctx, None, None)

    cache = program_ctx.dependency_cache
    # assertIn reports the missing key and the container on failure.
    self.assertIn(f, cache)
    self.assertIn(g, cache)
    f_node = cache[f][0]
    g_node = cache[g][0]
    self.assertEqual('tf__f', f_node.name)
    self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
    self.assertEqual('tf__g', g_node.name)
Ejemplo n.º 14
0
    def test_entity_to_graph_callable(self):
        # A plain function should convert to a FunctionDef named `tf__f`.
        def f(a):
            return a

        conversion_map = conversion.ConversionMap(True, (), (), None)
        # Named `node` rather than `ast` to avoid shadowing the stdlib module.
        node, new_name = conversion.entity_to_graph(f, conversion_map, None,
                                                    None)
        self.assertIsInstance(node, gast.FunctionDef)
        self.assertEqual('tf__f', new_name)
Ejemplo n.º 15
0
  def test_entity_to_graph_callable(self):
    # A plain function should convert to a FunctionDef named `tf__f`.

    def f(a):
      return a

    conversion_map = conversion.ConversionMap(True, (), (), None)
    # Named `node` rather than `ast` to avoid shadowing the stdlib module.
    node, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
    self.assertIsInstance(node, gast.FunctionDef)
    self.assertEqual('tf__f', new_name)
    def test_entity_to_graph_call_tree(self):
        # Converting `f` should also convert and cache the `g` that it calls.
        def g(a):
            return a

        def f(a):
            return g(a)

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(f, program_ctx, None, None)

        cache = program_ctx.dependency_cache
        # assertIn reports the missing key and the container on failure.
        self.assertIn(f, cache)
        self.assertIn(g, cache)
        self.assertEqual('tf__f', cache[f].name)
        # need the extra .body[0] in order to step past the with tf.name_scope('f')
        # that is added automatically
        self.assertEqual('tf__g', cache[f].body[0].body[0].value.func.id)
        self.assertEqual('tf__g', cache[g].name)
Ejemplo n.º 17
0
  def test_entity_to_graph_callable(self):
    # A function closing over `b` should convert to a FunctionDef named
    # `tf__f`, and the returned namespace should hold the same `b` object.
    b = 2
    def f(a):
      return a + b

    conversion_map = self._simple_conversion_map()
    # Named `node` rather than `ast` to avoid shadowing the stdlib module.
    node, name, ns = conversion.entity_to_graph(f, conversion_map, None, None)
    self.assertIsInstance(node, gast.FunctionDef)
    self.assertEqual('tf__f', name)
    self.assertIs(ns['b'], b)
Ejemplo n.º 18
0
    def test_entity_to_graph_call_tree(self):
        # Converting `f` should also convert and cache the `g` that it calls.
        def g(a):
            return a

        def f(a):
            return g(a)

        conversion_map = conversion.ConversionMap(True, (), (), None)
        conversion.entity_to_graph(f, conversion_map, None, None)

        cache = conversion_map.dependency_cache
        # assertIn reports the missing key and the container on failure.
        self.assertIn(f, cache)
        self.assertIn(g, cache)
        self.assertEqual('tf__f', cache[f].name)
        # need the extra .body[0] in order to step past the with tf.name_scope('f')
        # that is added automatically
        self.assertEqual('tf__g', cache[f].body[0].body[0].value.func.id)
        self.assertEqual('tf__g', cache[g].name)
Ejemplo n.º 19
0
  def test_entity_to_graph_call_tree(self):
    # Converting `f` should also convert and cache the `g` that it calls.

    def g(a):
      return a

    def f(a):
      return g(a)

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(f, program_ctx, None, None)

    cache = program_ctx.dependency_cache
    # assertIn reports the missing key and the container on failure.
    self.assertIn(f, cache)
    self.assertIn(g, cache)
    self.assertEqual('tf__f', cache[f].name)
    # need the extra .body[0] in order to step past the with tf.name_scope('f')
    # that is added automatically
    self.assertEqual('tf__g', cache[f].body[0].body[0].value.func.id)
    self.assertEqual('tf__g', cache[g].name)
Ejemplo n.º 20
0
    def test_entity_to_graph_class_hierarchy_whitelisted(self):
        # The subclass should be converted while its `training.Model` base is
        # left out of the cache and referenced by name instead.
        class TestSubclass(training.Model):
            def __init__(self, y):
                super(TestSubclass, self).__init__()
                self.built = False

            def call(self, x):
                return 3 * x

        program_ctx = self._simple_program_ctx()
        conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

        cache = program_ctx.dependency_cache
        # assertIn / assertNotIn report the key and container on failure.
        self.assertIn(TestSubclass, cache)
        self.assertNotIn(training.Model, cache)
        self.assertEqual('Model', cache[TestSubclass][0].names[0].name)
        self.assertEqual('TfTestSubclass', cache[TestSubclass][-1].name)
Ejemplo n.º 21
0
  def test_entity_to_graph_callable(self):
    # A function closing over `b` should convert to a FunctionDef named
    # `tf__f`, and the returned namespace should hold the same `b` object.
    b = 2
    def f(a):
      return a + b

    conversion_map = self._simple_conversion_map()
    # Named `node` rather than `ast` to avoid shadowing the stdlib module.
    node, name, ns = conversion.entity_to_graph(f, conversion_map, None, None)
    self.assertIsInstance(node, gast.FunctionDef)
    self.assertEqual('tf__f', name)
    self.assertIs(ns['b'], b)
Ejemplo n.º 22
0
  def test_entity_to_graph_callable(self):
    # A function closing over `b` should convert to a pair of nodes whose first
    # element is a FunctionDef; the converted name should be `tf__f` and the
    # returned namespace should hold the same `b` object.
    b = 2
    def f(a):
      return a + b

    ctx = self._simple_program_ctx()
    # The nested target unpacks the two-element node sequence in place.
    (fn_node, _), converted_name, namespace = conversion.entity_to_graph(
        f, ctx, None, None)
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', converted_name)
    self.assertIs(namespace['b'], b)
Ejemplo n.º 23
0
  def test_entity_to_graph_call_tree(self):
    # Converting `f` should also convert and cache the `g` that it calls.

    def g(a):
      return a

    def f(a):
      return g(a)

    conversion_map = conversion.ConversionMap(True, (), (), None)
    conversion.entity_to_graph(f, conversion_map, None, None)

    cache = conversion_map.dependency_cache
    # assertIn reports the missing key and the container on failure.
    self.assertIn(f, cache)
    self.assertIn(g, cache)
    self.assertEqual('tf__f', cache[f].name)
    # need the extra .body[0] in order to step past the with tf.name_scope('f')
    # that is added automatically
    self.assertEqual('tf__g', cache[f].body[0].body[0].value.func.id)
    self.assertEqual('tf__g', cache[g].name)
Ejemplo n.º 24
0
  def test_entity_to_graph_class_hierarchy_whitelisted(self):
    # The subclass should be converted while its `training.Model` base is left
    # out of the cache and referenced by name instead.

    class TestSubclass(training.Model):

      def __init__(self, y):
        super(TestSubclass, self).__init__()
        self.built = False

      def call(self, x):
        return 3 * x

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    cache = program_ctx.dependency_cache
    # assertIn / assertNotIn report the key and container on failure.
    self.assertIn(TestSubclass, cache)
    self.assertNotIn(training.Model, cache)
    self.assertEqual('Model', cache[TestSubclass].body[0].names[0].name)
    self.assertEqual('TfTestSubclass', cache[TestSubclass].body[-1].name)
Ejemplo n.º 25
0
  def test_entity_to_graph_class_hierarchy_whitelisted(self):
    # The subclass should be converted while its `training.Model` base is left
    # out of the cache and referenced by name instead.

    class TestSubclass(training.Model):

      def __init__(self, y):
        super(TestSubclass, self).__init__()
        self.built = False

      def call(self, x):
        return 3 * x

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    cache = program_ctx.dependency_cache
    # assertIn / assertNotIn report the key and container on failure.
    self.assertIn(TestSubclass, cache)
    self.assertNotIn(training.Model, cache)
    self.assertEqual('Model', cache[TestSubclass][0].names[0].name)
    # The returned nodes will include:
    # <import nodes>, <class node>, <assignment node>
    self.assertEqual('TfTestSubclass', cache[TestSubclass][-2].name)
Ejemplo n.º 26
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Returns the converted source code of an entity, using TensorFlow ops.

  The entity is converted with `conversion.entity_to_graph`; the entries of the
  program context's dependency cache are then rendered to source in reverse
  iteration order and appended to `program_ctx.required_imports`.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
        converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Snapshot the cached entries so they can be iterated in reverse.
  cached_deps = tuple(six.itervalues(program_ctx.dependency_cache))
  sources = [
      compiler.ast_to_source(dep, indentation) for dep in reversed(cached_deps)
  ]
  code = '\n'.join(sources)

  return program_ctx.required_imports + '\n\n' + code
Ejemplo n.º 27
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Returns the converted source code of an entity, using TensorFlow ops.

    The entity is converted with `conversion.entity_to_graph`; the entries of
    the program context's dependency cache are then rendered to source in
    reverse iteration order and appended to `program_ctx.required_imports`.

    Also see: `to_graph`, `convert`

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
          converted function may call.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
          function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
          function arguments.
      partial_types: Set[Type], reserved for internal use.
      indentation: Text, what to use for each level of indentation.

    Returns:
      Text, the converted code.
    """
    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

    rendered = []
    for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))):
        rendered.append(compiler.ast_to_source(dep, indentation))
    body = '\n'.join(rendered)

    return program_ctx.required_imports + '\n\n' + body
Ejemplo n.º 28
0
  def test_entity_to_graph_class_hierarchy(self):
    # Converting a subclass should also cache a converted version of its base.

    class TestBase(object):

      def __init__(self, x='base'):
        self.x = x

      def foo(self):
        return self.x

      def bar(self):
        return self.x

    class TestSubclass(TestBase):

      def __init__(self, y):
        super(TestSubclass, self).__init__('sub')
        self.y = y

      def foo(self):
        return self.y

      def baz(self):
        return self.y

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    cache = program_ctx.dependency_cache
    # assertIn reports the missing key and the container on failure.
    self.assertIn(TestBase, cache)
    self.assertIn(TestSubclass, cache)
    # The returned nodes will include:
    # <import nodes>, <class node>, <assignment node>
    self.assertEqual('TfTestBase', cache[TestBase][-2].name)
    self.assertEqual('TfTestSubclass', cache[TestSubclass][-2].name)
Ejemplo n.º 29
0
 def test_entity_to_graph_unsupported_types(self):
     # Build the conversion map outside the assertRaises block so that only a
     # ValueError raised by entity_to_graph itself can satisfy the assertion.
     conversion_map = conversion.ConversionMap(True, (), (), None)
     with self.assertRaises(ValueError):
         conversion.entity_to_graph('dummy', conversion_map, None, None)
Ejemplo n.º 30
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The compiled counterpart of `e`, looked up by its converted name on the
    compiled module, with the module's source map attached as its
    `ag_source_map` attribute. It is meant to have a signature identical to
    `e` and, when executed, to create a TF graph that has the same
    functionality as the original entity.

  Raises:
    ValueError: If the compiled entity already has a non-None `ag_source_map`
    attribute, a name that is reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  nodes = []
  # Materialize the values before reversing: `reversed()` rejects dict views on
  # Python 3 versions older than 3.8.
  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled_fn = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled_fn, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_fn, source_map_attribute_name))
  setattr(compiled_fn, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Ejemplo n.º 31
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  Classes are converted by converting all their methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
        converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
    partial_types: Set[Type], reserved for internal use.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted, or if the converted
        entity already has a non-None `ag_source_map` attribute, which is
        reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Flatten the cached dependencies, walking the cache in reverse order.
  # The values are materialized into a tuple first: dict views are not
  # reversible on Python 3 before 3.8, so calling reversed() directly on
  # `.values()` raises TypeError there.
  nodes = []
  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  # Refuse to clobber an attribute of the same name on the compiled entity.
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
Ejemplo n.º 32
0
 def test_entity_to_graph_unsupported_types(self):
   # Build the fixture outside the assertRaises block so that a ValueError
   # raised during setup cannot make the test pass spuriously; only the
   # entity_to_graph call on a str entity is expected to raise.
   conversion_map = self._simple_conversion_map()
   with self.assertRaises(ValueError):
     conversion.entity_to_graph('dummy', conversion_map, None, None)
Ejemplo n.º 33
0
 def test_entity_to_graph_unsupported_types(self):
   # Construct the conversion map outside the assertRaises block so that a
   # ValueError raised by the constructor cannot make the test pass
   # spuriously; only the entity_to_graph call on a str entity should raise.
   conversion_map = conversion.ConversionMap(True, (), (), None)
   with self.assertRaises(ValueError):
     conversion.entity_to_graph('dummy', conversion_map, None, None)
Ejemplo n.º 34
0
  def test_entity_to_graph_lambda(self):
    # The entity under test must be a lambda, so the assignment is deliberate.
    f = lambda a: a

    # Build the fixture outside the assertRaises block so that a
    # NotImplementedError raised during setup cannot make the test pass
    # spuriously; only the entity_to_graph call is expected to raise.
    program_ctx = self._simple_program_ctx()
    with self.assertRaises(NotImplementedError):
      conversion.entity_to_graph(f, program_ctx, None, None)
Ejemplo n.º 35
0
 def test_entity_to_graph_unsupported_types(self):
   # Build the fixture outside the assertRaises block so that a ValueError
   # raised during setup cannot make the test pass spuriously; only the
   # entity_to_graph call on a str entity is expected to raise.
   program_ctx = self._simple_program_ctx()
   with self.assertRaises(ValueError):
     conversion.entity_to_graph('dummy', program_ctx, None, None)
Ejemplo n.º 36
0
 def test_entity_to_graph_unsupported_types(self):
     # Build the fixture outside the assertRaises block so that a ValueError
     # raised during setup cannot make the test pass spuriously; only the
     # entity_to_graph call on a str entity is expected to raise.
     program_ctx = self._simple_program_ctx()
     with self.assertRaises(ValueError):
         conversion.entity_to_graph('dummy', program_ctx, None, None)
Ejemplo n.º 37
0
    def test_entity_to_graph_lambda(self):
        # The entity under test must be a lambda, so the assignment is
        # deliberate.
        f = lambda a: a

        # Build the fixture outside the assertRaises block so that a
        # NotImplementedError raised during setup cannot make the test pass
        # spuriously; only the entity_to_graph call is expected to raise.
        program_ctx = self._simple_program_ctx()
        with self.assertRaises(NotImplementedError):
            conversion.entity_to_graph(f, program_ctx, None, None)
Ejemplo n.º 38
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Converts a Python entity into equivalent code that uses TensorFlow ops.

    Supported Python entities include:
      * functions
      * classes

    Classes are converted by converting all their methods into a new class.

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
          converted function may call.
      verbose: bool, whether to output the compiled code in the logs.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
          function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
          function arguments.
      partial_types: Set[Type], reserved for internal use.

    Returns:
      Union[Callable, Type], the converted entity, which is the same kind as e
      (that is, a function if e is a function, a class if e is a class, etc.)
      but its code has been converted to use TF ops.

    Raises:
      ValueError: If the entity could not be converted, or if the converted
          entity already has a non-None `ag_source_map` attribute, which is
          reserved for AutoGraph.
    """
    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                    arg_types)

    # Flatten the cached dependencies, walking the cache in reverse order.
    # The values are materialized into a tuple first: dict views are not
    # reversible on Python 3 before 3.8, so calling reversed() directly on
    # `.values()` raises TypeError there.
    nodes = []
    for dep in reversed(tuple(program_ctx.dependency_cache.values())):
        nodes.extend(dep)
    compiled_module, compiled_src = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    # Refuse to clobber an attribute of the same name on the compiled entity.
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because is has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled
Ejemplo n.º 39
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed it
    creates a TF graph that has the same functionality as the original entity.

  Raises:
    ValueError: If the converted function defines or refers to symbol names that
        are reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Flatten the cached dependencies, walking the cache in reverse order.
  # The values are materialized into a tuple first: dict views are not
  # reversible on Python 3 before 3.8, so calling reversed() directly on
  # `.values()` raises TypeError there.
  nodes = []
  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  # Refuse to clobber an attribute of the same name on the compiled entity.
  if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled
Ejemplo n.º 40
0
 def test_entity_to_graph_unsupported_types(self):
   # Build the fixture outside the assertRaises block so that a ValueError
   # raised during setup cannot make the test pass spuriously; only the
   # entity_to_graph call on a str entity is expected to raise.
   conversion_map = self._simple_conversion_map()
   with self.assertRaises(ValueError):
     conversion.entity_to_graph('dummy', conversion_map, None, None)
Ejemplo n.º 41
0
  def test_entity_to_graph_lambda(self):
    # The entity under test must be a lambda, so the assignment is deliberate.
    f = lambda a: a

    # Build the fixture outside the assertRaises block so that a
    # NotImplementedError raised during setup cannot make the test pass
    # spuriously; only the entity_to_graph call is expected to raise.
    conversion_map = self._simple_conversion_map()
    with self.assertRaises(NotImplementedError):
      conversion.entity_to_graph(f, conversion_map, None, None)