コード例 #1
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Return the equivalent of an entity in TensorFlow code.

    See `to_graph` for more details.

    Args:
        e: A Python entity.
        recursive: See to_graph.
        arg_values: See to_graph.
        arg_types: See to_graph.
        partial_types: See to_graph.
        indentation: String, what to use for each level of indentation.

    Returns:
        String: the compiled import statements, a blank line, then the
        source of every entity recorded in the conversion map's cache.
    """
    nocompile = (convert, graph_ready, convert_inline)
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=nocompile,
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

    header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    # Emit cached conversions in the reverse of the cache's iteration order.
    # NOTE(review): presumably so callees precede their callers — confirm
    # against conversion.entity_to_graph.
    converted = list(six.itervalues(conversion_map.dependency_cache))
    converted.reverse()
    sources = [compiler.ast_to_source(node, indentation) for node in converted]

    return '\n\n'.join((header, '\n'.join(sources)))
コード例 #2
0
ファイル: api.py プロジェクト: DILASSS/tensorflow
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Return the equivalent of an entity in TensorFlow code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the compiled import statements, a blank line, then the source of
    every entity recorded in the conversion map's cache.
  """
  api_module = tf_inspect.getmodule(to_graph)
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=api_module)
  conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

  # Walk the cache back to front when emitting source.
  # NOTE(review): presumably this places dependencies before their users.
  cached_nodes = tuple(six.itervalues(conversion_map.dependency_cache))
  chunks = []
  for node in cached_nodes[::-1]:
    chunks.append(compiler.ast_to_source(node, indentation))

  return '%s\n\n%s' % (header, '\n'.join(chunks))
コード例 #3
0
    def test_entity_to_graph_call_tree(self):
        # Converting f must also record its callee g, and the call inside f's
        # converted body must refer to g's converted name.
        def g(a):
            return a

        def f(a):
            return g(a)

        conversion_map = conversion.ConversionMap(True, (), ())
        conversion.entity_to_graph(f, conversion_map, None, None)

        cache = conversion_map.dependency_cache
        # assertIn reports the missing key on failure, unlike assertTrue(x in y).
        self.assertIn(f, cache)
        self.assertIn(g, cache)
        self.assertEqual('tf__f', cache[f].name)
        self.assertEqual('tf__g', cache[f].body[0].value.func.id)
        self.assertEqual('tf__g', cache[g].name)
コード例 #4
0
  def test_entity_to_graph_call_tree(self):
    # Converting f must also record its callee g, and the call inside f's
    # converted body must refer to g's converted name.

    def g(a):
      return a

    def f(a):
      return g(a)

    conversion_map = conversion.ConversionMap(True, (), ())
    conversion.entity_to_graph(f, conversion_map, None, None)

    cache = conversion_map.dependency_cache
    # assertIn reports the missing key on failure, unlike assertTrue(x in y).
    self.assertIn(f, cache)
    self.assertIn(g, cache)
    self.assertEqual('tf__f', cache[f].name)
    self.assertEqual('tf__g', cache[f].body[0].value.func.id)
    self.assertEqual('tf__g', cache[g].name)
コード例 #5
0
  def test_entity_to_graph_callable(self):
    # Converting a plain function yields a FunctionDef node and its new name.

    def f(a):
      return a

    conversion_map = conversion.ConversionMap(True, (), ())
    # Named `node` rather than `ast` to avoid shadowing the stdlib module.
    node, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
    self.assertIsInstance(node, gast.FunctionDef)
    self.assertEqual('tf__f', new_name)
コード例 #6
0
    def test_entity_to_graph_callable(self):
        # Converting a plain function yields a FunctionDef node and its new name.
        def f(a):
            return a

        conversion_map = conversion.ConversionMap(True, (), (), None)
        # Named `node` rather than `ast` to avoid shadowing the stdlib module.
        node, new_name = conversion.entity_to_graph(f, conversion_map, None,
                                                    None)
        self.assertIsInstance(node, gast.FunctionDef)
        self.assertEqual('tf__f', new_name)
コード例 #7
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          converted entity may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      A function with a signature identical to `e` which, when executed,
      creates a TF graph that has the same functionality as the original
      entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        # parse_str returns a parsed module; splice its statements in instead
        # of nesting a whole Module node inside this Module's body.
        module.body.extend(parser.parse_str(import_line).body)
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # NOTE(review): six.get_function_globals exposes only e's module globals,
    # not names e captures from an enclosing closure.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_node.__dict__.update(six.get_function_globals(e))
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
コード例 #8
0
    def test_entity_to_graph_call_tree(self):
        # Converting f must also record its callee g, and the call inside f's
        # converted body must refer to g's converted name.
        def g(a):
            return a

        def f(a):
            return g(a)

        conversion_map = conversion.ConversionMap(True, (), (), None)
        conversion.entity_to_graph(f, conversion_map, None, None)

        cache = conversion_map.dependency_cache
        self.assertIn(f, cache)
        self.assertIn(g, cache)
        self.assertEqual('tf__f', cache[f].name)
        # The extra .body[0] steps past the `with tf.name_scope('f')` block
        # that is added automatically.
        self.assertEqual('tf__g', cache[f].body[0].body[0].value.func.id)
        self.assertEqual('tf__g', cache[g].name)
コード例 #9
0
ファイル: api.py プロジェクト: DILASSS/tensorflow
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        converted entity may call.
    verbose: Whether to log the compiled source code.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e` which, when executed,
    creates a TF graph that has the same functionality as the original entity.
  """
  api_module = tf_inspect.getmodule(to_graph)
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=api_module)
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  # Assemble one module: the import statements first, then every converted
  # entity from the cache.
  module = gast.Module([])
  statements = module.body
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    statements.extend(parser.parse_str(import_line).body)
  statements.extend(conversion_map.dependency_cache.values())
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    namespace = inspect_utils.getnamespace(e)
    compiled_node.__dict__.update(namespace)
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
  return compiled_fn
コード例 #10
0
  def test_entity_to_graph_call_tree(self):
    # Converting f must also record its callee g, and the call inside f's
    # converted body must refer to g's converted name.

    def g(a):
      return a

    def f(a):
      return g(a)

    conversion_map = conversion.ConversionMap(True, (), (), None)
    conversion.entity_to_graph(f, conversion_map, None, None)

    cache = conversion_map.dependency_cache
    self.assertIn(f, cache)
    self.assertIn(g, cache)
    self.assertEqual('tf__f', cache[f].name)
    # The extra .body[0] steps past the `with tf.name_scope('f')` block that
    # is added automatically.
    self.assertEqual('tf__g', cache[f].body[0].body[0].value.func.id)
    self.assertEqual('tf__g', cache[g].name)
コード例 #11
0
 def test_entity_to_graph_unsupported_types(self):
   # Build the map outside assertRaises so that only entity_to_graph can
   # satisfy the expected ValueError; a failing constructor must not pass.
   conversion_map = conversion.ConversionMap(True, (), ())
   with self.assertRaises(ValueError):
     conversion.entity_to_graph('dummy', conversion_map, None, None)
コード例 #12
0
 def test_entity_to_graph_unsupported_types(self):
     # Build the map outside assertRaises so that only entity_to_graph can
     # satisfy the expected ValueError; a failing constructor must not pass.
     conversion_map = conversion.ConversionMap(True, (), (), None)
     with self.assertRaises(ValueError):
         conversion.entity_to_graph('dummy', conversion_map, None, None)