コード例 #1
0
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Parses the source of `f`, converts its function definition with
  `node_to_graph`, and renames the result to the name chosen by the namer.

  Args:
    f: The callable to convert.
    conversion_map: Shared conversion state. It supplies the namer factory,
        the API module, and the recursion and nocompile-decorator settings. It
        also receives the namer's name map and the dependencies produced by
        this conversion.
    arg_values: Value hints for symbols such as the function's parameters.
    arg_types: Type hints for symbols such as the function's parameters.
    owner_type: The class that owns `f` when it is converted as a method,
        otherwise None.

  Returns:
    A tuple `(node, new_name)` holding the converted function definition node
    and the name it was given.

  Raises:
    NotImplementedError: if the namer did not rename `f` but the converted
        definition's name differs from `f.__name__`.
  """
  parsed, source = parser.parse_entity(f)
  fn_def = parsed.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, conversion_map.api_module)
  namer = conversion_map.new_namer(namespace)

  entity_info = context.EntityContext(
      namer=namer,
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive)
  fn_def, deps = node_to_graph(
      fn_def, entity_info, conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  original_name = f.__name__
  new_name, did_rename = namer.compiled_function_name(
      original_name, f, owner_type)
  if not did_rename:
    # Without a rename the converted definition must still carry the
    # original name; anything else is unexpected.
    if fn_def.name != original_name:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    new_name = original_name

  fn_def.name = new_name
  conversion_map.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(deps)

  return fn_def, new_name
コード例 #2
0
def class_to_graph(c, conversion_map):
  """Specialization of `entity_to_graph` for classes.

  Converts every method of `c` with `function_to_graph` and gathers the
  converted definitions into a new class definition node.

  Args:
    c: The class to convert.
    conversion_map: Shared conversion state. It is passed through to
        `function_to_graph` and used to create the namer for the new class.

  Returns:
    A tuple `(node, class_name)` holding the new `gast.ClassDef` and its name.

  Raises:
    ValueError: if `c` has no methods to convert.
  """
  converted_members = {}

  # On Python 3, functions defined in a class body are plain functions when
  # looked up on the class (there are no unbound methods), so `ismethod`
  # alone would only match classmethods. Accept both kinds.
  def _is_method(m):
    return tf_inspect.ismethod(m) or tf_inspect.isfunction(m)

  members = tf_inspect.getmembers(c, predicate=_is_method)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % (c,))

  class_namespace = None
  for _, m in members:
    node, _ = function_to_graph(
        m,
        conversion_map=conversion_map,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    if class_namespace is None:
      class_namespace = inspect_utils.getnamespace(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # AST statement bodies must be lists; on Python 3 `dict.values()` is a
      # view object, which the AST compiler rejects.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
コード例 #3
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          converted function may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      The converted entity, looked up by name in the compiled module. For a
      function, this is a function with a signature identical to `e` that,
      when executed, creates a TF graph with the same functionality as the
      original entity. For a class, it is the new class.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.extend(parser.parse_str(import_line).body)
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_namespace = compiled_node.__dict__
        for key, val in inspect_utils.getnamespace(e).items():
            # Only fill in missing names. A blanket update would replace the
            # converted entities and the compiled import statements with the
            # same-named original objects from `e`'s namespace.
            compiled_namespace.setdefault(key, val)
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
コード例 #4
0
ファイル: api.py プロジェクト: DILASSS/tensorflow
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
        converted function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The converted entity, looked up by name in the compiled module. For a
    function, this is a function with a signature identical to `e` that, when
    executed, creates a TF graph with the same functionality as the original
    entity. For a class, it is the new class.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(import_line).body)
  for dep in conversion_map.dependency_cache.values():
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    compiled_namespace = compiled_node.__dict__
    for key, val in inspect_utils.getnamespace(e).items():
      # Only fill in missing names. A blanket update would replace the
      # converted entities and the compiled import statements with the
      # same-named original objects from `e`'s namespace.
      compiled_namespace.setdefault(key, val)
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
コード例 #5
0
    def test_getnamespace_hermetic(self):
        """A closure variable shadows the global in the namespace only."""

        # Intentionally hiding the global function to make sure we don't overwrite
        # it in the global namespace.
        free_function = object()  # pylint:disable=redefined-outer-name

        def test_fn():
            return free_function

        ns = inspect_utils.getnamespace(test_fn)
        globs = six.get_function_globals(test_fn)
        # assertIs/assertIsNot report both operands on failure, unlike
        # assertTrue/assertFalse applied to an `is` expression.
        self.assertIs(ns['free_function'], free_function)
        self.assertIsNot(globs['free_function'], free_function)
コード例 #6
0
  def test_getnamespace_hermetic(self):
    """A closure variable shadows the global in the namespace only."""

    # Intentionally hiding the global function to make sure we don't overwrite
    # it in the global namespace.
    free_function = object()  # pylint:disable=redefined-outer-name

    def test_fn():
      return free_function

    ns = inspect_utils.getnamespace(test_fn)
    globs = six.get_function_globals(test_fn)
    # assertIs/assertIsNot report both operands on failure, unlike
    # assertTrue/assertFalse applied to an `is` expression.
    self.assertIs(ns['free_function'], free_function)
    self.assertIsNot(globs['free_function'], free_function)
コード例 #7
0
    def test_getnamespace_locals(self):
        """Called and closed-over names are in the namespace; locals are not."""
        def called_fn():
            return 0

        closed_over_list = []
        closed_over_primitive = 1

        def local_fn():
            closed_over_list.append(1)
            local_var = 1
            return called_fn() + local_var + closed_over_primitive

        ns = inspect_utils.getnamespace(local_fn)
        self.assertEqual(ns['called_fn'], called_fn)
        self.assertEqual(ns['closed_over_list'], closed_over_list)
        self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
        # local_fn's own local variables must not leak into its namespace.
        self.assertNotIn('local_var', ns)
コード例 #8
0
  def test_getnamespace_locals(self):
    """Called and closed-over names are in the namespace; locals are not."""

    def called_fn():
      return 0

    closed_over_list = []
    closed_over_primitive = 1

    def local_fn():
      closed_over_list.append(1)
      local_var = 1
      return called_fn() + local_var + closed_over_primitive

    ns = inspect_utils.getnamespace(local_fn)
    self.assertEqual(ns['called_fn'], called_fn)
    self.assertEqual(ns['closed_over_list'], closed_over_list)
    self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
    # local_fn's own local variables must not leak into its namespace.
    self.assertNotIn('local_var', ns)
コード例 #9
0
 def test_getnamespace_globals(self):
     """The namespace of `factory` includes the global `free_function`."""
     looked_up = inspect_utils.getnamespace(factory)['free_function']
     self.assertEqual(looked_up, free_function)
コード例 #10
0
 def test_getnamespace_globals(self):
   """The namespace of `factory` includes the global `free_function`."""
   looked_up = inspect_utils.getnamespace(factory)['free_function']
   self.assertEqual(looked_up, free_function)