Exemplo n.º 1
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Render the TensorFlow-converted form of an entity as source text.

    The entity is converted with `conversion.entity_to_graph`; see `to_graph`
    for more details on the conversion itself.

    Args:
        e: A Python entity.
        recursive: See to_graph.
        arg_values: See to_graph.
        arg_types: See to_graph.
        partial_types: See to_graph.
        indentation: String, the text to use for each level of indentation.

    Returns:
        String. The compiled import statements, a blank line, and then the
        source of every entry in the conversion map's dependency cache,
        rendered in reverse order of the cache's values.
    """
    ctx = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types)
    # Called for its side effect: it populates ctx.dependency_cache.
    conversion.entity_to_graph(e, ctx, arg_values, arg_types)

    header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    # Snapshot the cached nodes first, then walk the snapshot back to front.
    cached_nodes = tuple(six.itervalues(ctx.dependency_cache))
    rendered = []
    for node in reversed(cached_nodes):
        rendered.append(compiler.ast_to_source(node, indentation))
    body = '\n'.join(rendered)

    return '\n\n'.join((header, body))
Exemplo n.º 2
0
def to_code(o,
            recursive=True,
            arg_value_hints=None,
            partial_types=None,
            indentation='  '):
    """Produce the TensorFlow code equivalent of an entity, as a string.

    See `to_graph` for more details.

    Args:
        o: A Python function or class.
        recursive: Whether to recursively convert any functions that the
            decorator function may call.
        arg_value_hints: A dict mapping parameter names to objects that can
            hint at the type of those parameters.
        partial_types: A set of types (e.g. classes) that will not be
            converted entirely. Calls to member functions for these types
            will be renamed independently.
        indentation: String, the text to use for each level of indentation.

    Returns:
        String. The compiled import statements, a blank line, and then the
        source of each cached dependency, in reverse order of the cache's
        values.
    """
    conv_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types)
    # Only the side effect matters here: conv_map.dependency_cache is filled.
    conversion.object_to_graph(o, conv_map, arg_value_hints)

    import_block = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    cached = tuple(six.itervalues(conv_map.dependency_cache))
    source_block = '\n'.join(
        [compiler.ast_to_source(node, indentation) for node in cached[::-1]])

    return import_block + '\n\n' + source_block
Exemplo n.º 3
0
def to_graph(f, arg_value_hints=None):
    """Compile a Python function into equivalent TensorFlow code.

    Args:
        f: A Python function with arbitrary arguments and return values.
        arg_value_hints: A dict mapping parameter names to objects that can
            hint at the type of those parameters.

    Returns:
        A function with a signature identical to `f`, but which, when
        executed, creates a TF graph that has the same functionality as the
        original function.
    """
    ctx = conversion.ConversionMap()
    _, compiled_name = conversion.object_to_graph(f, ctx, arg_value_hints)

    # Module body: the configured import statements first, then every
    # converted dependency recorded in the conversion map.
    module = gast.Module([])
    module.body.extend(
        parser.parse_str(statement)
        for statement in config.COMPILED_IMPORT_STATEMENTS)
    module.body.extend(ctx.dependency_cache.values())
    compiled = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    entry_globals = six.get_function_globals(f)
    compiled.__dict__.update(entry_globals)

    return getattr(compiled, compiled_name)
Exemplo n.º 4
0
def to_graph(o, arg_value_hints=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    o: A Python function or class.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    The compiled counterpart of `o`: for a function, a function with an
    identical signature which, when executed, creates a TF graph that has the
    same functionality as the original entity.
  """
  ctx = conversion.ConversionMap()
  _, compiled_name = conversion.object_to_graph(o, ctx, arg_value_hints)

  # Assemble the module: configured imports first, converted dependencies after.
  statements = [
      parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS
  ]
  module = gast.Module([])
  module.body.extend(statements)
  module.body.extend(ctx.dependency_cache.values())
  compiled = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(o):
    entry_globals = six.get_function_globals(o)
    compiled.__dict__.update(entry_globals)

  return getattr(compiled, compiled_name)
Exemplo n.º 5
0
    def test_object_to_graph_callable(self):
        # Converting a plain function should yield a FunctionDef node and a
        # new name carrying the 'tf__' prefix.
        def f(a):
            return a

        conversion_map = conversion.ConversionMap()
        # `node` rather than `ast`, to avoid shadowing the stdlib module name.
        node, new_name = conversion.object_to_graph(f, conversion_map, {})
        # assertIsInstance reports the offending object and expected type on
        # failure, so no explicit message argument is needed.
        self.assertIsInstance(node, gast.FunctionDef)
        self.assertEqual('tf__f', new_name)
  def test_entity_to_graph_callable(self):
    # Converting a plain function should yield a FunctionDef node and a new
    # name carrying the 'tf__' prefix.

    def f(a):
      return a

    conversion_map = conversion.ConversionMap(True, (), ())
    # `node` rather than `ast`, to avoid shadowing the stdlib module name.
    node, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
    # assertIsInstance reports the offending object and expected type on
    # failure, so no explicit message argument is needed.
    self.assertIsInstance(node, gast.FunctionDef)
    self.assertEqual('tf__f', new_name)
Exemplo n.º 7
0
def to_graph(e,
             recursive=True,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    Args:
        e: A Python entity.
        recursive: Whether to recursively convert any functions that the
            decorator function may call.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.
        partial_types: A set of types (e.g. classes) that will not be
            converted entirely. Calls to member functions for these types
            will be renamed independently.

    Returns:
        The compiled counterpart of `e`: for a function, a function with an
        identical signature which, when executed, creates a TF graph that has
        the same functionality as the original entity.
    """
    ctx = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types)
    _, compiled_name = conversion.entity_to_graph(
        e, ctx, arg_values, arg_types)

    # Module body: the configured import statements first, then every
    # converted dependency recorded in the conversion map.
    module = gast.Module([])
    module.body.extend(
        parser.parse_str(statement)
        for statement in config.COMPILED_IMPORT_STATEMENTS)
    module.body.extend(ctx.dependency_cache.values())
    compiled = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        entry_globals = six.get_function_globals(e)
        compiled.__dict__.update(entry_globals)

    return getattr(compiled, compiled_name)
Exemplo n.º 8
0
    def test_object_to_graph_call_tree(self):
        # Converting `f` should also convert its callee `g`, and the call
        # inside the converted `f` should target the renamed `g`.
        def g(a):
            return a

        def f(a):
            return g(a)

        conversion_map = conversion.ConversionMap()
        conversion.object_to_graph(f, conversion_map, {})

        cache = conversion_map.dependency_cache
        # assertIn prints both the member and the container on failure,
        # unlike assertTrue(x in y), which only reports "False is not true".
        self.assertIn(f, cache)
        self.assertIn(g, cache)
        self.assertEqual('tf__f', cache[f].name)
        # First statement of the converted f is `return g(a)`; the callee's
        # identifier must have been rewritten.
        self.assertEqual('tf__g', cache[f].body[0].value.func.id)
        self.assertEqual('tf__g', cache[g].name)
Exemplo n.º 9
0
def to_code(f, arg_value_hints=None, indentation='  '):
    """Produce the TensorFlow code equivalent of a function, as a string.

    Args:
        f: A Python function with arbitrary arguments and return values.
        arg_value_hints: A dict mapping parameter names to objects that can
            hint at the type of those parameters.
        indentation: String, the text to use for each level of indentation.

    Returns:
        String. The compiled import statements, a blank line, and then the
        source of each cached dependency, in reverse order of the cache's
        values.
    """
    def _render(node):
        # Source text for one converted node, using the caller's indentation.
        return compiler.ast_to_source(node, indentation)

    ctx = conversion.ConversionMap()
    # Only the side effect matters here: ctx.dependency_cache is filled.
    conversion.object_to_graph(f, ctx, arg_value_hints)

    import_block = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    cached_nodes = tuple(six.itervalues(ctx.dependency_cache))
    source_block = '\n'.join(_render(node) for node in reversed(cached_nodes))

    return '\n\n'.join((import_block, source_block))
Exemplo n.º 10
0
 def test_entity_to_graph_unsupported_types(self):
   # Converting an unsupported entity (here a plain string) must raise
   # ValueError.
   #
   # The conversion map is built outside the assertRaises scope so that only
   # the entity_to_graph call can satisfy the expected ValueError; an error
   # raised while constructing the map would otherwise let the test pass.
   conversion_map = conversion.ConversionMap(True, (), ())
   with self.assertRaises(ValueError):
     conversion.entity_to_graph('dummy', conversion_map, None, None)