def testGetModule(self):
  """Checks that tf_inspect.getmodule agrees with inspect.getmodule."""
  # The same comparison is made for the class and for both functions.
  subjects = (
      TestDecoratedClass,
      test_decorated_function,
      test_undecorated_function,
  )
  for subject in subjects:
    expected_module = inspect.getmodule(subject)
    self.assertEqual(expected_module, tf_inspect.getmodule(subject))
Exemple #2
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Converts a Python entity and returns the generated source code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the configured import statements, a blank line, and then the
    source of every entry of the dependency cache, in reverse cache order.
  """
  skipped_decorators = (convert, do_not_convert, converted_call)
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=skipped_decorators,
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  # The return value is not needed here; the converted nodes are read from
  # conversion_map.dependency_cache below.
  conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

  # Snapshot the cache values so that they can be walked back to front.
  converted_nodes = tuple(six.itervalues(conversion_map.dependency_cache))
  rendered = [
      compiler.ast_to_source(node, indentation)
      for node in reversed(converted_nodes)
  ]

  return header + '\n\n' + '\n'.join(rendered)
Exemple #3
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Converts a Python entity and returns the generated source code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the imports required by the program context, a blank line, and
    then the source of every entry of the dependency cache, in reverse cache
    order.
  """
  autograph_decorators = (convert, do_not_convert, converted_call)
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=autograph_decorators,
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  # The return value is not needed here; the converted nodes are read from
  # program_ctx.dependency_cache below.
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Snapshot the cache values so that they can be walked back to front.
  converted_nodes = tuple(six.itervalues(program_ctx.dependency_cache))
  rendered = [
      compiler.ast_to_source(node, indentation)
      for node in reversed(converted_nodes)
  ]

  return program_ctx.required_imports + '\n\n' + '\n'.join(rendered)
Exemple #4
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed
    creates a TF graph that has the same functionality as the original entity.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  module = gast.Module([])
  # Snapshot the values before reversing them: on Python 3 `.values()` is a
  # view, and dict views only became reversible in Python 3.8. `to_code` takes
  # the same precaution.
  converted_nodes = tuple(program_ctx.dependency_cache.values())
  module.body.extend(reversed(converted_nodes))
  compiled_node, compiled_src = compiler.ast_to_object(
      module, source_prefix=program_ctx.required_imports)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_node.__dict__:
      compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Exemple #5
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compiles a Python entity into equivalent TensorFlow code.

  Entities that are currently supported:
    * functions
    * classes

  A class is handled by converting all of its methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to write the compiled code to the logs.
    arg_values: A dict of value hints for symbols such as function parameters.
    arg_types: A dict of type hints for symbols such as function parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function whose signature is identical to that of `e`, but which creates
    a TF graph with the same functionality as the original entity when it is
    executed.
  """
  skipped_decorators = (convert, do_not_convert, converted_call)
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=skipped_decorators,
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  # Assemble one module: the configured import statements first, followed by
  # every cached dependency.
  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    parsed_import = parser.parse_str(import_line)
    module.body.extend(parsed_import.body)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    compiled_symbols = compiled_node.__dict__
    for key, val in inspect_utils.getnamespace(e).items():
      # setdefault leaves entities that have been transformed untouched.
      compiled_symbols.setdefault(key, val)
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Exemple #6
0
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
  """Returns the name by which a value can be referred to in a given namespace.

  If the object defines a parent module, the function attempts to use it to
  locate the object.

  This function will recurse inside modules, but it will not search objects for
  attributes. The recursion depth is controlled by max_depth.

  Args:
    namespace: Dict[str, Any], the namespace to search into.
    object_: Any, the value to search.
    max_depth: Optional[int], a limit to the recursion depth when searching
        inside modules.
    visited: Optional[Set[int]], ID of modules to avoid visiting.
  Returns: Union[str, None], the fully-qualified name that resolves to the value
      object_, or None if it couldn't be found.
  """
  if visited is None:
    visited = set()

  for name in namespace:
    # The value may be referenced by more than one symbol, case in which
    # any symbol will be fine. If the program contains symbol aliases that
    # change over time, this may capture a symbol that will later point to
    # something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    if object_ is namespace[name]:
      return name

  # If an object is not found, try to search its parent modules.
  parent = tf_inspect.getmodule(object_)
  if (parent is not None and parent is not object_ and
      parent is not namespace):
    # No limit to recursion depth because of the guard above.
    parent_name = getqualifiedname(
        namespace, parent, max_depth=0, visited=visited)
    if parent_name is not None:
      name_in_parent = getqualifiedname(
          parent.__dict__, object_, max_depth=0, visited=visited)
      assert name_in_parent is not None, (
          'An object should always be found in its owner module')
      return '{}.{}'.format(parent_name, name_in_parent)

  if max_depth:
    # Iterate over a snapshot: on Python 3 `.keys()` and `.items()` are live
    # views, so walking them directly would still raise "dictionary changed
    # size during iteration" if the namespace grows while the recursive search
    # below runs. It's unclear why that happens - suspecting new modules may
    # load during iteration.
    for name, value in tuple(namespace.items()):
      if tf_inspect.ismodule(value) and id(value) not in visited:
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
          return '{}.{}'.format(name, name_in_module)
  return None
Exemple #7
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
  """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols, including function
      arguments, that maps string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols, including function
      arguments. A type hint specifies just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    The converted code as string.
  """
  options = converter.ConversionOptions(
      recursive=recursive,
      verbose=converter.Verbosity.BRIEF,
      strip_decorators=(convert, do_not_convert, converted_call),
      optional_features=experimental_optional_features)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=experimental_partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  # The return value is not needed here; the converted nodes are read from
  # the program context below.
  conversion.entity_to_graph(entity, program_ctx, arg_values, arg_types)

  # Render the dependencies in reverse conversion order.
  rendered = []
  for dep in reversed(program_ctx.conversion_order):
    converted = program_ctx.dependency_cache[dep]
    rendered.append(compiler.ast_to_source(converted, indentation))

  return program_ctx.required_imports + '\n\n' + '\n'.join(rendered)
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.
  Returns:
    Boolean, True if the name of the module that owns `o` starts with one of
    the prefixes in config.DEFAULT_UNCOMPILED_MODULES. False otherwise,
    including when the owner module cannot be determined.
  """
  m = tf_inspect.getmodule(o)
  # getmodule may return None (e.g. for functools.partial objects), in which
  # case there is no module name to match against the whitelist.
  module_name = getattr(m, '__name__', None)
  if module_name is None:
    return False
  # Each whitelist entry is a 1-tuple holding a module name prefix.
  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    if module_name.startswith(prefix):
      return True
  return False
Exemple #9
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Converts a Python entity and returns equivalent code using TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code.
  """
  options = converter.ConversionOptions(
      recursive=recursive,
      strip_decorators=(convert, do_not_convert, converted_call))
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  # The return value is not needed here; the converted nodes are read from
  # the program context below.
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Render the dependencies in reverse conversion order.
  rendered = []
  for dep in reversed(program_ctx.conversion_order):
    converted = program_ctx.dependency_cache[dep]
    rendered.append(compiler.ast_to_source(converted, indentation))

  return program_ctx.required_imports + '\n\n' + '\n'.join(rendered)
Exemple #10
0
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  An entity is whitelisted when any of the following holds:
    * the name of its owner module starts with one of the prefixes in
      config.DEFAULT_UNCOMPILED_MODULES;
    * it has an `autograph_info__` attribute;
    * inspect_utils.isnamedtuple reports it as a namedtuple.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)
  # getmodule may still return None for other entities. Skip the prefix match
  # in that case instead of failing on `None.__name__`, so that the remaining
  # checks below still run.
  module_name = getattr(m, '__name__', None)
  if module_name is not None:
    # Each whitelist entry is a 1-tuple holding a module name prefix.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if module_name.startswith(prefix):
        return True

  if hasattr(o, 'autograph_info__'):
    return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.log_first_n(
          logging.level_warning(),
          'Entity {} looks like a namedtuple subclass. If it has any custom'
          ' methods, they will not be converted by AutoGraph.'.format(o), 1)
    return True

  return False
Exemple #11
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  `to_graph` is a low-level transpiler from Python code to TensorFlow graph
  code. In contrast to `tf.function`, it implements no caching and no variable
  management, it creates no actual ops, and it does not wrap the graph into a
  TensorFlow function or a Python callable. It is best used where greater
  control over the generated TensorFlow graph is desired. Internally,
  `tf.function` uses `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  A function is converted into a new function with converted code.

  A class is converted by generating a new class whose methods use converted
  code.

  A method is converted into an unbound function that has an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols, including function
      arguments, that maps string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols, including function
      arguments. A type hint specifies just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    # The AutoGraph entry points are always stripped, on top of whatever the
    # caller requested.
    if experimental_strip_decorators is None:
      experimental_strip_decorators = ()
    experimental_strip_decorators += (convert, do_not_convert, converted_call)

    options = converter.ConversionOptions(
        recursive=recursive,
        verbose=experimental_verbose,
        strip_decorators=experimental_strip_decorators,
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=options,
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Flatten the converted dependencies, in reverse conversion order.
    nodes = [
        node
        for dep in reversed(program_ctx.conversion_order)
        for node in program_ctx.dependency_cache[dep]
    ]

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)
    module_symbols = compiled_module.__dict__

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # setdefault leaves entities that have been transformed untouched.
      module_symbols.setdefault(key, val)
    for key, val in program_ctx.additional_symbols.items():
      module_symbols.setdefault(key, val)
    compiled = getattr(compiled_module, name)

    # Carry the default argument values over to the compiled entity.
    if not hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)
    else:
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                module_symbols.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    existing_source_map = getattr(compiled, source_map_attribute_name, None)
    if existing_source_map is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because is has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            module_symbols['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)
Exemple #12
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  `to_graph` is a low-level transpiler from Python code to TensorFlow graph
  code. In contrast to `tf.function`, it implements no caching and no variable
  management, it creates no actual ops, and it does not wrap the graph into a
  TensorFlow function or a Python callable. It is best used where greater
  control over the generated TensorFlow graph is desired. Internally,
  `tf.function` uses `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  A function is converted into a new function with converted code.

  A class is converted by generating a new class whose methods use converted
  code.

  A method is converted into an unbound function that has an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols, including function
      arguments, that maps string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      This function discards the argument on entry.
    arg_types: Optional dict of type hints for symbols, including function
      arguments. A type hint specifies just the type of a variable, rather
      than a specific value. This function discards the argument on entry.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    # TODO(b/129431421): Remove these args.
    del arg_values, arg_types
    options = converter.ConversionOptions(
        recursive=recursive,
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=options,
        autograph_module=tf_inspect.getmodule(to_graph))
    return conversion.convert(entity, program_ctx)
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)
Exemple #13
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None):
  """Converts a Python entity into equivalent code that uses TensorFlow ops.

  Supported Python entities include:
    * functions
    * classes

  A class is converted by converting all of its methods into a new class.

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    verbose: bool, whether to output the compiled code in the logs.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Set[Type], reserved for internal use.
    strip_decorators: Tuple[Callable], same as
      ConversionOptions.strip_decorators.

  Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.) but
    its code has been converted to use TF ops.

  Raises:
    ValueError: If the entity could not be converted.
  """
  # The AutoGraph entry points are always stripped, on top of whatever the
  # caller requested.
  if strip_decorators is None:
    strip_decorators = ()
  strip_decorators += (convert, do_not_convert, converted_call)

  options = converter.ConversionOptions(
      recursive=recursive,
      verbose=verbose,
      strip_decorators=strip_decorators)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Flatten the converted dependencies, in reverse conversion order.
  nodes = [
      node
      for dep in reversed(program_ctx.conversion_order)
      for node in program_ctx.dependency_cache[dep]
  ]

  # The generated source is returned as well, but it is not used here.
  compiled_module, _ = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)
  module_symbols = compiled_module.__dict__

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # setdefault leaves entities that have been transformed untouched.
    module_symbols.setdefault(key, val)
  compiled = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  # TODO(mdan): Record this statically in the generated code.
  # TODO(mdan): Rename this attribute to 'autograph_info__'
  source_map_attribute_name = 'ag_source_map'
  existing_source_map = getattr(compiled, source_map_attribute_name, None)
  if existing_source_map is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
  setattr(compiled, source_map_attribute_name,
          module_symbols['ag_source_map__'])

  return compiled
Exemple #14
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  `to_graph` is a low-level transpiler from Python code to TensorFlow graph
  code. In contrast to `tf.function`, it implements no caching and no variable
  management, it creates no actual ops, and it does not wrap the graph into a
  TensorFlow function or a Python callable. It is best used where greater
  control over the generated TensorFlow graph is desired. Internally,
  `tf.function` uses `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  A function is converted into a new function with converted code.

  A class is converted by generating a new class whose methods use converted
  code.

  A method is converted into an unbound function that has an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols, including function
      arguments, that maps string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols, including function
      arguments. A type hint specifies just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple specifying decorators that should be
      excluded from the compiled output. By default, when converting a function
      before the decorators are applied, the compiled output will include those
      decorators.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    # The AutoGraph entry points are always stripped, on top of whatever the
    # caller requested.
    if experimental_strip_decorators is None:
      experimental_strip_decorators = ()
    experimental_strip_decorators += (convert, do_not_convert, converted_call)

    conversion_options = converter.ConversionOptions(
        recursive=recursive,
        verbose=experimental_verbose,
        strip_decorators=experimental_strip_decorators,
        optional_features=experimental_optional_features)
    program_ctx = converter.ProgramContext(
        options=conversion_options,
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    # Flatten the converted dependencies, in reverse conversion order.
    nodes = [
        node
        for dep in reversed(program_ctx.conversion_order)
        for node in program_ctx.dependency_cache[dep]
    ]

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)
    compiled_symbols = compiled_module.__dict__

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # setdefault leaves entities that have been transformed untouched.
      compiled_symbols.setdefault(key, val)
    for key, val in program_ctx.additional_symbols.items():
      compiled_symbols.setdefault(key, val)
    compiled = getattr(compiled_module, name)

    # Carry the default argument values over to the compiled entity.
    if not hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)
    else:
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_symbols.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    previous_source_map = getattr(compiled, source_map_attribute_name, None)
    if previous_source_map is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because is has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_symbols['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)
Exemple #15
0
def to_graph(entity, recursive=True, experimental_optional_features=None):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally, `tf.function`
    uses `to_graph`.

    Example usage:

    >>> def f(x):
    ...   if x > 0:
    ...     y = x * x
    ...   else:
    ...     y = -x
    ...   return y
    ...
    >>> converted_f = to_graph(f)
    >>> x = tf.constant(2)
    >>> converted_f(x)  # converted_f is like a TensorFlow Op.
    <tf.Tensor: shape=(), dtype=int32, numpy=4>

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use
    converted code.

    Methods are converted into unbound function that have an additional first
    argument called `self`.

    For a tutorial, see the
    [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
    For more detailed information, see the
    [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ValueError: If the entity could not be converted.
    """
    try:
        # Direct calls to to_graph are always flagged as user-requested.
        conversion_options = converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features)
        autograph_module = tf_inspect.getmodule(to_graph)
        program_ctx = converter.ProgramContext(
            options=conversion_options, autograph_module=autograph_module)
        return conversion.convert(entity, program_ctx)
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as e:
        # Log the original failure with its traceback, then surface a single
        # ConversionError that names the entity and the underlying error type.
        logging.error(1, 'Error converting %s', entity, exc_info=True)
        error_name = e.__class__.__name__
        raise ConversionError('converting {}: {}: {}'.format(
            entity, error_name, str(e)))
Exemple #16
0
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
    """Compiles a function call inline.

    For internal use only.

    Note: The argument list is optimized for readability of generated code,
    which may look like this:

      ag__.converted_call(f, (arg1, arg2), None, fscope)
      ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
      ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

    The function first runs a series of screening checks that decide whether
    `f` should simply be called as-is (via `_call_unconverted`). Only when none
    of them match is `f` converted and the converted function called. If the
    conversion itself fails, a warning is logged and `f` is called unconverted,
    unless strict conversion mode is enabled, in which case the error is
    re-raised.

    Args:
      f: The function to convert.
      args: Tuple, the original positional arguments of f
      kwargs: Optional[Dict], the original keyword arguments of f
      caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
        scope of the converted function in which this call was originally made.
      options: Optional[converter.ConversionOptions], conversion options. If
        not specified, the value of caller_fn_scope.callopts is used. Either
        options or caller_fn_scope must be present.

    Returns:
      Any, the result of executing a possibly-converted `f` with the given
        arguments.

    Raises:
      ValueError: If both `options` and `caller_fn_scope` are None.
      Exception: Any conversion error when
        `is_autograph_strict_conversion_mode()` is true, and any exception
        raised by the converted function itself (after metadata is attached).
    """
    logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f,
                args, kwargs)

    if options is None:
        if caller_fn_scope is None:
            raise ValueError(
                'either caller_fn_scope or options must have a value')
        options = caller_fn_scope.callopts

    if conversion.is_in_whitelist_cache(f, options):
        # Fix: the format string has a %s placeholder, so `f` must be passed.
        logging.log(2, 'Whitelisted %s: from cache', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
        logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if inspect_utils.isbuiltin(f):
        # eval and super need the caller's scope; every other builtin is
        # dispatched to its overload.
        if f is eval:
            return py_builtins.eval_in_original_context(
                f, args, caller_fn_scope)
        if f is super:
            return py_builtins.super_in_original_context(
                f, args, caller_fn_scope)
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    # TODO(mdan): Clean up the naming inconsistency.
    if hasattr(f, 'autograph_info__') or hasattr(f, '__ag_compiled'):
        logging.log(2, 'Permanently whitelisted: %s: already converted', f)
        return _call_unconverted(f, args, kwargs, options)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            '{} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will run as-is.'
            ' You may still apply AutoGraph before the wrapt decorator.'.
            format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs, options)

    if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
        logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
        return _call_unconverted(f, args, kwargs, options)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs, options)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, inspect, re)):
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs, options)

    # Custom ops and kernels are also permanently whitelisted.
    # See tensorflow.framework.load_library.
    # NOTE(review): `f.__module__` is normally the module *name* (a str), so
    # probing it with hasattr likely never matches -- confirm whether the
    # module object itself was intended here.
    if (hasattr(f, '__module__')
            and hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
        logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
        return _call_unconverted(f, args, kwargs, options)

    if not options.user_requested and conversion.is_whitelisted(f):
        return _call_unconverted(f, args, kwargs, options)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs, options)

    # Fix: bind target_entity before the guarded block so that the error
    # handler below can always report it, even when the failure happens before
    # any branch assigns it (e.g. while unwrapping functools.partial).
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        # TODO(b/120224672): This unwrapping should be done before the checks above.
        while isinstance(f, functools.partial):
            # Bound positional args go first; call-site kwargs override the
            # keywords stored on the partial.
            args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            if kwargs is not None:
                new_kwargs.update(kwargs)
            kwargs = new_kwargs
            f = f.func

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            f_self = inspect_utils.getmethodself(f)

            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            effective_args = (f, ) + args

        elif tf_inspect.isclass(f):
            # Constructors
            # Note: Until we support class constructors, and enable whole-class
            # conversion with an experimental flag, this branch is dead code.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            effective_args = args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        if not tf_inspect.isclass(target_entity):
            if not hasattr(target_entity, '__code__'):
                logging.log(2, 'Permanently whitelisted: %s: native binding',
                            target_entity)
                return _call_unconverted(f, args, kwargs, options)
            elif (hasattr(target_entity.__code__, 'co_filename')
                  and target_entity.__code__.co_filename == '<string>'):
                # TODO(mdan): __globals__['txt'] might work in Py3.
                logging.log(
                    2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                    target_entity)
                return _call_unconverted(f, args, kwargs, options)

        program_ctx = converter.ProgramContext(
            options=options,
            autograph_module=tf_inspect.getmodule(converted_call))
        converted_f = conversion.convert(target_entity, program_ctx)

        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            if six.PY3:
                logging.log(2, 'KW defaults of %s : %s', converted_f,
                            converted_f.__kwdefaults__)

            if kwargs is not None:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                                  **kwargs)
            else:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args)

            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    except Exception as e:  # pylint:disable=broad-except
        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise

        if isinstance(e, errors.UnsupportedLanguageElementError):
            # Repeating the check made upon function entry because the state might
            # have updated in the meantime.
            if not conversion.is_in_whitelist_cache(f, options):
                logging.warn(
                    'AutoGraph could not transform %s and will run it as-is.\n'
                    'Cause: %s', target_entity, e)
        else:
            logging.warn(
                'AutoGraph could not transform %s and will run it as-is.\n'
                'Please report this to the TensorFlow team. When filing the bug, set'
                ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
                ' attach the full output.\n'
                'Cause: %s', target_entity, e)

        # Best-effort fallback: run the original callable unconverted.
        return _call_unconverted(f, args, kwargs, options)

    with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
        try:
            if kwargs is not None:
                result = converted_f(*effective_args, **kwargs)
            else:
                result = converted_f(*effective_args)
        except Exception as e:
            # Attach metadata about the converted function to the error, then
            # let it propagate.
            _attach_metadata(e, converted_f, True)
            raise

    return result
Exemple #17
0
def to_graph(entity, recursive=True, experimental_optional_features=None):
    """Converts a Python entity into a TensorFlow graph.

    Also see: `tf.autograph.to_code`, `tf.function`.

    Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
    Python code to TensorFlow graph code. It does not implement any caching,
    variable management or create any actual ops, and is best used where
    greater control over the generated TensorFlow graph is desired. Another
    difference from `tf.function` is that `to_graph` will not wrap the graph
    into a TensorFlow function or a Python callable. Internally, `tf.function`
    uses `to_graph`.

    _Example Usage_

    ```python
      def foo(x):
        if x > 0:
          y = x * x
        else:
          y = -x
        return y

      converted_foo = to_graph(foo)

      x = tf.constant(1)
      y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
      assert is_tensor(y)
    ```

    Supported Python entities include:
      * functions
      * classes
      * object methods

    Functions are converted into new functions with converted code.

    Classes are converted by generating a new class whose methods use
    converted code.

    Methods are converted into unbound function that have an additional first
    argument called `self`.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.

    Returns:
      Same as `entity`, the converted Python function or class.

    Raises:
      ValueError: If the entity could not be converted.
    """
    try:
        conversion_options = converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features)
        autograph_module = tf_inspect.getmodule(to_graph)
        program_ctx = converter.ProgramContext(
            options=conversion_options, autograph_module=autograph_module)
        return conversion.convert(entity, program_ctx)
    except (ValueError, AttributeError, KeyError, NameError,
            AssertionError) as e:
        # Log the original failure with its traceback, then surface a single
        # ConversionError that names the entity and the underlying error type.
        logging.error(1, 'Error converting %s', entity, exc_info=True)
        error_name = e.__class__.__name__
        raise ConversionError('converting {}: {}: {}'.format(
            entity, error_name, str(e)))
Exemple #18
0
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules are tried in order: module-name prefix, already-converted marker,
  callable object whose `__call__` is whitelisted, method whose defining class
  is whitelisted, and finally namedtuples.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # tf_inspect.getmodule(functools.partial(...)) would return None, since
  # functools.partial objects do not have a __module__ attribute; attribute
  # them to functools directly instead.
  m = functools if isinstance(o, functools.partial) else tf_inspect.getmodule(o)

  if not hasattr(m, '__name__'):
    # Note: typically it's builtins that fall in this category. Builtins will
    # be handled by specific code that follows this screening layer.
    logging.log(2, '%s is NOT whitelisted: unknown module name', o)
    return False

  # Each entry of DEFAULT_UNCOMPILED_MODULES is a 1-tuple holding a prefix.
  for (prefix,) in config.DEFAULT_UNCOMPILED_MODULES:
    if m.__name__.startswith(prefix):
      logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
      return True

  already_converted = (
      hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'))
  if already_converted:
    logging.log(2, '%s is whitelisted: already converted', o)
    return True

  is_callable_object = (
      not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and hasattr(o, '__class__'))
  # Callable objects: whitelisted if their __call__ method is.
  if is_callable_object and is_whitelisted_for_graph(o.__call__):
    logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
    return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.
    bound_class = inspect_utils.getmethodclass(o)
    if bound_class is not None:
      defining_class = inspect_utils.getdefiningclass(o, bound_class)
      if is_whitelisted_for_graph(defining_class):
        logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
                    defining_class)
        return True

  if not inspect_utils.isnamedtuple(o):
    logging.log(2, '%s is NOT whitelisted', o)
    return False

  # Due to the way they're constructed, namedtuple types cannot be converted
  # because they don't expose source code. But we assume they are safe for
  # graph mode since they are just containers.
  if tf_inspect.isclass(o) and len(o.__bases__) > 1:
    logging.warn_first_n(
        'Entity {} looks like a namedtuple subclass. If it has any custom'
        ' methods, they will not be converted by AutoGraph.'.format(o), 1)
  logging.log(2, '%s is whitelisted: named tuple', o)
  return True
Exemple #19
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed
    creates a TF graph that has the same functionality as the original entity.
  Raises:
    ValueError: If the converted function defines or refers to symbol names that
    are reserved for AutoGraph.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Emit the dependencies in reverse order of the cache's iteration order.
  # The values are materialized into a tuple first: a plain dict's values view
  # only supports reversed() from Python 3.8 onward.
  nodes = []
  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
    nodes.extend(dep)
  compiled_module, compiled_src = compiler.ast_to_object(
      nodes,
      source_prefix=program_ctx.required_imports,
      include_source_map=True)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  for key, val in namespace.items():
    # Avoid overwriting entities that have been transformed.
    if key not in compiled_module.__dict__:
      compiled_module.__dict__[key] = val
  compiled_fn = getattr(compiled_module, name)

  # Need this so the source_mapping attribute is available for the context
  # manager to access for runtime errors.
  #
  # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
  # symbol to the compiled module.
  source_map_attribute_name = 'ag_source_map'
  if getattr(compiled_fn, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because is has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled_fn, source_map_attribute_name))
  setattr(compiled_fn, source_map_attribute_name,
          compiled_module.__dict__['ag_source_map__'])

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Exemple #20
0
def is_whitelisted_for_graph(o):
  """Checks whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules are tried in order: module-name prefix, tensorboard modules,
  already-converted marker, generator functions, callable objects whose
  `__call__` is whitelisted, methods of `unittest.TestCase` subclasses or of
  whitelisted classes, and finally namedtuples.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    # Each entry of DEFAULT_UNCOMPILED_MODULES is a 1-tuple holding a prefix.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__.startswith(prefix):
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

    # Temporary -- whitelist tensorboard modules.
    # TODO(b/122731813): Remove.
    if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
      logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # Fix: pass `o` as a lazy %s argument. The previous call pre-formatted the
    # message and then passed a stray `1` as a formatting argument to a
    # message that no longer had any placeholder.
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      # Fix: same stray-argument problem as the generator warning above.
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
    """Returns the name by which a value can be referred to in a namespace.

    The lookup proceeds in three stages:
      1. a direct identity match against the values of `namespace`;
      2. if the object reports a parent module, resolving that module in
         `namespace` and then the object inside the module;
      3. a bounded recursive search through the modules found in `namespace`.

    This function will recurse inside modules, but it will not search objects
    for attributes. The recursion depth is controlled by max_depth.

    Args:
      namespace: Dict[str, Any], the namespace to search into.
      object_: Any, the value to search.
      max_depth: Optional[int], a limit to the recursion depth when searching
          inside modules.
      visited: Optional[Set[int]], ID of modules to avoid visiting. The set is
          updated in place as modules are explored.
    Returns: Union[str, None], the fully-qualified name that resolves to the
        value, or None if it couldn't be found.
    """
    if visited is None:
        visited = set()

    # Work on a snapshot to avoid "changed size" errors during concurrent
    # invocations.
    # TODO(mdan): This is on the hot path. Can we avoid the copy?
    snapshot = dict(namespace)

    # Stage 1: direct match. The value may be referenced by more than one
    # symbol, in which case any symbol will be fine. If the program contains
    # symbol aliases that change over time, this may capture a symbol that
    # will later point to something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    for name, value in snapshot.items():
        if value is object_:
            return name

    # Stage 2: if the object was not found, try going through its parent
    # module.
    parent = tf_inspect.getmodule(object_)
    has_usable_parent = (parent is not None and parent is not object_
                         and parent is not snapshot)
    if has_usable_parent:
        # max_depth=0 restricts both lookups to direct matches (plus this
        # parent-module step); they do not descend into other modules.
        parent_name = getqualifiedname(snapshot,
                                       parent,
                                       max_depth=0,
                                       visited=visited)
        if parent_name is not None:
            name_in_parent = getqualifiedname(parent.__dict__,
                                              object_,
                                              max_depth=0,
                                              visited=visited)
            assert name_in_parent is not None, (
                'An object should always be found in its owner module')
            return '{}.{}'.format(parent_name, name_in_parent)

    # Stage 3: descend into modules, as long as the depth budget allows it.
    if not max_depth:
        return None

    for name, value in snapshot.items():
        if not tf_inspect.ismodule(value) or id(value) in visited:
            continue
        # Mark before recursing so that import cycles are not revisited.
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
            return '{}.{}'.format(name, name_in_module)
    return None
Exemple #22
0
def is_whitelisted_for_graph(o):
    """Check whether an entity is whitelisted for use in graph mode.

    Examples of whitelisted entities include all members of the tensorflow
    package.

    The rules are tried in order: module-name prefix, tensorboard modules,
    already-converted marker, callable objects whose `__call__` is
    whitelisted, methods of `unittest.TestCase` subclasses or of whitelisted
    classes, and finally namedtuples.

    Args:
      o: A Python entity.
    Returns:
      Boolean
    """
    # TODO(b/120224672): Fix this.
    if isinstance(o, functools.partial):
        # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
        # functools.partial objects do not have a __module__ attribute.
        m = functools
    else:
        m = tf_inspect.getmodule(o)

    if hasattr(m, '__name__'):
        # Builtins typically have unnamed modules.
        # Each entry of DEFAULT_UNCOMPILED_MODULES is a 1-tuple holding a
        # prefix.
        for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
            if m.__name__.startswith(prefix):
                logging.log(2, 'Whitelisted: %s: name starts with "%s"', o,
                            prefix)
                return True

        # Temporary -- whitelist tensorboard modules.
        # TODO(b/122731813): Remove.
        if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
            logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
            return True

    if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
        logging.log(2, 'Whitelisted: %s: already converted', o)
        return True

    if hasattr(o, '__call__'):
        # Callable objects: whitelisted if their __call__ method is.
        # The type check avoids infinite recursion around the __call__ method
        # of function objects.
        if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(
                o.__call__):  # pylint: disable=unidiomatic-typecheck
            logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
            return True

    if tf_inspect.ismethod(o):
        # Methods of whitelisted classes are also whitelisted, even if they are
        # bound via user subclasses.
        #
        # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
        # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
        # whitelisted.
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # For the example above, if `Custom` did overload `bar`, then it would no
        # longer be whitelisted.

        owner_class = inspect_utils.getmethodclass(o)
        if owner_class is not None:
            if issubclass(owner_class, unittest.TestCase):
                logging.log(2, 'Whitelisted: %s: method of TestCase subclass',
                            o)
                return True

            owner_class = inspect_utils.getdefiningclass(o, owner_class)
            if is_whitelisted_for_graph(owner_class):
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                            owner_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # Due to the way they're constructed, namedtuple types cannot be converted
        # because they don't expose source code. But we assume they are safe for
        # graph mode since they are just containers.
        if tf_inspect.isclass(o) and len(o.__bases__) > 1:
            # Fix: pass `o` as a lazy %s argument. The previous call
            # pre-formatted the message and then passed a stray `1` as a
            # formatting argument to a message with no placeholder left.
            logging.warn(
                'Entity %s looks like a namedtuple subclass. Its constructor will'
                ' not be converted by AutoGraph, but if it has any custom methods,'
                ' those will be.', o)
        logging.log(2, 'Whitelisted: %s: named tuple', o)
        return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
Exemple #23
0
def is_whitelisted_for_graph(o,
                             check_call_override=True,
                             allow_namedtuple_subclass=False):
    """Decides whether an entity is whitelisted for use in graph mode.

    The checks below run in order; the first one that matches decides the
    result:

      1. The rules in `config.CONVERSION_RULES`, applied to the module of `o`.
      2. Generator functions are whitelisted (a warning is logged).
      3. Callable non-class objects are whitelisted when their `__call__`
         is whitelisted.
      4. Methods are whitelisted when their owner is a `unittest.TestCase`
         subclass, or when the class that defines them is whitelisted.
      5. Named tuples are whitelisted, subject to
         `allow_namedtuple_subclass`.

    Anything that matches none of these is reported as not whitelisted.

    Args:
      o: A Python entity.
      check_call_override: Reserved for internal use. When falsy, check 3
        above (whitelisting through `__call__`) is skipped.
      allow_namedtuple_subclass: Reserved for internal use. When truthy, a
        named tuple is only whitelisted by check 5 if none of its direct
        bases is itself a named tuple; when falsy, every named tuple is
        whitelisted.

    Returns:
      Boolean: True if `o` is whitelisted, False otherwise.
    """
    # TODO(b/120224672): Fix this.
    # functools.partial objects have no __module__ attribute, so
    # tf_inspect.getmodule would otherwise return None for them.
    module = (functools if isinstance(o, functools.partial) else
              tf_inspect.getmodule(o))

    # Builtins are an example of callables for which no usable module is
    # found; the rule table is skipped for them.
    if hasattr(module, '__name__'):
        for rule in config.CONVERSION_RULES:
            verdict = rule.get_action(module)
            if verdict == config.Action.CONVERT:
                logging.log(2, 'Not whitelisted: %s: %s', o, rule)
                return False
            if verdict == config.Action.DO_NOT_CONVERT:
                logging.log(2, 'Whitelisted: %s: %s', o, rule)
                return True

    if tf_inspect.isgeneratorfunction(o):
        logging.warn(
            'Entity %s appears to be a generator function. It will not be converted'
            ' by AutoGraph.', o)
        logging.log(
            2, 'Whitelisted: %s: generator functions are not converted', o)
        return True

    inspect_call = (check_call_override and not tf_inspect.isclass(o)
                    and hasattr(o, '__call__'))
    if inspect_call:
        # Comparing the types first prevents infinite recursion on the
        # __call__ method of function objects.
        call_type_differs = (
            type(o) != type(o.__call__))  # pylint: disable=unidiomatic-typecheck
        if call_type_differs and is_whitelisted_for_graph(o.__call__):
            logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
            return True

    if tf_inspect.ismethod(o):
        # A method inherits the whitelisting of the class that defines it,
        # even when it is reached through a user subclass. E.g. if the
        # whitelisted `tf.Foo` defines `bar`, then with
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # `baz.bar` is whitelisted as well. Should `Custom` overload `bar`,
        # the method would no longer be whitelisted.
        bound_class = inspect_utils.getmethodclass(o)
        if bound_class is not None:
            if issubclass(bound_class, unittest.TestCase):
                logging.log(
                    2, 'Whitelisted: %s: method of TestCase subclass', o)
                return True

            defining_class = inspect_utils.getdefiningclass(o, bound_class)
            owner_ok = is_whitelisted_for_graph(
                defining_class,
                check_call_override=False,
                allow_namedtuple_subclass=True)
            if owner_ok:
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                            defining_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # namedtuple types do not expose source code because of the way
        # they are built, so they cannot be converted. They are assumed to
        # be safe for graph mode since they are just containers.
        if not allow_namedtuple_subclass:
            logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
            return True
        if all(not inspect_utils.isnamedtuple(base) for base in o.__bases__):
            logging.log(2, 'Whitelisted: %s: named tuple', o)
            return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
Exemple #24
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None,
             strip_decorators=None,
             optional_features=converter.Feature.ALL):
    """Converts a Python entity into equivalent code that uses TensorFlow ops.

    Supported Python entities include:
      * functions
      * classes

    Classes are converted by converting all their methods into a new class.

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
        converted function may call.
      verbose: bool, whether to output the compiled code in the logs.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
      partial_types: Set[Type], reserved for internal use.
      strip_decorators: Tuple[Callable], same as
        ConversionOptions.strip_decorators. The sequence passed in is never
        modified; `convert`, `do_not_convert` and `converted_call` are
        appended to a copy of it.
      optional_features: Union[Feature, Set[Feature]], same as
        ConversionOptions.optional_features.

    Returns:
      Union[Callable, Type], the converted entity, which is the same kind as e
      (that is, a function if e is a function, a class if e is a class, etc.)
      but its code has been converted to use TF ops.

    Raises:
      ValueError: If the entity could not be converted, including when the
        compiled entity already has a non-None `ag_source_map` attribute,
        which is reserved for AutoGraph.
    """
    if strip_decorators is None:
        strip_decorators = ()
    # Build a fresh tuple instead of using `+=`: on a caller-supplied list,
    # `+=` would extend that list in place and leak into the caller's state.
    strip_decorators = tuple(strip_decorators) + (convert, do_not_convert,
                                                  converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=verbose,
            strip_decorators=strip_decorators,
            optional_features=optional_features),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                    arg_types)

    # Collect the AST nodes of every converted dependency, walking the
    # recorded conversion order backwards.
    nodes = []
    for dep in reversed(program_ctx.conversion_order):
        nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_module.__dict__:
            compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    # Carry over the default argument values of the original function.
    if tf_inspect.isfunction(e):
        compiled.__defaults__ = e.__defaults__

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
        raise ValueError('cannot convert %s because it has an attribute '
                         '"%s", which is reserved for AutoGraph.' %
                         (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
Exemple #25
0
def is_whitelisted_for_graph(o):
    """Reports whether an entity is whitelisted for use in graph mode.

    An entity is whitelisted when any of the following holds (checked in this
    order):

      * the name of its module starts with one of the prefixes listed in
        `config.DEFAULT_UNCOMPILED_MODULES`;
      * it has an `autograph_info__` or `__ag_compiled` attribute;
      * it is a callable non-class object whose `__call__` is whitelisted;
      * it is a method and the class that defines it is whitelisted;
      * it is a named tuple.

    Entities for which no module with a `__name__` can be found are reported
    as not whitelisted before any of the checks above run.

    Args:
      o: A Python entity.
    Returns:
      Boolean
    """
    # TODO(b/120224672): Fix this.
    # functools.partial objects have no __module__ attribute, so
    # tf_inspect.getmodule would otherwise return None for them.
    module = (functools if isinstance(o, functools.partial) else
              tf_inspect.getmodule(o))
    if not hasattr(module, '__name__'):
        # Typically builtins end up here. They are handled by specific code
        # that follows this screening layer.
        logging.log(2, '%s is NOT whitelisted: unknown module name', o)
        return False

    # Each entry is a one-element tuple holding a module name prefix.
    for entry in config.DEFAULT_UNCOMPILED_MODULES:
        (prefix,) = entry
        if module.__name__.startswith(prefix):
            logging.log(2, '%s is whitelisted: name starts with "%s"', o,
                        prefix)
            return True

    if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
        logging.log(2, '%s is whitelisted: already converted', o)
        return True

    is_callable_object = (
        not (inspect_utils.isweakrefself(o) or tf_inspect.isclass(o))
        and hasattr(o, '__call__') and hasattr(o, '__class__'))
    if is_callable_object:
        # A callable object follows the whitelisting of its __call__ method.
        if is_whitelisted_for_graph(o.__call__):
            logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
            return True

    if tf_inspect.ismethod(o):
        # A method inherits the whitelisting of the class that defines it,
        # even when it is reached through a user subclass. E.g. if the
        # whitelisted `tf.Foo` defines `bar`, then with
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # `baz.bar` is whitelisted as well. Should `Custom` overload `bar`,
        # the method would no longer be whitelisted.
        bound_class = inspect_utils.getmethodclass(o)
        if bound_class is not None:
            defining_class = inspect_utils.getdefiningclass(o, bound_class)
            if is_whitelisted_for_graph(defining_class):
                logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
                            defining_class)
                return True

    if not inspect_utils.isnamedtuple(o):
        logging.log(2, '%s is NOT whitelisted', o)
        return False

    # namedtuple types do not expose source code because of the way they are
    # built, so they cannot be converted. They are assumed to be safe for
    # graph mode since they are just containers.
    looks_like_subclass = tf_inspect.isclass(o) and len(o.__bases__) > 1
    if looks_like_subclass:
        logging.warn_first_n(
            'Entity {} looks like a namedtuple subclass. If it has any custom'
            ' methods, they will not be converted by AutoGraph.'.format(o),
            1)
    logging.log(2, '%s is whitelisted: named tuple', o)
    return True